ndarray/dimension/
mod.rs

1// Copyright 2014-2016 bluss and ndarray developers.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9use crate::error::{from_kind, ErrorKind, ShapeError};
10use crate::slice::SliceArg;
11use crate::{Ix, Ixs, Slice, SliceInfoElem};
12use crate::shape_builder::Strides;
13use num_integer::div_floor;
14
15pub use self::axes::{Axes, AxisDescription};
16pub use self::axis::Axis;
17pub use self::broadcast::DimMax;
18pub use self::conversion::IntoDimension;
19pub use self::dim::*;
20pub use self::dimension_trait::Dimension;
21pub use self::dynindeximpl::IxDynImpl;
22pub use self::ndindex::NdIndex;
23pub use self::ops::DimAdd;
24pub use self::remove_axis::RemoveAxis;
25
26pub(crate) use self::axes::axes_of;
27pub(crate) use self::reshape::reshape_dim;
28
29use std::isize;
30use std::mem;
31
32#[macro_use]
33mod macros;
34mod axes;
35mod axis;
36pub(crate) mod broadcast;
37mod conversion;
38pub mod dim;
39mod dimension_trait;
40mod dynindeximpl;
41mod ndindex;
42mod ops;
43mod remove_axis;
44pub(crate) mod reshape;
45mod sequence;
46
47/// Calculate offset from `Ix` stride converting sign properly
48#[inline(always)]
49pub fn stride_offset(n: Ix, stride: Ix) -> isize {
50    (n as isize) * ((stride as Ixs) as isize)
51}
52
53/// Check whether the given `dim` and `stride` lead to overlapping indices
54///
55/// There is overlap if, when iterating through the dimensions in order of
56/// increasing stride, the current stride is less than or equal to the maximum
57/// possible offset along the preceding axes. (Axes of length ≤1 are ignored.)
58pub fn dim_stride_overlap<D: Dimension>(dim: &D, strides: &D) -> bool {
59    let order = strides._fastest_varying_stride_order();
60    let mut sum_prev_offsets = 0;
61    for &index in order.slice() {
62        let d = dim[index];
63        let s = (strides[index] as isize).abs();
64        match d {
65            0 => return false,
66            1 => {}
67            _ => {
68                if s <= sum_prev_offsets {
69                    return true;
70                }
71                sum_prev_offsets += (d - 1) as isize * s;
72            }
73        }
74    }
75    false
76}
77
78/// Returns the `size` of the `dim`, checking that the product of non-zero axis
79/// lengths does not exceed `isize::MAX`.
80///
81/// If `size_of_checked_shape(dim)` returns `Ok(size)`, the data buffer is a
82/// slice or `Vec` of length `size`, and `strides` are created with
83/// `self.default_strides()` or `self.fortran_strides()`, then the invariants
84/// are met to construct an array from the data buffer, `dim`, and `strides`.
85/// (The data buffer being a slice or `Vec` guarantees that it contains no more
86/// than `isize::MAX` bytes.)
87pub fn size_of_shape_checked<D: Dimension>(dim: &D) -> Result<usize, ShapeError> {
88    let size_nonzero = dim
89        .slice()
90        .iter()
91        .filter(|&&d| d != 0)
92        .try_fold(1usize, |acc, &d| acc.checked_mul(d))
93        .ok_or_else(|| from_kind(ErrorKind::Overflow))?;
94    if size_nonzero > ::std::isize::MAX as usize {
95        Err(from_kind(ErrorKind::Overflow))
96    } else {
97        Ok(dim.size())
98    }
99}
100
101/// Checks whether the given data and dimension meet the invariants of the
102/// `ArrayBase` type, assuming the strides are created using
103/// `dim.default_strides()` or `dim.fortran_strides()`.
104///
105/// To meet the invariants,
106///
107/// 1. The product of non-zero axis lengths must not exceed `isize::MAX`.
108///
109/// 2. The result of `dim.size()` (assuming no overflow) must be less than or
110///    equal to the length of the slice.
111///
112///    (Since `dim.default_strides()` and `dim.fortran_strides()` always return
113///    contiguous strides for non-empty arrays, this ensures that for non-empty
114///    arrays the difference between the least address and greatest address
115///    accessible by moving along all axes is < the length of the slice. Since
116///    `dim.default_strides()` and `dim.fortran_strides()` always return all
117///    zero strides for empty arrays, this ensures that for empty arrays the
118///    difference between the least address and greatest address accessible by
119///    moving along all axes is ≤ the length of the slice.)
120///
121/// Note that since slices cannot contain more than `isize::MAX` bytes,
122/// conditions 1 and 2 are sufficient to guarantee that the offset in units of
123/// `A` and in units of bytes between the least address and greatest address
124/// accessible by moving along all axes does not exceed `isize::MAX`.
125pub(crate) fn can_index_slice_with_strides<A, D: Dimension>(data: &[A], dim: &D,
126                                                            strides: &Strides<D>)
127    -> Result<(), ShapeError>
128{
129    if let Strides::Custom(strides) = strides {
130        can_index_slice(data, dim, strides)
131    } else {
132        can_index_slice_not_custom(data.len(), dim)
133    }
134}
135
136pub(crate) fn can_index_slice_not_custom<D: Dimension>(data_len: usize, dim: &D)
137    -> Result<(), ShapeError>
138{
139    // Condition 1.
140    let len = size_of_shape_checked(dim)?;
141    // Condition 2.
142    if len > data_len {
143        return Err(from_kind(ErrorKind::OutOfBounds));
144    }
145    Ok(())
146}
147
148/// Returns the absolute difference in units of `A` between least and greatest
149/// address accessible by moving along all axes.
150///
151/// Returns `Ok` only if
152///
153/// 1. The ndim of `dim` and `strides` is the same.
154///
155/// 2. The absolute difference in units of `A` and in units of bytes between
156///    the least address and greatest address accessible by moving along all axes
157///    does not exceed `isize::MAX`.
158///
159/// 3. The product of non-zero axis lengths does not exceed `isize::MAX`. (This
160///    also implies that the length of any individual axis does not exceed
161///    `isize::MAX`.)
162pub fn max_abs_offset_check_overflow<A, D>(dim: &D, strides: &D) -> Result<usize, ShapeError>
163where
164    D: Dimension,
165{
166    max_abs_offset_check_overflow_impl(mem::size_of::<A>(), dim, strides)
167}
168
169fn max_abs_offset_check_overflow_impl<D>(elem_size: usize, dim: &D, strides: &D)
170    -> Result<usize, ShapeError>
171where
172    D: Dimension,
173{
174    // Condition 1.
175    if dim.ndim() != strides.ndim() {
176        return Err(from_kind(ErrorKind::IncompatibleLayout));
177    }
178
179    // Condition 3.
180    let _ = size_of_shape_checked(dim)?;
181
182    // Determine absolute difference in units of `A` between least and greatest
183    // address accessible by moving along all axes.
184    let max_offset: usize = izip!(dim.slice(), strides.slice())
185        .try_fold(0usize, |acc, (&d, &s)| {
186            let s = s as isize;
187            // Calculate maximum possible absolute movement along this axis.
188            let off = d.saturating_sub(1).checked_mul(s.abs() as usize)?;
189            acc.checked_add(off)
190        })
191        .ok_or_else(|| from_kind(ErrorKind::Overflow))?;
192    // Condition 2a.
193    if max_offset > isize::MAX as usize {
194        return Err(from_kind(ErrorKind::Overflow));
195    }
196
197    // Determine absolute difference in units of bytes between least and
198    // greatest address accessible by moving along all axes
199    let max_offset_bytes = max_offset
200        .checked_mul(elem_size)
201        .ok_or_else(|| from_kind(ErrorKind::Overflow))?;
202    // Condition 2b.
203    if max_offset_bytes > isize::MAX as usize {
204        return Err(from_kind(ErrorKind::Overflow));
205    }
206
207    Ok(max_offset)
208}
209
210/// Checks whether the given data, dimension, and strides meet the invariants
211/// of the `ArrayBase` type (except for checking ownership of the data).
212///
213/// To meet the invariants,
214///
215/// 1. The ndim of `dim` and `strides` must be the same.
216///
217/// 2. The product of non-zero axis lengths must not exceed `isize::MAX`.
218///
219/// 3. If the array will be empty (any axes are zero-length), the difference
220///    between the least address and greatest address accessible by moving
221///    along all axes must be ≤ `data.len()`. (It's fine in this case to move
222///    one byte past the end of the slice since the pointers will be offset but
223///    never dereferenced.)
224///
225///    If the array will not be empty, the difference between the least address
226///    and greatest address accessible by moving along all axes must be <
227///    `data.len()`. This and #3 ensure that all dereferenceable pointers point
228///    to elements within the slice.
229///
230/// 4. The strides must not allow any element to be referenced by two different
231///    indices.
232///
233/// Note that since slices cannot contain more than `isize::MAX` bytes,
234/// condition 4 is sufficient to guarantee that the absolute difference in
235/// units of `A` and in units of bytes between the least address and greatest
236/// address accessible by moving along all axes does not exceed `isize::MAX`.
237///
238/// Warning: This function is sufficient to check the invariants of ArrayBase
239/// only if the pointer to the first element of the array is chosen such that
240/// the element with the smallest memory address is at the start of the
241/// allocation. (In other words, the pointer to the first element of the array
242/// must be computed using `offset_from_low_addr_ptr_to_logical_ptr` so that
243/// negative strides are correctly handled.)
244pub(crate) fn can_index_slice<A, D: Dimension>(
245    data: &[A],
246    dim: &D,
247    strides: &D,
248) -> Result<(), ShapeError> {
249    // Check conditions 1 and 2 and calculate `max_offset`.
250    let max_offset = max_abs_offset_check_overflow::<A, _>(dim, strides)?;
251    can_index_slice_impl(max_offset, data.len(), dim, strides)
252}
253
254fn can_index_slice_impl<D: Dimension>(
255    max_offset: usize,
256    data_len: usize,
257    dim: &D,
258    strides: &D,
259) -> Result<(), ShapeError> {
260    // Check condition 3.
261    let is_empty = dim.slice().iter().any(|&d| d == 0);
262    if is_empty && max_offset > data_len {
263        return Err(from_kind(ErrorKind::OutOfBounds));
264    }
265    if !is_empty && max_offset >= data_len {
266        return Err(from_kind(ErrorKind::OutOfBounds));
267    }
268
269    // Check condition 4.
270    if !is_empty && dim_stride_overlap(dim, strides) {
271        return Err(from_kind(ErrorKind::Unsupported));
272    }
273
274    Ok(())
275}
276
277/// Stride offset checked general version (slices)
278#[inline]
279pub fn stride_offset_checked(dim: &[Ix], strides: &[Ix], index: &[Ix]) -> Option<isize> {
280    if index.len() != dim.len() {
281        return None;
282    }
283    let mut offset = 0;
284    for (&d, &i, &s) in izip!(dim, index, strides) {
285        if i >= d {
286            return None;
287        }
288        offset += stride_offset(i, s);
289    }
290    Some(offset)
291}
292
293/// Checks if strides are non-negative.
294pub fn strides_non_negative<D>(strides: &D) -> Result<(), ShapeError>
295where
296    D: Dimension,
297{
298    for &stride in strides.slice() {
299        if (stride as isize) < 0 {
300            return Err(from_kind(ErrorKind::Unsupported));
301        }
302    }
303    Ok(())
304}
305
/// Implementation-specific extensions to `Dimension`
///
/// Provides read and write access to a single axis length, addressed by an
/// `Axis` value rather than a raw index.
pub trait DimensionExt {
    // note: many extensions go in the main trait if they need to be special-
    // cased per dimension
    /// Get the dimension at `axis`.
    ///
    /// *Panics* if `axis` is out of bounds.
    fn axis(&self, axis: Axis) -> Ix;

    /// Set the dimension at `axis`.
    ///
    /// *Panics* if `axis` is out of bounds.
    fn set_axis(&mut self, axis: Axis, value: Ix);
}
320
321impl<D> DimensionExt for D
322where
323    D: Dimension,
324{
325    #[inline]
326    fn axis(&self, axis: Axis) -> Ix {
327        self[axis.index()]
328    }
329
330    #[inline]
331    fn set_axis(&mut self, axis: Axis, value: Ix) {
332        self[axis.index()] = value;
333    }
334}
335
336impl DimensionExt for [Ix] {
337    #[inline]
338    fn axis(&self, axis: Axis) -> Ix {
339        self[axis.index()]
340    }
341
342    #[inline]
343    fn set_axis(&mut self, axis: Axis, value: Ix) {
344        self[axis.index()] = value;
345    }
346}
347
348/// Collapse axis `axis` and shift so that only subarray `index` is
349/// available.
350///
351/// **Panics** if `index` is larger than the size of the axis
352// FIXME: Move to Dimension trait
353pub fn do_collapse_axis<D: Dimension>(
354    dims: &mut D,
355    strides: &D,
356    axis: usize,
357    index: usize,
358) -> isize {
359    let dim = dims.slice()[axis];
360    let stride = strides.slice()[axis];
361    ndassert!(
362        index < dim,
363        "collapse_axis: Index {} must be less than axis length {} for \
364         array with shape {:?}",
365        index,
366        dim,
367        *dims
368    );
369    dims.slice_mut()[axis] = 1;
370    stride_offset(index, stride)
371}
372
373/// Compute the equivalent unsigned index given the axis length and signed index.
374#[inline]
375pub fn abs_index(len: Ix, index: Ixs) -> Ix {
376    if index < 0 {
377        len - (-index as Ix)
378    } else {
379        index as Ix
380    }
381}
382
383/// Determines nonnegative start and end indices, and performs sanity checks.
384///
385/// The return value is (start, end, step).
386///
387/// **Panics** if stride is 0 or if any index is out of bounds.
388fn to_abs_slice(axis_len: usize, slice: Slice) -> (usize, usize, isize) {
389    let Slice { start, end, step } = slice;
390    let start = abs_index(axis_len, start);
391    let mut end = abs_index(axis_len, end.unwrap_or(axis_len as isize));
392    if end < start {
393        end = start;
394    }
395    ndassert!(
396        start <= axis_len,
397        "Slice begin {} is past end of axis of length {}",
398        start,
399        axis_len,
400    );
401    ndassert!(
402        end <= axis_len,
403        "Slice end {} is past end of axis of length {}",
404        end,
405        axis_len,
406    );
407    ndassert!(step != 0, "Slice stride must not be zero");
408    (start, end, step)
409}
410
411/// Returns the offset from the lowest-address element to the logically first
412/// element.
413pub fn offset_from_low_addr_ptr_to_logical_ptr<D: Dimension>(dim: &D, strides: &D) -> usize {
414    let offset = izip!(dim.slice(), strides.slice()).fold(0, |_offset, (&d, &s)| {
415        let s = s as isize;
416        if s < 0 && d > 1 {
417            _offset - s * (d as isize - 1)
418        } else {
419            _offset
420        }
421    });
422    debug_assert!(offset >= 0);
423    offset as usize
424}
425
426/// Modify dimension, stride and return data pointer offset
427///
428/// **Panics** if stride is 0 or if any index is out of bounds.
429pub fn do_slice(dim: &mut usize, stride: &mut usize, slice: Slice) -> isize {
430    let (start, end, step) = to_abs_slice(*dim, slice);
431
432    let m = end - start;
433    let s = (*stride) as isize;
434
435    // Compute data pointer offset.
436    let offset = if m == 0 {
437        // In this case, the resulting array is empty, so we *can* avoid performing a nonzero
438        // offset.
439        //
440        // In two special cases (which are the true reason for this `m == 0` check), we *must* avoid
441        // the nonzero offset corresponding to the general case.
442        //
443        // * When `end == 0 && step < 0`. (These conditions imply that `m == 0` since `to_abs_slice`
444        //   ensures that `0 <= start <= end`.) We cannot execute `stride_offset(end - 1, *stride)`
445        //   because the `end - 1` would underflow.
446        //
447        // * When `start == *dim && step > 0`. (These conditions imply that `m == 0` since
448        //   `to_abs_slice` ensures that `start <= end <= *dim`.) We cannot use the offset returned
449        //   by `stride_offset(start, *stride)` because that would be past the end of the axis.
450        0
451    } else if step < 0 {
452        // When the step is negative, the new first element is `end - 1`, not `start`, since the
453        // direction is reversed.
454        stride_offset(end - 1, *stride)
455    } else {
456        stride_offset(start, *stride)
457    };
458
459    // Update dimension.
460    let abs_step = step.abs() as usize;
461    *dim = if abs_step == 1 {
462        m
463    } else {
464        let d = m / abs_step;
465        let r = m % abs_step;
466        d + if r > 0 { 1 } else { 0 }
467    };
468
469    // Update stride. The additional check is necessary to avoid possible
470    // overflow in the multiplication.
471    *stride = if *dim <= 1 { 0 } else { (s * step) as usize };
472
473    offset
474}
475
/// Solves `a * x + b * y = gcd(a, b)` for `x`, `y`, and `gcd(a, b)` using the
/// extended Euclidean algorithm.
///
/// Returns `(g, (x, y))`, where `g` is `gcd(a, b)`, and `g` is always
/// nonnegative. If either input is zero, the result is built directly from
/// the absolute value and sign of the other input.
///
/// See https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm
fn extended_gcd(a: isize, b: isize) -> (isize, (isize, isize)) {
    match (a, b) {
        (0, _) => (b.abs(), (0, b.signum())),
        (_, 0) => (a.abs(), (a.signum(), 0)),
        _ => {
            // Remainder sequence and the two Bézout coefficient sequences,
            // each kept as a (previous, current) pair of scalars.
            let (mut r_prev, mut r_cur): (isize, isize) = (a, b);
            let (mut x_prev, mut x_cur): (isize, isize) = (1, 0);
            let (mut y_prev, mut y_cur): (isize, isize) = (0, 1);
            while r_cur != 0 {
                let quotient = r_prev / r_cur;

                let r_next = r_prev - quotient * r_cur;
                r_prev = r_cur;
                r_cur = r_next;

                let x_next = x_prev - quotient * x_cur;
                x_prev = x_cur;
                x_cur = x_next;

                let y_next = y_prev - quotient * y_cur;
                y_prev = y_cur;
                y_cur = y_next;
            }
            // Normalize so that the reported gcd is nonnegative.
            if r_prev > 0 {
                (r_prev, (x_prev, y_prev))
            } else {
                (-r_prev, (-x_prev, -y_prev))
            }
        }
    }
}
504
505/// Solves `a * x + b * y = c` for `x` where `a`, `b`, `c`, `x`, and `y` are
506/// integers.
507///
508/// If the return value is `Some((x0, xd))`, there is a solution. `xd` is
509/// always positive. Solutions `x` are given by `x0 + xd * t` where `t` is any
510/// integer. The value of `y` for any `x` is then `y = (c - a * x) / b`.
511///
512/// If the return value is `None`, no solutions exist.
513///
514/// **Note** `a` and `b` must be nonzero.
515///
516/// See https://en.wikipedia.org/wiki/Diophantine_equation#One_equation
517/// and https://math.stackexchange.com/questions/1656120#1656138
518fn solve_linear_diophantine_eq(a: isize, b: isize, c: isize) -> Option<(isize, isize)> {
519    debug_assert_ne!(a, 0);
520    debug_assert_ne!(b, 0);
521    let (g, (u, _)) = extended_gcd(a, b);
522    if c % g == 0 {
523        Some((c / g * u, (b / g).abs()))
524    } else {
525        None
526    }
527}
528
529/// Returns `true` if two (finite length) arithmetic sequences intersect.
530///
531/// `min*` and `max*` are the (inclusive) bounds of the sequences, and they
532/// must be elements in the sequences. `step*` are the steps between
533/// consecutive elements (the sign is irrelevant).
534///
535/// **Note** `step1` and `step2` must be nonzero.
536fn arith_seq_intersect(
537    (min1, max1, step1): (isize, isize, isize),
538    (min2, max2, step2): (isize, isize, isize),
539) -> bool {
540    debug_assert!(max1 >= min1);
541    debug_assert!(max2 >= min2);
542    debug_assert_eq!((max1 - min1) % step1, 0);
543    debug_assert_eq!((max2 - min2) % step2, 0);
544
545    // Handle the easy case where we don't have to solve anything.
546    if min1 > max2 || min2 > max1 {
547        false
548    } else {
549        // The sign doesn't matter semantically, and it's mathematically convenient
550        // for `step1` and `step2` to be positive.
551        let step1 = step1.abs();
552        let step2 = step2.abs();
553        // Ignoring the min/max bounds, the sequences are
554        //   a(x) = min1 + step1 * x
555        //   b(y) = min2 + step2 * y
556        //
557        // For intersections a(x) = b(y), we have:
558        //   min1 + step1 * x = min2 + step2 * y
559        //   ⇒ -step1 * x + step2 * y = min1 - min2
560        // which is a linear Diophantine equation.
561        if let Some((x0, xd)) = solve_linear_diophantine_eq(-step1, step2, min1 - min2) {
562            // Minimum of [min1, max1] ∩ [min2, max2]
563            let min = ::std::cmp::max(min1, min2);
564            // Maximum of [min1, max1] ∩ [min2, max2]
565            let max = ::std::cmp::min(max1, max2);
566            // The potential intersections are
567            //   a(x) = min1 + step1 * (x0 + xd * t)
568            // where `t` is any integer.
569            //
570            // There is an intersection in `[min, max]` if there exists an
571            // integer `t` such that
572            //   min ≤ a(x) ≤ max
573            //   ⇒ min ≤ min1 + step1 * (x0 + xd * t) ≤ max
574            //   ⇒ min ≤ min1 + step1 * x0 + step1 * xd * t ≤ max
575            //   ⇒ min - min1 - step1 * x0 ≤ (step1 * xd) * t ≤ max - min1 - step1 * x0
576            //
577            // Therefore, the least possible intersection `a(x)` that is ≥ `min` has
578            //   t = ⌈(min - min1 - step1 * x0) / (step1 * xd)⌉
579            // If this `a(x) is also ≤ `max`, then there is an intersection in `[min, max]`.
580            //
581            // The greatest possible intersection `a(x)` that is ≤ `max` has
582            //   t = ⌊(max - min1 - step1 * x0) / (step1 * xd)⌋
583            // If this `a(x) is also ≥ `min`, then there is an intersection in `[min, max]`.
584            min1 + step1 * (x0 - xd * div_floor(min - min1 - step1 * x0, -step1 * xd)) <= max
585                || min1 + step1 * (x0 + xd * div_floor(max - min1 - step1 * x0, step1 * xd)) >= min
586        } else {
587            false
588        }
589    }
590}
591
592/// Returns the minimum and maximum values of the indices (inclusive).
593///
594/// If the slice is empty, then returns `None`, otherwise returns `Some((min, max))`.
595fn slice_min_max(axis_len: usize, slice: Slice) -> Option<(usize, usize)> {
596    let (start, end, step) = to_abs_slice(axis_len, slice);
597    if start == end {
598        None
599    } else if step > 0 {
600        Some((start, end - 1 - (end - start - 1) % (step as usize)))
601    } else {
602        Some((start + (end - start - 1) % (-step as usize), end - 1))
603    }
604}
605
606/// Returns `true` iff the slices intersect.
607pub fn slices_intersect<D: Dimension>(
608    dim: &D,
609    indices1: impl SliceArg<D>,
610    indices2: impl SliceArg<D>,
611) -> bool {
612    debug_assert_eq!(indices1.in_ndim(), indices2.in_ndim());
613    for (&axis_len, &si1, &si2) in izip!(
614        dim.slice(),
615        indices1.as_ref().iter().filter(|si| !si.is_new_axis()),
616        indices2.as_ref().iter().filter(|si| !si.is_new_axis()),
617    ) {
618        // The slices do not intersect iff any pair of `SliceInfoElem` does not intersect.
619        match (si1, si2) {
620            (
621                SliceInfoElem::Slice {
622                    start: start1,
623                    end: end1,
624                    step: step1,
625                },
626                SliceInfoElem::Slice {
627                    start: start2,
628                    end: end2,
629                    step: step2,
630                },
631            ) => {
632                let (min1, max1) = match slice_min_max(axis_len, Slice::new(start1, end1, step1)) {
633                    Some(m) => m,
634                    None => return false,
635                };
636                let (min2, max2) = match slice_min_max(axis_len, Slice::new(start2, end2, step2)) {
637                    Some(m) => m,
638                    None => return false,
639                };
640                if !arith_seq_intersect(
641                    (min1 as isize, max1 as isize, step1),
642                    (min2 as isize, max2 as isize, step2),
643                ) {
644                    return false;
645                }
646            }
647            (SliceInfoElem::Slice { start, end, step }, SliceInfoElem::Index(ind))
648            | (SliceInfoElem::Index(ind), SliceInfoElem::Slice { start, end, step }) => {
649                let ind = abs_index(axis_len, ind);
650                let (min, max) = match slice_min_max(axis_len, Slice::new(start, end, step)) {
651                    Some(m) => m,
652                    None => return false,
653                };
654                if ind < min || ind > max || (ind - min) % step.abs() as usize != 0 {
655                    return false;
656                }
657            }
658            (SliceInfoElem::Index(ind1), SliceInfoElem::Index(ind2)) => {
659                let ind1 = abs_index(axis_len, ind1);
660                let ind2 = abs_index(axis_len, ind2);
661                if ind1 != ind2 {
662                    return false;
663                }
664            }
665            (SliceInfoElem::NewAxis, _) | (_, SliceInfoElem::NewAxis) => unreachable!(),
666        }
667    }
668    true
669}
670
671pub(crate) fn is_layout_c<D: Dimension>(dim: &D, strides: &D) -> bool {
672    if let Some(1) = D::NDIM {
673        return strides[0] == 1 || dim[0] <= 1;
674    }
675
676    for &d in dim.slice() {
677        if d == 0 {
678            return true;
679        }
680    }
681
682    let mut contig_stride = 1_isize;
683    // check all dimensions -- a dimension of length 1 can have unequal strides
684    for (&dim, &s) in izip!(dim.slice().iter().rev(), strides.slice().iter().rev()) {
685        if dim != 1 {
686            let s = s as isize;
687            if s != contig_stride {
688                return false;
689            }
690            contig_stride *= dim as isize;
691        }
692    }
693    true
694}
695
696pub(crate) fn is_layout_f<D: Dimension>(dim: &D, strides: &D) -> bool {
697    if let Some(1) = D::NDIM {
698        return strides[0] == 1 || dim[0] <= 1;
699    }
700
701    for &d in dim.slice() {
702        if d == 0 {
703            return true;
704        }
705    }
706
707    let mut contig_stride = 1_isize;
708    // check all dimensions -- a dimension of length 1 can have unequal strides
709    for (&dim, &s) in izip!(dim.slice(), strides.slice()) {
710        if dim != 1 {
711            let s = s as isize;
712            if s != contig_stride {
713                return false;
714            }
715            contig_stride *= dim as isize;
716        }
717    }
718    true
719}
720
721pub fn merge_axes<D>(dim: &mut D, strides: &mut D, take: Axis, into: Axis) -> bool
722where
723    D: Dimension,
724{
725    let into_len = dim.axis(into);
726    let into_stride = strides.axis(into) as isize;
727    let take_len = dim.axis(take);
728    let take_stride = strides.axis(take) as isize;
729    let merged_len = into_len * take_len;
730    if take_len <= 1 {
731        dim.set_axis(into, merged_len);
732        dim.set_axis(take, if merged_len == 0 { 0 } else { 1 });
733        true
734    } else if into_len <= 1 {
735        strides.set_axis(into, take_stride as usize);
736        dim.set_axis(into, merged_len);
737        dim.set_axis(take, if merged_len == 0 { 0 } else { 1 });
738        true
739    } else if take_stride == into_len as isize * into_stride {
740        dim.set_axis(into, merged_len);
741        dim.set_axis(take, 1);
742        true
743    } else {
744        false
745    }
746}
747
748/// Move the axis which has the smallest absolute stride and a length
749/// greater than one to be the last axis.
750pub fn move_min_stride_axis_to_last<D>(dim: &mut D, strides: &mut D)
751where
752    D: Dimension,
753{
754    debug_assert_eq!(dim.ndim(), strides.ndim());
755    match dim.ndim() {
756        0 | 1 => {}
757        2 => {
758            if dim[1] <= 1
759                || dim[0] > 1 && (strides[0] as isize).abs() < (strides[1] as isize).abs()
760            {
761                dim.slice_mut().swap(0, 1);
762                strides.slice_mut().swap(0, 1);
763            }
764        }
765        n => {
766            if let Some(min_stride_axis) = (0..n)
767                .filter(|&ax| dim[ax] > 1)
768                .min_by_key(|&ax| (strides[ax] as isize).abs())
769            {
770                let last = n - 1;
771                dim.slice_mut().swap(last, min_stride_axis);
772                strides.slice_mut().swap(last, min_stride_axis);
773            }
774        }
775    }
776}
777
778#[cfg(test)]
779mod test {
780    use super::{
781        arith_seq_intersect, can_index_slice, can_index_slice_not_custom, extended_gcd,
782        max_abs_offset_check_overflow, slice_min_max, slices_intersect,
783        solve_linear_diophantine_eq, IntoDimension,
784    };
785    use crate::error::{from_kind, ErrorKind};
786    use crate::slice::Slice;
787    use crate::{Dim, Dimension, Ix0, Ix1, Ix2, Ix3, IxDyn, NewAxis};
788    use num_integer::gcd;
789    use quickcheck::{quickcheck, TestResult};
790
791    #[test]
792    fn slice_indexing_uncommon_strides() {
793        let v: alloc::vec::Vec<_> = (0..12).collect();
794        let dim = (2, 3, 2).into_dimension();
795        let strides = (1, 2, 6).into_dimension();
796        assert!(super::can_index_slice(&v, &dim, &strides).is_ok());
797
798        let strides = (2, 4, 12).into_dimension();
799        assert_eq!(
800            super::can_index_slice(&v, &dim, &strides),
801            Err(from_kind(ErrorKind::OutOfBounds))
802        );
803    }
804
805    #[test]
806    fn overlapping_strides_dim() {
807        let dim = (2, 3, 2).into_dimension();
808        let strides = (5, 2, 1).into_dimension();
809        assert!(super::dim_stride_overlap(&dim, &strides));
810        let strides = (-5isize as usize, 2, -1isize as usize).into_dimension();
811        assert!(super::dim_stride_overlap(&dim, &strides));
812        let strides = (6, 2, 1).into_dimension();
813        assert!(!super::dim_stride_overlap(&dim, &strides));
814        let strides = (6, -2isize as usize, 1).into_dimension();
815        assert!(!super::dim_stride_overlap(&dim, &strides));
816        let strides = (6, 0, 1).into_dimension();
817        assert!(super::dim_stride_overlap(&dim, &strides));
818        let strides = (-6isize as usize, 0, 1).into_dimension();
819        assert!(super::dim_stride_overlap(&dim, &strides));
820        let dim = (2, 2).into_dimension();
821        let strides = (3, 2).into_dimension();
822        assert!(!super::dim_stride_overlap(&dim, &strides));
823        let strides = (3, -2isize as usize).into_dimension();
824        assert!(!super::dim_stride_overlap(&dim, &strides));
825    }
826
827    #[test]
828    fn max_abs_offset_check_overflow_examples() {
829        let dim = (1, ::std::isize::MAX as usize, 1).into_dimension();
830        let strides = (1, 1, 1).into_dimension();
831        max_abs_offset_check_overflow::<u8, _>(&dim, &strides).unwrap();
832        let dim = (1, ::std::isize::MAX as usize, 2).into_dimension();
833        let strides = (1, 1, 1).into_dimension();
834        max_abs_offset_check_overflow::<u8, _>(&dim, &strides).unwrap_err();
835        let dim = (0, 2, 2).into_dimension();
836        let strides = (1, ::std::isize::MAX as usize, 1).into_dimension();
837        max_abs_offset_check_overflow::<u8, _>(&dim, &strides).unwrap_err();
838        let dim = (0, 2, 2).into_dimension();
839        let strides = (1, ::std::isize::MAX as usize / 4, 1).into_dimension();
840        max_abs_offset_check_overflow::<i32, _>(&dim, &strides).unwrap_err();
841    }
842
843    #[test]
844    fn can_index_slice_ix0() {
845        can_index_slice::<i32, _>(&[1], &Ix0(), &Ix0()).unwrap();
846        can_index_slice::<i32, _>(&[], &Ix0(), &Ix0()).unwrap_err();
847    }
848
849    #[test]
850    fn can_index_slice_ix1() {
851        can_index_slice::<i32, _>(&[], &Ix1(0), &Ix1(0)).unwrap();
852        can_index_slice::<i32, _>(&[], &Ix1(0), &Ix1(1)).unwrap();
853        can_index_slice::<i32, _>(&[], &Ix1(1), &Ix1(0)).unwrap_err();
854        can_index_slice::<i32, _>(&[], &Ix1(1), &Ix1(1)).unwrap_err();
855        can_index_slice::<i32, _>(&[1], &Ix1(1), &Ix1(0)).unwrap();
856        can_index_slice::<i32, _>(&[1], &Ix1(1), &Ix1(2)).unwrap();
857        can_index_slice::<i32, _>(&[1], &Ix1(1), &Ix1(-1isize as usize)).unwrap();
858        can_index_slice::<i32, _>(&[1], &Ix1(2), &Ix1(1)).unwrap_err();
859        can_index_slice::<i32, _>(&[1, 2], &Ix1(2), &Ix1(0)).unwrap_err();
860        can_index_slice::<i32, _>(&[1, 2], &Ix1(2), &Ix1(1)).unwrap();
861        can_index_slice::<i32, _>(&[1, 2], &Ix1(2), &Ix1(-1isize as usize)).unwrap();
862    }
863
864    #[test]
865    fn can_index_slice_ix2() {
866        can_index_slice::<i32, _>(&[], &Ix2(0, 0), &Ix2(0, 0)).unwrap();
867        can_index_slice::<i32, _>(&[], &Ix2(0, 0), &Ix2(2, 1)).unwrap();
868        can_index_slice::<i32, _>(&[], &Ix2(0, 1), &Ix2(0, 0)).unwrap();
869        can_index_slice::<i32, _>(&[], &Ix2(0, 1), &Ix2(2, 1)).unwrap();
870        can_index_slice::<i32, _>(&[], &Ix2(0, 2), &Ix2(0, 0)).unwrap();
871        can_index_slice::<i32, _>(&[], &Ix2(0, 2), &Ix2(2, 1)).unwrap_err();
872        can_index_slice::<i32, _>(&[1], &Ix2(1, 2), &Ix2(5, 1)).unwrap_err();
873        can_index_slice::<i32, _>(&[1, 2], &Ix2(1, 2), &Ix2(5, 1)).unwrap();
874        can_index_slice::<i32, _>(&[1, 2], &Ix2(1, 2), &Ix2(5, 2)).unwrap_err();
875        can_index_slice::<i32, _>(&[1, 2, 3, 4, 5], &Ix2(2, 2), &Ix2(3, 1)).unwrap();
876        can_index_slice::<i32, _>(&[1, 2, 3, 4], &Ix2(2, 2), &Ix2(3, 1)).unwrap_err();
877    }
878
879    #[test]
880    fn can_index_slice_ix3() {
881        can_index_slice::<i32, _>(&[], &Ix3(0, 0, 1), &Ix3(2, 1, 3)).unwrap();
882        can_index_slice::<i32, _>(&[], &Ix3(1, 1, 1), &Ix3(2, 1, 3)).unwrap_err();
883        can_index_slice::<i32, _>(&[1], &Ix3(1, 1, 1), &Ix3(2, 1, 3)).unwrap();
884        can_index_slice::<i32, _>(&[1; 11], &Ix3(2, 2, 3), &Ix3(6, 3, 1)).unwrap_err();
885        can_index_slice::<i32, _>(&[1; 12], &Ix3(2, 2, 3), &Ix3(6, 3, 1)).unwrap();
886    }
887
888    #[test]
889    fn can_index_slice_zero_size_elem() {
890        can_index_slice::<(), _>(&[], &Ix1(0), &Ix1(1)).unwrap();
891        can_index_slice::<(), _>(&[()], &Ix1(1), &Ix1(1)).unwrap();
892        can_index_slice::<(), _>(&[(), ()], &Ix1(2), &Ix1(1)).unwrap();
893
894        // These might seem okay because the element type is zero-sized, but
895        // there could be a zero-sized type such that the number of instances
896        // in existence are carefully controlled.
897        can_index_slice::<(), _>(&[], &Ix1(1), &Ix1(1)).unwrap_err();
898        can_index_slice::<(), _>(&[()], &Ix1(2), &Ix1(1)).unwrap_err();
899
900        can_index_slice::<(), _>(&[(), ()], &Ix2(2, 1), &Ix2(1, 0)).unwrap();
901        can_index_slice::<(), _>(&[], &Ix2(0, 2), &Ix2(0, 0)).unwrap();
902
903        // This case would be probably be sound, but that's not entirely clear
904        // and it's not worth the special case code.
905        can_index_slice::<(), _>(&[], &Ix2(0, 2), &Ix2(2, 1)).unwrap_err();
906    }
907
908    quickcheck! {
909        fn can_index_slice_not_custom_same_as_can_index_slice(data: alloc::vec::Vec<u8>, dim: alloc::vec::Vec<usize>) -> bool {
910            let dim = IxDyn(&dim);
911            let result = can_index_slice_not_custom(data.len(), &dim);
912            if dim.size_checked().is_none() {
913                // Avoid overflow `dim.default_strides()` or `dim.fortran_strides()`.
914                result.is_err()
915            } else {
916                result == can_index_slice(&data, &dim, &dim.default_strides()) &&
917                    result == can_index_slice(&data, &dim, &dim.fortran_strides())
918            }
919        }
920    }
921
922    quickcheck! {
923        // FIXME: This test can't handle larger values at the moment
924        fn extended_gcd_solves_eq(a: i16, b: i16) -> bool {
925            let (a, b) = (a as isize, b as isize);
926            let (g, (x, y)) = extended_gcd(a, b);
927            a * x + b * y == g
928        }
929
930        // FIXME: This test can't handle larger values at the moment
931        fn extended_gcd_correct_gcd(a: i16, b: i16) -> bool {
932            let (a, b) = (a as isize, b as isize);
933            let (g, _) = extended_gcd(a, b);
934            g == gcd(a, b)
935        }
936    }
937
938    #[test]
939    fn extended_gcd_zero() {
940        assert_eq!(extended_gcd(0, 0), (0, (0, 0)));
941        assert_eq!(extended_gcd(0, 5), (5, (0, 1)));
942        assert_eq!(extended_gcd(5, 0), (5, (1, 0)));
943        assert_eq!(extended_gcd(0, -5), (5, (0, -1)));
944        assert_eq!(extended_gcd(-5, 0), (5, (-1, 0)));
945    }
946
947    quickcheck! {
948        // FIXME: This test can't handle larger values at the moment
949        fn solve_linear_diophantine_eq_solution_existence(
950            a: i16, b: i16, c: i16
951        ) -> TestResult {
952            let (a, b, c) = (a as isize, b as isize, c as isize);
953
954            if a == 0 || b == 0 {
955                TestResult::discard()
956            } else {
957                TestResult::from_bool(
958                    (c % gcd(a, b) == 0) == solve_linear_diophantine_eq(a, b, c).is_some()
959                )
960            }
961        }
962
963        // FIXME: This test can't handle larger values at the moment
964        fn solve_linear_diophantine_eq_correct_solution(
965            a: i8, b: i8, c: i8, t: i8
966        ) -> TestResult {
967            let (a, b, c, t) = (a as isize, b as isize, c as isize, t as isize);
968
969            if a == 0 || b == 0 {
970                TestResult::discard()
971            } else {
972                match solve_linear_diophantine_eq(a, b, c) {
973                    Some((x0, xd)) => {
974                        let x = x0 + xd * t;
975                        let y = (c - a * x) / b;
976                        TestResult::from_bool(a * x + b * y == c)
977                    }
978                    None => TestResult::discard(),
979                }
980            }
981        }
982    }
983
984    quickcheck! {
985        // FIXME: This test is extremely slow, even with i16 values, investigate
986        fn arith_seq_intersect_correct(
987            first1: i8, len1: i8, step1: i8,
988            first2: i8, len2: i8, step2: i8
989        ) -> TestResult {
990            use std::cmp;
991
992            let (len1, len2) = (len1 as isize, len2 as isize);
993            let (first1, step1) = (first1 as isize, step1 as isize);
994            let (first2, step2) = (first2 as isize, step2 as isize);
995
996            if len1 == 0 || len2 == 0 {
997                // This case is impossible to reach in `arith_seq_intersect()`
998                // because the `min*` and `max*` arguments are inclusive.
999                return TestResult::discard();
1000            }
1001
1002            let len1 = len1.abs();
1003            let len2 = len2.abs();
1004
1005            // Convert to `min*` and `max*` arguments for `arith_seq_intersect()`.
1006            let last1 = first1 + step1 * (len1 - 1);
1007            let (min1, max1) = (cmp::min(first1, last1), cmp::max(first1, last1));
1008            let last2 = first2 + step2 * (len2 - 1);
1009            let (min2, max2) = (cmp::min(first2, last2), cmp::max(first2, last2));
1010
1011            // Naively determine if the sequences intersect.
1012            let seq1: alloc::vec::Vec<_> = (0..len1)
1013                .map(|n| first1 + step1 * n)
1014                .collect();
1015            let intersects = (0..len2)
1016                .map(|n| first2 + step2 * n)
1017                .any(|elem2| seq1.contains(&elem2));
1018
1019            TestResult::from_bool(
1020                arith_seq_intersect(
1021                    (min1, max1, if step1 == 0 { 1 } else { step1 }),
1022                    (min2, max2, if step2 == 0 { 1 } else { step2 })
1023                ) == intersects
1024            )
1025        }
1026    }
1027
1028    #[test]
1029    fn slice_min_max_empty() {
1030        assert_eq!(slice_min_max(0, Slice::new(0, None, 3)), None);
1031        assert_eq!(slice_min_max(10, Slice::new(1, Some(1), 3)), None);
1032        assert_eq!(slice_min_max(10, Slice::new(-1, Some(-1), 3)), None);
1033        assert_eq!(slice_min_max(10, Slice::new(1, Some(1), -3)), None);
1034        assert_eq!(slice_min_max(10, Slice::new(-1, Some(-1), -3)), None);
1035    }
1036
1037    #[test]
1038    fn slice_min_max_pos_step() {
1039        assert_eq!(slice_min_max(10, Slice::new(1, Some(8), 3)), Some((1, 7)));
1040        assert_eq!(slice_min_max(10, Slice::new(1, Some(9), 3)), Some((1, 7)));
1041        assert_eq!(slice_min_max(10, Slice::new(-9, Some(8), 3)), Some((1, 7)));
1042        assert_eq!(slice_min_max(10, Slice::new(-9, Some(9), 3)), Some((1, 7)));
1043        assert_eq!(slice_min_max(10, Slice::new(1, Some(-2), 3)), Some((1, 7)));
1044        assert_eq!(slice_min_max(10, Slice::new(1, Some(-1), 3)), Some((1, 7)));
1045        assert_eq!(slice_min_max(10, Slice::new(-9, Some(-2), 3)), Some((1, 7)));
1046        assert_eq!(slice_min_max(10, Slice::new(-9, Some(-1), 3)), Some((1, 7)));
1047        assert_eq!(slice_min_max(10, Slice::new(1, None, 3)), Some((1, 7)));
1048        assert_eq!(slice_min_max(10, Slice::new(-9, None, 3)), Some((1, 7)));
1049        assert_eq!(slice_min_max(11, Slice::new(1, None, 3)), Some((1, 10)));
1050        assert_eq!(slice_min_max(11, Slice::new(-10, None, 3)), Some((1, 10)));
1051    }
1052
1053    #[test]
1054    fn slice_min_max_neg_step() {
1055        assert_eq!(slice_min_max(10, Slice::new(1, Some(8), -3)), Some((1, 7)));
1056        assert_eq!(slice_min_max(10, Slice::new(2, Some(8), -3)), Some((4, 7)));
1057        assert_eq!(slice_min_max(10, Slice::new(-9, Some(8), -3)), Some((1, 7)));
1058        assert_eq!(slice_min_max(10, Slice::new(-8, Some(8), -3)), Some((4, 7)));
1059        assert_eq!(slice_min_max(10, Slice::new(1, Some(-2), -3)), Some((1, 7)));
1060        assert_eq!(slice_min_max(10, Slice::new(2, Some(-2), -3)), Some((4, 7)));
1061        assert_eq!(
1062            slice_min_max(10, Slice::new(-9, Some(-2), -3)),
1063            Some((1, 7))
1064        );
1065        assert_eq!(
1066            slice_min_max(10, Slice::new(-8, Some(-2), -3)),
1067            Some((4, 7))
1068        );
1069        assert_eq!(slice_min_max(9, Slice::new(2, None, -3)), Some((2, 8)));
1070        assert_eq!(slice_min_max(9, Slice::new(-7, None, -3)), Some((2, 8)));
1071        assert_eq!(slice_min_max(9, Slice::new(3, None, -3)), Some((5, 8)));
1072        assert_eq!(slice_min_max(9, Slice::new(-6, None, -3)), Some((5, 8)));
1073    }
1074
1075    #[test]
1076    fn slices_intersect_true() {
1077        assert!(slices_intersect(
1078            &Dim([4, 5]),
1079            s![NewAxis, .., NewAxis, ..],
1080            s![.., NewAxis, .., NewAxis]
1081        ));
1082        assert!(slices_intersect(
1083            &Dim([4, 5]),
1084            s![NewAxis, 0, ..],
1085            s![0, ..]
1086        ));
1087        assert!(slices_intersect(
1088            &Dim([4, 5]),
1089            s![..;2, ..],
1090            s![..;3, NewAxis, ..]
1091        ));
1092        assert!(slices_intersect(
1093            &Dim([4, 5]),
1094            s![.., ..;2],
1095            s![.., 1..;3, NewAxis]
1096        ));
1097        assert!(slices_intersect(&Dim([4, 10]), s![.., ..;9], s![.., 3..;6]));
1098    }
1099
1100    #[test]
1101    fn slices_intersect_false() {
1102        assert!(!slices_intersect(
1103            &Dim([4, 5]),
1104            s![..;2, ..],
1105            s![NewAxis, 1..;2, ..]
1106        ));
1107        assert!(!slices_intersect(
1108            &Dim([4, 5]),
1109            s![..;2, NewAxis, ..],
1110            s![1..;3, ..]
1111        ));
1112        assert!(!slices_intersect(
1113            &Dim([4, 5]),
1114            s![.., ..;9],
1115            s![.., 3..;6, NewAxis]
1116        ));
1117    }
1118}