ndarray/stacking.rs
// Copyright 2014-2020 bluss and ndarray developers.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.
8
9#[cfg(not(feature = "std"))]
10use alloc::vec::Vec;
11
12use crate::dimension;
13use crate::error::{from_kind, ErrorKind, ShapeError};
14use crate::imp_prelude::*;
15
16/// Concatenate arrays along the given axis.
17///
18/// ***Errors*** if the arrays have mismatching shapes, apart from along `axis`.
19/// (may be made more flexible in the future).<br>
20/// ***Errors*** if `arrays` is empty, if `axis` is out of bounds,
21/// if the result is larger than is possible to represent.
22///
23/// ```
24/// use ndarray::{arr2, Axis, concatenate};
25///
26/// let a = arr2(&[[2., 2.],
27/// [3., 3.]]);
28/// assert!(
29/// concatenate(Axis(0), &[a.view(), a.view()])
30/// == Ok(arr2(&[[2., 2.],
31/// [3., 3.],
32/// [2., 2.],
33/// [3., 3.]]))
34/// );
35/// ```
36pub fn concatenate<A, D>(axis: Axis, arrays: &[ArrayView<A, D>]) -> Result<Array<A, D>, ShapeError>
37where
38 A: Clone,
39 D: RemoveAxis,
40{
41 if arrays.is_empty() {
42 return Err(from_kind(ErrorKind::Unsupported));
43 }
44 let mut res_dim = arrays[0].raw_dim();
45 if axis.index() >= res_dim.ndim() {
46 return Err(from_kind(ErrorKind::OutOfBounds));
47 }
48 let common_dim = res_dim.remove_axis(axis);
49 if arrays
50 .iter()
51 .any(|a| a.raw_dim().remove_axis(axis) != common_dim)
52 {
53 return Err(from_kind(ErrorKind::IncompatibleShape));
54 }
55
56 let stacked_dim = arrays.iter().fold(0, |acc, a| acc + a.len_of(axis));
57 res_dim.set_axis(axis, stacked_dim);
58 let new_len = dimension::size_of_shape_checked(&res_dim)?;
59
60 // start with empty array with precomputed capacity
61 // append's handling of empty arrays makes sure `axis` is ok for appending
62 res_dim.set_axis(axis, 0);
63 let mut res = unsafe {
64 // Safety: dimension is size 0 and vec is empty
65 Array::from_shape_vec_unchecked(res_dim, Vec::with_capacity(new_len))
66 };
67
68 for array in arrays {
69 res.append(axis, array.clone())?;
70 }
71 debug_assert_eq!(res.len_of(axis), stacked_dim);
72 Ok(res)
73}
74
75/// Stack arrays along the new axis.
76///
77/// ***Errors*** if the arrays have mismatching shapes.
78/// ***Errors*** if `arrays` is empty, if `axis` is out of bounds,
79/// if the result is larger than is possible to represent.
80///
81/// ```
82/// extern crate ndarray;
83///
84/// use ndarray::{arr2, arr3, stack, Axis};
85///
86/// # fn main() {
87///
88/// let a = arr2(&[[2., 2.],
89/// [3., 3.]]);
90/// assert!(
91/// stack(Axis(0), &[a.view(), a.view()])
92/// == Ok(arr3(&[[[2., 2.],
93/// [3., 3.]],
94/// [[2., 2.],
95/// [3., 3.]]]))
96/// );
97/// # }
98/// ```
99pub fn stack<A, D>(axis: Axis, arrays: &[ArrayView<A, D>]) -> Result<Array<A, D::Larger>, ShapeError>
100where
101 A: Clone,
102 D: Dimension,
103 D::Larger: RemoveAxis,
104{
105 if arrays.is_empty() {
106 return Err(from_kind(ErrorKind::Unsupported));
107 }
108 let common_dim = arrays[0].raw_dim();
109 // Avoid panic on `insert_axis` call, return an Err instead of it.
110 if axis.index() > common_dim.ndim() {
111 return Err(from_kind(ErrorKind::OutOfBounds));
112 }
113 let mut res_dim = common_dim.insert_axis(axis);
114
115 if arrays.iter().any(|a| a.raw_dim() != common_dim) {
116 return Err(from_kind(ErrorKind::IncompatibleShape));
117 }
118
119 res_dim.set_axis(axis, arrays.len());
120
121 let new_len = dimension::size_of_shape_checked(&res_dim)?;
122
123 // start with empty array with precomputed capacity
124 // append's handling of empty arrays makes sure `axis` is ok for appending
125 res_dim.set_axis(axis, 0);
126 let mut res = unsafe {
127 // Safety: dimension is size 0 and vec is empty
128 Array::from_shape_vec_unchecked(res_dim, Vec::with_capacity(new_len))
129 };
130
131 for array in arrays {
132 res.append(axis, array.clone().insert_axis(axis))?;
133 }
134
135 debug_assert_eq!(res.len_of(axis), arrays.len());
136 Ok(res)
137}
138
/// Stack arrays along a new axis.
///
/// Shorthand for the [`stack()`] function: each argument `a` is converted
/// with `ArrayView::from(&a)`, so arrays are borrowed, not moved. A trailing
/// comma after the last array is accepted.
///
/// ***Panics*** if the `stack` function would return an error.
///
/// ```
/// use ndarray::{arr2, arr3, stack, Axis};
///
/// let a = arr2(&[[1., 2.],
///                [3., 4.]]);
/// assert_eq!(
///     stack![Axis(0), a, a],
///     arr3(&[[[1., 2.],
///             [3., 4.]],
///            [[1., 2.],
///             [3., 4.]]]),
/// );
/// assert_eq!(
///     stack![Axis(1), a, a,],
///     arr3(&[[[1., 2.],
///             [1., 2.]],
///            [[3., 4.],
///             [3., 4.]]]),
/// );
/// assert_eq!(
///     stack![Axis(2), a, a],
///     arr3(&[[[1., 1.],
///             [2., 2.]],
///            [[3., 3.],
///             [4., 4.]]]),
/// );
/// ```
#[macro_export]
macro_rules! stack {
    ($axis:expr, $( $array:expr ),+ $(,)?) => {
        $crate::stack($axis, &[ $( $crate::ArrayView::from(&$array) ),+ ]).unwrap()
    };
}
187
/// Concatenate arrays along the given axis.
///
/// Shorthand for the [`concatenate()`] function: each argument `a` is
/// converted with `ArrayView::from(&a)`, so arrays are borrowed, not moved.
/// A trailing comma after the last array is accepted.
///
/// ***Panics*** if the `concatenate` function would return an error.
///
/// ```
/// use ndarray::{arr2, concatenate, Axis};
///
/// let a = arr2(&[[1., 2.],
///                [3., 4.]]);
/// assert_eq!(
///     concatenate![Axis(0), a, a],
///     arr2(&[[1., 2.],
///            [3., 4.],
///            [1., 2.],
///            [3., 4.]]),
/// );
/// assert_eq!(
///     concatenate![Axis(1), a, a,],
///     arr2(&[[1., 2., 1., 2.],
///            [3., 4., 3., 4.]]),
/// );
/// ```
#[macro_export]
macro_rules! concatenate {
    ($axis:expr, $( $array:expr ),+ $(,)?) => {
        $crate::concatenate($axis, &[ $( $crate::ArrayView::from(&$array) ),+ ]).unwrap()
    };
}
226}