// ndarray/iterators/mod.rs

// Copyright 2014-2016 bluss and ndarray developers.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.
8
9#[macro_use]
10mod macros;
11mod chunks;
12mod into_iter;
13pub mod iter;
14mod lanes;
15mod windows;
16
17use std::iter::FromIterator;
18use std::marker::PhantomData;
19use std::ptr;
20use alloc::vec::Vec;
21
22use crate::Ix1;
23
24use super::{ArrayBase, ArrayView, ArrayViewMut, Axis, Data, NdProducer, RemoveAxis};
25use super::{Dimension, Ix, Ixs};
26
27pub use self::chunks::{ExactChunks, ExactChunksIter, ExactChunksIterMut, ExactChunksMut};
28pub use self::lanes::{Lanes, LanesMut};
29pub use self::windows::Windows;
30pub use self::into_iter::IntoIter;
31
32use std::slice::{self, Iter as SliceIter, IterMut as SliceIterMut};
33
/// Base for iterators over all axes.
///
/// Iterator element type is `*mut A`.
pub struct Baseiter<A, D> {
    // Pointer that each index is offset from (via `strides`) to reach an element.
    ptr: *mut A,
    // Shape being iterated. For `Ix1`, `next_back` decrements `dim[0]`, so it
    // also acts as the exclusive back bound of the remaining range.
    dim: D,
    // Per-axis strides, in elements, used to turn an index into a pointer offset.
    strides: D,
    // Index of the element `next` will yield; `None` once the iterator is exhausted.
    index: Option<D>,
}
43
44impl<A, D: Dimension> Baseiter<A, D> {
45    /// Creating a Baseiter is unsafe because shape and stride parameters need
46    /// to be correct to avoid performing an unsafe pointer offset while
47    /// iterating.
48    #[inline]
49    pub unsafe fn new(ptr: *mut A, len: D, stride: D) -> Baseiter<A, D> {
50        Baseiter {
51            ptr,
52            index: len.first_index(),
53            dim: len,
54            strides: stride,
55        }
56    }
57}
58
59impl<A, D: Dimension> Iterator for Baseiter<A, D> {
60    type Item = *mut A;
61
62    #[inline]
63    fn next(&mut self) -> Option<*mut A> {
64        let index = match self.index {
65            None => return None,
66            Some(ref ix) => ix.clone(),
67        };
68        let offset = D::stride_offset(&index, &self.strides);
69        self.index = self.dim.next_for(index);
70        unsafe { Some(self.ptr.offset(offset)) }
71    }
72
73    fn size_hint(&self) -> (usize, Option<usize>) {
74        let len = self.len();
75        (len, Some(len))
76    }
77
78    fn fold<Acc, G>(mut self, init: Acc, mut g: G) -> Acc
79    where
80        G: FnMut(Acc, *mut A) -> Acc,
81    {
82        let ndim = self.dim.ndim();
83        debug_assert_ne!(ndim, 0);
84        let mut accum = init;
85        while let Some(mut index) = self.index {
86            let stride = self.strides.last_elem() as isize;
87            let elem_index = index.last_elem();
88            let len = self.dim.last_elem();
89            let offset = D::stride_offset(&index, &self.strides);
90            unsafe {
91                let row_ptr = self.ptr.offset(offset);
92                let mut i = 0;
93                let i_end = len - elem_index;
94                while i < i_end {
95                    accum = g(accum, row_ptr.offset(i as isize * stride));
96                    i += 1;
97                }
98            }
99            index.set_last_elem(len - 1);
100            self.index = self.dim.next_for(index);
101        }
102        accum
103    }
104}
105
106impl<A, D: Dimension> ExactSizeIterator for Baseiter<A, D> {
107    fn len(&self) -> usize {
108        match self.index {
109            None => 0,
110            Some(ref ix) => {
111                let gone = self
112                    .dim
113                    .default_strides()
114                    .slice()
115                    .iter()
116                    .zip(ix.slice().iter())
117                    .fold(0, |s, (&a, &b)| s + a as usize * b as usize);
118                self.dim.size() - gone
119            }
120        }
121    }
122}
123
124impl<A> DoubleEndedIterator for Baseiter<A, Ix1> {
125    #[inline]
126    fn next_back(&mut self) -> Option<*mut A> {
127        let index = match self.index {
128            None => return None,
129            Some(ix) => ix,
130        };
131        self.dim[0] -= 1;
132        let offset = <_>::stride_offset(&self.dim, &self.strides);
133        if index == self.dim {
134            self.index = None;
135        }
136
137        unsafe { Some(self.ptr.offset(offset)) }
138    }
139
140    fn nth_back(&mut self, n: usize) -> Option<*mut A> {
141        let index = self.index?;
142        let len = self.dim[0] - index[0];
143        if n < len {
144            self.dim[0] -= n + 1;
145            let offset = <_>::stride_offset(&self.dim, &self.strides);
146            if index == self.dim {
147                self.index = None;
148            }
149            unsafe { Some(self.ptr.offset(offset)) }
150        } else {
151            self.index = None;
152            None
153        }
154    }
155
156    fn rfold<Acc, G>(mut self, init: Acc, mut g: G) -> Acc
157    where
158        G: FnMut(Acc, *mut A) -> Acc,
159    {
160        let mut accum = init;
161        if let Some(index) = self.index {
162            let elem_index = index[0];
163            unsafe {
164                // self.dim[0] is the current length
165                while self.dim[0] > elem_index {
166                    self.dim[0] -= 1;
167                    accum = g(
168                        accum,
169                        self.ptr
170                            .offset(Ix1::stride_offset(&self.dim, &self.strides)),
171                    );
172                }
173            }
174        }
175        accum
176    }
177}
178
// `Clone` for `Baseiter`, generated by `clone_bounds!` (from the `macros`
// module) with only `D: Clone` required. Presumably fields listed under
// `@copy` are bit-copied and the rest are `.clone()`d — confirm in `macros`.
clone_bounds!(
    [A, D: Clone]
    Baseiter[A, D] {
        @copy {
            ptr,
        }
        dim,
        strides,
        index,
    }
);
190
// `Clone` for `ElementsBase` requiring only `D: Clone` (not `A: Clone`),
// generated by `clone_bounds!` from the `macros` module.
clone_bounds!(
    ['a, A, D: Clone]
    ElementsBase['a, A, D] {
        @copy {
            life,
        }
        inner,
    }
);
200
201impl<'a, A, D: Dimension> ElementsBase<'a, A, D> {
202    pub fn new(v: ArrayView<'a, A, D>) -> Self {
203        ElementsBase {
204            inner: v.into_base_iter(),
205            life: PhantomData,
206        }
207    }
208}
209
210impl<'a, A, D: Dimension> Iterator for ElementsBase<'a, A, D> {
211    type Item = &'a A;
212    #[inline]
213    fn next(&mut self) -> Option<&'a A> {
214        self.inner.next().map(|p| unsafe { &*p })
215    }
216
217    fn size_hint(&self) -> (usize, Option<usize>) {
218        self.inner.size_hint()
219    }
220
221    fn fold<Acc, G>(self, init: Acc, mut g: G) -> Acc
222    where
223        G: FnMut(Acc, Self::Item) -> Acc,
224    {
225        unsafe { self.inner.fold(init, move |acc, ptr| g(acc, &*ptr)) }
226    }
227}
228
229impl<'a, A> DoubleEndedIterator for ElementsBase<'a, A, Ix1> {
230    #[inline]
231    fn next_back(&mut self) -> Option<&'a A> {
232        self.inner.next_back().map(|p| unsafe { &*p })
233    }
234
235    fn rfold<Acc, G>(self, init: Acc, mut g: G) -> Acc
236    where
237        G: FnMut(Acc, Self::Item) -> Acc,
238    {
239        unsafe { self.inner.rfold(init, move |acc, ptr| g(acc, &*ptr)) }
240    }
241}
242
243impl<'a, A, D> ExactSizeIterator for ElementsBase<'a, A, D>
244where
245    D: Dimension,
246{
247    fn len(&self) -> usize {
248        self.inner.len()
249    }
250}
251
// Evaluates `$body` against whichever `ElementsRepr` variant is active,
// binding the variant's payload with the pattern `$bind`. Both arms expand
// the same expression, so `$body` must type-check for the slice iterator
// and for the counted iterator alike.
macro_rules! either {
    ($repr:expr, $bind:pat => $body:expr) => {
        match $repr {
            ElementsRepr::Slice($bind) => $body,
            ElementsRepr::Counted($bind) => $body,
        }
    };
}
260
// Like `either!`, but binds the active `ElementsRepr` payload by mutable
// reference (`ref mut $bind`), so `$body` can call `&mut self` methods.
macro_rules! either_mut {
    ($repr:expr, $bind:ident => $body:expr) => {
        match $repr {
            ElementsRepr::Slice(ref mut $bind) => $body,
            ElementsRepr::Counted(ref mut $bind) => $body,
        }
    };
}
269
// `Clone` for `Iter` requiring only `D: Clone`, generated by `clone_bounds!`
// (from the `macros` module). There are no bit-copied fields here.
clone_bounds!(
    ['a, A, D: Clone]
    Iter['a, A, D] {
        @copy {
        }
        inner,
    }
);
278
279impl<'a, A, D> Iter<'a, A, D>
280where
281    D: Dimension,
282{
283    pub(crate) fn new(self_: ArrayView<'a, A, D>) -> Self {
284        Iter {
285            inner: if let Some(slc) = self_.to_slice() {
286                ElementsRepr::Slice(slc.iter())
287            } else {
288                ElementsRepr::Counted(self_.into_elements_base())
289            },
290        }
291    }
292}
293
294impl<'a, A, D> IterMut<'a, A, D>
295where
296    D: Dimension,
297{
298    pub(crate) fn new(self_: ArrayViewMut<'a, A, D>) -> Self {
299        IterMut {
300            inner: match self_.try_into_slice() {
301                Ok(x) => ElementsRepr::Slice(x.iter_mut()),
302                Err(self_) => ElementsRepr::Counted(self_.into_elements_base()),
303            },
304        }
305    }
306}
307
/// Backend of `Iter`/`IterMut`: either a slice iterator (`Slice`, used when
/// the view's memory is available as a slice) or a counted strided iterator
/// (`Counted`, the general fallback).
#[derive(Clone)]
pub enum ElementsRepr<S, C> {
    Slice(S),
    Counted(C),
}
313
/// An iterator over the elements of an array.
///
/// Iterator element type is `&'a A`.
///
/// See [`.iter()`](ArrayBase::iter) for more information.
pub struct Iter<'a, A, D> {
    // Slice fast path or counted fallback; chosen once in `Iter::new`.
    inner: ElementsRepr<SliceIter<'a, A>, ElementsBase<'a, A, D>>,
}
322
/// Counted read-only iterator over the elements of an array view.
///
/// Wraps a `Baseiter` and turns its raw pointers into `&'a A`.
pub struct ElementsBase<'a, A, D> {
    inner: Baseiter<A, D>,
    // Ties the yielded references to the borrowed view's lifetime.
    life: PhantomData<&'a A>,
}
328
/// An iterator over the elements of an array (mutable).
///
/// Iterator element type is `&'a mut A`.
///
/// See [`.iter_mut()`](ArrayBase::iter_mut) for more information.
pub struct IterMut<'a, A, D> {
    // Mutable slice fast path or counted fallback; chosen once in `IterMut::new`.
    inner: ElementsRepr<SliceIterMut<'a, A>, ElementsBaseMut<'a, A, D>>,
}
337
/// Counted iterator over the elements of an array (mutable).
///
/// Wraps a `Baseiter` and turns its raw pointers into `&'a mut A`.
pub struct ElementsBaseMut<'a, A, D> {
    inner: Baseiter<A, D>,
    // Ties the yielded mutable references to the borrowed view's lifetime.
    life: PhantomData<&'a mut A>,
}
345
346impl<'a, A, D: Dimension> ElementsBaseMut<'a, A, D> {
347    pub fn new(v: ArrayViewMut<'a, A, D>) -> Self {
348        ElementsBaseMut {
349            inner: v.into_base_iter(),
350            life: PhantomData,
351        }
352    }
353}
354
355/// An iterator over the indexes and elements of an array.
356///
357/// See [`.indexed_iter()`](ArrayBase::indexed_iter) for more information.
358#[derive(Clone)]
359pub struct IndexedIter<'a, A, D>(ElementsBase<'a, A, D>);
/// An iterator over the indexes and elements of an array (mutable).
///
/// See [`.indexed_iter_mut()`](ArrayBase::indexed_iter_mut) for more information.
// Not `Clone` — presumably because two copies could yield aliasing `&mut A`.
pub struct IndexedIterMut<'a, A, D>(ElementsBaseMut<'a, A, D>);
364
365impl<'a, A, D> IndexedIter<'a, A, D>
366where
367    D: Dimension,
368{
369    pub(crate) fn new(x: ElementsBase<'a, A, D>) -> Self {
370        IndexedIter(x)
371    }
372}
373
374impl<'a, A, D> IndexedIterMut<'a, A, D>
375where
376    D: Dimension,
377{
378    pub(crate) fn new(x: ElementsBaseMut<'a, A, D>) -> Self {
379        IndexedIterMut(x)
380    }
381}
382
383impl<'a, A, D: Dimension> Iterator for Iter<'a, A, D> {
384    type Item = &'a A;
385    #[inline]
386    fn next(&mut self) -> Option<&'a A> {
387        either_mut!(self.inner, iter => iter.next())
388    }
389
390    fn size_hint(&self) -> (usize, Option<usize>) {
391        either!(self.inner, ref iter => iter.size_hint())
392    }
393
394    fn fold<Acc, G>(self, init: Acc, g: G) -> Acc
395    where
396        G: FnMut(Acc, Self::Item) -> Acc,
397    {
398        either!(self.inner, iter => iter.fold(init, g))
399    }
400
401    fn nth(&mut self, n: usize) -> Option<Self::Item> {
402        either_mut!(self.inner, iter => iter.nth(n))
403    }
404
405    fn collect<B>(self) -> B
406    where
407        B: FromIterator<Self::Item>,
408    {
409        either!(self.inner, iter => iter.collect())
410    }
411
412    fn all<F>(&mut self, f: F) -> bool
413    where
414        F: FnMut(Self::Item) -> bool,
415    {
416        either_mut!(self.inner, iter => iter.all(f))
417    }
418
419    fn any<F>(&mut self, f: F) -> bool
420    where
421        F: FnMut(Self::Item) -> bool,
422    {
423        either_mut!(self.inner, iter => iter.any(f))
424    }
425
426    fn find<P>(&mut self, predicate: P) -> Option<Self::Item>
427    where
428        P: FnMut(&Self::Item) -> bool,
429    {
430        either_mut!(self.inner, iter => iter.find(predicate))
431    }
432
433    fn find_map<B, F>(&mut self, f: F) -> Option<B>
434    where
435        F: FnMut(Self::Item) -> Option<B>,
436    {
437        either_mut!(self.inner, iter => iter.find_map(f))
438    }
439
440    fn count(self) -> usize {
441        either!(self.inner, iter => iter.count())
442    }
443
444    fn last(self) -> Option<Self::Item> {
445        either!(self.inner, iter => iter.last())
446    }
447
448    fn position<P>(&mut self, predicate: P) -> Option<usize>
449    where
450        P: FnMut(Self::Item) -> bool,
451    {
452        either_mut!(self.inner, iter => iter.position(predicate))
453    }
454}
455
456impl<'a, A> DoubleEndedIterator for Iter<'a, A, Ix1> {
457    #[inline]
458    fn next_back(&mut self) -> Option<&'a A> {
459        either_mut!(self.inner, iter => iter.next_back())
460    }
461
462    fn nth_back(&mut self, n: usize) -> Option<&'a A> {
463        either_mut!(self.inner, iter => iter.nth_back(n))
464    }
465
466    fn rfold<Acc, G>(self, init: Acc, g: G) -> Acc
467    where
468        G: FnMut(Acc, Self::Item) -> Acc,
469    {
470        either!(self.inner, iter => iter.rfold(init, g))
471    }
472}
473
474impl<'a, A, D> ExactSizeIterator for Iter<'a, A, D>
475where
476    D: Dimension,
477{
478    fn len(&self) -> usize {
479        either!(self.inner, ref iter => iter.len())
480    }
481}
482
483impl<'a, A, D: Dimension> Iterator for IndexedIter<'a, A, D> {
484    type Item = (D::Pattern, &'a A);
485    #[inline]
486    fn next(&mut self) -> Option<Self::Item> {
487        let index = match self.0.inner.index {
488            None => return None,
489            Some(ref ix) => ix.clone(),
490        };
491        match self.0.next() {
492            None => None,
493            Some(elem) => Some((index.into_pattern(), elem)),
494        }
495    }
496
497    fn size_hint(&self) -> (usize, Option<usize>) {
498        self.0.size_hint()
499    }
500}
501
502impl<'a, A, D> ExactSizeIterator for IndexedIter<'a, A, D>
503where
504    D: Dimension,
505{
506    fn len(&self) -> usize {
507        self.0.inner.len()
508    }
509}
510
511impl<'a, A, D: Dimension> Iterator for IterMut<'a, A, D> {
512    type Item = &'a mut A;
513    #[inline]
514    fn next(&mut self) -> Option<&'a mut A> {
515        either_mut!(self.inner, iter => iter.next())
516    }
517
518    fn size_hint(&self) -> (usize, Option<usize>) {
519        either!(self.inner, ref iter => iter.size_hint())
520    }
521
522    fn fold<Acc, G>(self, init: Acc, g: G) -> Acc
523    where
524        G: FnMut(Acc, Self::Item) -> Acc,
525    {
526        either!(self.inner, iter => iter.fold(init, g))
527    }
528
529    fn nth(&mut self, n: usize) -> Option<Self::Item> {
530        either_mut!(self.inner, iter => iter.nth(n))
531    }
532
533    fn collect<B>(self) -> B
534    where
535        B: FromIterator<Self::Item>,
536    {
537        either!(self.inner, iter => iter.collect())
538    }
539
540    fn all<F>(&mut self, f: F) -> bool
541    where
542        F: FnMut(Self::Item) -> bool,
543    {
544        either_mut!(self.inner, iter => iter.all(f))
545    }
546
547    fn any<F>(&mut self, f: F) -> bool
548    where
549        F: FnMut(Self::Item) -> bool,
550    {
551        either_mut!(self.inner, iter => iter.any(f))
552    }
553
554    fn find<P>(&mut self, predicate: P) -> Option<Self::Item>
555    where
556        P: FnMut(&Self::Item) -> bool,
557    {
558        either_mut!(self.inner, iter => iter.find(predicate))
559    }
560
561    fn find_map<B, F>(&mut self, f: F) -> Option<B>
562    where
563        F: FnMut(Self::Item) -> Option<B>,
564    {
565        either_mut!(self.inner, iter => iter.find_map(f))
566    }
567
568    fn count(self) -> usize {
569        either!(self.inner, iter => iter.count())
570    }
571
572    fn last(self) -> Option<Self::Item> {
573        either!(self.inner, iter => iter.last())
574    }
575
576    fn position<P>(&mut self, predicate: P) -> Option<usize>
577    where
578        P: FnMut(Self::Item) -> bool,
579    {
580        either_mut!(self.inner, iter => iter.position(predicate))
581    }
582}
583
584impl<'a, A> DoubleEndedIterator for IterMut<'a, A, Ix1> {
585    #[inline]
586    fn next_back(&mut self) -> Option<&'a mut A> {
587        either_mut!(self.inner, iter => iter.next_back())
588    }
589
590    fn nth_back(&mut self, n: usize) -> Option<&'a mut A> {
591        either_mut!(self.inner, iter => iter.nth_back(n))
592    }
593
594    fn rfold<Acc, G>(self, init: Acc, g: G) -> Acc
595    where
596        G: FnMut(Acc, Self::Item) -> Acc,
597    {
598        either!(self.inner, iter => iter.rfold(init, g))
599    }
600}
601
602impl<'a, A, D> ExactSizeIterator for IterMut<'a, A, D>
603where
604    D: Dimension,
605{
606    fn len(&self) -> usize {
607        either!(self.inner, ref iter => iter.len())
608    }
609}
610
611impl<'a, A, D: Dimension> Iterator for ElementsBaseMut<'a, A, D> {
612    type Item = &'a mut A;
613    #[inline]
614    fn next(&mut self) -> Option<&'a mut A> {
615        self.inner.next().map(|p| unsafe { &mut *p })
616    }
617
618    fn size_hint(&self) -> (usize, Option<usize>) {
619        self.inner.size_hint()
620    }
621
622    fn fold<Acc, G>(self, init: Acc, mut g: G) -> Acc
623    where
624        G: FnMut(Acc, Self::Item) -> Acc,
625    {
626        unsafe { self.inner.fold(init, move |acc, ptr| g(acc, &mut *ptr)) }
627    }
628}
629
630impl<'a, A> DoubleEndedIterator for ElementsBaseMut<'a, A, Ix1> {
631    #[inline]
632    fn next_back(&mut self) -> Option<&'a mut A> {
633        self.inner.next_back().map(|p| unsafe { &mut *p })
634    }
635
636    fn rfold<Acc, G>(self, init: Acc, mut g: G) -> Acc
637    where
638        G: FnMut(Acc, Self::Item) -> Acc,
639    {
640        unsafe { self.inner.rfold(init, move |acc, ptr| g(acc, &mut *ptr)) }
641    }
642}
643
644impl<'a, A, D> ExactSizeIterator for ElementsBaseMut<'a, A, D>
645where
646    D: Dimension,
647{
648    fn len(&self) -> usize {
649        self.inner.len()
650    }
651}
652
653impl<'a, A, D: Dimension> Iterator for IndexedIterMut<'a, A, D> {
654    type Item = (D::Pattern, &'a mut A);
655    #[inline]
656    fn next(&mut self) -> Option<Self::Item> {
657        let index = match self.0.inner.index {
658            None => return None,
659            Some(ref ix) => ix.clone(),
660        };
661        match self.0.next() {
662            None => None,
663            Some(elem) => Some((index.into_pattern(), elem)),
664        }
665    }
666
667    fn size_hint(&self) -> (usize, Option<usize>) {
668        self.0.size_hint()
669    }
670}
671
672impl<'a, A, D> ExactSizeIterator for IndexedIterMut<'a, A, D>
673where
674    D: Dimension,
675{
676    fn len(&self) -> usize {
677        self.0.inner.len()
678    }
679}
680
/// An iterator that traverses over all axes but one, and yields a view for
/// each lane along that axis.
///
/// See [`.lanes()`](ArrayBase::lanes) for more information.
pub struct LanesIter<'a, A, D> {
    // Length of each yielded lane.
    inner_len: Ix,
    // Stride between consecutive elements within a lane.
    inner_stride: Ixs,
    // Yields the pointer to the first element of each lane.
    iter: Baseiter<A, D>,
    life: PhantomData<&'a A>,
}
691
// `Clone` for `LanesIter` requiring only `D: Clone`, generated by
// `clone_bounds!` (from the `macros` module).
clone_bounds!(
    ['a, A, D: Clone]
    LanesIter['a, A, D] {
        @copy {
            inner_len,
            inner_stride,
            life,
        }
        iter,
    }
);
703
704impl<'a, A, D> Iterator for LanesIter<'a, A, D>
705where
706    D: Dimension,
707{
708    type Item = ArrayView<'a, A, Ix1>;
709    fn next(&mut self) -> Option<Self::Item> {
710        self.iter.next().map(|ptr| unsafe {
711            ArrayView::new_(ptr, Ix1(self.inner_len), Ix1(self.inner_stride as Ix))
712        })
713    }
714
715    fn size_hint(&self) -> (usize, Option<usize>) {
716        self.iter.size_hint()
717    }
718}
719
720impl<'a, A, D> ExactSizeIterator for LanesIter<'a, A, D>
721where
722    D: Dimension,
723{
724    fn len(&self) -> usize {
725        self.iter.len()
726    }
727}
728
// NOTE: LanesIterMut hands out `ArrayViewMut`s, so the lanes it yields must
// never alias one another. This relies on `iter` producing the start pointer
// of each distinct lane exactly once — presumably guaranteed by the
// constructor (not in this chunk, likely in `lanes.rs`); confirm there.
/// An iterator that traverses over all axes but one, and yields a mutable
/// one-dimensional view (`ArrayViewMut<'a, A, Ix1>`) for each lane; the
/// mutable counterpart of `LanesIter`.
///
/// See [`.lanes_mut()`](ArrayBase::lanes_mut)
/// for more information.
pub struct LanesIterMut<'a, A, D> {
    // Length of each yielded lane.
    inner_len: Ix,
    // Stride between consecutive elements within a lane.
    inner_stride: Ixs,
    // Yields the pointer to the first element of each lane.
    iter: Baseiter<A, D>,
    life: PhantomData<&'a mut A>,
}
743
744impl<'a, A, D> Iterator for LanesIterMut<'a, A, D>
745where
746    D: Dimension,
747{
748    type Item = ArrayViewMut<'a, A, Ix1>;
749    fn next(&mut self) -> Option<Self::Item> {
750        self.iter.next().map(|ptr| unsafe {
751            ArrayViewMut::new_(ptr, Ix1(self.inner_len), Ix1(self.inner_stride as Ix))
752        })
753    }
754
755    fn size_hint(&self) -> (usize, Option<usize>) {
756        self.iter.size_hint()
757    }
758}
759
760impl<'a, A, D> ExactSizeIterator for LanesIterMut<'a, A, D>
761where
762    D: Dimension,
763{
764    fn len(&self) -> usize {
765        self.iter.len()
766    }
767}
768
/// Shared core of `AxisIter` and `AxisIterMut`: yields a raw pointer to the
/// start of each subview along one axis, from either end.
#[derive(Debug)]
pub struct AxisIterCore<A, D> {
    /// Index along the axis of the value of `.next()`, relative to the start
    /// of the axis.
    index: Ix,
    /// (Exclusive) upper bound on `index`. Initially, this is equal to the
    /// length of the axis.
    end: Ix,
    /// Stride along the axis (offset between consecutive pointers).
    stride: Ixs,
    /// Shape of the iterator's items.
    inner_dim: D,
    /// Strides of the iterator's items.
    inner_strides: D,
    /// Pointer corresponding to `index == 0`.
    ptr: *mut A,
}
786
// `Clone` for `AxisIterCore` requiring only `D: Clone`, generated by
// `clone_bounds!` (from the `macros` module).
clone_bounds!(
    [A, D: Clone]
    AxisIterCore[A, D] {
        @copy {
            index,
            end,
            stride,
            ptr,
        }
        inner_dim,
        inner_strides,
    }
);
800
801impl<A, D: Dimension> AxisIterCore<A, D> {
802    /// Constructs a new iterator over the specified axis.
803    fn new<S, Di>(v: ArrayBase<S, Di>, axis: Axis) -> Self
804    where
805        Di: RemoveAxis<Smaller = D>,
806        S: Data<Elem = A>,
807    {
808        AxisIterCore {
809            index: 0,
810            end: v.len_of(axis),
811            stride: v.stride_of(axis),
812            inner_dim: v.dim.remove_axis(axis),
813            inner_strides: v.strides.remove_axis(axis),
814            ptr: v.ptr.as_ptr(),
815        }
816    }
817
818    #[inline]
819    unsafe fn offset(&self, index: usize) -> *mut A {
820        debug_assert!(
821            index < self.end,
822            "index={}, end={}, stride={}",
823            index,
824            self.end,
825            self.stride
826        );
827        self.ptr.offset(index as isize * self.stride)
828    }
829
830    /// Splits the iterator at `index`, yielding two disjoint iterators.
831    ///
832    /// `index` is relative to the current state of the iterator (which is not
833    /// necessarily the start of the axis).
834    ///
835    /// **Panics** if `index` is strictly greater than the iterator's remaining
836    /// length.
837    fn split_at(self, index: usize) -> (Self, Self) {
838        assert!(index <= self.len());
839        let mid = self.index + index;
840        let left = AxisIterCore {
841            index: self.index,
842            end: mid,
843            stride: self.stride,
844            inner_dim: self.inner_dim.clone(),
845            inner_strides: self.inner_strides.clone(),
846            ptr: self.ptr,
847        };
848        let right = AxisIterCore {
849            index: mid,
850            end: self.end,
851            stride: self.stride,
852            inner_dim: self.inner_dim,
853            inner_strides: self.inner_strides,
854            ptr: self.ptr,
855        };
856        (left, right)
857    }
858
859    /// Does the same thing as `.next()` but also returns the index of the item
860    /// relative to the start of the axis.
861    fn next_with_index(&mut self) -> Option<(usize, *mut A)> {
862        let index = self.index;
863        self.next().map(|ptr| (index, ptr))
864    }
865
866    /// Does the same thing as `.next_back()` but also returns the index of the
867    /// item relative to the start of the axis.
868    fn next_back_with_index(&mut self) -> Option<(usize, *mut A)> {
869        self.next_back().map(|ptr| (self.end, ptr))
870    }
871}
872
873impl<A, D> Iterator for AxisIterCore<A, D>
874where
875    D: Dimension,
876{
877    type Item = *mut A;
878
879    fn next(&mut self) -> Option<Self::Item> {
880        if self.index >= self.end {
881            None
882        } else {
883            let ptr = unsafe { self.offset(self.index) };
884            self.index += 1;
885            Some(ptr)
886        }
887    }
888
889    fn size_hint(&self) -> (usize, Option<usize>) {
890        let len = self.len();
891        (len, Some(len))
892    }
893}
894
895impl<A, D> DoubleEndedIterator for AxisIterCore<A, D>
896where
897    D: Dimension,
898{
899    fn next_back(&mut self) -> Option<Self::Item> {
900        if self.index >= self.end {
901            None
902        } else {
903            let ptr = unsafe { self.offset(self.end - 1) };
904            self.end -= 1;
905            Some(ptr)
906        }
907    }
908}
909
910impl<A, D> ExactSizeIterator for AxisIterCore<A, D>
911where
912    D: Dimension,
913{
914    fn len(&self) -> usize {
915        self.end - self.index
916    }
917}
918
/// An iterator that traverses over an axis and
/// yields each subview.
///
/// The outermost dimension is `Axis(0)`, created with `.outer_iter()`,
/// but you can traverse arbitrary dimension with `.axis_iter()`.
///
/// For example, in a 3 × 5 × 5 array, with `axis` equal to `Axis(2)`,
/// the iterator element is a 3 × 5 subview (and there are 5 in total).
///
/// Iterator element type is `ArrayView<'a, A, D>`.
///
/// See [`.outer_iter()`](ArrayBase::outer_iter)
/// or [`.axis_iter()`](ArrayBase::axis_iter)
/// for more information.
#[derive(Debug)]
pub struct AxisIter<'a, A, D> {
    // Produces a raw pointer per subview; wrapped into views by `NdProducer::as_ref`.
    iter: AxisIterCore<A, D>,
    life: PhantomData<&'a A>,
}
938
// `Clone` for `AxisIter` requiring only `D: Clone`, generated by
// `clone_bounds!` (from the `macros` module).
clone_bounds!(
    ['a, A, D: Clone]
    AxisIter['a, A, D] {
        @copy {
            life,
        }
        iter,
    }
);
948
949impl<'a, A, D: Dimension> AxisIter<'a, A, D> {
950    /// Creates a new iterator over the specified axis.
951    pub(crate) fn new<Di>(v: ArrayView<'a, A, Di>, axis: Axis) -> Self
952    where
953        Di: RemoveAxis<Smaller = D>,
954    {
955        AxisIter {
956            iter: AxisIterCore::new(v, axis),
957            life: PhantomData,
958        }
959    }
960
961    /// Splits the iterator at `index`, yielding two disjoint iterators.
962    ///
963    /// `index` is relative to the current state of the iterator (which is not
964    /// necessarily the start of the axis).
965    ///
966    /// **Panics** if `index` is strictly greater than the iterator's remaining
967    /// length.
968    pub fn split_at(self, index: usize) -> (Self, Self) {
969        let (left, right) = self.iter.split_at(index);
970        (
971            AxisIter {
972                iter: left,
973                life: self.life,
974            },
975            AxisIter {
976                iter: right,
977                life: self.life,
978            },
979        )
980    }
981}
982
983impl<'a, A, D> Iterator for AxisIter<'a, A, D>
984where
985    D: Dimension,
986{
987    type Item = ArrayView<'a, A, D>;
988
989    fn next(&mut self) -> Option<Self::Item> {
990        self.iter.next().map(|ptr| unsafe { self.as_ref(ptr) })
991    }
992
993    fn size_hint(&self) -> (usize, Option<usize>) {
994        self.iter.size_hint()
995    }
996}
997
998impl<'a, A, D> DoubleEndedIterator for AxisIter<'a, A, D>
999where
1000    D: Dimension,
1001{
1002    fn next_back(&mut self) -> Option<Self::Item> {
1003        self.iter.next_back().map(|ptr| unsafe { self.as_ref(ptr) })
1004    }
1005}
1006
1007impl<'a, A, D> ExactSizeIterator for AxisIter<'a, A, D>
1008where
1009    D: Dimension,
1010{
1011    fn len(&self) -> usize {
1012        self.iter.len()
1013    }
1014}
1015
/// An iterator that traverses over an axis and
/// yields each subview (mutable).
///
/// The outermost dimension is `Axis(0)`, created with `.outer_iter()`,
/// but you can traverse arbitrary dimension with `.axis_iter()`.
///
/// For example, in a 3 × 5 × 5 array, with `axis` equal to `Axis(2)`,
/// the iterator element is a 3 × 5 subview (and there are 5 in total).
///
/// Iterator element type is `ArrayViewMut<'a, A, D>`.
///
/// See [`.outer_iter_mut()`](ArrayBase::outer_iter_mut)
/// or [`.axis_iter_mut()`](ArrayBase::axis_iter_mut)
/// for more information.
pub struct AxisIterMut<'a, A, D> {
    // Produces a raw pointer per subview; wrapped into mutable views.
    iter: AxisIterCore<A, D>,
    life: PhantomData<&'a mut A>,
}
1034
1035impl<'a, A, D: Dimension> AxisIterMut<'a, A, D> {
1036    /// Creates a new iterator over the specified axis.
1037    pub(crate) fn new<Di>(v: ArrayViewMut<'a, A, Di>, axis: Axis) -> Self
1038    where
1039        Di: RemoveAxis<Smaller = D>,
1040    {
1041        AxisIterMut {
1042            iter: AxisIterCore::new(v, axis),
1043            life: PhantomData,
1044        }
1045    }
1046
1047    /// Splits the iterator at `index`, yielding two disjoint iterators.
1048    ///
1049    /// `index` is relative to the current state of the iterator (which is not
1050    /// necessarily the start of the axis).
1051    ///
1052    /// **Panics** if `index` is strictly greater than the iterator's remaining
1053    /// length.
1054    pub fn split_at(self, index: usize) -> (Self, Self) {
1055        let (left, right) = self.iter.split_at(index);
1056        (
1057            AxisIterMut {
1058                iter: left,
1059                life: self.life,
1060            },
1061            AxisIterMut {
1062                iter: right,
1063                life: self.life,
1064            },
1065        )
1066    }
1067}
1068
1069impl<'a, A, D> Iterator for AxisIterMut<'a, A, D>
1070where
1071    D: Dimension,
1072{
1073    type Item = ArrayViewMut<'a, A, D>;
1074
1075    fn next(&mut self) -> Option<Self::Item> {
1076        self.iter.next().map(|ptr| unsafe { self.as_ref(ptr) })
1077    }
1078
1079    fn size_hint(&self) -> (usize, Option<usize>) {
1080        self.iter.size_hint()
1081    }
1082}
1083
1084impl<'a, A, D> DoubleEndedIterator for AxisIterMut<'a, A, D>
1085where
1086    D: Dimension,
1087{
1088    fn next_back(&mut self) -> Option<Self::Item> {
1089        self.iter.next_back().map(|ptr| unsafe { self.as_ref(ptr) })
1090    }
1091}
1092
1093impl<'a, A, D> ExactSizeIterator for AxisIterMut<'a, A, D>
1094where
1095    D: Dimension,
1096{
1097    fn len(&self) -> usize {
1098        self.iter.len()
1099    }
1100}
1101
1102impl<'a, A, D: Dimension> NdProducer for AxisIter<'a, A, D> {
1103    type Item = <Self as Iterator>::Item;
1104    type Dim = Ix1;
1105    type Ptr = *mut A;
1106    type Stride = isize;
1107
1108    fn layout(&self) -> crate::Layout {
1109        crate::Layout::one_dimensional()
1110    }
1111
1112    fn raw_dim(&self) -> Self::Dim {
1113        Ix1(self.len())
1114    }
1115
1116    fn as_ptr(&self) -> Self::Ptr {
1117        if self.len() > 0 {
1118            // `self.iter.index` is guaranteed to be in-bounds if any of the
1119            // iterator remains (i.e. if `self.len() > 0`).
1120            unsafe { self.iter.offset(self.iter.index) }
1121        } else {
1122            // In this case, `self.iter.index` may be past the end, so we must
1123            // not call `.offset()`. It's okay to return a dangling pointer
1124            // because it will never be used in the length 0 case.
1125            std::ptr::NonNull::dangling().as_ptr()
1126        }
1127    }
1128
1129    fn contiguous_stride(&self) -> isize {
1130        self.iter.stride
1131    }
1132
1133    unsafe fn as_ref(&self, ptr: Self::Ptr) -> Self::Item {
1134        ArrayView::new_(
1135            ptr,
1136            self.iter.inner_dim.clone(),
1137            self.iter.inner_strides.clone(),
1138        )
1139    }
1140
1141    unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr {
1142        self.iter.offset(self.iter.index + i[0])
1143    }
1144
1145    fn stride_of(&self, _axis: Axis) -> isize {
1146        self.contiguous_stride()
1147    }
1148
1149    fn split_at(self, _axis: Axis, index: usize) -> (Self, Self) {
1150        self.split_at(index)
1151    }
1152
1153    private_impl! {}
1154}
1155
1156impl<'a, A, D: Dimension> NdProducer for AxisIterMut<'a, A, D> {
1157    type Item = <Self as Iterator>::Item;
1158    type Dim = Ix1;
1159    type Ptr = *mut A;
1160    type Stride = isize;
1161
1162    fn layout(&self) -> crate::Layout {
1163        crate::Layout::one_dimensional()
1164    }
1165
1166    fn raw_dim(&self) -> Self::Dim {
1167        Ix1(self.len())
1168    }
1169
1170    fn as_ptr(&self) -> Self::Ptr {
1171        if self.len() > 0 {
1172            // `self.iter.index` is guaranteed to be in-bounds if any of the
1173            // iterator remains (i.e. if `self.len() > 0`).
1174            unsafe { self.iter.offset(self.iter.index) }
1175        } else {
1176            // In this case, `self.iter.index` may be past the end, so we must
1177            // not call `.offset()`. It's okay to return a dangling pointer
1178            // because it will never be used in the length 0 case.
1179            std::ptr::NonNull::dangling().as_ptr()
1180        }
1181    }
1182
1183    fn contiguous_stride(&self) -> isize {
1184        self.iter.stride
1185    }
1186
1187    unsafe fn as_ref(&self, ptr: Self::Ptr) -> Self::Item {
1188        ArrayViewMut::new_(
1189            ptr,
1190            self.iter.inner_dim.clone(),
1191            self.iter.inner_strides.clone(),
1192        )
1193    }
1194
1195    unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr {
1196        self.iter.offset(self.iter.index + i[0])
1197    }
1198
1199    fn stride_of(&self, _axis: Axis) -> isize {
1200        self.contiguous_stride()
1201    }
1202
1203    fn split_at(self, _axis: Axis, index: usize) -> (Self, Self) {
1204        self.split_at(index)
1205    }
1206
1207    private_impl! {}
1208}
1209
/// An iterator that traverses over the specified axis
/// and yields views of the specified size on this axis.
///
/// For example, in a 2 × 8 × 3 array, if the axis of iteration
/// is 1 and the chunk size is 2, the yielded elements
/// are 2 × 2 × 3 views (and there are 4 in total).
///
/// If the axis length is not a multiple of the chunk size, the last view
/// is shorter along the axis (it holds the remaining indices).
///
/// Iterator element type is `ArrayView<'a, A, D>`.
///
/// See [`.axis_chunks_iter()`](ArrayBase::axis_chunks_iter) for more information.
pub struct AxisChunksIter<'a, A, D> {
    /// Iterator over the start pointers of the chunks. Its `inner_dim` is the
    /// shape of a whole (full-size) chunk.
    iter: AxisIterCore<A, D>,
    /// Index of the partial chunk (the chunk smaller than the specified chunk
    /// size due to the axis length not being evenly divisible). If the axis
    /// length is evenly divisible by the chunk size, this index is larger than
    /// the maximum valid index.
    partial_chunk_index: usize,
    /// Dimension of the partial chunk.
    partial_chunk_dim: D,
    /// Ties the yielded views to the lifetime `'a` of the borrowed data.
    life: PhantomData<&'a A>,
}
1231
// `Clone` for `AxisChunksIter`, generated by the `clone_bounds!` helper
// macro. Only `D: Clone` is listed as a bound (nothing on `A`), because the
// iterator holds a raw-pointer core and dimensions rather than elements.
// Fields under `@copy` are presumably copied as-is and the rest presumably
// `.clone()`d; confirm against the macro definition.
clone_bounds!(
    ['a, A, D: Clone]
    AxisChunksIter['a, A, D] {
        @copy {
            life,
            partial_chunk_index,
        }
        iter,
        partial_chunk_dim,
    }
);
1243
1244/// Computes the information necessary to construct an iterator over chunks
1245/// along an axis, given a `view` of the array, the `axis` to iterate over, and
1246/// the chunk `size`.
1247///
1248/// Returns an axis iterator with the correct stride to move between chunks,
1249/// the number of chunks, and the shape of the last chunk.
1250///
1251/// **Panics** if `size == 0`.
1252fn chunk_iter_parts<A, D: Dimension>(
1253    v: ArrayView<'_, A, D>,
1254    axis: Axis,
1255    size: usize,
1256) -> (AxisIterCore<A, D>, usize, D) {
1257    assert_ne!(size, 0, "Chunk size must be nonzero.");
1258    let axis_len = v.len_of(axis);
1259    let n_whole_chunks = axis_len / size;
1260    let chunk_remainder = axis_len % size;
1261    let iter_len = if chunk_remainder == 0 {
1262        n_whole_chunks
1263    } else {
1264        n_whole_chunks + 1
1265    };
1266    let stride = if n_whole_chunks == 0 {
1267        // This case avoids potential overflow when `size > axis_len`.
1268        0
1269    } else {
1270        v.stride_of(axis) * size as isize
1271    };
1272
1273    let axis = axis.index();
1274    let mut inner_dim = v.dim.clone();
1275    inner_dim[axis] = size;
1276
1277    let mut partial_chunk_dim = v.dim;
1278    partial_chunk_dim[axis] = chunk_remainder;
1279    let partial_chunk_index = n_whole_chunks;
1280
1281    let iter = AxisIterCore {
1282        index: 0,
1283        end: iter_len,
1284        stride,
1285        inner_dim,
1286        inner_strides: v.strides,
1287        ptr: v.ptr.as_ptr(),
1288    };
1289
1290    (iter, partial_chunk_index, partial_chunk_dim)
1291}
1292
1293impl<'a, A, D: Dimension> AxisChunksIter<'a, A, D> {
1294    pub(crate) fn new(v: ArrayView<'a, A, D>, axis: Axis, size: usize) -> Self {
1295        let (iter, partial_chunk_index, partial_chunk_dim) = chunk_iter_parts(v, axis, size);
1296        AxisChunksIter {
1297            iter,
1298            partial_chunk_index,
1299            partial_chunk_dim,
1300            life: PhantomData,
1301        }
1302    }
1303}
1304
/// Generates the shared inherent methods (`get_subview`, `split_at`) and the
/// `Iterator`, `DoubleEndedIterator` and `ExactSizeIterator` impls for a
/// chunk iterator type `$iter` that yields `$array` views.
macro_rules! chunk_iter_impl {
    ($iter:ident, $array:ident) => {
        impl<'a, A, D> $iter<'a, A, D>
        where
            D: Dimension,
        {
            /// Builds the view for the chunk at absolute position `index`,
            /// starting at `ptr`. The partial chunk gets its own (shorter)
            /// shape; every other chunk uses the full chunk shape.
            fn get_subview(&self, index: usize, ptr: *mut A) -> $array<'a, A, D> {
                let dim = if index == self.partial_chunk_index {
                    self.partial_chunk_dim.clone()
                } else {
                    self.iter.inner_dim.clone()
                };
                // SAFETY: assumes `ptr` was produced by `self.iter` for chunk
                // `index`, so `dim` with the shared inner strides stays inside
                // the original array.
                unsafe { $array::new_(ptr, dim, self.iter.inner_strides.clone()) }
            }

            /// Splits the iterator at index, yielding two disjoint iterators.
            ///
            /// `index` is relative to the current state of the iterator (which is not
            /// necessarily the start of the axis).
            ///
            /// **Panics** if `index` is strictly greater than the iterator's remaining
            /// length.
            pub fn split_at(self, index: usize) -> (Self, Self) {
                let (front, back) = self.iter.split_at(index);
                // Both halves keep the same partial-chunk bookkeeping.
                let first = Self {
                    iter: front,
                    partial_chunk_index: self.partial_chunk_index,
                    partial_chunk_dim: self.partial_chunk_dim.clone(),
                    life: self.life,
                };
                let second = Self {
                    iter: back,
                    partial_chunk_index: self.partial_chunk_index,
                    partial_chunk_dim: self.partial_chunk_dim,
                    life: self.life,
                };
                (first, second)
            }
        }

        impl<'a, A, D> Iterator for $iter<'a, A, D>
        where
            D: Dimension,
        {
            type Item = $array<'a, A, D>;

            fn next(&mut self) -> Option<Self::Item> {
                let (index, ptr) = self.iter.next_with_index()?;
                Some(self.get_subview(index, ptr))
            }

            fn size_hint(&self) -> (usize, Option<usize>) {
                self.iter.size_hint()
            }
        }

        impl<'a, A, D> DoubleEndedIterator for $iter<'a, A, D>
        where
            D: Dimension,
        {
            fn next_back(&mut self) -> Option<Self::Item> {
                let (index, ptr) = self.iter.next_back_with_index()?;
                Some(self.get_subview(index, ptr))
            }
        }

        // `len()` comes from the default method, which relies on `size_hint`.
        impl<'a, A, D> ExactSizeIterator for $iter<'a, A, D> where D: Dimension {}
    };
}
1388
/// An iterator that traverses over the specified axis
/// and yields mutable views of the specified size on this axis.
///
/// For example, in a 2 × 8 × 3 array, if the axis of iteration
/// is 1 and the chunk size is 2, the yielded elements
/// are 2 × 2 × 3 views (and there are 4 in total).
///
/// If the axis length is not a multiple of the chunk size, the last view
/// is shorter along the axis (it holds the remaining indices).
///
/// Iterator element type is `ArrayViewMut<'a, A, D>`.
///
/// See [`.axis_chunks_iter_mut()`](ArrayBase::axis_chunks_iter_mut)
/// for more information.
pub struct AxisChunksIterMut<'a, A, D> {
    /// Iterator over the start pointers of the chunks. Its `inner_dim` is the
    /// shape of a whole (full-size) chunk.
    iter: AxisIterCore<A, D>,
    /// Index of the partial chunk (the chunk smaller than the specified chunk
    /// size). If the axis length is evenly divisible by the chunk size, this
    /// index is larger than the maximum valid index.
    partial_chunk_index: usize,
    /// Dimension of the partial chunk.
    partial_chunk_dim: D,
    /// Ties the yielded views to the exclusive borrow `'a` of the data.
    life: PhantomData<&'a mut A>,
}
1406
1407impl<'a, A, D: Dimension> AxisChunksIterMut<'a, A, D> {
1408    pub(crate) fn new(v: ArrayViewMut<'a, A, D>, axis: Axis, size: usize) -> Self {
1409        let (iter, partial_chunk_index, partial_chunk_dim) =
1410            chunk_iter_parts(v.into_view(), axis, size);
1411        AxisChunksIterMut {
1412            iter,
1413            partial_chunk_index,
1414            partial_chunk_dim,
1415            life: PhantomData,
1416        }
1417    }
1418}
1419
// Generate `get_subview`, `split_at` and the `Iterator`,
// `DoubleEndedIterator` and `ExactSizeIterator` impls for both chunk iterators.
chunk_iter_impl!(AxisChunksIter, ArrayView);
chunk_iter_impl!(AxisChunksIterMut, ArrayViewMut);
1422
// `Send`/`Sync` impls for the read-only iterators, generated by a helper
// macro defined elsewhere. These types hold raw pointers, so the auto traits
// are presumably not derived automatically; the macro presumably mirrors the
// rules for `&A`. Confirm against the macro definition.
send_sync_read_only!(Iter);
send_sync_read_only!(IndexedIter);
send_sync_read_only!(LanesIter);
send_sync_read_only!(AxisIter);
send_sync_read_only!(AxisChunksIter);
send_sync_read_only!(ElementsBase);

// Same for the read-write iterators; presumably mirrors the rules for
// `&mut A`. Confirm against the macro definition.
send_sync_read_write!(IterMut);
send_sync_read_write!(IndexedIterMut);
send_sync_read_write!(LanesIterMut);
send_sync_read_write!(AxisIterMut);
send_sync_read_write!(AxisChunksIterMut);
send_sync_read_write!(ElementsBaseMut);
1436
/// (Trait used internally) An iterator that we trust
/// to deliver exactly as many items as it said it would.
///
/// The iterator must produce exactly the number of elements it reported or
/// diverge before reaching the end.
///
/// # Safety
///
/// The "reported" count is the lower bound of `size_hint()`. `to_vec_mapped`
/// reserves exactly that much capacity and writes one item per element
/// without further checks. An implementor that yields *more* items than it
/// reported therefore causes out-of-bounds writes (undefined behavior).
#[allow(clippy::missing_safety_doc)] // not nameable downstream
pub unsafe trait TrustedIterator {}
1444
1445use crate::indexes::IndicesIterF;
1446use crate::iter::IndicesIter;
1447#[cfg(feature = "std")]
1448use crate::{geomspace::Geomspace, linspace::Linspace, logspace::Logspace};
// Iterators whose reported length is trusted to be exact. The `Cloned` and
// `Map` adapters yield exactly one item per item of the wrapped iterator, so
// they inherit trust from it. A panicking `Map` closure counts as diverging,
// which the contract allows.
#[cfg(feature = "std")]
unsafe impl<F> TrustedIterator for Linspace<F> {}
#[cfg(feature = "std")]
unsafe impl<F> TrustedIterator for Geomspace<F> {}
#[cfg(feature = "std")]
unsafe impl<F> TrustedIterator for Logspace<F> {}
unsafe impl<'a, A, D> TrustedIterator for Iter<'a, A, D> {}
unsafe impl<'a, A, D> TrustedIterator for IterMut<'a, A, D> {}
unsafe impl<I> TrustedIterator for std::iter::Cloned<I> where I: TrustedIterator {}
unsafe impl<I, F> TrustedIterator for std::iter::Map<I, F> where I: TrustedIterator {}
unsafe impl<'a, A> TrustedIterator for slice::Iter<'a, A> {}
unsafe impl<'a, A> TrustedIterator for slice::IterMut<'a, A> {}
unsafe impl TrustedIterator for ::std::ops::Range<usize> {}
// FIXME: Trusting these indices iterators is dubious -- their size needs to
// be checked up front.
unsafe impl<D> TrustedIterator for IndicesIter<D> where D: Dimension {}
unsafe impl<D> TrustedIterator for IndicesIterF<D> where D: Dimension {}
unsafe impl<A, D> TrustedIterator for IntoIter<A, D> where D: Dimension {}
1466
1467/// Like Iterator::collect, but only for trusted length iterators
1468pub fn to_vec<I>(iter: I) -> Vec<I::Item>
1469where
1470    I: TrustedIterator + ExactSizeIterator,
1471{
1472    to_vec_mapped(iter, |x| x)
1473}
1474
1475/// Like Iterator::collect, but only for trusted length iterators
1476pub fn to_vec_mapped<I, F, B>(iter: I, mut f: F) -> Vec<B>
1477where
1478    I: TrustedIterator + ExactSizeIterator,
1479    F: FnMut(I::Item) -> B,
1480{
1481    // Use an `unsafe` block to do this efficiently.
1482    // We know that iter will produce exactly .size() elements,
1483    // and the loop can vectorize if it's clean (without branch to grow the vector).
1484    let (size, _) = iter.size_hint();
1485    let mut result = Vec::with_capacity(size);
1486    let mut out_ptr = result.as_mut_ptr();
1487    let mut len = 0;
1488    iter.fold((), |(), elt| unsafe {
1489        ptr::write(out_ptr, f(elt));
1490        len += 1;
1491        result.set_len(len);
1492        out_ptr = out_ptr.offset(1);
1493    });
1494    debug_assert_eq!(size, result.len());
1495    result
1496}