ndarray/stacking.rs
1// Copyright 2014-2020 bluss and ndarray developers.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9use alloc::vec::Vec;
10
11use crate::dimension;
12use crate::error::{from_kind, ErrorKind, ShapeError};
13use crate::imp_prelude::*;
14
15/// Stack arrays along the new axis.
16///
17/// ***Errors*** if the arrays have mismatching shapes.
18/// ***Errors*** if `arrays` is empty, if `axis` is out of bounds,
19/// if the result is larger than is possible to represent.
20///
21/// ```
22/// extern crate ndarray;
23///
24/// use ndarray::{arr2, arr3, stack, Axis};
25///
26/// # fn main() {
27///
28/// let a = arr2(&[[2., 2.],
29/// [3., 3.]]);
30/// assert!(
31/// stack(Axis(0), &[a.view(), a.view()])
32/// == Ok(arr3(&[[[2., 2.],
33/// [3., 3.]],
34/// [[2., 2.],
35/// [3., 3.]]]))
36/// );
37/// # }
38/// ```
39pub fn stack<A, D>(
40 axis: Axis,
41 arrays: &[ArrayView<A, D>],
42) -> Result<Array<A, D::Larger>, ShapeError>
43where
44 A: Clone,
45 D: Dimension,
46 D::Larger: RemoveAxis,
47{
48 #[allow(deprecated)]
49 stack_new_axis(axis, arrays)
50}
51
52/// Concatenate arrays along the given axis.
53///
54/// ***Errors*** if the arrays have mismatching shapes, apart from along `axis`.
55/// (may be made more flexible in the future).<br>
56/// ***Errors*** if `arrays` is empty, if `axis` is out of bounds,
57/// if the result is larger than is possible to represent.
58///
59/// ```
60/// use ndarray::{arr2, Axis, concatenate};
61///
62/// let a = arr2(&[[2., 2.],
63/// [3., 3.]]);
64/// assert!(
65/// concatenate(Axis(0), &[a.view(), a.view()])
66/// == Ok(arr2(&[[2., 2.],
67/// [3., 3.],
68/// [2., 2.],
69/// [3., 3.]]))
70/// );
71/// ```
72pub fn concatenate<A, D>(axis: Axis, arrays: &[ArrayView<A, D>]) -> Result<Array<A, D>, ShapeError>
73where
74 A: Clone,
75 D: RemoveAxis,
76{
77 if arrays.is_empty() {
78 return Err(from_kind(ErrorKind::Unsupported));
79 }
80 let mut res_dim = arrays[0].raw_dim();
81 if axis.index() >= res_dim.ndim() {
82 return Err(from_kind(ErrorKind::OutOfBounds));
83 }
84 let common_dim = res_dim.remove_axis(axis);
85 if arrays
86 .iter()
87 .any(|a| a.raw_dim().remove_axis(axis) != common_dim)
88 {
89 return Err(from_kind(ErrorKind::IncompatibleShape));
90 }
91
92 let stacked_dim = arrays.iter().fold(0, |acc, a| acc + a.len_of(axis));
93 res_dim.set_axis(axis, stacked_dim);
94 let new_len = dimension::size_of_shape_checked(&res_dim)?;
95
96 // start with empty array with precomputed capacity
97 // append's handling of empty arrays makes sure `axis` is ok for appending
98 res_dim.set_axis(axis, 0);
99 let mut res = unsafe {
100 // Safety: dimension is size 0 and vec is empty
101 Array::from_shape_vec_unchecked(res_dim, Vec::with_capacity(new_len))
102 };
103
104 for array in arrays {
105 res.append(axis, array.clone())?;
106 }
107 debug_assert_eq!(res.len_of(axis), stacked_dim);
108 Ok(res)
109}
110
111#[deprecated(note="Use under the name stack instead.", since="0.15.0")]
112/// Stack arrays along the new axis.
113///
114/// ***Errors*** if the arrays have mismatching shapes.
115/// ***Errors*** if `arrays` is empty, if `axis` is out of bounds,
116/// if the result is larger than is possible to represent.
117///
118/// ```
119/// extern crate ndarray;
120///
121/// use ndarray::{arr2, arr3, stack_new_axis, Axis};
122///
123/// # fn main() {
124///
125/// let a = arr2(&[[2., 2.],
126/// [3., 3.]]);
127/// assert!(
128/// stack_new_axis(Axis(0), &[a.view(), a.view()])
129/// == Ok(arr3(&[[[2., 2.],
130/// [3., 3.]],
131/// [[2., 2.],
132/// [3., 3.]]]))
133/// );
134/// # }
135/// ```
136pub fn stack_new_axis<A, D>(
137 axis: Axis,
138 arrays: &[ArrayView<A, D>],
139) -> Result<Array<A, D::Larger>, ShapeError>
140where
141 A: Clone,
142 D: Dimension,
143 D::Larger: RemoveAxis,
144{
145 if arrays.is_empty() {
146 return Err(from_kind(ErrorKind::Unsupported));
147 }
148 let common_dim = arrays[0].raw_dim();
149 // Avoid panic on `insert_axis` call, return an Err instead of it.
150 if axis.index() > common_dim.ndim() {
151 return Err(from_kind(ErrorKind::OutOfBounds));
152 }
153 let mut res_dim = common_dim.insert_axis(axis);
154
155 if arrays.iter().any(|a| a.raw_dim() != common_dim) {
156 return Err(from_kind(ErrorKind::IncompatibleShape));
157 }
158
159 res_dim.set_axis(axis, arrays.len());
160
161 let new_len = dimension::size_of_shape_checked(&res_dim)?;
162
163 // start with empty array with precomputed capacity
164 // append's handling of empty arrays makes sure `axis` is ok for appending
165 res_dim.set_axis(axis, 0);
166 let mut res = unsafe {
167 // Safety: dimension is size 0 and vec is empty
168 Array::from_shape_vec_unchecked(res_dim, Vec::with_capacity(new_len))
169 };
170
171 for array in arrays {
172 res.append(axis, array.clone().insert_axis(axis))?;
173 }
174
175 debug_assert_eq!(res.len_of(axis), arrays.len());
176 Ok(res)
177}
178
/// Stack arrays along a new axis.
///
/// Shorthand for the [`stack()`] function: every argument `a` is passed as
/// `ArrayView::from(&a)`. A trailing comma after the last array is accepted.
///
/// ***Panics*** if the `stack` function would return an error.
///
/// ```
/// use ndarray::{arr2, arr3, stack, Axis};
///
/// let a = arr2(&[[1., 2.],
///                [3., 4.]]);
/// assert_eq!(
///     stack![Axis(0), a, a],
///     arr3(&[[[1., 2.],
///             [3., 4.]],
///            [[1., 2.],
///             [3., 4.]]]),
/// );
/// assert_eq!(
///     stack![Axis(1), a, a,],
///     arr3(&[[[1., 2.],
///             [1., 2.]],
///            [[3., 4.],
///             [3., 4.]]]),
/// );
/// assert_eq!(
///     stack![Axis(2), a, a],
///     arr3(&[[[1., 1.],
///             [2., 2.]],
///            [[3., 3.],
///             [4., 4.]]]),
/// );
/// ```
#[macro_export]
macro_rules! stack {
    ($axis:expr, $( $array:expr ),+ $(,)?) => {
        $crate::stack($axis, &[ $( $crate::ArrayView::from(&$array) ),+ ]).unwrap()
    };
}
227
/// Concatenate arrays along the given axis.
///
/// Shorthand for the [`concatenate()`] function: every argument `a` is passed
/// as `ArrayView::from(&a)`. A trailing comma after the last array is accepted.
///
/// ***Panics*** if the `concatenate` function would return an error.
///
/// ```
/// use ndarray::{arr2, concatenate, Axis};
///
/// let a = arr2(&[[1., 2.],
///                [3., 4.]]);
/// assert_eq!(
///     concatenate![Axis(0), a, a],
///     arr2(&[[1., 2.],
///            [3., 4.],
///            [1., 2.],
///            [3., 4.]]),
/// );
/// assert_eq!(
///     concatenate![Axis(1), a, a,],
///     arr2(&[[1., 2., 1., 2.],
///            [3., 4., 3., 4.]]),
/// );
/// ```
#[macro_export]
macro_rules! concatenate {
    ($axis:expr, $( $array:expr ),+ $(,)?) => {
        $crate::concatenate($axis, &[ $( $crate::ArrayView::from(&$array) ),+ ]).unwrap()
    };
}
267
/// Stack arrays along a new axis.
///
/// Deprecated former name of the `stack!` macro. Uses the [`stack_new_axis()`]
/// function, calling `ArrayView::from(&a)` on each argument `a`. Like `stack!`,
/// it accepts an optional trailing comma after the last array.
///
/// ***Panics*** if the `stack_new_axis` function would return an error.
///
/// ```
/// use ndarray::{arr2, arr3, stack_new_axis, Axis};
///
/// let a = arr2(&[[2., 2.],
///                [3., 3.]]);
/// assert!(
///     stack_new_axis![Axis(0), a, a]
///     == arr3(&[[[2., 2.],
///                [3., 3.]],
///               [[2., 2.],
///                [3., 3.]]])
/// );
/// ```
#[macro_export]
#[deprecated(note="Use under the name stack instead.", since="0.15.0")]
macro_rules! stack_new_axis {
    ($axis:expr, $( $array:expr ),+ $(,)?) => {
        $crate::stack_new_axis($axis, &[ $( $crate::ArrayView::from(&$array) ),+ ]).unwrap()
    };
}