ndarray/zip/mod.rs

1// Copyright 2017 bluss and ndarray developers.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9#[macro_use]
10mod zipmacro;
11mod ndproducer;
12
13#[cfg(feature = "rayon")]
14use std::mem::MaybeUninit;
15
16use crate::imp_prelude::*;
17use crate::AssignElem;
18use crate::IntoDimension;
19use crate::Layout;
20use crate::partial::Partial;
21
22use crate::indexes::{indices, Indices};
23use crate::split_at::{SplitPreference, SplitAt};
24use crate::dimension;
25
26pub use self::ndproducer::{NdProducer, IntoNdProducer, Offset};
27
28/// Return if the expression is a break value.
/// Return if the expression is a break value.
///
/// Evaluates `$e` (a `FoldWhile<T>`); on `Continue` it yields the inner
/// value, while any other variant (`Done`) is returned immediately from
/// the *enclosing function* — i.e. this is an early-exit helper for the
/// fold loops below.
macro_rules! fold_while {
    ($e:expr) => {
        match $e {
            FoldWhile::Continue(x) => x,
            x => return x,
        }
    };
}
37
/// Broadcast an array so that it acts like a larger size and/or shape array.
///
/// See [broadcasting](ArrayBase#broadcasting) for more information.
trait Broadcast<E>
where
    E: IntoDimension,
{
    /// The resulting producer, with the broadcast dimensionality `E::Dim`.
    type Output: NdProducer<Dim = E::Dim>;
    /// Broadcast the array to the new dimensions `shape`.
    ///
    /// ***Panics*** if broadcasting isn’t possible.
    fn broadcast_unwrap(self, shape: E) -> Self::Output;
    // Crate-private sealing marker (see `private_decl!`) — presumably keeps
    // downstream crates from implementing this trait.
    private_decl! {}
}
52
53/// Compute `Layout` hints for array shape dim, strides
54fn array_layout<D: Dimension>(dim: &D, strides: &D) -> Layout {
55    let n = dim.ndim();
56    if dimension::is_layout_c(dim, strides) {
57        // effectively one-dimensional => C and F layout compatible
58        if n <= 1 || dim.slice().iter().filter(|&&len| len > 1).count() <= 1 {
59            Layout::one_dimensional()
60        } else {
61            Layout::c()
62        }
63    } else if n > 1 && dimension::is_layout_f(dim, strides) {
64        Layout::f()
65    } else if n > 1 {
66        if dim[0] > 1 && strides[0] == 1 {
67            Layout::fpref()
68        } else if dim[n - 1] > 1 && strides[n - 1] == 1 {
69            Layout::cpref()
70        } else {
71            Layout::none()
72        }
73    } else {
74        Layout::none()
75    }
76}
77
impl<S, D> ArrayBase<S, D>
where
    S: RawData,
    D: Dimension,
{
    /// Compute the `Layout` hint for this array's current shape and strides
    /// by delegating to the free function `array_layout`.
    pub(crate) fn layout_impl(&self) -> Layout {
        array_layout(&self.dim, &self.strides)
    }
}
87
impl<'a, A, D, E> Broadcast<E> for ArrayView<'a, A, D>
where
    E: IntoDimension,
    D: Dimension,
{
    type Output = ArrayView<'a, A, E::Dim>;
    fn broadcast_unwrap(self, shape: E) -> Self::Output {
        // Delegate to the by-reference broadcast (which panics on failure),
        // then rebuild a view with the original lifetime 'a from its raw parts.
        let res: ArrayView<'_, A, E::Dim> = (&self).broadcast_unwrap(shape.into_dimension());
        // SAFETY: ptr/dim/strides come from a valid broadcast view of `self`,
        // whose data is borrowed for 'a.
        unsafe { ArrayView::new(res.ptr, res.dim, res.strides) }
    }
    private_impl! {}
}
100
/// Implemented for tuples of `NdProducer`s (see `zipt_impl!` below);
/// supplies the element-wise operations `Zip` needs to traverse all
/// producers in lock step.
trait ZippableTuple: Sized {
    /// Tuple of the producers' item types.
    type Item;
    /// Tuple of raw pointers, one per producer.
    type Ptr: OffsetTuple<Args = Self::Stride> + Copy;
    type Dim: Dimension;
    /// Tuple of per-producer strides.
    type Stride: Copy;
    /// Tuple of each producer's base pointer.
    fn as_ptr(&self) -> Self::Ptr;
    /// Convert a pointer tuple into an item tuple.
    unsafe fn as_ref(&self, ptr: Self::Ptr) -> Self::Item;
    /// Pointer tuple for the element at multi-index `i` (unchecked).
    unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr;
    /// Strides of all producers along the axis `index`.
    fn stride_of(&self, index: usize) -> Self::Stride;
    /// Strides used when the whole traversal is one contiguous stretch.
    fn contiguous_stride(&self) -> Self::Stride;
    /// Split every producer at `index` along `axis`.
    fn split_at(self, axis: Axis, index: usize) -> (Self, Self);
}
113
114/// Lock step function application across several arrays or other producers.
115///
116/// Zip allows matching several producers to each other elementwise and applying
117/// a function over all tuples of elements (one item from each input at
118/// a time).
119///
120/// In general, the zip uses a tuple of producers
121/// ([`NdProducer`] trait) that all have to be of the
122/// same shape. The NdProducer implementation defines what its item type is
123/// (for example if it's a shared reference, mutable reference or an array
124/// view etc).
125///
126/// If all the input arrays are of the same memory layout the zip performs much
127/// better and the compiler can usually vectorize the loop (if applicable).
128///
129/// The order elements are visited is not specified. The producers don’t have to
130/// have the same item type.
131///
132/// The `Zip` has two methods for function application: `for_each` and
133/// `fold_while`. The zip object can be split, which allows parallelization.
134/// A read-only zip object (no mutable producers) can be cloned.
135///
136/// See also the [`azip!()`] which offers a convenient shorthand
137/// to common ways to use `Zip`.
138///
139/// ```
140/// use ndarray::Zip;
141/// use ndarray::Array2;
142///
143/// type M = Array2<f64>;
144///
145/// // Create four 2d arrays of the same size
146/// let mut a = M::zeros((64, 32));
147/// let b = M::from_elem(a.dim(), 1.);
148/// let c = M::from_elem(a.dim(), 2.);
149/// let d = M::from_elem(a.dim(), 3.);
150///
151/// // Example 1: Perform an elementwise arithmetic operation across
152/// // the four arrays a, b, c, d.
153///
154/// Zip::from(&mut a)
155///     .and(&b)
156///     .and(&c)
157///     .and(&d)
158///     .for_each(|w, &x, &y, &z| {
159///         *w += x + y * z;
160///     });
161///
162/// // Example 2: Create a new array `totals` with one entry per row of `a`.
163/// //  Use Zip to traverse the rows of `a` and assign to the corresponding
164/// //  entry in `totals` with the sum across each row.
165/// //  This is possible because the producer for `totals` and the row producer
166/// //  for `a` have the same shape and dimensionality.
167/// //  The rows producer yields one array view (`row`) per iteration.
168///
169/// use ndarray::{Array1, Axis};
170///
171/// let mut totals = Array1::zeros(a.nrows());
172///
173/// Zip::from(&mut totals)
174///     .and(a.rows())
175///     .for_each(|totals, row| *totals = row.sum());
176///
177/// // Check the result against the built in `.sum_axis()` along axis 1.
178/// assert_eq!(totals, a.sum_axis(Axis(1)));
179///
180///
181/// // Example 3: Recreate Example 2 using map_collect to make a new array
182///
183/// let totals2 = Zip::from(a.rows()).map_collect(|row| row.sum());
184///
185/// // Check the result against the previous example.
186/// assert_eq!(totals, totals2);
187/// ```
#[derive(Debug, Clone)]
#[must_use = "zipping producers is lazy and does nothing unless consumed"]
pub struct Zip<Parts, D> {
    /// Tuple of the co-iterated producers; all share dimension `D`.
    parts: Parts,
    /// The common shape of all producers in `parts`.
    dimension: D,
    /// Intersection of the parts' layout hints (see `build_and`).
    layout: Layout,
    /// The sum of the layout tendencies of the parts;
    /// positive for c- and negative for f-layout preference.
    layout_tendency: i32,
}
198
199
200impl<P, D> Zip<(P,), D>
201where
202    D: Dimension,
203    P: NdProducer<Dim = D>,
204{
205    /// Create a new `Zip` from the input array or other producer `p`.
206    ///
207    /// The Zip will take the exact dimension of `p` and all inputs
208    /// must have the same dimensions (or be broadcast to them).
209    pub fn from<IP>(p: IP) -> Self
210    where
211        IP: IntoNdProducer<Dim = D, Output = P, Item = P::Item>,
212    {
213        let array = p.into_producer();
214        let dim = array.raw_dim();
215        let layout = array.layout();
216        Zip {
217            dimension: dim,
218            layout,
219            parts: (array,),
220            layout_tendency: layout.tendency(),
221        }
222    }
223}
224impl<P, D> Zip<(Indices<D>, P), D>
225where
226    D: Dimension + Copy,
227    P: NdProducer<Dim = D>,
228{
229    /// Create a new `Zip` with an index producer and the producer `p`.
230    ///
231    /// The Zip will take the exact dimension of `p` and all inputs
232    /// must have the same dimensions (or be broadcast to them).
233    ///
234    /// *Note:* Indexed zip has overhead.
235    pub fn indexed<IP>(p: IP) -> Self
236    where
237        IP: IntoNdProducer<Dim = D, Output = P, Item = P::Item>,
238    {
239        let array = p.into_producer();
240        let dim = array.raw_dim();
241        Zip::from(indices(dim)).and(array)
242    }
243}
244
/// Assert (via `ndassert!`) that `part`'s shape equals the Zip's `dimension`;
/// panics with a descriptive message on mismatch.
#[inline]
fn zip_dimension_check<D, P>(dimension: &D, part: &P)
where
    D: Dimension,
    P: NdProducer<Dim = D>,
{
    ndassert!(
        part.equal_dim(dimension),
        "Zip: Producer dimension mismatch, expected: {:?}, got: {:?}",
        dimension,
        part.raw_dim()
    );
}
258
259
impl<Parts, D> Zip<Parts, D>
where
    D: Dimension,
{
    /// Return the number of element tuples in the Zip
    pub fn size(&self) -> usize {
        self.dimension.size()
    }

    /// Return the length of `axis`
    ///
    /// ***Panics*** if `axis` is out of bounds.
    fn len_of(&self, axis: Axis) -> usize {
        self.dimension[axis.index()]
    }

    /// True when traversal should run in f-order: the layout is not c-order,
    /// and is either f-order or the parts lean towards f (negative tendency).
    fn prefer_f(&self) -> bool {
        !self.layout.is(Layout::CORDER) &&
            (self.layout.is(Layout::FORDER) || self.layout_tendency < 0)
    }

    /// Return an *approximation* to the max stride axis; if
    /// component arrays disagree, there may be no choice better than the
    /// others.
    fn max_stride_axis(&self) -> Axis {
        // For an f-preference pick the *last* axis of length > 1, otherwise
        // the *first*; fall back to the last/first axis respectively when
        // every axis has length <= 1.
        let i = if self.prefer_f() {
            self
                .dimension
                .slice()
                .iter()
                .rposition(|&len| len > 1)
                .unwrap_or(self.dimension.ndim() - 1)
        } else {
            /* corder or default */
            self
                .dimension
                .slice()
                .iter()
                .position(|&len| len > 1)
                .unwrap_or(0)
        };
        Axis(i)
    }
}
304
impl<P, D> Zip<P, D>
where
    D: Dimension,
{
    /// Fold `function` over all element tuples, dispatching to the fastest
    /// applicable traversal: a direct call for 0-d, one flat loop when all
    /// parts are contiguous in the same order, otherwise the strided loops.
    fn for_each_core<F, Acc>(&mut self, acc: Acc, mut function: F) -> FoldWhile<Acc>
    where
        F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
        P: ZippableTuple<Dim = D>,
    {
        if self.dimension.ndim() == 0 {
            // 0-d: exactly one element tuple, at the base pointers.
            function(acc, unsafe { self.parts.as_ref(self.parts.as_ptr()) })
        } else if self.layout.is(Layout::CORDER | Layout::FORDER) {
            self.for_each_core_contiguous(acc, function)
        } else {
            self.for_each_core_strided(acc, function)
        }
    }

    /// Traverse all parts as a single contiguous stretch of `size` elements
    /// using each part's contiguous stride.
    fn for_each_core_contiguous<F, Acc>(&mut self, acc: Acc, mut function: F) -> FoldWhile<Acc>
    where
        F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
        P: ZippableTuple<Dim = D>,
    {
        debug_assert!(self.layout.is(Layout::CORDER | Layout::FORDER));
        let size = self.dimension.size();
        let ptrs = self.parts.as_ptr();
        let inner_strides = self.parts.contiguous_stride();
        unsafe {
            self.inner(acc, ptrs, inner_strides, size, &mut function)
        }
    }

    /// The innermost loop of the Zip for_each methods
    ///
    /// Run the fold while operation on a stretch of elements with constant strides
    ///
    /// `ptr`: base pointer for the first element in this stretch
    /// `strides`: strides for the elements in this stretch
    /// `len`: number of elements
    /// `function`: closure
    ///
    /// ## Safety
    ///
    /// Caller must ensure that offsetting `ptr` by `strides` for each of the
    /// `len` steps stays within the parts' allocations.
    unsafe fn inner<F, Acc>(&self, mut acc: Acc, ptr: P::Ptr, strides: P::Stride,
                            len: usize, function: &mut F) -> FoldWhile<Acc>
    where
        F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
        P: ZippableTuple
    {
        let mut i = 0;
        while i < len {
            let p = ptr.stride_offset(strides, i);
            // fold_while! propagates a `Done` out of this function early.
            acc = fold_while!(function(acc, self.parts.as_ref(p)));
            i += 1;
        }
        FoldWhile::Continue(acc)
    }


    /// Strided traversal: pick the unroll axis by the overall layout tendency.
    fn for_each_core_strided<F, Acc>(&mut self, acc: Acc, function: F) -> FoldWhile<Acc>
    where
        F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
        P: ZippableTuple<Dim = D>,
    {
        let n = self.dimension.ndim();
        if n == 0 {
            panic!("Unreachable: ndim == 0 is contiguous")
        }
        if n == 1 || self.layout_tendency >= 0 {
            self.for_each_core_strided_c(acc, function)
        } else {
            self.for_each_core_strided_f(acc, function)
        }
    }

    // Non-contiguous but preference for C - unroll over Axis(ndim - 1)
    fn for_each_core_strided_c<F, Acc>(&mut self, mut acc: Acc, mut function: F) -> FoldWhile<Acc>
    where
        F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
        P: ZippableTuple<Dim = D>,
    {
        let n = self.dimension.ndim();
        let unroll_axis = n - 1;
        let inner_len = self.dimension[unroll_axis];
        // Collapse the unroll axis to 1 so the index iteration below only
        // walks the outer axes; `inner` handles the unrolled axis.
        self.dimension[unroll_axis] = 1;
        let mut index_ = self.dimension.first_index();
        let inner_strides = self.parts.stride_of(unroll_axis);
        // Loop unrolled over closest axis
        while let Some(index) = index_ {
            unsafe {
                let ptr = self.parts.uget_ptr(&index);
                acc = fold_while![self.inner(acc, ptr, inner_strides, inner_len, &mut function)];
            }

            index_ = self.dimension.next_for(index);
        }
        FoldWhile::Continue(acc)
    }

    // Non-contiguous but preference for F - unroll over Axis(0)
    fn for_each_core_strided_f<F, Acc>(&mut self, mut acc: Acc, mut function: F) -> FoldWhile<Acc>
    where
        F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
        P: ZippableTuple<Dim = D>,
    {
        let unroll_axis = 0;
        let inner_len = self.dimension[unroll_axis];
        // Collapse the unroll axis to 1; iterate the remaining axes in
        // f-order via `next_for_f`.
        self.dimension[unroll_axis] = 1;
        let index_ = self.dimension.first_index();
        let inner_strides = self.parts.stride_of(unroll_axis);
        // Loop unrolled over closest axis
        if let Some(mut index) = index_ {
            loop {
                unsafe {
                    let ptr = self.parts.uget_ptr(&index);
                    acc = fold_while![self.inner(acc, ptr, inner_strides, inner_len, &mut function)];
                }

                if !self.dimension.next_for_f(&mut index) {
                    break;
                }
            }
        }
        FoldWhile::Continue(acc)
    }

    /// Allocate an uninitialized output array whose memory order (c or f)
    /// matches this Zip's traversal preference.
    #[cfg(feature = "rayon")]
    pub(crate) fn uninitialized_for_current_layout<T>(&self) -> Array<MaybeUninit<T>, D>
    {
        let is_f = self.prefer_f();
        Array::uninit(self.dimension.clone().set_f(is_f))
    }
}
435
436impl<D, P1, P2> Zip<(P1, P2), D>
437where
438    D: Dimension,
439    P1: NdProducer<Dim=D>,
440    P1: NdProducer<Dim=D>,
441{
442    /// Debug assert traversal order is like c (including 1D case)
443    // Method placement: only used for binary Zip at the moment.
444    #[inline]
445    pub(crate) fn debug_assert_c_order(self) -> Self {
446        debug_assert!(self.layout.is(Layout::CORDER) || self.layout_tendency >= 0 ||
447                      self.dimension.slice().iter().filter(|&&d| d > 1).count() <= 1,
448                      "Assertion failed: traversal is not c-order or 1D for \
449                      layout {:?}, tendency {}, dimension {:?}",
450                      self.layout, self.layout_tendency, self.dimension);
451        self
452    }
453}
454
455
456/*
457trait Offset : Copy {
458    unsafe fn offset(self, off: isize) -> Self;
459    unsafe fn stride_offset(self, index: usize, stride: isize) -> Self {
460        self.offset(index as isize * stride)
461    }
462}
463
464impl<T> Offset for *mut T {
465    unsafe fn offset(self, off: isize) -> Self {
466        self.offset(off)
467    }
468}
469*/
470
/// Offset a (tuple of) raw pointer(s) by a (tuple of) stride(s) scaled by an
/// element index; implemented for `*mut T` and for tuples via `offset_impl!`.
trait OffsetTuple {
    /// The stride type paired with this pointer type.
    type Args;
    /// Return `self` offset by `index * stride` (per component).
    unsafe fn stride_offset(self, stride: Self::Args, index: usize) -> Self;
}
475
impl<T> OffsetTuple for *mut T {
    type Args = isize;
    unsafe fn stride_offset(self, stride: Self::Args, index: usize) -> Self {
        // Plain pointer arithmetic: caller guarantees the result is in bounds.
        self.offset(index as isize * stride)
    }
}
482
/// Implement `OffsetTuple` for pointer tuples of arity 1..=6: each component
/// pointer is offset by its own stride times the shared `index`.
macro_rules! offset_impl {
    ($([$($param:ident)*][ $($q:ident)*],)+) => {
        $(
        #[allow(non_snake_case)]
        impl<$($param: Offset),*> OffsetTuple for ($($param, )*) {
            type Args = ($($param::Stride,)*);
            unsafe fn stride_offset(self, stride: Self::Args, index: usize) -> Self {
                // Destructure pointers and strides pairwise, offset each.
                let ($($param, )*) = self;
                let ($($q, )*) = stride;
                ($(Offset::stride_offset($param, $q, index),)*)
            }
        }
        )+
    }
}

offset_impl! {
    [A ][ a],
    [A B][ a b],
    [A B C][ a b c],
    [A B C D][ a b c d],
    [A B C D E][ a b c d e],
    [A B C D E F][ a b c d e f],
}
507
/// Implement `ZippableTuple` for tuples of 1..=6 `NdProducer`s sharing one
/// dimension type; every operation simply applies component-wise.
macro_rules! zipt_impl {
    ($([$($p:ident)*][ $($q:ident)*],)+) => {
        $(
        #[allow(non_snake_case)]
        impl<Dim: Dimension, $($p: NdProducer<Dim=Dim>),*> ZippableTuple for ($($p, )*) {
            type Item = ($($p::Item, )*);
            type Ptr = ($($p::Ptr, )*);
            type Dim = Dim;
            type Stride = ($($p::Stride,)* );

            fn stride_of(&self, index: usize) -> Self::Stride {
                let ($(ref $p,)*) = *self;
                ($($p.stride_of(Axis(index)), )*)
            }

            fn contiguous_stride(&self) -> Self::Stride {
                let ($(ref $p,)*) = *self;
                ($($p.contiguous_stride(), )*)
            }

            fn as_ptr(&self) -> Self::Ptr {
                let ($(ref $p,)*) = *self;
                ($($p.as_ptr(), )*)
            }
            unsafe fn as_ref(&self, ptr: Self::Ptr) -> Self::Item {
                // Uses the second ident set ($q) for the producers so the
                // pointer tuple can reuse the $p names.
                let ($(ref $q ,)*) = *self;
                let ($($p,)*) = ptr;
                ($($q.as_ref($p),)*)
            }

            unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr {
                let ($(ref $p,)*) = *self;
                ($($p.uget_ptr(i), )*)
            }

            fn split_at(self, axis: Axis, index: Ix) -> (Self, Self) {
                // Split each producer, then regroup the halves into two tuples.
                let ($($p,)*) = self;
                let ($($p,)*) = (
                    $($p.split_at(axis, index), )*
                );
                (
                    ($($p.0,)*),
                    ($($p.1,)*)
                )
            }
        }
        )+
    }
}

zipt_impl! {
    [A ][ a],
    [A B][ a b],
    [A B C][ a b c],
    [A B C D][ a b c d],
    [A B C D E][ a b c d e],
    [A B C D E F][ a b c d e f],
}
566
/// Generate the public `Zip` API for each producer-tuple arity (1..=6).
///
/// Each `[$notlast $($p)*]` entry expands to the impls for one arity.
/// `$notlast` is `true` for every arity except the largest (see the
/// invocation below); `expand_if!(@bool [$notlast] ...)` gates the methods
/// that grow the tuple (`and`, `and_broadcast`, ...) and the collect
/// machinery, so they are absent at the maximum arity.
macro_rules! map_impl {
    ($([$notlast:ident $($p:ident)*],)+) => {
        $(
        #[allow(non_snake_case)]
        impl<D, $($p),*> Zip<($($p,)*), D>
            where D: Dimension,
                  $($p: NdProducer<Dim=D> ,)*
        {
            /// Apply a function to all elements of the input arrays,
            /// visiting elements in lock step.
            pub fn for_each<F>(mut self, mut function: F)
                where F: FnMut($($p::Item),*)
            {
                self.for_each_core((), move |(), args| {
                    let ($($p,)*) = args;
                    FoldWhile::Continue(function($($p),*))
                });
            }

            /// Apply a function to all elements of the input arrays,
            /// visiting elements in lock step.
            #[deprecated(note="Renamed to .for_each()", since="0.15.0")]
            pub fn apply<F>(self, function: F)
                where F: FnMut($($p::Item),*)
            {
                self.for_each(function)
            }

            /// Apply a fold function to all elements of the input arrays,
            /// visiting elements in lock step.
            ///
            /// # Example
            ///
            /// The expression `tr(AᵀB)` can be more efficiently computed as
            /// the equivalent expression `∑ᵢⱼ(A∘B)ᵢⱼ` (i.e. the sum of the
            /// elements of the entry-wise product). It would be possible to
            /// evaluate this expression by first computing the entry-wise
            /// product, `A∘B`, and then computing the elementwise sum of that
            /// product, but it's possible to do this in a single loop (and
            /// avoid an extra heap allocation if `A` and `B` can't be
            /// consumed) by using `Zip`:
            ///
            /// ```
            /// use ndarray::{array, Zip};
            ///
            /// let a = array![[1, 5], [3, 7]];
            /// let b = array![[2, 4], [8, 6]];
            ///
            /// // Without using `Zip`. This involves two loops and an extra
            /// // heap allocation for the result of `&a * &b`.
            /// let sum_prod_nonzip = (&a * &b).sum();
            /// // Using `Zip`. This is a single loop without any heap allocations.
            /// let sum_prod_zip = Zip::from(&a).and(&b).fold(0, |acc, a, b| acc + a * b);
            ///
            /// assert_eq!(sum_prod_nonzip, sum_prod_zip);
            /// ```
            pub fn fold<F, Acc>(mut self, acc: Acc, mut function: F) -> Acc
            where
                F: FnMut(Acc, $($p::Item),*) -> Acc,
            {
                self.for_each_core(acc, move |acc, args| {
                    let ($($p,)*) = args;
                    FoldWhile::Continue(function(acc, $($p),*))
                }).into_inner()
            }

            /// Apply a fold function to the input arrays while the return
            /// value is `FoldWhile::Continue`, visiting elements in lock step.
            ///
            pub fn fold_while<F, Acc>(mut self, acc: Acc, mut function: F)
                -> FoldWhile<Acc>
                where F: FnMut(Acc, $($p::Item),*) -> FoldWhile<Acc>
            {
                self.for_each_core(acc, move |acc, args| {
                    let ($($p,)*) = args;
                    function(acc, $($p),*)
                })
            }

            /// Tests if every element of the iterator matches a predicate.
            ///
            /// Returns `true` if `predicate` evaluates to `true` for all elements.
            /// Returns `true` if the input arrays are empty.
            ///
            /// Example:
            ///
            /// ```
            /// use ndarray::{array, Zip};
            /// let a = array![1, 2, 3];
            /// let b = array![1, 4, 9];
            /// assert!(Zip::from(&a).and(&b).all(|&a, &b| a * a == b));
            /// ```
            pub fn all<F>(mut self, mut predicate: F) -> bool
                where F: FnMut($($p::Item),*) -> bool
            {
                // Short-circuits via Done on the first failing element;
                // "all" holds iff the fold was never marked Done.
                !self.for_each_core((), move |_, args| {
                    let ($($p,)*) = args;
                    if predicate($($p),*) {
                        FoldWhile::Continue(())
                    } else {
                        FoldWhile::Done(())
                    }
                }).is_done()
            }

            // Methods below are only emitted while the tuple can still grow.
            expand_if!(@bool [$notlast]

            /// Include the producer `p` in the Zip.
            ///
            /// ***Panics*** if `p`’s shape doesn’t match the Zip’s exactly.
            pub fn and<P>(self, p: P) -> Zip<($($p,)* P::Output, ), D>
                where P: IntoNdProducer<Dim=D>,
            {
                let part = p.into_producer();
                zip_dimension_check(&self.dimension, &part);
                self.build_and(part)
            }

            /// Include the producer `p` in the Zip.
            ///
            /// ## Safety
            ///
            /// The caller must ensure that the producer's shape is equal to the Zip's shape.
            /// Uses assertions when debug assertions are enabled.
            #[allow(unused)]
            pub(crate) unsafe fn and_unchecked<P>(self, p: P) -> Zip<($($p,)* P::Output, ), D>
                where P: IntoNdProducer<Dim=D>,
            {
                #[cfg(debug_assertions)]
                {
                    self.and(p)
                }
                #[cfg(not(debug_assertions))]
                {
                    self.build_and(p.into_producer())
                }
            }

            /// Include the producer `p` in the Zip, broadcasting if needed.
            ///
            /// If their shapes disagree, `rhs` is broadcast to the shape of `self`.
            ///
            /// ***Panics*** if broadcasting isn’t possible.
            pub fn and_broadcast<'a, P, D2, Elem>(self, p: P)
                -> Zip<($($p,)* ArrayView<'a, Elem, D>, ), D>
                where P: IntoNdProducer<Dim=D2, Output=ArrayView<'a, Elem, D2>, Item=&'a Elem>,
                      D2: Dimension,
            {
                let part = p.into_producer().broadcast_unwrap(self.dimension.clone());
                self.build_and(part)
            }

            // Append `part` to the tuple, merging layout info and tendency.
            fn build_and<P>(self, part: P) -> Zip<($($p,)* P, ), D>
                where P: NdProducer<Dim=D>,
            {
                let part_layout = part.layout();
                let ($($p,)*) = self.parts;
                Zip {
                    parts: ($($p,)* part, ),
                    layout: self.layout.intersect(part_layout),
                    dimension: self.dimension,
                    layout_tendency: self.layout_tendency + part_layout.tendency(),
                }
            }

            /// Map and collect the results into a new array, which has the same size as the
            /// inputs.
            ///
            /// If all inputs are c- or f-order respectively, that is preserved in the output.
            pub fn map_collect<R>(self, f: impl FnMut($($p::Item,)* ) -> R) -> Array<R, D> {
                self.map_collect_owned(f)
            }

            pub(crate) fn map_collect_owned<S, R>(self, f: impl FnMut($($p::Item,)* ) -> R)
                -> ArrayBase<S, D>
                where S: DataOwned<Elem = R>
            {
                // safe because: all elements are written before the array is completed

                let shape = self.dimension.clone().set_f(self.prefer_f());
                let output = <ArrayBase<S, D>>::build_uninit(shape, |output| {
                    // Use partial to count the number of filled elements, and can drop the right
                    // number of elements on unwinding (if it happens during apply/collect).
                    unsafe {
                        let output_view = output.into_raw_view_mut().cast::<R>();
                        self.and(output_view)
                            .collect_with_partial(f)
                            .release_ownership();
                    }
                });
                unsafe {
                    output.assume_init()
                }
            }

            /// Map and collect the results into a new array, which has the same size as the
            /// inputs.
            ///
            /// If all inputs are c- or f-order respectively, that is preserved in the output.
            #[deprecated(note="Renamed to .map_collect()", since="0.15.0")]
            pub fn apply_collect<R>(self, f: impl FnMut($($p::Item,)* ) -> R) -> Array<R, D> {
                self.map_collect(f)
            }

            /// Map and assign the results into the producer `into`, which should have the same
            /// size as the other inputs.
            ///
            /// The producer should have assignable items as dictated by the `AssignElem` trait,
            /// for example `&mut R`.
            pub fn map_assign_into<R, Q>(self, into: Q, mut f: impl FnMut($($p::Item,)* ) -> R)
                where Q: IntoNdProducer<Dim=D>,
                      Q::Item: AssignElem<R>
            {
                self.and(into)
                    .for_each(move |$($p, )* output_| {
                        output_.assign_elem(f($($p ),*));
                    });
            }

            /// Map and assign the results into the producer `into`, which should have the same
            /// size as the other inputs.
            ///
            /// The producer should have assignable items as dictated by the `AssignElem` trait,
            /// for example `&mut R`.
            #[deprecated(note="Renamed to .map_assign_into()", since="0.15.0")]
            pub fn apply_assign_into<R, Q>(self, into: Q, f: impl FnMut($($p::Item,)* ) -> R)
                where Q: IntoNdProducer<Dim=D>,
                      Q::Item: AssignElem<R>
            {
                self.map_assign_into(into, f)
            }


            );

            /// Split the `Zip` evenly in two.
            ///
            /// It will be split in the way that best preserves element locality.
            pub fn split(self) -> (Self, Self) {
                debug_assert_ne!(self.size(), 0, "Attempt to split empty zip");
                debug_assert_ne!(self.size(), 1, "Attempt to split zip with 1 elem");
                SplitPreference::split(self)
            }
        }

        expand_if!(@bool [$notlast]
            // For collect; Last producer is a RawViewMut
            #[allow(non_snake_case)]
            impl<D, PLast, R, $($p),*> Zip<($($p,)* PLast), D>
                where D: Dimension,
                      $($p: NdProducer<Dim=D> ,)*
                      PLast: NdProducer<Dim = D, Item = *mut R, Ptr = *mut R, Stride = isize>,
            {
                /// The inner workings of map_collect and par_map_collect
                ///
                /// Apply the function and collect the results into the output (last producer)
                /// which should be a raw array view; a Partial that owns the written
                /// elements is returned.
                ///
                /// Elements will be overwritten in place (in the sense of std::ptr::write).
                ///
                /// ## Safety
                ///
                /// The last producer is a RawArrayViewMut and must be safe to write into.
                /// The producer must be c- or f-contig and have the same layout tendency
                /// as the whole Zip.
                ///
                /// The returned Partial's proxy ownership of the elements must be handled,
                /// before the array the raw view points to realizes its ownership.
                pub(crate) unsafe fn collect_with_partial<F>(self, mut f: F) -> Partial<R>
                    where F: FnMut($($p::Item,)* ) -> R
                {
                    // Get the last producer; and make a Partial that aliases its data pointer
                    let (.., ref output) = &self.parts;

                    // debug assert that the output is contiguous in the memory layout we need
                    if cfg!(debug_assertions) {
                        let out_layout = output.layout();
                        assert!(out_layout.is(Layout::CORDER | Layout::FORDER));
                        assert!(
                            (self.layout_tendency <= 0 && out_layout.tendency() <= 0) ||
                            (self.layout_tendency >= 0 && out_layout.tendency() >= 0),
                            "layout tendency violation for self layout {:?}, output layout {:?},\
                            output shape {:?}",
                            self.layout, out_layout, output.raw_dim());
                    }

                    let mut partial = Partial::new(output.as_ptr());

                    // Apply the mapping function on this zip
                    // if we panic with unwinding; Partial will drop the written elements.
                    let partial_len = &mut partial.len;
                    self.for_each(move |$($p,)* output_elem: *mut R| {
                        output_elem.write(f($($p),*));
                        if std::mem::needs_drop::<R>() {
                            *partial_len += 1;
                        }
                    });

                    partial
                }
            }
        );

        impl<D, $($p),*> SplitPreference for Zip<($($p,)*), D>
            where D: Dimension,
                  $($p: NdProducer<Dim=D> ,)*
        {
            fn can_split(&self) -> bool { self.size() > 1 }

            fn split_preference(&self) -> (Axis, usize) {
                // Always split in a way that preserves layout (if any)
                let axis = self.max_stride_axis();
                let index = self.len_of(axis) / 2;
                (axis, index)
            }
        }

        impl<D, $($p),*> SplitAt for Zip<($($p,)*), D>
            where D: Dimension,
                  $($p: NdProducer<Dim=D> ,)*
        {
            fn split_at(self, axis: Axis, index: usize) -> (Self, Self) {
                // Split parts and dimension in tandem; layout hints carry over.
                let (p1, p2) = self.parts.split_at(axis, index);
                let (d1, d2) = self.dimension.split_at(axis, index);
                (Zip {
                    dimension: d1,
                    layout: self.layout,
                    parts: p1,
                    layout_tendency: self.layout_tendency,
                },
                Zip {
                    dimension: d2,
                    layout: self.layout,
                    parts: p2,
                    layout_tendency: self.layout_tendency,
                })
            }

        }

        )+
    }
}

map_impl! {
    [true P1],
    [true P1 P2],
    [true P1 P2 P3],
    [true P1 P2 P3 P4],
    [true P1 P2 P3 P4 P5],
    [false P1 P2 P3 P4 P5 P6],
}
920
/// Value controlling the execution of `.fold_while` on `Zip`.
#[derive(Debug, Copy, Clone)]
pub enum FoldWhile<T> {
    /// Continue folding with this value
    Continue(T),
    /// Fold is complete and will return this value
    Done(T),
}
929
930impl<T> FoldWhile<T> {
931    /// Return the inner value
932    pub fn into_inner(self) -> T {
933        match self {
934            FoldWhile::Continue(x) | FoldWhile::Done(x) => x,
935        }
936    }
937
938    /// Return true if it is `Done`, false if `Continue`
939    pub fn is_done(&self) -> bool {
940        match *self {
941            FoldWhile::Continue(_) => false,
942            FoldWhile::Done(_) => true,
943        }
944    }
945}