// ndarray/zip/ndproducer.rs
use crate::imp_prelude::*;
use crate::Layout;
use crate::NdIndex;
#[cfg(not(feature = "std"))]
use alloc::vec::Vec;

8/// Argument conversion into a producer.
9///
10/// Slices and vectors can be used (equivalent to 1-dimensional array views).
11///
12/// This trait is like `IntoIterator` for `NdProducers` instead of iterators.
13pub trait IntoNdProducer {
14    /// The element produced per iteration.
15    type Item;
16    /// Dimension type of the producer
17    type Dim: Dimension;
18    type Output: NdProducer<Dim = Self::Dim, Item = Self::Item>;
19    /// Convert the value into an `NdProducer`.
20    fn into_producer(self) -> Self::Output;
21}
22
23impl<P> IntoNdProducer for P
24where
25    P: NdProducer,
26{
27    type Item = P::Item;
28    type Dim = P::Dim;
29    type Output = Self;
30    fn into_producer(self) -> Self::Output {
31        self
32    }
33}
34
35/// A producer of an n-dimensional set of elements;
36/// for example an array view, mutable array view or an iterator
37/// that yields chunks.
38///
39/// Producers are used as a arguments to [`Zip`](crate::Zip) and
40/// [`azip!()`].
41///
42/// # Comparison to `IntoIterator`
43///
44/// Most `NdProducers` are *iterable* (implement `IntoIterator`) but not directly
45/// iterators. This separation is needed because the producer represents
46/// a multidimensional set of items, it can be split along a particular axis for
47/// parallelization, and it has no fixed correspondence to a sequence.
48///
49/// The natural exception is one dimensional producers, like `AxisIter`, which
50/// implement `Iterator` directly
51/// (`AxisIter` traverses a one dimensional sequence, along an axis, while
52/// *producing* multidimensional items).
53///
54/// See also [`IntoNdProducer`]
55pub trait NdProducer {
56    /// The element produced per iteration.
57    type Item;
58    // Internal use / Pointee type
59    /// Dimension type
60    type Dim: Dimension;
61
62    // The pointer Ptr is used by an array view to simply point to the
63    // current element. It doesn't have to be a pointer (see Indices).
64    // Its main function is that it can be incremented with a particular
65    // stride (= along a particular axis)
66    #[doc(hidden)]
67    /// Pointer or stand-in for pointer
68    type Ptr: Offset<Stride = Self::Stride>;
69    #[doc(hidden)]
70    /// Pointer stride
71    type Stride: Copy;
72
73    #[doc(hidden)]
74    fn layout(&self) -> Layout;
75    /// Return the shape of the producer.
76    fn raw_dim(&self) -> Self::Dim;
77    #[doc(hidden)]
78    fn equal_dim(&self, dim: &Self::Dim) -> bool {
79        self.raw_dim() == *dim
80    }
81    #[doc(hidden)]
82    fn as_ptr(&self) -> Self::Ptr;
83    #[doc(hidden)]
84    unsafe fn as_ref(&self, ptr: Self::Ptr) -> Self::Item;
85    #[doc(hidden)]
86    unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr;
87    #[doc(hidden)]
88    fn stride_of(&self, axis: Axis) -> <Self::Ptr as Offset>::Stride;
89    #[doc(hidden)]
90    fn contiguous_stride(&self) -> Self::Stride;
91    #[doc(hidden)]
92    fn split_at(self, axis: Axis, index: usize) -> (Self, Self)
93    where
94        Self: Sized;
95
96    private_decl! {}
97}
98
99pub trait Offset: Copy {
100    type Stride: Copy;
101    unsafe fn stride_offset(self, s: Self::Stride, index: usize) -> Self;
102    private_decl! {}
103}
104
105impl<T> Offset for *const T {
106    type Stride = isize;
107    unsafe fn stride_offset(self, s: Self::Stride, index: usize) -> Self {
108        self.offset(s * (index as isize))
109    }
110    private_impl! {}
111}
112
113impl<T> Offset for *mut T {
114    type Stride = isize;
115    unsafe fn stride_offset(self, s: Self::Stride, index: usize) -> Self {
116        self.offset(s * (index as isize))
117    }
118    private_impl! {}
119}
120
121/// An array reference is an n-dimensional producer of element references
122/// (like ArrayView).
123impl<'a, A: 'a, S, D> IntoNdProducer for &'a ArrayBase<S, D>
124where
125    D: Dimension,
126    S: Data<Elem = A>,
127{
128    type Item = &'a A;
129    type Dim = D;
130    type Output = ArrayView<'a, A, D>;
131    fn into_producer(self) -> Self::Output {
132        self.view()
133    }
134}
135
136/// A mutable array reference is an n-dimensional producer of mutable element
137/// references (like ArrayViewMut).
138impl<'a, A: 'a, S, D> IntoNdProducer for &'a mut ArrayBase<S, D>
139where
140    D: Dimension,
141    S: DataMut<Elem = A>,
142{
143    type Item = &'a mut A;
144    type Dim = D;
145    type Output = ArrayViewMut<'a, A, D>;
146    fn into_producer(self) -> Self::Output {
147        self.view_mut()
148    }
149}
150
151/// A slice is a one-dimensional producer
152impl<'a, A: 'a> IntoNdProducer for &'a [A] {
153    type Item = <Self::Output as NdProducer>::Item;
154    type Dim = Ix1;
155    type Output = ArrayView1<'a, A>;
156    fn into_producer(self) -> Self::Output {
157        <_>::from(self)
158    }
159}
160
161/// A mutable slice is a mutable one-dimensional producer
162impl<'a, A: 'a> IntoNdProducer for &'a mut [A] {
163    type Item = <Self::Output as NdProducer>::Item;
164    type Dim = Ix1;
165    type Output = ArrayViewMut1<'a, A>;
166    fn into_producer(self) -> Self::Output {
167        <_>::from(self)
168    }
169}
170
171/// A Vec is a one-dimensional producer
172impl<'a, A: 'a> IntoNdProducer for &'a Vec<A> {
173    type Item = <Self::Output as NdProducer>::Item;
174    type Dim = Ix1;
175    type Output = ArrayView1<'a, A>;
176    fn into_producer(self) -> Self::Output {
177        <_>::from(self)
178    }
179}
180
181/// A mutable Vec is a mutable one-dimensional producer
182impl<'a, A: 'a> IntoNdProducer for &'a mut Vec<A> {
183    type Item = <Self::Output as NdProducer>::Item;
184    type Dim = Ix1;
185    type Output = ArrayViewMut1<'a, A>;
186    fn into_producer(self) -> Self::Output {
187        <_>::from(self)
188    }
189}
190
191impl<'a, A, D: Dimension> NdProducer for ArrayView<'a, A, D> {
192    type Item = &'a A;
193    type Dim = D;
194    type Ptr = *mut A;
195    type Stride = isize;
196
197    private_impl! {}
198
199    fn raw_dim(&self) -> Self::Dim {
200        self.raw_dim()
201    }
202
203    fn equal_dim(&self, dim: &Self::Dim) -> bool {
204        self.dim.equal(dim)
205    }
206
207    fn as_ptr(&self) -> *mut A {
208        self.as_ptr() as _
209    }
210
211    fn layout(&self) -> Layout {
212        self.layout_impl()
213    }
214
215    unsafe fn as_ref(&self, ptr: *mut A) -> Self::Item {
216        &*ptr
217    }
218
219    unsafe fn uget_ptr(&self, i: &Self::Dim) -> *mut A {
220        self.ptr.as_ptr().offset(i.index_unchecked(&self.strides))
221    }
222
223    fn stride_of(&self, axis: Axis) -> isize {
224        self.stride_of(axis)
225    }
226
227    #[inline(always)]
228    fn contiguous_stride(&self) -> Self::Stride {
229        1
230    }
231
232    fn split_at(self, axis: Axis, index: usize) -> (Self, Self) {
233        self.split_at(axis, index)
234    }
235}
236
237impl<'a, A, D: Dimension> NdProducer for ArrayViewMut<'a, A, D> {
238    type Item = &'a mut A;
239    type Dim = D;
240    type Ptr = *mut A;
241    type Stride = isize;
242
243    private_impl! {}
244
245    fn raw_dim(&self) -> Self::Dim {
246        self.raw_dim()
247    }
248
249    fn equal_dim(&self, dim: &Self::Dim) -> bool {
250        self.dim.equal(dim)
251    }
252
253    fn as_ptr(&self) -> *mut A {
254        self.as_ptr() as _
255    }
256
257    fn layout(&self) -> Layout {
258        self.layout_impl()
259    }
260
261    unsafe fn as_ref(&self, ptr: *mut A) -> Self::Item {
262        &mut *ptr
263    }
264
265    unsafe fn uget_ptr(&self, i: &Self::Dim) -> *mut A {
266        self.ptr.as_ptr().offset(i.index_unchecked(&self.strides))
267    }
268
269    fn stride_of(&self, axis: Axis) -> isize {
270        self.stride_of(axis)
271    }
272
273    #[inline(always)]
274    fn contiguous_stride(&self) -> Self::Stride {
275        1
276    }
277
278    fn split_at(self, axis: Axis, index: usize) -> (Self, Self) {
279        self.split_at(axis, index)
280    }
281}
282
283impl<A, D: Dimension> NdProducer for RawArrayView<A, D> {
284    type Item = *const A;
285    type Dim = D;
286    type Ptr = *const A;
287    type Stride = isize;
288
289    private_impl! {}
290
291    fn raw_dim(&self) -> Self::Dim {
292        self.raw_dim()
293    }
294
295    fn equal_dim(&self, dim: &Self::Dim) -> bool {
296        self.dim.equal(dim)
297    }
298
299    fn as_ptr(&self) -> *const A {
300        self.as_ptr()
301    }
302
303    fn layout(&self) -> Layout {
304        self.layout_impl()
305    }
306
307    unsafe fn as_ref(&self, ptr: *const A) -> *const A {
308        ptr
309    }
310
311    unsafe fn uget_ptr(&self, i: &Self::Dim) -> *const A {
312        self.ptr.as_ptr().offset(i.index_unchecked(&self.strides))
313    }
314
315    fn stride_of(&self, axis: Axis) -> isize {
316        self.stride_of(axis)
317    }
318
319    #[inline(always)]
320    fn contiguous_stride(&self) -> Self::Stride {
321        1
322    }
323
324    fn split_at(self, axis: Axis, index: usize) -> (Self, Self) {
325        self.split_at(axis, index)
326    }
327}
328
329impl<A, D: Dimension> NdProducer for RawArrayViewMut<A, D> {
330    type Item = *mut A;
331    type Dim = D;
332    type Ptr = *mut A;
333    type Stride = isize;
334
335    private_impl! {}
336
337    fn raw_dim(&self) -> Self::Dim {
338        self.raw_dim()
339    }
340
341    fn equal_dim(&self, dim: &Self::Dim) -> bool {
342        self.dim.equal(dim)
343    }
344
345    fn as_ptr(&self) -> *mut A {
346        self.as_ptr() as _
347    }
348
349    fn layout(&self) -> Layout {
350        self.layout_impl()
351    }
352
353    unsafe fn as_ref(&self, ptr: *mut A) -> *mut A {
354        ptr
355    }
356
357    unsafe fn uget_ptr(&self, i: &Self::Dim) -> *mut A {
358        self.ptr.as_ptr().offset(i.index_unchecked(&self.strides))
359    }
360
361    fn stride_of(&self, axis: Axis) -> isize {
362        self.stride_of(axis)
363    }
364
365    #[inline(always)]
366    fn contiguous_stride(&self) -> Self::Stride {
367        1
368    }
369
370    fn split_at(self, axis: Axis, index: usize) -> (Self, Self) {
371        self.split_at(axis, index)
372    }
373}
374