ndarray/zip/mod.rs
1// Copyright 2017 bluss and ndarray developers.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9#[macro_use]
10mod zipmacro;
11mod ndproducer;
12
13#[cfg(feature = "rayon")]
14use std::mem::MaybeUninit;
15
16use crate::imp_prelude::*;
17use crate::AssignElem;
18use crate::IntoDimension;
19use crate::Layout;
20use crate::partial::Partial;
21
22use crate::indexes::{indices, Indices};
23use crate::split_at::{SplitPreference, SplitAt};
24use crate::dimension;
25
26pub use self::ndproducer::{NdProducer, IntoNdProducer, Offset};
27
/// Evaluate a `FoldWhile` expression inside a fold loop.
///
/// When the expression is `FoldWhile::Continue(value)` the macro yields
/// `value`; when it is `FoldWhile::Done(_)` the enclosing function returns
/// that `Done` value immediately, short-circuiting the fold.
macro_rules! fold_while {
    ($e:expr) => {
        match $e {
            FoldWhile::Continue(value) => value,
            done @ FoldWhile::Done(_) => return done,
        }
    };
}
37
38/// Broadcast an array so that it acts like a larger size and/or shape array.
39///
40/// See [broadcasting](ArrayBase#broadcasting) for more information.
41trait Broadcast<E>
42where
43 E: IntoDimension,
44{
45 type Output: NdProducer<Dim = E::Dim>;
46 /// Broadcast the array to the new dimensions `shape`.
47 ///
48 /// ***Panics*** if broadcasting isn’t possible.
49 fn broadcast_unwrap(self, shape: E) -> Self::Output;
50 private_decl! {}
51}
52
53/// Compute `Layout` hints for array shape dim, strides
54fn array_layout<D: Dimension>(dim: &D, strides: &D) -> Layout {
55 let n = dim.ndim();
56 if dimension::is_layout_c(dim, strides) {
57 // effectively one-dimensional => C and F layout compatible
58 if n <= 1 || dim.slice().iter().filter(|&&len| len > 1).count() <= 1 {
59 Layout::one_dimensional()
60 } else {
61 Layout::c()
62 }
63 } else if n > 1 && dimension::is_layout_f(dim, strides) {
64 Layout::f()
65 } else if n > 1 {
66 if dim[0] > 1 && strides[0] == 1 {
67 Layout::fpref()
68 } else if dim[n - 1] > 1 && strides[n - 1] == 1 {
69 Layout::cpref()
70 } else {
71 Layout::none()
72 }
73 } else {
74 Layout::none()
75 }
76}
77
78impl<S, D> ArrayBase<S, D>
79where
80 S: RawData,
81 D: Dimension,
82{
83 pub(crate) fn layout_impl(&self) -> Layout {
84 array_layout(&self.dim, &self.strides)
85 }
86}
87
88impl<'a, A, D, E> Broadcast<E> for ArrayView<'a, A, D>
89where
90 E: IntoDimension,
91 D: Dimension,
92{
93 type Output = ArrayView<'a, A, E::Dim>;
94 fn broadcast_unwrap(self, shape: E) -> Self::Output {
95 let res: ArrayView<'_, A, E::Dim> = (&self).broadcast_unwrap(shape.into_dimension());
96 unsafe { ArrayView::new(res.ptr, res.dim, res.strides) }
97 }
98 private_impl! {}
99}
100
101trait ZippableTuple: Sized {
102 type Item;
103 type Ptr: OffsetTuple<Args = Self::Stride> + Copy;
104 type Dim: Dimension;
105 type Stride: Copy;
106 fn as_ptr(&self) -> Self::Ptr;
107 unsafe fn as_ref(&self, ptr: Self::Ptr) -> Self::Item;
108 unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr;
109 fn stride_of(&self, index: usize) -> Self::Stride;
110 fn contiguous_stride(&self) -> Self::Stride;
111 fn split_at(self, axis: Axis, index: usize) -> (Self, Self);
112}
113
114/// Lock step function application across several arrays or other producers.
115///
116/// Zip allows matching several producers to each other elementwise and applying
117/// a function over all tuples of elements (one item from each input at
118/// a time).
119///
120/// In general, the zip uses a tuple of producers
121/// ([`NdProducer`] trait) that all have to be of the
122/// same shape. The NdProducer implementation defines what its item type is
123/// (for example if it's a shared reference, mutable reference or an array
124/// view etc).
125///
126/// If all the input arrays are of the same memory layout the zip performs much
127/// better and the compiler can usually vectorize the loop (if applicable).
128///
129/// The order elements are visited is not specified. The producers don’t have to
130/// have the same item type.
131///
132/// The `Zip` has two methods for function application: `for_each` and
133/// `fold_while`. The zip object can be split, which allows parallelization.
134/// A read-only zip object (no mutable producers) can be cloned.
135///
136/// See also the [`azip!()`] which offers a convenient shorthand
137/// to common ways to use `Zip`.
138///
139/// ```
140/// use ndarray::Zip;
141/// use ndarray::Array2;
142///
143/// type M = Array2<f64>;
144///
145/// // Create four 2d arrays of the same size
146/// let mut a = M::zeros((64, 32));
147/// let b = M::from_elem(a.dim(), 1.);
148/// let c = M::from_elem(a.dim(), 2.);
149/// let d = M::from_elem(a.dim(), 3.);
150///
151/// // Example 1: Perform an elementwise arithmetic operation across
152/// // the four arrays a, b, c, d.
153///
154/// Zip::from(&mut a)
155/// .and(&b)
156/// .and(&c)
157/// .and(&d)
158/// .for_each(|w, &x, &y, &z| {
159/// *w += x + y * z;
160/// });
161///
162/// // Example 2: Create a new array `totals` with one entry per row of `a`.
163/// // Use Zip to traverse the rows of `a` and assign to the corresponding
164/// // entry in `totals` with the sum across each row.
165/// // This is possible because the producer for `totals` and the row producer
166/// // for `a` have the same shape and dimensionality.
167/// // The rows producer yields one array view (`row`) per iteration.
168///
169/// use ndarray::{Array1, Axis};
170///
171/// let mut totals = Array1::zeros(a.nrows());
172///
173/// Zip::from(&mut totals)
174/// .and(a.rows())
175/// .for_each(|totals, row| *totals = row.sum());
176///
177/// // Check the result against the built in `.sum_axis()` along axis 1.
178/// assert_eq!(totals, a.sum_axis(Axis(1)));
179///
180///
181/// // Example 3: Recreate Example 2 using map_collect to make a new array
182///
183/// let totals2 = Zip::from(a.rows()).map_collect(|row| row.sum());
184///
185/// // Check the result against the previous example.
186/// assert_eq!(totals, totals2);
187/// ```
188#[derive(Debug, Clone)]
189#[must_use = "zipping producers is lazy and does nothing unless consumed"]
190pub struct Zip<Parts, D> {
191 parts: Parts,
192 dimension: D,
193 layout: Layout,
194 /// The sum of the layout tendencies of the parts;
195 /// positive for c- and negative for f-layout preference.
196 layout_tendency: i32,
197}
198
199
200impl<P, D> Zip<(P,), D>
201where
202 D: Dimension,
203 P: NdProducer<Dim = D>,
204{
205 /// Create a new `Zip` from the input array or other producer `p`.
206 ///
207 /// The Zip will take the exact dimension of `p` and all inputs
208 /// must have the same dimensions (or be broadcast to them).
209 pub fn from<IP>(p: IP) -> Self
210 where
211 IP: IntoNdProducer<Dim = D, Output = P, Item = P::Item>,
212 {
213 let array = p.into_producer();
214 let dim = array.raw_dim();
215 let layout = array.layout();
216 Zip {
217 dimension: dim,
218 layout,
219 parts: (array,),
220 layout_tendency: layout.tendency(),
221 }
222 }
223}
224impl<P, D> Zip<(Indices<D>, P), D>
225where
226 D: Dimension + Copy,
227 P: NdProducer<Dim = D>,
228{
229 /// Create a new `Zip` with an index producer and the producer `p`.
230 ///
231 /// The Zip will take the exact dimension of `p` and all inputs
232 /// must have the same dimensions (or be broadcast to them).
233 ///
234 /// *Note:* Indexed zip has overhead.
235 pub fn indexed<IP>(p: IP) -> Self
236 where
237 IP: IntoNdProducer<Dim = D, Output = P, Item = P::Item>,
238 {
239 let array = p.into_producer();
240 let dim = array.raw_dim();
241 Zip::from(indices(dim)).and(array)
242 }
243}
244
245#[inline]
246fn zip_dimension_check<D, P>(dimension: &D, part: &P)
247where
248 D: Dimension,
249 P: NdProducer<Dim = D>,
250{
251 ndassert!(
252 part.equal_dim(dimension),
253 "Zip: Producer dimension mismatch, expected: {:?}, got: {:?}",
254 dimension,
255 part.raw_dim()
256 );
257}
258
259
260impl<Parts, D> Zip<Parts, D>
261where
262 D: Dimension,
263{
264 /// Return a the number of element tuples in the Zip
265 pub fn size(&self) -> usize {
266 self.dimension.size()
267 }
268
269 /// Return the length of `axis`
270 ///
271 /// ***Panics*** if `axis` is out of bounds.
272 fn len_of(&self, axis: Axis) -> usize {
273 self.dimension[axis.index()]
274 }
275
276 fn prefer_f(&self) -> bool {
277 !self.layout.is(Layout::CORDER) &&
278 (self.layout.is(Layout::FORDER) || self.layout_tendency < 0)
279 }
280
281 /// Return an *approximation* to the max stride axis; if
282 /// component arrays disagree, there may be no choice better than the
283 /// others.
284 fn max_stride_axis(&self) -> Axis {
285 let i = if self.prefer_f() {
286 self
287 .dimension
288 .slice()
289 .iter()
290 .rposition(|&len| len > 1)
291 .unwrap_or(self.dimension.ndim() - 1)
292 } else {
293 /* corder or default */
294 self
295 .dimension
296 .slice()
297 .iter()
298 .position(|&len| len > 1)
299 .unwrap_or(0)
300 };
301 Axis(i)
302 }
303}
304
impl<P, D> Zip<P, D>
where
    D: Dimension,
{
    /// Core traversal shared by the `for_each` / `fold` / `fold_while` / `all`
    /// methods.
    ///
    /// Threads `acc` through calls of `function` on each tuple of items and stops
    /// as soon as `function` returns `FoldWhile::Done`. Dispatch:
    /// - zero-dimensional Zip: `function` is called once on the single element;
    /// - layout flagged C- or F-contiguous: one flat loop over all elements;
    /// - otherwise: a strided traversal (see `for_each_core_strided`).
    fn for_each_core<F, Acc>(&mut self, acc: Acc, mut function: F) -> FoldWhile<Acc>
    where
        F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
        P: ZippableTuple<Dim = D>,
    {
        if self.dimension.ndim() == 0 {
            function(acc, unsafe { self.parts.as_ref(self.parts.as_ptr()) })
        } else if self.layout.is(Layout::CORDER | Layout::FORDER) {
            self.for_each_core_contiguous(acc, function)
        } else {
            self.for_each_core_strided(acc, function)
        }
    }

    /// Flat traversal for a Zip whose layout is flagged C- or F-contiguous.
    ///
    /// Walks all `size` elements starting from the base pointers, stepping
    /// by each part's `contiguous_stride`.
    fn for_each_core_contiguous<F, Acc>(&mut self, acc: Acc, mut function: F) -> FoldWhile<Acc>
    where
        F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
        P: ZippableTuple<Dim = D>,
    {
        debug_assert!(self.layout.is(Layout::CORDER | Layout::FORDER));
        let size = self.dimension.size();
        let ptrs = self.parts.as_ptr();
        let inner_strides = self.parts.contiguous_stride();
        unsafe {
            self.inner(acc, ptrs, inner_strides, size, &mut function)
        }
    }

    /// The innermost loop of the Zip for_each methods
    ///
    /// Run the fold while operation on a stretch of elements with constant strides
    ///
    /// `ptr`: base pointer for the first element in this stretch
    /// `strides`: strides for the elements in this stretch
    /// `len`: number of elements
    /// `function`: closure
    ///
    /// Element `i` is reached by offsetting `ptr` by `i` times `strides`.
    /// Returns early with `FoldWhile::Done` as soon as `function` does
    /// (via `fold_while!`); otherwise returns `FoldWhile::Continue(acc)`.
    ///
    /// Unsafe: the pointers for all `len` elements must be valid for these parts.
    unsafe fn inner<F, Acc>(&self, mut acc: Acc, ptr: P::Ptr, strides: P::Stride,
                            len: usize, function: &mut F) -> FoldWhile<Acc>
    where
        F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
        P: ZippableTuple
    {
        let mut i = 0;
        while i < len {
            let p = ptr.stride_offset(strides, i);
            acc = fold_while!(function(acc, self.parts.as_ref(p)));
            i += 1;
        }
        FoldWhile::Continue(acc)
    }


    /// Strided (non-contiguous) traversal.
    ///
    /// One-dimensional Zips and Zips whose layout tendency is not negative use
    /// the C-preferring variant; the rest use the F-preferring variant.
    fn for_each_core_strided<F, Acc>(&mut self, acc: Acc, function: F) -> FoldWhile<Acc>
    where
        F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
        P: ZippableTuple<Dim = D>,
    {
        let n = self.dimension.ndim();
        if n == 0 {
            // `for_each_core` handles ndim == 0 before reaching this function.
            panic!("Unreachable: ndim == 0 is contiguous")
        }
        if n == 1 || self.layout_tendency >= 0 {
            self.for_each_core_strided_c(acc, function)
        } else {
            self.for_each_core_strided_f(acc, function)
        }
    }

    // Non-contiguous but preference for C - unroll over Axis(ndim - 1)
    //
    // The last axis is run by `inner`; the remaining indices are visited with
    // `next_for` (presumably row-major order - see `Dimension::next_for`).
    // NOTE: this sets `self.dimension[ndim - 1]` to 1 and does not restore it;
    // the public callers take the Zip by value, so it is not reused afterwards.
    fn for_each_core_strided_c<F, Acc>(&mut self, mut acc: Acc, mut function: F) -> FoldWhile<Acc>
    where
        F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
        P: ZippableTuple<Dim = D>,
    {
        let n = self.dimension.ndim();
        let unroll_axis = n - 1;
        let inner_len = self.dimension[unroll_axis];
        // Collapse the unrolled axis so the outer index walk skips it.
        self.dimension[unroll_axis] = 1;
        let mut index_ = self.dimension.first_index();
        let inner_strides = self.parts.stride_of(unroll_axis);
        // Loop unrolled over closest axis
        while let Some(index) = index_ {
            unsafe {
                let ptr = self.parts.uget_ptr(&index);
                acc = fold_while![self.inner(acc, ptr, inner_strides, inner_len, &mut function)];
            }

            index_ = self.dimension.next_for(index);
        }
        FoldWhile::Continue(acc)
    }

    // Non-contiguous but preference for F - unroll over Axis(0)
    //
    // Axis 0 is run by `inner`; the remaining indices are advanced in place with
    // `next_for_f` (presumably column-major order - see `Dimension::next_for_f`).
    // NOTE: this sets `self.dimension[0]` to 1 and does not restore it.
    fn for_each_core_strided_f<F, Acc>(&mut self, mut acc: Acc, mut function: F) -> FoldWhile<Acc>
    where
        F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
        P: ZippableTuple<Dim = D>,
    {
        let unroll_axis = 0;
        let inner_len = self.dimension[unroll_axis];
        // Collapse the unrolled axis so the outer index walk skips it.
        self.dimension[unroll_axis] = 1;
        let index_ = self.dimension.first_index();
        let inner_strides = self.parts.stride_of(unroll_axis);
        // Loop unrolled over closest axis
        if let Some(mut index) = index_ {
            loop {
                unsafe {
                    let ptr = self.parts.uget_ptr(&index);
                    acc = fold_while![self.inner(acc, ptr, inner_strides, inner_len, &mut function)];
                }

                if !self.dimension.next_for_f(&mut index) {
                    break;
                }
            }
        }
        FoldWhile::Continue(acc)
    }

    /// Allocate an uninitialized array with the Zip's shape, in F order when
    /// `prefer_f` says so and in C order otherwise.
    #[cfg(feature = "rayon")]
    pub(crate) fn uninitialized_for_current_layout<T>(&self) -> Array<MaybeUninit<T>, D>
    {
        let is_f = self.prefer_f();
        Array::uninit(self.dimension.clone().set_f(is_f))
    }
}
435
436impl<D, P1, P2> Zip<(P1, P2), D>
437where
438 D: Dimension,
439 P1: NdProducer<Dim=D>,
440 P1: NdProducer<Dim=D>,
441{
442 /// Debug assert traversal order is like c (including 1D case)
443 // Method placement: only used for binary Zip at the moment.
444 #[inline]
445 pub(crate) fn debug_assert_c_order(self) -> Self {
446 debug_assert!(self.layout.is(Layout::CORDER) || self.layout_tendency >= 0 ||
447 self.dimension.slice().iter().filter(|&&d| d > 1).count() <= 1,
448 "Assertion failed: traversal is not c-order or 1D for \
449 layout {:?}, tendency {}, dimension {:?}",
450 self.layout, self.layout_tendency, self.dimension);
451 self
452 }
453}
454
455
456/*
457trait Offset : Copy {
458 unsafe fn offset(self, off: isize) -> Self;
459 unsafe fn stride_offset(self, index: usize, stride: isize) -> Self {
460 self.offset(index as isize * stride)
461 }
462}
463
464impl<T> Offset for *mut T {
465 unsafe fn offset(self, off: isize) -> Self {
466 self.offset(off)
467 }
468}
469*/
470
/// Offset a pointer, or a tuple of pointers, by `index` steps of a stride.
trait OffsetTuple {
    /// The stride argument: a single stride for a plain pointer, or a tuple of
    /// strides matching a tuple of pointers.
    type Args;

    /// Advance by `index * stride` elements.
    ///
    /// Unsafe because implementations use raw pointer arithmetic (for `*mut T`,
    /// `pointer::offset`), whose safety conditions the caller must uphold.
    unsafe fn stride_offset(self, stride: Self::Args, index: usize) -> Self;
}

impl<T> OffsetTuple for *mut T {
    type Args = isize;

    unsafe fn stride_offset(self, stride: Self::Args, index: usize) -> Self {
        let element_offset = index as isize * stride;
        self.offset(element_offset)
    }
}
482
483macro_rules! offset_impl {
484 ($([$($param:ident)*][ $($q:ident)*],)+) => {
485 $(
486 #[allow(non_snake_case)]
487 impl<$($param: Offset),*> OffsetTuple for ($($param, )*) {
488 type Args = ($($param::Stride,)*);
489 unsafe fn stride_offset(self, stride: Self::Args, index: usize) -> Self {
490 let ($($param, )*) = self;
491 let ($($q, )*) = stride;
492 ($(Offset::stride_offset($param, $q, index),)*)
493 }
494 }
495 )+
496 }
497}
498
499offset_impl! {
500 [A ][ a],
501 [A B][ a b],
502 [A B C][ a b c],
503 [A B C D][ a b c d],
504 [A B C D E][ a b c d e],
505 [A B C D E F][ a b c d e f],
506}
507
508macro_rules! zipt_impl {
509 ($([$($p:ident)*][ $($q:ident)*],)+) => {
510 $(
511 #[allow(non_snake_case)]
512 impl<Dim: Dimension, $($p: NdProducer<Dim=Dim>),*> ZippableTuple for ($($p, )*) {
513 type Item = ($($p::Item, )*);
514 type Ptr = ($($p::Ptr, )*);
515 type Dim = Dim;
516 type Stride = ($($p::Stride,)* );
517
518 fn stride_of(&self, index: usize) -> Self::Stride {
519 let ($(ref $p,)*) = *self;
520 ($($p.stride_of(Axis(index)), )*)
521 }
522
523 fn contiguous_stride(&self) -> Self::Stride {
524 let ($(ref $p,)*) = *self;
525 ($($p.contiguous_stride(), )*)
526 }
527
528 fn as_ptr(&self) -> Self::Ptr {
529 let ($(ref $p,)*) = *self;
530 ($($p.as_ptr(), )*)
531 }
532 unsafe fn as_ref(&self, ptr: Self::Ptr) -> Self::Item {
533 let ($(ref $q ,)*) = *self;
534 let ($($p,)*) = ptr;
535 ($($q.as_ref($p),)*)
536 }
537
538 unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr {
539 let ($(ref $p,)*) = *self;
540 ($($p.uget_ptr(i), )*)
541 }
542
543 fn split_at(self, axis: Axis, index: Ix) -> (Self, Self) {
544 let ($($p,)*) = self;
545 let ($($p,)*) = (
546 $($p.split_at(axis, index), )*
547 );
548 (
549 ($($p.0,)*),
550 ($($p.1,)*)
551 )
552 }
553 }
554 )+
555 }
556}
557
558zipt_impl! {
559 [A ][ a],
560 [A B][ a b],
561 [A B C][ a b c],
562 [A B C D][ a b c d],
563 [A B C D E][ a b c d e],
564 [A B C D E F][ a b c d e f],
565}
566
567macro_rules! map_impl {
568 ($([$notlast:ident $($p:ident)*],)+) => {
569 $(
570 #[allow(non_snake_case)]
571 impl<D, $($p),*> Zip<($($p,)*), D>
572 where D: Dimension,
573 $($p: NdProducer<Dim=D> ,)*
574 {
575 /// Apply a function to all elements of the input arrays,
576 /// visiting elements in lock step.
577 pub fn for_each<F>(mut self, mut function: F)
578 where F: FnMut($($p::Item),*)
579 {
580 self.for_each_core((), move |(), args| {
581 let ($($p,)*) = args;
582 FoldWhile::Continue(function($($p),*))
583 });
584 }
585
586 /// Apply a function to all elements of the input arrays,
587 /// visiting elements in lock step.
588 #[deprecated(note="Renamed to .for_each()", since="0.15.0")]
589 pub fn apply<F>(self, function: F)
590 where F: FnMut($($p::Item),*)
591 {
592 self.for_each(function)
593 }
594
595 /// Apply a fold function to all elements of the input arrays,
596 /// visiting elements in lock step.
597 ///
598 /// # Example
599 ///
600 /// The expression `tr(AᵀB)` can be more efficiently computed as
601 /// the equivalent expression `∑ᵢⱼ(A∘B)ᵢⱼ` (i.e. the sum of the
602 /// elements of the entry-wise product). It would be possible to
603 /// evaluate this expression by first computing the entry-wise
604 /// product, `A∘B`, and then computing the elementwise sum of that
605 /// product, but it's possible to do this in a single loop (and
606 /// avoid an extra heap allocation if `A` and `B` can't be
607 /// consumed) by using `Zip`:
608 ///
609 /// ```
610 /// use ndarray::{array, Zip};
611 ///
612 /// let a = array![[1, 5], [3, 7]];
613 /// let b = array![[2, 4], [8, 6]];
614 ///
615 /// // Without using `Zip`. This involves two loops and an extra
616 /// // heap allocation for the result of `&a * &b`.
617 /// let sum_prod_nonzip = (&a * &b).sum();
618 /// // Using `Zip`. This is a single loop without any heap allocations.
619 /// let sum_prod_zip = Zip::from(&a).and(&b).fold(0, |acc, a, b| acc + a * b);
620 ///
621 /// assert_eq!(sum_prod_nonzip, sum_prod_zip);
622 /// ```
623 pub fn fold<F, Acc>(mut self, acc: Acc, mut function: F) -> Acc
624 where
625 F: FnMut(Acc, $($p::Item),*) -> Acc,
626 {
627 self.for_each_core(acc, move |acc, args| {
628 let ($($p,)*) = args;
629 FoldWhile::Continue(function(acc, $($p),*))
630 }).into_inner()
631 }
632
633 /// Apply a fold function to the input arrays while the return
634 /// value is `FoldWhile::Continue`, visiting elements in lock step.
635 ///
636 pub fn fold_while<F, Acc>(mut self, acc: Acc, mut function: F)
637 -> FoldWhile<Acc>
638 where F: FnMut(Acc, $($p::Item),*) -> FoldWhile<Acc>
639 {
640 self.for_each_core(acc, move |acc, args| {
641 let ($($p,)*) = args;
642 function(acc, $($p),*)
643 })
644 }
645
646 /// Tests if every element of the iterator matches a predicate.
647 ///
648 /// Returns `true` if `predicate` evaluates to `true` for all elements.
649 /// Returns `true` if the input arrays are empty.
650 ///
651 /// Example:
652 ///
653 /// ```
654 /// use ndarray::{array, Zip};
655 /// let a = array![1, 2, 3];
656 /// let b = array![1, 4, 9];
657 /// assert!(Zip::from(&a).and(&b).all(|&a, &b| a * a == b));
658 /// ```
659 pub fn all<F>(mut self, mut predicate: F) -> bool
660 where F: FnMut($($p::Item),*) -> bool
661 {
662 !self.for_each_core((), move |_, args| {
663 let ($($p,)*) = args;
664 if predicate($($p),*) {
665 FoldWhile::Continue(())
666 } else {
667 FoldWhile::Done(())
668 }
669 }).is_done()
670 }
671
672 expand_if!(@bool [$notlast]
673
674 /// Include the producer `p` in the Zip.
675 ///
676 /// ***Panics*** if `p`’s shape doesn’t match the Zip’s exactly.
677 pub fn and<P>(self, p: P) -> Zip<($($p,)* P::Output, ), D>
678 where P: IntoNdProducer<Dim=D>,
679 {
680 let part = p.into_producer();
681 zip_dimension_check(&self.dimension, &part);
682 self.build_and(part)
683 }
684
685 /// Include the producer `p` in the Zip.
686 ///
687 /// ## Safety
688 ///
689 /// The caller must ensure that the producer's shape is equal to the Zip's shape.
690 /// Uses assertions when debug assertions are enabled.
691 #[allow(unused)]
692 pub(crate) unsafe fn and_unchecked<P>(self, p: P) -> Zip<($($p,)* P::Output, ), D>
693 where P: IntoNdProducer<Dim=D>,
694 {
695 #[cfg(debug_assertions)]
696 {
697 self.and(p)
698 }
699 #[cfg(not(debug_assertions))]
700 {
701 self.build_and(p.into_producer())
702 }
703 }
704
705 /// Include the producer `p` in the Zip, broadcasting if needed.
706 ///
707 /// If their shapes disagree, `rhs` is broadcast to the shape of `self`.
708 ///
709 /// ***Panics*** if broadcasting isn’t possible.
710 pub fn and_broadcast<'a, P, D2, Elem>(self, p: P)
711 -> Zip<($($p,)* ArrayView<'a, Elem, D>, ), D>
712 where P: IntoNdProducer<Dim=D2, Output=ArrayView<'a, Elem, D2>, Item=&'a Elem>,
713 D2: Dimension,
714 {
715 let part = p.into_producer().broadcast_unwrap(self.dimension.clone());
716 self.build_and(part)
717 }
718
719 fn build_and<P>(self, part: P) -> Zip<($($p,)* P, ), D>
720 where P: NdProducer<Dim=D>,
721 {
722 let part_layout = part.layout();
723 let ($($p,)*) = self.parts;
724 Zip {
725 parts: ($($p,)* part, ),
726 layout: self.layout.intersect(part_layout),
727 dimension: self.dimension,
728 layout_tendency: self.layout_tendency + part_layout.tendency(),
729 }
730 }
731
732 /// Map and collect the results into a new array, which has the same size as the
733 /// inputs.
734 ///
735 /// If all inputs are c- or f-order respectively, that is preserved in the output.
736 pub fn map_collect<R>(self, f: impl FnMut($($p::Item,)* ) -> R) -> Array<R, D> {
737 self.map_collect_owned(f)
738 }
739
740 pub(crate) fn map_collect_owned<S, R>(self, f: impl FnMut($($p::Item,)* ) -> R)
741 -> ArrayBase<S, D>
742 where S: DataOwned<Elem = R>
743 {
744 // safe because: all elements are written before the array is completed
745
746 let shape = self.dimension.clone().set_f(self.prefer_f());
747 let output = <ArrayBase<S, D>>::build_uninit(shape, |output| {
748 // Use partial to count the number of filled elements, and can drop the right
749 // number of elements on unwinding (if it happens during apply/collect).
750 unsafe {
751 let output_view = output.into_raw_view_mut().cast::<R>();
752 self.and(output_view)
753 .collect_with_partial(f)
754 .release_ownership();
755 }
756 });
757 unsafe {
758 output.assume_init()
759 }
760 }
761
762 /// Map and collect the results into a new array, which has the same size as the
763 /// inputs.
764 ///
765 /// If all inputs are c- or f-order respectively, that is preserved in the output.
766 #[deprecated(note="Renamed to .map_collect()", since="0.15.0")]
767 pub fn apply_collect<R>(self, f: impl FnMut($($p::Item,)* ) -> R) -> Array<R, D> {
768 self.map_collect(f)
769 }
770
771 /// Map and assign the results into the producer `into`, which should have the same
772 /// size as the other inputs.
773 ///
774 /// The producer should have assignable items as dictated by the `AssignElem` trait,
775 /// for example `&mut R`.
776 pub fn map_assign_into<R, Q>(self, into: Q, mut f: impl FnMut($($p::Item,)* ) -> R)
777 where Q: IntoNdProducer<Dim=D>,
778 Q::Item: AssignElem<R>
779 {
780 self.and(into)
781 .for_each(move |$($p, )* output_| {
782 output_.assign_elem(f($($p ),*));
783 });
784 }
785
786 /// Map and assign the results into the producer `into`, which should have the same
787 /// size as the other inputs.
788 ///
789 /// The producer should have assignable items as dictated by the `AssignElem` trait,
790 /// for example `&mut R`.
791 #[deprecated(note="Renamed to .map_assign_into()", since="0.15.0")]
792 pub fn apply_assign_into<R, Q>(self, into: Q, f: impl FnMut($($p::Item,)* ) -> R)
793 where Q: IntoNdProducer<Dim=D>,
794 Q::Item: AssignElem<R>
795 {
796 self.map_assign_into(into, f)
797 }
798
799
800 );
801
802 /// Split the `Zip` evenly in two.
803 ///
804 /// It will be split in the way that best preserves element locality.
805 pub fn split(self) -> (Self, Self) {
806 debug_assert_ne!(self.size(), 0, "Attempt to split empty zip");
807 debug_assert_ne!(self.size(), 1, "Attempt to split zip with 1 elem");
808 SplitPreference::split(self)
809 }
810 }
811
812 expand_if!(@bool [$notlast]
813 // For collect; Last producer is a RawViewMut
814 #[allow(non_snake_case)]
815 impl<D, PLast, R, $($p),*> Zip<($($p,)* PLast), D>
816 where D: Dimension,
817 $($p: NdProducer<Dim=D> ,)*
818 PLast: NdProducer<Dim = D, Item = *mut R, Ptr = *mut R, Stride = isize>,
819 {
820 /// The inner workings of map_collect and par_map_collect
821 ///
822 /// Apply the function and collect the results into the output (last producer)
823 /// which should be a raw array view; a Partial that owns the written
824 /// elements is returned.
825 ///
826 /// Elements will be overwritten in place (in the sense of std::ptr::write).
827 ///
828 /// ## Safety
829 ///
830 /// The last producer is a RawArrayViewMut and must be safe to write into.
831 /// The producer must be c- or f-contig and have the same layout tendency
832 /// as the whole Zip.
833 ///
834 /// The returned Partial's proxy ownership of the elements must be handled,
835 /// before the array the raw view points to realizes its ownership.
836 pub(crate) unsafe fn collect_with_partial<F>(self, mut f: F) -> Partial<R>
837 where F: FnMut($($p::Item,)* ) -> R
838 {
839 // Get the last producer; and make a Partial that aliases its data pointer
840 let (.., ref output) = &self.parts;
841
842 // debug assert that the output is contiguous in the memory layout we need
843 if cfg!(debug_assertions) {
844 let out_layout = output.layout();
845 assert!(out_layout.is(Layout::CORDER | Layout::FORDER));
846 assert!(
847 (self.layout_tendency <= 0 && out_layout.tendency() <= 0) ||
848 (self.layout_tendency >= 0 && out_layout.tendency() >= 0),
849 "layout tendency violation for self layout {:?}, output layout {:?},\
850 output shape {:?}",
851 self.layout, out_layout, output.raw_dim());
852 }
853
854 let mut partial = Partial::new(output.as_ptr());
855
856 // Apply the mapping function on this zip
857 // if we panic with unwinding; Partial will drop the written elements.
858 let partial_len = &mut partial.len;
859 self.for_each(move |$($p,)* output_elem: *mut R| {
860 output_elem.write(f($($p),*));
861 if std::mem::needs_drop::<R>() {
862 *partial_len += 1;
863 }
864 });
865
866 partial
867 }
868 }
869 );
870
871 impl<D, $($p),*> SplitPreference for Zip<($($p,)*), D>
872 where D: Dimension,
873 $($p: NdProducer<Dim=D> ,)*
874 {
875 fn can_split(&self) -> bool { self.size() > 1 }
876
877 fn split_preference(&self) -> (Axis, usize) {
878 // Always split in a way that preserves layout (if any)
879 let axis = self.max_stride_axis();
880 let index = self.len_of(axis) / 2;
881 (axis, index)
882 }
883 }
884
885 impl<D, $($p),*> SplitAt for Zip<($($p,)*), D>
886 where D: Dimension,
887 $($p: NdProducer<Dim=D> ,)*
888 {
889 fn split_at(self, axis: Axis, index: usize) -> (Self, Self) {
890 let (p1, p2) = self.parts.split_at(axis, index);
891 let (d1, d2) = self.dimension.split_at(axis, index);
892 (Zip {
893 dimension: d1,
894 layout: self.layout,
895 parts: p1,
896 layout_tendency: self.layout_tendency,
897 },
898 Zip {
899 dimension: d2,
900 layout: self.layout,
901 parts: p2,
902 layout_tendency: self.layout_tendency,
903 })
904 }
905
906 }
907
908 )+
909 }
910}
911
912map_impl! {
913 [true P1],
914 [true P1 P2],
915 [true P1 P2 P3],
916 [true P1 P2 P3 P4],
917 [true P1 P2 P3 P4 P5],
918 [false P1 P2 P3 P4 P5 P6],
919}
920
/// Value controlling the execution of `.fold_while` on `Zip`.
#[derive(Debug, Copy, Clone)]
pub enum FoldWhile<T> {
    /// Continue folding with this value
    Continue(T),
    /// Fold is complete and will return this value
    Done(T),
}

impl<T> FoldWhile<T> {
    /// Return the inner value, whichever variant carries it.
    pub fn into_inner(self) -> T {
        match self {
            FoldWhile::Continue(inner) => inner,
            FoldWhile::Done(inner) => inner,
        }
    }

    /// Return true if it is `Done`, false if `Continue`
    pub fn is_done(&self) -> bool {
        matches!(self, FoldWhile::Done(_))
    }
}