1use super::Dimension;
9use crate::dimension::IntoDimension;
10use crate::zip::Offset;
11use crate::split_at::SplitAt;
12use crate::Axis;
13use crate::Layout;
14use crate::NdProducer;
15use crate::{ArrayBase, Data};
16
17#[derive(Clone)]
21pub struct IndicesIter<D> {
22 dim: D,
23 index: Option<D>,
24}
25
26pub fn indices<E>(shape: E) -> Indices<E::Dim>
31where
32 E: IntoDimension,
33{
34 let dim = shape.into_dimension();
35 Indices {
36 start: E::Dim::zeros(dim.ndim()),
37 dim,
38 }
39}
40
41pub fn indices_of<S, D>(array: &ArrayBase<S, D>) -> Indices<D>
46where
47 S: Data,
48 D: Dimension,
49{
50 indices(array.dim())
51}
52
53impl<D> Iterator for IndicesIter<D>
54where
55 D: Dimension,
56{
57 type Item = D::Pattern;
58 #[inline]
59 fn next(&mut self) -> Option<Self::Item> {
60 let index = match self.index {
61 None => return None,
62 Some(ref ix) => ix.clone(),
63 };
64 self.index = self.dim.next_for(index.clone());
65 Some(index.into_pattern())
66 }
67
68 fn size_hint(&self) -> (usize, Option<usize>) {
69 let l = match self.index {
70 None => 0,
71 Some(ref ix) => {
72 let gone = self
73 .dim
74 .default_strides()
75 .slice()
76 .iter()
77 .zip(ix.slice().iter())
78 .fold(0, |s, (&a, &b)| s + a as usize * b as usize);
79 self.dim.size() - gone
80 }
81 };
82 (l, Some(l))
83 }
84
85 fn fold<B, F>(self, init: B, mut f: F) -> B
86 where
87 F: FnMut(B, D::Pattern) -> B,
88 {
89 let IndicesIter { mut index, dim } = self;
90 let ndim = dim.ndim();
91 if ndim == 0 {
92 return match index {
93 Some(ix) => f(init, ix.into_pattern()),
94 None => init,
95 };
96 }
97 let inner_axis = ndim - 1;
98 let inner_len = dim[inner_axis];
99 let mut acc = init;
100 while let Some(mut ix) = index {
101 for i in ix[inner_axis]..inner_len {
103 ix[inner_axis] = i;
104 acc = f(acc, ix.clone().into_pattern());
105 }
106 index = dim.next_for(ix);
107 }
108 acc
109 }
110}
111
112impl<D> ExactSizeIterator for IndicesIter<D> where D: Dimension {}
113
114impl<D> IntoIterator for Indices<D>
115where
116 D: Dimension,
117{
118 type Item = D::Pattern;
119 type IntoIter = IndicesIter<D>;
120 fn into_iter(self) -> Self::IntoIter {
121 let sz = self.dim.size();
122 let index = if sz != 0 { Some(self.start) } else { None };
123 IndicesIter {
124 index,
125 dim: self.dim,
126 }
127 }
128}
129
130#[derive(Copy, Clone, Debug)]
134pub struct Indices<D>
135where
136 D: Dimension,
137{
138 start: D,
139 dim: D,
140}
141
/// The "pointer" type used by [`Indices`] as an `NdProducer`: it holds the
/// absolute index it currently points at.
#[derive(Clone, Copy, Debug)]
pub struct IndexPtr<D> {
    index: D,
}
146
147impl<D> Offset for IndexPtr<D>
148where
149 D: Dimension + Copy,
150{
151 type Stride = usize;
153
154 unsafe fn stride_offset(mut self, stride: Self::Stride, index: usize) -> Self {
155 self.index[stride] += index;
156 self
157 }
158 private_impl! {}
159}
160
161impl<D: Dimension + Copy> NdProducer for Indices<D> {
176 type Item = D::Pattern;
177 type Dim = D;
178 type Ptr = IndexPtr<D>;
179 type Stride = usize;
180
181 private_impl! {}
182
183 fn raw_dim(&self) -> Self::Dim {
184 self.dim
185 }
186
187 fn equal_dim(&self, dim: &Self::Dim) -> bool {
188 self.dim.equal(dim)
189 }
190
191 fn as_ptr(&self) -> Self::Ptr {
192 IndexPtr { index: self.start }
193 }
194
195 fn layout(&self) -> Layout {
196 if self.dim.ndim() <= 1 {
197 Layout::one_dimensional()
198 } else {
199 Layout::none()
200 }
201 }
202
203 unsafe fn as_ref(&self, ptr: Self::Ptr) -> Self::Item {
204 ptr.index.into_pattern()
205 }
206
207 unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr {
208 let mut index = *i;
209 index += &self.start;
210 IndexPtr { index }
211 }
212
213 fn stride_of(&self, axis: Axis) -> Self::Stride {
214 axis.index()
215 }
216
217 #[inline(always)]
218 fn contiguous_stride(&self) -> Self::Stride {
219 0
220 }
221
222 fn split_at(self, axis: Axis, index: usize) -> (Self, Self) {
223 let start_a = self.start;
224 let mut start_b = start_a;
225 let (a, b) = self.dim.split_at(axis, index);
226 start_b[axis.index()] += index;
227 (
228 Indices {
229 start: start_a,
230 dim: a,
231 },
232 Indices {
233 start: start_b,
234 dim: b,
235 },
236 )
237 }
238}
239
/// An iterator over the indices of an array shape, in column-major
/// ("Fortran") order.
///
/// The iterator element type is `D::Pattern`.
#[derive(Clone, Debug)]
pub struct IndicesIterF<D> {
    /// Shape being iterated.
    dim: D,
    /// Next index to yield; only meaningful while `has_remaining` is true.
    index: D,
    /// Whether `index` holds an index that has not been yielded yet.
    has_remaining: bool,
}
249
250pub fn indices_iter_f<E>(shape: E) -> IndicesIterF<E::Dim>
251where
252 E: IntoDimension,
253{
254 let dim = shape.into_dimension();
255 let zero = E::Dim::zeros(dim.ndim());
256 IndicesIterF {
257 has_remaining: dim.size_checked() != Some(0),
258 index: zero,
259 dim,
260 }
261}
262
263impl<D> Iterator for IndicesIterF<D>
264where
265 D: Dimension,
266{
267 type Item = D::Pattern;
268 #[inline]
269 fn next(&mut self) -> Option<Self::Item> {
270 if !self.has_remaining {
271 None
272 } else {
273 let elt = self.index.clone().into_pattern();
274 self.has_remaining = self.dim.next_for_f(&mut self.index);
275 Some(elt)
276 }
277 }
278
279 fn size_hint(&self) -> (usize, Option<usize>) {
280 if !self.has_remaining {
281 return (0, Some(0));
282 }
283 let gone = self
284 .dim
285 .fortran_strides()
286 .slice()
287 .iter()
288 .zip(self.index.slice().iter())
289 .fold(0, |s, (&a, &b)| s + a as usize * b as usize);
290 let l = self.dim.size() - gone;
291 (l, Some(l))
292 }
293}
294
295impl<D> ExactSizeIterator for IndicesIterF<D> where D: Dimension {}
296
#[cfg(test)]
mod tests {
    use super::indices;
    use super::indices_iter_f;

    #[test]
    fn test_indices_iter_c_size_hint() {
        let dim = (3, 4);
        let mut it = indices(dim).into_iter();
        let mut len = dim.0 * dim.1;
        assert_eq!(it.len(), len);
        while it.next().is_some() {
            len -= 1;
            assert_eq!(it.len(), len);
        }
        assert_eq!(len, 0);
    }

    #[test]
    fn test_indices_iter_c_order() {
        let order: Vec<(usize, usize)> = indices((2, 3)).into_iter().collect();
        assert_eq!(order, [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]);
    }

    #[test]
    fn test_indices_iter_c_fold() {
        macro_rules! run_test {
            ($dim:expr) => {
                for num_consume in 0..3 {
                    let mut it = indices($dim).into_iter();
                    for _ in 0..num_consume {
                        it.next();
                    }
                    let clone = it.clone();
                    let len = it.len();
                    let acc = clone.fold(0, |acc, ix| {
                        assert_eq!(ix, it.next().unwrap());
                        acc + 1
                    });
                    assert_eq!(acc, len);
                    assert!(it.next().is_none());
                }
            };
        }
        run_test!(());
        run_test!((2,));
        run_test!((2, 3));
        run_test!((2, 0, 3));
        run_test!((2, 3, 4));
        run_test!((2, 3, 4, 2));
    }

    #[test]
    fn test_indices_iter_f_size_hint() {
        let dim = (3, 4);
        let mut it = indices_iter_f(dim);
        let mut len = dim.0 * dim.1;
        assert_eq!(it.len(), len);
        while it.next().is_some() {
            len -= 1;
            assert_eq!(it.len(), len);
        }
        assert_eq!(len, 0);
    }

    #[test]
    fn test_indices_iter_f_order() {
        let order: Vec<(usize, usize)> = indices_iter_f((2, 3)).collect();
        assert_eq!(order, [(0, 0), (1, 0), (0, 1), (1, 1), (0, 2), (1, 2)]);
        // A shape with a zero-length axis has no indices at all.
        assert_eq!(indices_iter_f((2, 0, 3)).next(), None);
    }
}