ndarray/
stacking.rs

1// Copyright 2014-2020 bluss and ndarray developers.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9use alloc::vec::Vec;
10
11use crate::dimension;
12use crate::error::{from_kind, ErrorKind, ShapeError};
13use crate::imp_prelude::*;
14
15/// Stack arrays along the new axis.
16///
17/// ***Errors*** if the arrays have mismatching shapes.
18/// ***Errors*** if `arrays` is empty, if `axis` is out of bounds,
19/// if the result is larger than is possible to represent.
20///
21/// ```
22/// extern crate ndarray;
23///
24/// use ndarray::{arr2, arr3, stack, Axis};
25///
26/// # fn main() {
27///
28/// let a = arr2(&[[2., 2.],
29///                [3., 3.]]);
30/// assert!(
31///     stack(Axis(0), &[a.view(), a.view()])
32///     == Ok(arr3(&[[[2., 2.],
33///                   [3., 3.]],
34///                  [[2., 2.],
35///                   [3., 3.]]]))
36/// );
37/// # }
38/// ```
39pub fn stack<A, D>(
40    axis: Axis,
41    arrays: &[ArrayView<A, D>],
42) -> Result<Array<A, D::Larger>, ShapeError>
43where
44    A: Clone,
45    D: Dimension,
46    D::Larger: RemoveAxis,
47{
48    #[allow(deprecated)]
49    stack_new_axis(axis, arrays)
50}
51
52/// Concatenate arrays along the given axis.
53///
54/// ***Errors*** if the arrays have mismatching shapes, apart from along `axis`.
55/// (may be made more flexible in the future).<br>
56/// ***Errors*** if `arrays` is empty, if `axis` is out of bounds,
57/// if the result is larger than is possible to represent.
58///
59/// ```
60/// use ndarray::{arr2, Axis, concatenate};
61///
62/// let a = arr2(&[[2., 2.],
63///                [3., 3.]]);
64/// assert!(
65///     concatenate(Axis(0), &[a.view(), a.view()])
66///     == Ok(arr2(&[[2., 2.],
67///                  [3., 3.],
68///                  [2., 2.],
69///                  [3., 3.]]))
70/// );
71/// ```
72pub fn concatenate<A, D>(axis: Axis, arrays: &[ArrayView<A, D>]) -> Result<Array<A, D>, ShapeError>
73where
74    A: Clone,
75    D: RemoveAxis,
76{
77    if arrays.is_empty() {
78        return Err(from_kind(ErrorKind::Unsupported));
79    }
80    let mut res_dim = arrays[0].raw_dim();
81    if axis.index() >= res_dim.ndim() {
82        return Err(from_kind(ErrorKind::OutOfBounds));
83    }
84    let common_dim = res_dim.remove_axis(axis);
85    if arrays
86        .iter()
87        .any(|a| a.raw_dim().remove_axis(axis) != common_dim)
88    {
89        return Err(from_kind(ErrorKind::IncompatibleShape));
90    }
91
92    let stacked_dim = arrays.iter().fold(0, |acc, a| acc + a.len_of(axis));
93    res_dim.set_axis(axis, stacked_dim);
94    let new_len = dimension::size_of_shape_checked(&res_dim)?;
95
96    // start with empty array with precomputed capacity
97    // append's handling of empty arrays makes sure `axis` is ok for appending
98    res_dim.set_axis(axis, 0);
99    let mut res = unsafe {
100        // Safety: dimension is size 0 and vec is empty
101        Array::from_shape_vec_unchecked(res_dim, Vec::with_capacity(new_len))
102    };
103
104    for array in arrays {
105        res.append(axis, array.clone())?;
106    }
107    debug_assert_eq!(res.len_of(axis), stacked_dim);
108    Ok(res)
109}
110
111#[deprecated(note="Use under the name stack instead.", since="0.15.0")]
112/// Stack arrays along the new axis.
113///
114/// ***Errors*** if the arrays have mismatching shapes.
115/// ***Errors*** if `arrays` is empty, if `axis` is out of bounds,
116/// if the result is larger than is possible to represent.
117///
118/// ```
119/// extern crate ndarray;
120///
121/// use ndarray::{arr2, arr3, stack_new_axis, Axis};
122///
123/// # fn main() {
124///
125/// let a = arr2(&[[2., 2.],
126///                [3., 3.]]);
127/// assert!(
128///     stack_new_axis(Axis(0), &[a.view(), a.view()])
129///     == Ok(arr3(&[[[2., 2.],
130///                   [3., 3.]],
131///                  [[2., 2.],
132///                   [3., 3.]]]))
133/// );
134/// # }
135/// ```
136pub fn stack_new_axis<A, D>(
137    axis: Axis,
138    arrays: &[ArrayView<A, D>],
139) -> Result<Array<A, D::Larger>, ShapeError>
140where
141    A: Clone,
142    D: Dimension,
143    D::Larger: RemoveAxis,
144{
145    if arrays.is_empty() {
146        return Err(from_kind(ErrorKind::Unsupported));
147    }
148    let common_dim = arrays[0].raw_dim();
149    // Avoid panic on `insert_axis` call, return an Err instead of it.
150    if axis.index() > common_dim.ndim() {
151        return Err(from_kind(ErrorKind::OutOfBounds));
152    }
153    let mut res_dim = common_dim.insert_axis(axis);
154
155    if arrays.iter().any(|a| a.raw_dim() != common_dim) {
156        return Err(from_kind(ErrorKind::IncompatibleShape));
157    }
158
159    res_dim.set_axis(axis, arrays.len());
160
161    let new_len = dimension::size_of_shape_checked(&res_dim)?;
162
163    // start with empty array with precomputed capacity
164    // append's handling of empty arrays makes sure `axis` is ok for appending
165    res_dim.set_axis(axis, 0);
166    let mut res = unsafe {
167        // Safety: dimension is size 0 and vec is empty
168        Array::from_shape_vec_unchecked(res_dim, Vec::with_capacity(new_len))
169    };
170
171    for array in arrays {
172        res.append(axis, array.clone().insert_axis(axis))?;
173    }
174
175    debug_assert_eq!(res.len_of(axis), arrays.len());
176    Ok(res)
177}
178
/// Stack arrays along the new axis.
///
/// Shorthand for the [`stack()`] function. Each argument `a` is turned into a
/// view with `ArrayView::from(&a)`, and the function's result is unwrapped.
/// A trailing comma after the last array is accepted.
///
/// ***Panics*** if the `stack` function would return an error.
///
/// ```
/// use ndarray::{arr2, arr3, stack, Axis};
///
/// let a = arr2(&[[1., 2.],
///                [3., 4.]]);
/// assert_eq!(
///     stack![Axis(0), a, a],
///     arr3(&[[[1., 2.],
///             [3., 4.]],
///            [[1., 2.],
///             [3., 4.]]]),
/// );
/// assert_eq!(
///     stack![Axis(1), a, a,],
///     arr3(&[[[1., 2.],
///             [1., 2.]],
///            [[3., 4.],
///             [3., 4.]]]),
/// );
/// assert_eq!(
///     stack![Axis(2), a, a],
///     arr3(&[[[1., 1.],
///             [2., 2.]],
///            [[3., 3.],
///             [4., 4.]]]),
/// );
/// ```
#[macro_export]
macro_rules! stack {
    // One arm handles both forms: `$(,)?` swallows an optional trailing comma.
    ($axis:expr, $( $arr:expr ),+ $(,)?) => {
        $crate::stack($axis, &[ $( $crate::ArrayView::from(&$arr) ),* ]).unwrap()
    };
}
227
/// Concatenate arrays along the given axis.
///
/// Shorthand for the [`concatenate()`] function. Each argument `a` is turned
/// into a view with `ArrayView::from(&a)`, and the function's result is
/// unwrapped. A trailing comma after the last array is accepted.
///
/// ***Panics*** if the `concatenate` function would return an error.
///
/// ```
/// use ndarray::{arr2, concatenate, Axis};
///
/// let a = arr2(&[[1., 2.],
///                [3., 4.]]);
/// assert_eq!(
///     concatenate![Axis(0), a, a],
///     arr2(&[[1., 2.],
///            [3., 4.],
///            [1., 2.],
///            [3., 4.]]),
/// );
/// assert_eq!(
///     concatenate![Axis(1), a, a,],
///     arr2(&[[1., 2., 1., 2.],
///            [3., 4., 3., 4.]]),
/// );
/// ```
#[macro_export]
macro_rules! concatenate {
    // One arm handles both forms: `$(,)?` swallows an optional trailing comma.
    ($axis:expr, $( $arr:expr ),+ $(,)?) => {
        $crate::concatenate($axis, &[ $( $crate::ArrayView::from(&$arr) ),* ]).unwrap()
    };
}
267
/// Stack arrays along the new axis.
///
/// Uses the [`stack_new_axis()`] function, calling `ArrayView::from(&a)` on each
/// argument `a`. As with the `stack!` and `concatenate!` macros, a trailing
/// comma after the last array is accepted.
///
/// ***Panics*** if the `stack_new_axis` function would return an error.
///
/// ```
/// use ndarray::{arr2, arr3, stack_new_axis, Axis};
///
/// let a = arr2(&[[2., 2.],
///                [3., 3.]]);
/// assert!(
///     stack_new_axis![Axis(0), a, a]
///     == arr3(&[[[2., 2.],
///                [3., 3.]],
///               [[2., 2.],
///                [3., 3.]]])
/// );
/// ```
#[macro_export]
#[deprecated(note="Use under the name stack instead.", since="0.15.0")]
macro_rules! stack_new_axis {
    // `$(,)?` accepts an optional trailing comma, which the previous single
    // arm rejected.
    ($axis:expr, $( $array:expr ),+ $(,)?) => {
        $crate::stack_new_axis($axis, &[ $($crate::ArrayView::from(&$array) ),* ]).unwrap()
    };
}