ndarray/
stacking.rs

1// Copyright 2014-2020 bluss and ndarray developers.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9#[cfg(not(feature = "std"))]
10use alloc::vec::Vec;
11
12use crate::dimension;
13use crate::error::{from_kind, ErrorKind, ShapeError};
14use crate::imp_prelude::*;
15
16/// Concatenate arrays along the given axis.
17///
18/// ***Errors*** if the arrays have mismatching shapes, apart from along `axis`.
19/// (may be made more flexible in the future).<br>
20/// ***Errors*** if `arrays` is empty, if `axis` is out of bounds,
21/// if the result is larger than is possible to represent.
22///
23/// ```
24/// use ndarray::{arr2, Axis, concatenate};
25///
26/// let a = arr2(&[[2., 2.],
27///                [3., 3.]]);
28/// assert!(
29///     concatenate(Axis(0), &[a.view(), a.view()])
30///     == Ok(arr2(&[[2., 2.],
31///                  [3., 3.],
32///                  [2., 2.],
33///                  [3., 3.]]))
34/// );
35/// ```
36pub fn concatenate<A, D>(axis: Axis, arrays: &[ArrayView<A, D>]) -> Result<Array<A, D>, ShapeError>
37where
38    A: Clone,
39    D: RemoveAxis,
40{
41    if arrays.is_empty() {
42        return Err(from_kind(ErrorKind::Unsupported));
43    }
44    let mut res_dim = arrays[0].raw_dim();
45    if axis.index() >= res_dim.ndim() {
46        return Err(from_kind(ErrorKind::OutOfBounds));
47    }
48    let common_dim = res_dim.remove_axis(axis);
49    if arrays
50        .iter()
51        .any(|a| a.raw_dim().remove_axis(axis) != common_dim)
52    {
53        return Err(from_kind(ErrorKind::IncompatibleShape));
54    }
55
56    let stacked_dim = arrays.iter().fold(0, |acc, a| acc + a.len_of(axis));
57    res_dim.set_axis(axis, stacked_dim);
58    let new_len = dimension::size_of_shape_checked(&res_dim)?;
59
60    // start with empty array with precomputed capacity
61    // append's handling of empty arrays makes sure `axis` is ok for appending
62    res_dim.set_axis(axis, 0);
63    let mut res = unsafe {
64        // Safety: dimension is size 0 and vec is empty
65        Array::from_shape_vec_unchecked(res_dim, Vec::with_capacity(new_len))
66    };
67
68    for array in arrays {
69        res.append(axis, array.clone())?;
70    }
71    debug_assert_eq!(res.len_of(axis), stacked_dim);
72    Ok(res)
73}
74
75/// Stack arrays along the new axis.
76///
77/// ***Errors*** if the arrays have mismatching shapes.
78/// ***Errors*** if `arrays` is empty, if `axis` is out of bounds,
79/// if the result is larger than is possible to represent.
80///
81/// ```
82/// extern crate ndarray;
83///
84/// use ndarray::{arr2, arr3, stack, Axis};
85///
86/// # fn main() {
87///
88/// let a = arr2(&[[2., 2.],
89///                [3., 3.]]);
90/// assert!(
91///     stack(Axis(0), &[a.view(), a.view()])
92///     == Ok(arr3(&[[[2., 2.],
93///                   [3., 3.]],
94///                  [[2., 2.],
95///                   [3., 3.]]]))
96/// );
97/// # }
98/// ```
99pub fn stack<A, D>(axis: Axis, arrays: &[ArrayView<A, D>]) -> Result<Array<A, D::Larger>, ShapeError>
100where
101    A: Clone,
102    D: Dimension,
103    D::Larger: RemoveAxis,
104{
105    if arrays.is_empty() {
106        return Err(from_kind(ErrorKind::Unsupported));
107    }
108    let common_dim = arrays[0].raw_dim();
109    // Avoid panic on `insert_axis` call, return an Err instead of it.
110    if axis.index() > common_dim.ndim() {
111        return Err(from_kind(ErrorKind::OutOfBounds));
112    }
113    let mut res_dim = common_dim.insert_axis(axis);
114
115    if arrays.iter().any(|a| a.raw_dim() != common_dim) {
116        return Err(from_kind(ErrorKind::IncompatibleShape));
117    }
118
119    res_dim.set_axis(axis, arrays.len());
120
121    let new_len = dimension::size_of_shape_checked(&res_dim)?;
122
123    // start with empty array with precomputed capacity
124    // append's handling of empty arrays makes sure `axis` is ok for appending
125    res_dim.set_axis(axis, 0);
126    let mut res = unsafe {
127        // Safety: dimension is size 0 and vec is empty
128        Array::from_shape_vec_unchecked(res_dim, Vec::with_capacity(new_len))
129    };
130
131    for array in arrays {
132        res.append(axis, array.clone().insert_axis(axis))?;
133    }
134
135    debug_assert_eq!(res.len_of(axis), arrays.len());
136    Ok(res)
137}
138
/// Stack arrays along a new axis.
///
/// Convenience wrapper around the [`stack()`] function: every argument `a`
/// after the axis is converted with `ArrayView::from(&a)`, and the result of
/// the call is unwrapped. A trailing comma after the last array is accepted.
///
/// ***Panics*** if the `stack` function would return an error.
///
/// ```
/// use ndarray::{arr2, arr3, stack, Axis};
///
/// let a = arr2(&[[1., 2.],
///                [3., 4.]]);
/// assert_eq!(
///     stack![Axis(0), a, a],
///     arr3(&[[[1., 2.],
///             [3., 4.]],
///            [[1., 2.],
///             [3., 4.]]]),
/// );
/// assert_eq!(
///     stack![Axis(1), a, a,],
///     arr3(&[[[1., 2.],
///             [1., 2.]],
///            [[3., 4.],
///             [3., 4.]]]),
/// );
/// assert_eq!(
///     stack![Axis(2), a, a],
///     arr3(&[[[1., 1.],
///             [2., 2.]],
///            [[3., 3.],
///             [4., 4.]]]),
/// );
/// ```
#[macro_export]
macro_rules! stack {
    // One or more comma-separated arrays, with an optional trailing comma.
    ($axis:expr, $( $array:expr ),+ $(,)?) => {
        $crate::stack($axis, &[ $( $crate::ArrayView::from(&$array) ),+ ]).unwrap()
    };
}
187
/// Concatenate arrays along an existing axis.
///
/// Convenience wrapper around the [`concatenate()`] function: every argument
/// `a` after the axis is converted with `ArrayView::from(&a)`, and the result
/// of the call is unwrapped. A trailing comma after the last array is accepted.
///
/// ***Panics*** if the `concatenate` function would return an error.
///
/// ```
/// use ndarray::{arr2, concatenate, Axis};
///
/// let a = arr2(&[[1., 2.],
///                [3., 4.]]);
/// assert_eq!(
///     concatenate![Axis(0), a, a],
///     arr2(&[[1., 2.],
///            [3., 4.],
///            [1., 2.],
///            [3., 4.]]),
/// );
/// assert_eq!(
///     concatenate![Axis(1), a, a,],
///     arr2(&[[1., 2., 1., 2.],
///            [3., 4., 3., 4.]]),
/// );
/// ```
#[macro_export]
macro_rules! concatenate {
    // One or more comma-separated arrays, with an optional trailing comma.
    ($axis:expr, $( $array:expr ),+ $(,)?) => {
        $crate::concatenate($axis, &[ $( $crate::ArrayView::from(&$array) ),+ ]).unwrap()
    };
}