ndarray/
indexes.rs

1// Copyright 2014-2016 bluss and ndarray developers.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8use super::Dimension;
9use crate::dimension::IntoDimension;
10use crate::zip::Offset;
11use crate::split_at::SplitAt;
12use crate::Axis;
13use crate::Layout;
14use crate::NdProducer;
15use crate::{ArrayBase, Data};
16
/// Iterator that yields every index of an array shape, one at a time.
///
/// Items are produced as `D::Pattern` (the pattern form of the index),
/// in the order defined by `Dimension::next_for`.
#[derive(Clone)]
pub struct IndicesIter<D> {
    /// Shape whose indices are being enumerated.
    dim: D,
    /// The index that the next call to `next` will yield; `None` once the
    /// iterator is exhausted (or if the shape has no elements).
    index: Option<D>,
}
25
26/// Create an iterable of the array shape `shape`.
27///
28/// *Note:* prefer higher order methods, arithmetic operations and
29/// non-indexed iteration before using indices.
30pub fn indices<E>(shape: E) -> Indices<E::Dim>
31where
32    E: IntoDimension,
33{
34    let dim = shape.into_dimension();
35    Indices {
36        start: E::Dim::zeros(dim.ndim()),
37        dim,
38    }
39}
40
41/// Return an iterable of the indices of the passed-in array.
42///
43/// *Note:* prefer higher order methods, arithmetic operations and
44/// non-indexed iteration before using indices.
45pub fn indices_of<S, D>(array: &ArrayBase<S, D>) -> Indices<D>
46where
47    S: Data,
48    D: Dimension,
49{
50    indices(array.dim())
51}
52
53impl<D> Iterator for IndicesIter<D>
54where
55    D: Dimension,
56{
57    type Item = D::Pattern;
58    #[inline]
59    fn next(&mut self) -> Option<Self::Item> {
60        let index = match self.index {
61            None => return None,
62            Some(ref ix) => ix.clone(),
63        };
64        self.index = self.dim.next_for(index.clone());
65        Some(index.into_pattern())
66    }
67
68    fn size_hint(&self) -> (usize, Option<usize>) {
69        let l = match self.index {
70            None => 0,
71            Some(ref ix) => {
72                let gone = self
73                    .dim
74                    .default_strides()
75                    .slice()
76                    .iter()
77                    .zip(ix.slice().iter())
78                    .fold(0, |s, (&a, &b)| s + a as usize * b as usize);
79                self.dim.size() - gone
80            }
81        };
82        (l, Some(l))
83    }
84
85    fn fold<B, F>(self, init: B, mut f: F) -> B
86    where
87        F: FnMut(B, D::Pattern) -> B,
88    {
89        let IndicesIter { mut index, dim } = self;
90        let ndim = dim.ndim();
91        if ndim == 0 {
92            return match index {
93                Some(ix) => f(init, ix.into_pattern()),
94                None => init,
95            };
96        }
97        let inner_axis = ndim - 1;
98        let inner_len = dim[inner_axis];
99        let mut acc = init;
100        while let Some(mut ix) = index {
101            // unroll innermost axis
102            for i in ix[inner_axis]..inner_len {
103                ix[inner_axis] = i;
104                acc = f(acc, ix.clone().into_pattern());
105            }
106            index = dim.next_for(ix);
107        }
108        acc
109    }
110}
111
112impl<D> ExactSizeIterator for IndicesIter<D> where D: Dimension {}
113
114impl<D> IntoIterator for Indices<D>
115where
116    D: Dimension,
117{
118    type Item = D::Pattern;
119    type IntoIter = IndicesIter<D>;
120    fn into_iter(self) -> Self::IntoIter {
121        let sz = self.dim.size();
122        let index = if sz != 0 { Some(self.start) } else { None };
123        IndicesIter {
124            index,
125            dim: self.dim,
126        }
127    }
128}
129
130/// Indices producer and iterable.
131///
132/// `Indices` is an `NdProducer` that produces the indices of an array shape.
133#[derive(Copy, Clone, Debug)]
134pub struct Indices<D>
135where
136    D: Dimension,
137{
138    start: D,
139    dim: D,
140}
141
/// Stand-in for a raw pointer in the `NdProducer` implementation of
/// `Indices`: instead of an address it carries a plain index value.
#[derive(Copy, Clone, Debug)]
pub struct IndexPtr<D> {
    /// The index this "pointer" currently refers to.
    index: D,
}
146
147impl<D> Offset for IndexPtr<D>
148where
149    D: Dimension + Copy,
150{
151    // stride: The axis to increment
152    type Stride = usize;
153
154    unsafe fn stride_offset(mut self, stride: Self::Stride, index: usize) -> Self {
155        self.index[stride] += index;
156        self
157    }
158    private_impl! {}
159}
160
161// How the NdProducer for Indices works.
162//
163// NdProducer allows for raw pointers (Ptr), strides (Stride) and the produced
164// item (Item).
165//
166// Instead of Ptr, there is `IndexPtr<D>` which is an index value, like [0, 0, 0]
167// for the three dimensional case.
168//
169// The stride is simply which axis is currently being incremented. The stride for axis 1, is 1.
170//
171// .stride_offset(stride, index) simply computes the new index along that axis, for example:
172// [0, 0, 0].stride_offset(1, 10) => [0, 10, 0]  axis 1 is incremented by 10.
173//
174// .as_ref() converts the Ptr value to an Item. For example [0, 10, 0] => (0, 10, 0)
175impl<D: Dimension + Copy> NdProducer for Indices<D> {
176    type Item = D::Pattern;
177    type Dim = D;
178    type Ptr = IndexPtr<D>;
179    type Stride = usize;
180
181    private_impl! {}
182
183    fn raw_dim(&self) -> Self::Dim {
184        self.dim
185    }
186
187    fn equal_dim(&self, dim: &Self::Dim) -> bool {
188        self.dim.equal(dim)
189    }
190
191    fn as_ptr(&self) -> Self::Ptr {
192        IndexPtr { index: self.start }
193    }
194
195    fn layout(&self) -> Layout {
196        if self.dim.ndim() <= 1 {
197            Layout::one_dimensional()
198        } else {
199            Layout::none()
200        }
201    }
202
203    unsafe fn as_ref(&self, ptr: Self::Ptr) -> Self::Item {
204        ptr.index.into_pattern()
205    }
206
207    unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr {
208        let mut index = *i;
209        index += &self.start;
210        IndexPtr { index }
211    }
212
213    fn stride_of(&self, axis: Axis) -> Self::Stride {
214        axis.index()
215    }
216
217    #[inline(always)]
218    fn contiguous_stride(&self) -> Self::Stride {
219        0
220    }
221
222    fn split_at(self, axis: Axis, index: usize) -> (Self, Self) {
223        let start_a = self.start;
224        let mut start_b = start_a;
225        let (a, b) = self.dim.split_at(axis, index);
226        start_b[axis.index()] += index;
227        (
228            Indices {
229                start: start_a,
230                dim: a,
231            },
232            Indices {
233                start: start_b,
234                dim: b,
235            },
236        )
237    }
238}
239
/// Iterator over the indexes of an array shape, advanced with
/// `Dimension::next_for_f`.
///
/// Items are produced as `D::Pattern`. The `F` presumably stands for
/// Fortran (column-major) order, matching the `fortran_strides` used by
/// `size_hint`.
#[derive(Clone)]
pub struct IndicesIterF<D> {
    /// Shape whose indices are being enumerated.
    dim: D,
    /// The index that the next call to `next` will yield.
    index: D,
    /// `false` once every index has been yielded (or if the shape is empty).
    has_remaining: bool,
}
249
250pub fn indices_iter_f<E>(shape: E) -> IndicesIterF<E::Dim>
251where
252    E: IntoDimension,
253{
254    let dim = shape.into_dimension();
255    let zero = E::Dim::zeros(dim.ndim());
256    IndicesIterF {
257        has_remaining: dim.size_checked() != Some(0),
258        index: zero,
259        dim,
260    }
261}
262
263impl<D> Iterator for IndicesIterF<D>
264where
265    D: Dimension,
266{
267    type Item = D::Pattern;
268    #[inline]
269    fn next(&mut self) -> Option<Self::Item> {
270        if !self.has_remaining {
271            None
272        } else {
273            let elt = self.index.clone().into_pattern();
274            self.has_remaining = self.dim.next_for_f(&mut self.index);
275            Some(elt)
276        }
277    }
278
279    fn size_hint(&self) -> (usize, Option<usize>) {
280        if !self.has_remaining {
281            return (0, Some(0));
282        }
283        let gone = self
284            .dim
285            .fortran_strides()
286            .slice()
287            .iter()
288            .zip(self.index.slice().iter())
289            .fold(0, |s, (&a, &b)| s + a as usize * b as usize);
290        let l = self.dim.size() - gone;
291        (l, Some(l))
292    }
293}
294
295impl<D> ExactSizeIterator for IndicesIterF<D> where D: Dimension {}
296
#[cfg(test)]
mod tests {
    use super::indices;
    use super::indices_iter_f;

    /// `len()` of the C-order iterator must drop by one per yielded index.
    #[test]
    fn test_indices_iter_c_size_hint() {
        let dim = (3, 4);
        let mut it = indices(dim).into_iter();
        let mut expected = dim.0 * dim.1;
        assert_eq!(it.len(), expected);
        while it.next().is_some() {
            expected -= 1;
            assert_eq!(it.len(), expected);
        }
        assert_eq!(expected, 0);
    }

    /// The specialised `fold` must visit exactly the indices that repeated
    /// `next` calls would, also after part of the iterator was consumed.
    #[test]
    fn test_indices_iter_c_fold() {
        macro_rules! check_fold {
            ($dim:expr) => {
                for skipped in 0..3 {
                    let mut reference = indices($dim).into_iter();
                    for _ in 0..skipped {
                        reference.next();
                    }
                    let folded = reference.clone();
                    let expected_count = reference.len();
                    let count = folded.fold(0, |n, ix| {
                        assert_eq!(ix, reference.next().unwrap());
                        n + 1
                    });
                    assert_eq!(count, expected_count);
                    assert!(reference.next().is_none());
                }
            };
        }
        check_fold!(());
        check_fold!((2,));
        check_fold!((2, 3));
        check_fold!((2, 0, 3));
        check_fold!((2, 3, 4));
        check_fold!((2, 3, 4, 2));
    }

    /// `len()` of the F-order iterator must drop by one per yielded index.
    #[test]
    fn test_indices_iter_f_size_hint() {
        let dim = (3, 4);
        let mut it = indices_iter_f(dim);
        let mut expected = dim.0 * dim.1;
        assert_eq!(it.len(), expected);
        while it.next().is_some() {
            expected -= 1;
            assert_eq!(it.len(), expected);
        }
        assert_eq!(expected, 0);
    }
}
355}