ndarray/dimension/
axes.rs

1use crate::{Axis, Dimension, Ix, Ixs};
2
3/// Create a new Axes iterator
4pub(crate) fn axes_of<'a, D>(d: &'a D, strides: &'a D) -> Axes<'a, D>
5where
6    D: Dimension,
7{
8    Axes {
9        dim: d,
10        strides,
11        start: 0,
12        end: d.ndim(),
13    }
14}
15
/// Iterator yielding, for each axis of an array, that axis' index, length
/// and stride.
///
/// Obtain one through the array method
/// [`.axes()`](crate::ArrayBase::axes).
///
/// Each item is an [`AxisDescription`].
///
/// # Examples
///
/// ```
/// use ndarray::Array3;
/// use ndarray::Axis;
///
/// let a = Array3::<f32>::zeros((3, 5, 4));
///
/// // find the largest axis in the array
/// // check the axis index and its length
///
/// let largest_axis = a.axes()
///                     .max_by_key(|ax| ax.len)
///                     .unwrap();
/// assert_eq!(largest_axis.axis, Axis(1));
/// assert_eq!(largest_axis.len, 5);
/// ```
#[derive(Debug)]
pub struct Axes<'a, D> {
    // Per-axis lengths.
    dim: &'a D,
    // Per-axis strides, stored in the same dimension type as the lengths.
    strides: &'a D,
    // Index of the next axis to yield from the front (inclusive).
    start: usize,
    // One past the index of the next axis to yield from the back (exclusive).
    end: usize,
}
47
48/// Description of the axis, its length and its stride.
49#[derive(Debug)]
50pub struct AxisDescription {
51    /// Axis identifier (index)
52    pub axis: Axis,
53    /// Length in count of elements of the current axis
54    pub len: usize,
55    /// Stride in count of elements of the current axis
56    pub stride: isize,
57}
58
59copy_and_clone!(AxisDescription);
60
61// AxisDescription can't really be empty
62// https://github.com/rust-ndarray/ndarray/pull/642#discussion_r296051702
63#[allow(clippy::len_without_is_empty)]
64impl AxisDescription {
65    /// Return axis
66    #[deprecated(note = "Use .axis field instead", since = "0.15.0")]
67    #[inline(always)]
68    pub fn axis(self) -> Axis {
69        self.axis
70    }
71    /// Return length
72    #[deprecated(note = "Use .len field instead", since = "0.15.0")]
73    #[inline(always)]
74    pub fn len(self) -> Ix {
75        self.len
76    }
77    /// Return stride
78    #[deprecated(note = "Use .stride field instead", since = "0.15.0")]
79    #[inline(always)]
80    pub fn stride(self) -> Ixs {
81        self.stride
82    }
83}
84
85copy_and_clone!(['a, D] Axes<'a, D>);
86
87impl<'a, D> Iterator for Axes<'a, D>
88where
89    D: Dimension,
90{
91    /// Description of the axis, its length and its stride.
92    type Item = AxisDescription;
93
94    fn next(&mut self) -> Option<Self::Item> {
95        if self.start < self.end {
96            let i = self.start.post_inc();
97            Some(AxisDescription {
98                axis: Axis(i),
99                len: self.dim[i],
100                stride: self.strides[i] as Ixs,
101            })
102        } else {
103            None
104        }
105    }
106
107    fn fold<B, F>(self, init: B, f: F) -> B
108    where
109        F: FnMut(B, AxisDescription) -> B,
110    {
111        (self.start..self.end)
112            .map(move |i| AxisDescription {
113                axis: Axis(i),
114                len: self.dim[i],
115                stride: self.strides[i] as isize,
116            })
117            .fold(init, f)
118    }
119
120    fn size_hint(&self) -> (usize, Option<usize>) {
121        let len = self.end - self.start;
122        (len, Some(len))
123    }
124}
125
126impl<'a, D> DoubleEndedIterator for Axes<'a, D>
127where
128    D: Dimension,
129{
130    fn next_back(&mut self) -> Option<Self::Item> {
131        if self.start < self.end {
132            let i = self.end.pre_dec();
133            Some(AxisDescription {
134                axis: Axis(i),
135                len: self.dim[i],
136                stride: self.strides[i] as Ixs,
137            })
138        } else {
139            None
140        }
141    }
142}
143
/// C-style increment/decrement helpers: each method updates the value in
/// place and returns either the old or the new value.
trait IncOps: Copy {
    /// Add one in place; return the value held *before* the increment.
    fn post_inc(&mut self) -> Self;
    /// Subtract one in place; return the value held *before* the decrement.
    fn post_dec(&mut self) -> Self;
    /// Subtract one in place; return the value held *after* the decrement.
    fn pre_dec(&mut self) -> Self;
}

impl IncOps for usize {
    #[inline(always)]
    fn post_inc(&mut self) -> Self {
        let bumped = *self + 1;
        // Store the new value and hand back the one that was there before.
        core::mem::replace(self, bumped)
    }
    #[inline(always)]
    fn post_dec(&mut self) -> Self {
        let lowered = *self - 1;
        // Store the new value and hand back the one that was there before.
        core::mem::replace(self, lowered)
    }
    #[inline(always)]
    fn pre_dec(&mut self) -> Self {
        let lowered = *self - 1;
        *self = lowered;
        lowered
    }
}