ndarray/
impl_ops.rs

1// Copyright 2014-2016 bluss and ndarray developers.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9use crate::dimension::DimMax;
10use crate::Zip;
11use num_complex::Complex;
12
/// Elements that can be used as direct operands in arithmetic with arrays.
///
/// For example, `f64` is a `ScalarOperand`, which means that for an array `a`,
/// arithmetic like `a + 1.0`, `a * 2.`, and `a += 3.` is allowed.
///
/// In the description below, let `A` be an array or array view,
/// let `B` be an array with owned data,
/// and let `C` be an array with mutable data.
///
/// `ScalarOperand` determines for which scalars `K` operations `&A @ K`, and `B @ K`,
/// and `C @= K` are defined, as ***right hand side operands***, for applicable
/// arithmetic operators (denoted `@`).
///
/// ***Left hand side*** scalar operands are not related to this trait
/// (they need one `impl` per concrete scalar type); but they are still
/// implemented for the same types, allowing operations
/// `K @ &A`, and `K @ B` for primitive numeric types `K`.
///
/// This trait ***does not*** limit which elements can be stored in an array in general.
/// Non-`ScalarOperand` types can still participate in arithmetic as array elements
/// in array-array operations.
pub trait ScalarOperand: 'static + Clone {}

// Primitive scalar types usable as right-hand-side operands.
impl ScalarOperand for bool {}
impl ScalarOperand for i8 {}
impl ScalarOperand for u8 {}
impl ScalarOperand for i16 {}
impl ScalarOperand for u16 {}
impl ScalarOperand for i32 {}
impl ScalarOperand for u32 {}
impl ScalarOperand for i64 {}
impl ScalarOperand for u64 {}
impl ScalarOperand for i128 {}
impl ScalarOperand for u128 {}
impl ScalarOperand for isize {}
impl ScalarOperand for usize {}
impl ScalarOperand for f32 {}
impl ScalarOperand for f64 {}
// Complex scalars (the `num_complex::Complex` type imported at the top of this
// file) are also accepted as right-hand-side operands.
impl ScalarOperand for Complex<f32> {}
impl ScalarOperand for Complex<f64> {}
52
// Generates the elementwise impls of one binary operator trait for arrays.
//
// * `$trt` / `$mth`: the `std::ops` trait and its method (e.g. `Add` / `add`).
// * `$operator`: the operator token, used by the array-with-scalar impls.
// * `$iop`: the compound-assignment token; accepted but not used below.
// * `$doc`: the operation name spliced into the generated rustdoc.
//
// Each invocation produces the four owned/borrowed array-array combinations,
// plus owned and borrowed array-with-`ScalarOperand` impls.
macro_rules! impl_binary_op(
    ($trt:ident, $operator:tt, $mth:ident, $iop:tt, $doc:expr) => (
/// Perform elementwise
#[doc=$doc]
/// between `self` and `rhs`,
/// and return the result.
///
/// `self` must be an `Array` or `ArcArray`.
///
/// If their shapes disagree, `self` is broadcast to their broadcast shape,
/// cloning the data if needed.
///
/// **Panics** if broadcasting isn’t possible.
impl<A, B, S, S2, D, E> $trt<ArrayBase<S2, E>> for ArrayBase<S, D>
where
    A: Clone + $trt<B, Output=A>,
    B: Clone,
    S: DataOwned<Elem=A> + DataMut,
    S2: Data<Elem=B>,
    D: Dimension + DimMax<E>,
    E: Dimension,
{
    type Output = ArrayBase<S, <D as DimMax<E>>::Output>;
    fn $mth(self, rhs: ArrayBase<S2, E>) -> Self::Output
    {
        // `rhs` is only read, so reuse the `ArrayBase @ &ArrayBase` impl.
        self.$mth(&rhs)
    }
}

/// Perform elementwise
#[doc=$doc]
/// between `self` and reference `rhs`,
/// and return the result.
///
/// `self` must be an `Array` or `ArcArray`.
///
/// If their shapes disagree, `self` is broadcast to their broadcast shape,
/// cloning the data if needed. The storage of `self` is reused for the result
/// whenever `self` already has the broadcast shape.
///
/// **Panics** if broadcasting isn’t possible.
impl<'a, A, B, S, S2, D, E> $trt<&'a ArrayBase<S2, E>> for ArrayBase<S, D>
where
    A: Clone + $trt<B, Output=A>,
    B: Clone,
    S: DataOwned<Elem=A> + DataMut,
    S2: Data<Elem=B>,
    D: Dimension + DimMax<E>,
    E: Dimension,
{
    type Output = ArrayBase<S, <D as DimMax<E>>::Output>;
    fn $mth(self, rhs: &ArrayBase<S2, E>) -> Self::Output
    {
        if self.ndim() == rhs.ndim() && self.shape() == rhs.shape() {
            // Same shape: update `self` in place. Both operands have the same
            // number of axes, so the dimension conversion is expected to succeed.
            let mut out = self.into_dimensionality::<<D as DimMax<E>>::Output>().unwrap();
            out.zip_mut_with_same_shape(rhs, clone_iopf(A::$mth));
            out
        } else {
            let (lhs_view, rhs_view) = self.broadcast_with(rhs).unwrap();
            if lhs_view.shape() == self.shape() {
                // Only `rhs` had to be broadcast, so `self` can still hold the result.
                let mut out = self.into_dimensionality::<<D as DimMax<E>>::Output>().unwrap();
                out.zip_mut_with_same_shape(&rhs_view, clone_iopf(A::$mth));
                out
            } else {
                // `self` itself had to be broadcast: allocate a new result array.
                Zip::from(&lhs_view).and(&rhs_view).map_collect_owned(clone_opf(A::$mth))
            }
        }
    }
}

/// Perform elementwise
#[doc=$doc]
/// between reference `self` and `rhs`,
/// and return the result.
///
/// `rhs` must be an `Array` or `ArcArray`.
///
/// If their shapes disagree, `self` and `rhs` are broadcast to their broadcast
/// shape, cloning the data if needed. The storage of `rhs` is reused for the
/// result whenever `rhs` already has the broadcast shape.
///
/// **Panics** if broadcasting isn’t possible.
impl<'a, A, B, S, S2, D, E> $trt<ArrayBase<S2, E>> for &'a ArrayBase<S, D>
where
    A: Clone + $trt<B, Output=B>,
    B: Clone,
    S: Data<Elem=A>,
    S2: DataOwned<Elem=B> + DataMut,
    D: Dimension,
    E: Dimension + DimMax<D>,
{
    type Output = ArrayBase<S2, <E as DimMax<D>>::Output>;
    fn $mth(self, rhs: ArrayBase<S2, E>) -> Self::Output
    {
        if self.ndim() == rhs.ndim() && self.shape() == rhs.shape() {
            // Same shape: write the result into `rhs`, keeping `self` as the
            // left operand of the operator (`clone_iopf_rev`).
            let mut out = rhs.into_dimensionality::<<E as DimMax<D>>::Output>().unwrap();
            out.zip_mut_with_same_shape(self, clone_iopf_rev(A::$mth));
            out
        } else {
            let (rhs_view, lhs_view) = rhs.broadcast_with(self).unwrap();
            if rhs_view.shape() == rhs.shape() {
                // Only `self` had to be broadcast, so `rhs` can still hold the result.
                let mut out = rhs.into_dimensionality::<<E as DimMax<D>>::Output>().unwrap();
                out.zip_mut_with_same_shape(&lhs_view, clone_iopf_rev(A::$mth));
                out
            } else {
                // `rhs` itself had to be broadcast: allocate a new result array.
                Zip::from(&lhs_view).and(&rhs_view).map_collect_owned(clone_opf(A::$mth))
            }
        }
    }
}

/// Perform elementwise
#[doc=$doc]
/// between references `self` and `rhs`,
/// and return the result as a new `Array`.
///
/// If their shapes disagree, `self` and `rhs` are broadcast to their broadcast shape,
/// cloning the data if needed.
///
/// **Panics** if broadcasting isn’t possible.
impl<'a, A, B, S, S2, D, E> $trt<&'a ArrayBase<S2, E>> for &'a ArrayBase<S, D>
where
    A: Clone + $trt<B, Output=A>,
    B: Clone,
    S: Data<Elem=A>,
    S2: Data<Elem=B>,
    D: Dimension + DimMax<E>,
    E: Dimension,
{
    type Output = Array<A, <D as DimMax<E>>::Output>;
    fn $mth(self, rhs: &'a ArrayBase<S2, E>) -> Self::Output {
        // Bring both operands to views of the common dimension type; broadcasting
        // is only needed when the shapes differ.
        let (lhs, rhs) = if self.ndim() == rhs.ndim() && self.shape() == rhs.shape() {
            let lhs = self.view().into_dimensionality::<<D as DimMax<E>>::Output>().unwrap();
            let rhs = rhs.view().into_dimensionality::<<D as DimMax<E>>::Output>().unwrap();
            (lhs, rhs)
        } else {
            self.broadcast_with(rhs).unwrap()
        };
        Zip::from(lhs).and(rhs).map_collect(clone_opf(A::$mth))
    }
}

/// Perform elementwise
#[doc=$doc]
/// between `self` and the scalar `x`,
/// and return the result (based on `self`).
///
/// `self` must be an `Array` or `ArcArray`.
impl<A, S, D, B> $trt<B> for ArrayBase<S, D>
    where A: Clone + $trt<B, Output=A>,
          S: DataOwned<Elem=A> + DataMut,
          D: Dimension,
          B: ScalarOperand,
{
    type Output = ArrayBase<S, D>;
    fn $mth(mut self, x: B) -> ArrayBase<S, D> {
        // `x` is cloned once per element.
        self.map_inplace(move |elt| {
            *elt = elt.clone() $operator x.clone();
        });
        self
    }
}

/// Perform elementwise
#[doc=$doc]
/// between the reference `self` and the scalar `x`,
/// and return the result as a new `Array`.
impl<'a, A, S, D, B> $trt<B> for &'a ArrayBase<S, D>
    where A: Clone + $trt<B, Output=A>,
          S: Data<Elem=A>,
          D: Dimension,
          B: ScalarOperand,
{
    type Output = Array<A, D>;
    fn $mth(self, x: B) -> Self::Output {
        self.map(move |elt| elt.clone() $operator x.clone())
    }
}
    );
);
231
// Selects one of two expressions according to a marker identifier:
// `Commute { a } or { b }` expands to `a`, while `Ordered { a } or { b }`
// expands to `b`. The unselected expression is discarded unexpanded.
macro_rules! if_commutative {
    (Ordered { $commuted:expr } or { $ordered:expr }) => ($ordered);
    (Commute { $commuted:expr } or { $ordered:expr }) => ($commuted);
}
241
// Generates `$scalar @ ArrayBase` and `$scalar @ &ArrayBase` for one concrete
// scalar type and one operator.
//
// * `$commutative` is `Commute` or `Ordered` (see `if_commutative!`). For
//   commutative operators the array-on-the-left impl is reused by swapping the
//   operands; that is only valid because the listed scalar types' operators
//   really commute.
// * `$scalar` must be `Copy`: both impls read elements with `*elt` and use the
//   scalar `self` by value once per element.
// * `$doc` is accepted for symmetry with `impl_binary_op!` but is not used;
//   these impls deliberately carry no rustdoc.
macro_rules! impl_scalar_lhs_op {
    ($scalar:ty, $commutative:ident, $operator:tt, $trt:ident, $mth:ident, $doc:expr) => (
// these have no doc -- they are not visible in rustdoc
// Perform elementwise
// between the scalar `self` and array `rhs`,
// and return the result (stored in the reused storage of `rhs`).
impl<S, D> $trt<ArrayBase<S, D>> for $scalar
    where S: DataOwned<Elem=$scalar> + DataMut,
          D: Dimension,
{
    type Output = ArrayBase<S, D>;
    fn $mth(self, rhs: ArrayBase<S, D>) -> ArrayBase<S, D> {
        if_commutative!($commutative {
            rhs.$mth(self)
        } or {{
            let mut rhs = rhs;
            rhs.map_inplace(move |elt| {
                *elt = self $operator *elt;
            });
            rhs
        }})
    }
}

// Perform elementwise
// between the scalar `self` and array `rhs`,
// and return the result as a new `Array`.
impl<'a, S, D> $trt<&'a ArrayBase<S, D>> for $scalar
    where S: Data<Elem=$scalar>,
          D: Dimension,
{
    type Output = Array<$scalar, D>;
    fn $mth(self, rhs: &ArrayBase<S, D>) -> Self::Output {
        if_commutative!($commutative {
            rhs.$mth(self)
        } or {
            // `$scalar` is `Copy`, so copy instead of cloning (matches the owned impl).
            rhs.map(move |elt| self $operator *elt)
        })
    }
}
    );
}
286
287mod arithmetic_ops {
288    use super::*;
289    use crate::imp_prelude::*;
290
291    use num_complex::Complex;
292    use std::ops::*;
293
294    fn clone_opf<A: Clone, B: Clone, C>(f: impl Fn(A, B) -> C) -> impl FnMut(&A, &B) -> C {
295        move |x, y| f(x.clone(), y.clone())
296    }
297
298    fn clone_iopf<A: Clone, B: Clone>(f: impl Fn(A, B) -> A) -> impl FnMut(&mut A, &B) {
299        move |x, y| *x = f(x.clone(), y.clone())
300    }
301
302    fn clone_iopf_rev<A: Clone, B: Clone>(f: impl Fn(A, B) -> B) -> impl FnMut(&mut B, &A) {
303        move |x, y| *x = f(y.clone(), x.clone())
304    }
305
306    impl_binary_op!(Add, +, add, +=, "addition");
307    impl_binary_op!(Sub, -, sub, -=, "subtraction");
308    impl_binary_op!(Mul, *, mul, *=, "multiplication");
309    impl_binary_op!(Div, /, div, /=, "division");
310    impl_binary_op!(Rem, %, rem, %=, "remainder");
311    impl_binary_op!(BitAnd, &, bitand, &=, "bit and");
312    impl_binary_op!(BitOr, |, bitor, |=, "bit or");
313    impl_binary_op!(BitXor, ^, bitxor, ^=, "bit xor");
314    impl_binary_op!(Shl, <<, shl, <<=, "left shift");
315    impl_binary_op!(Shr, >>, shr, >>=, "right shift");
316
317    macro_rules! all_scalar_ops {
318        ($int_scalar:ty) => (
319            impl_scalar_lhs_op!($int_scalar, Commute, +, Add, add, "addition");
320            impl_scalar_lhs_op!($int_scalar, Ordered, -, Sub, sub, "subtraction");
321            impl_scalar_lhs_op!($int_scalar, Commute, *, Mul, mul, "multiplication");
322            impl_scalar_lhs_op!($int_scalar, Ordered, /, Div, div, "division");
323            impl_scalar_lhs_op!($int_scalar, Ordered, %, Rem, rem, "remainder");
324            impl_scalar_lhs_op!($int_scalar, Commute, &, BitAnd, bitand, "bit and");
325            impl_scalar_lhs_op!($int_scalar, Commute, |, BitOr, bitor, "bit or");
326            impl_scalar_lhs_op!($int_scalar, Commute, ^, BitXor, bitxor, "bit xor");
327            impl_scalar_lhs_op!($int_scalar, Ordered, <<, Shl, shl, "left shift");
328            impl_scalar_lhs_op!($int_scalar, Ordered, >>, Shr, shr, "right shift");
329        );
330    }
331    all_scalar_ops!(i8);
332    all_scalar_ops!(u8);
333    all_scalar_ops!(i16);
334    all_scalar_ops!(u16);
335    all_scalar_ops!(i32);
336    all_scalar_ops!(u32);
337    all_scalar_ops!(i64);
338    all_scalar_ops!(u64);
339    all_scalar_ops!(isize);
340    all_scalar_ops!(usize);
341    all_scalar_ops!(i128);
342    all_scalar_ops!(u128);
343
344    impl_scalar_lhs_op!(bool, Commute, &, BitAnd, bitand, "bit and");
345    impl_scalar_lhs_op!(bool, Commute, |, BitOr, bitor, "bit or");
346    impl_scalar_lhs_op!(bool, Commute, ^, BitXor, bitxor, "bit xor");
347
348    impl_scalar_lhs_op!(f32, Commute, +, Add, add, "addition");
349    impl_scalar_lhs_op!(f32, Ordered, -, Sub, sub, "subtraction");
350    impl_scalar_lhs_op!(f32, Commute, *, Mul, mul, "multiplication");
351    impl_scalar_lhs_op!(f32, Ordered, /, Div, div, "division");
352    impl_scalar_lhs_op!(f32, Ordered, %, Rem, rem, "remainder");
353
354    impl_scalar_lhs_op!(f64, Commute, +, Add, add, "addition");
355    impl_scalar_lhs_op!(f64, Ordered, -, Sub, sub, "subtraction");
356    impl_scalar_lhs_op!(f64, Commute, *, Mul, mul, "multiplication");
357    impl_scalar_lhs_op!(f64, Ordered, /, Div, div, "division");
358    impl_scalar_lhs_op!(f64, Ordered, %, Rem, rem, "remainder");
359
360    impl_scalar_lhs_op!(Complex<f32>, Commute, +, Add, add, "addition");
361    impl_scalar_lhs_op!(Complex<f32>, Ordered, -, Sub, sub, "subtraction");
362    impl_scalar_lhs_op!(Complex<f32>, Commute, *, Mul, mul, "multiplication");
363    impl_scalar_lhs_op!(Complex<f32>, Ordered, /, Div, div, "division");
364
365    impl_scalar_lhs_op!(Complex<f64>, Commute, +, Add, add, "addition");
366    impl_scalar_lhs_op!(Complex<f64>, Ordered, -, Sub, sub, "subtraction");
367    impl_scalar_lhs_op!(Complex<f64>, Commute, *, Mul, mul, "multiplication");
368    impl_scalar_lhs_op!(Complex<f64>, Ordered, /, Div, div, "division");
369
370    impl<A, S, D> Neg for ArrayBase<S, D>
371    where
372        A: Clone + Neg<Output = A>,
373        S: DataOwned<Elem = A> + DataMut,
374        D: Dimension,
375    {
376        type Output = Self;
377        /// Perform an elementwise negation of `self` and return the result.
378        fn neg(mut self) -> Self {
379            self.map_inplace(|elt| {
380                *elt = -elt.clone();
381            });
382            self
383        }
384    }
385
386    impl<'a, A, S, D> Neg for &'a ArrayBase<S, D>
387    where
388        &'a A: 'a + Neg<Output = A>,
389        S: Data<Elem = A>,
390        D: Dimension,
391    {
392        type Output = Array<A, D>;
393        /// Perform an elementwise negation of reference `self` and return the
394        /// result as a new `Array`.
395        fn neg(self) -> Array<A, D> {
396            self.map(Neg::neg)
397        }
398    }
399
400    impl<A, S, D> Not for ArrayBase<S, D>
401    where
402        A: Clone + Not<Output = A>,
403        S: DataOwned<Elem = A> + DataMut,
404        D: Dimension,
405    {
406        type Output = Self;
407        /// Perform an elementwise unary not of `self` and return the result.
408        fn not(mut self) -> Self {
409            self.map_inplace(|elt| {
410                *elt = !elt.clone();
411            });
412            self
413        }
414    }
415
416    impl<'a, A, S, D> Not for &'a ArrayBase<S, D>
417    where
418        &'a A: 'a + Not<Output = A>,
419        S: Data<Elem = A>,
420        D: Dimension,
421    {
422        type Output = Array<A, D>;
423        /// Perform an elementwise unary not of reference `self` and return the
424        /// result as a new `Array`.
425        fn not(self) -> Array<A, D> {
426            self.map(Not::not)
427        }
428    }
429}
430
431mod assign_ops {
432    use super::*;
433    use crate::imp_prelude::*;
434
435    macro_rules! impl_assign_op {
436        ($trt:ident, $method:ident, $doc:expr) => {
437            use std::ops::$trt;
438
439            #[doc=$doc]
440            /// If their shapes disagree, `rhs` is broadcast to the shape of `self`.
441            ///
442            /// **Panics** if broadcasting isn’t possible.
443            impl<'a, A, S, S2, D, E> $trt<&'a ArrayBase<S2, E>> for ArrayBase<S, D>
444            where
445                A: Clone + $trt<A>,
446                S: DataMut<Elem = A>,
447                S2: Data<Elem = A>,
448                D: Dimension,
449                E: Dimension,
450            {
451                fn $method(&mut self, rhs: &ArrayBase<S2, E>) {
452                    self.zip_mut_with(rhs, |x, y| {
453                        x.$method(y.clone());
454                    });
455                }
456            }
457
458            #[doc=$doc]
459            impl<A, S, D> $trt<A> for ArrayBase<S, D>
460            where
461                A: ScalarOperand + $trt<A>,
462                S: DataMut<Elem = A>,
463                D: Dimension,
464            {
465                fn $method(&mut self, rhs: A) {
466                    self.map_inplace(move |elt| {
467                        elt.$method(rhs.clone());
468                    });
469                }
470            }
471        };
472    }
473
474    impl_assign_op!(
475        AddAssign,
476        add_assign,
477        "Perform `self += rhs` as elementwise addition (in place).\n"
478    );
479    impl_assign_op!(
480        SubAssign,
481        sub_assign,
482        "Perform `self -= rhs` as elementwise subtraction (in place).\n"
483    );
484    impl_assign_op!(
485        MulAssign,
486        mul_assign,
487        "Perform `self *= rhs` as elementwise multiplication (in place).\n"
488    );
489    impl_assign_op!(
490        DivAssign,
491        div_assign,
492        "Perform `self /= rhs` as elementwise division (in place).\n"
493    );
494    impl_assign_op!(
495        RemAssign,
496        rem_assign,
497        "Perform `self %= rhs` as elementwise remainder (in place).\n"
498    );
499    impl_assign_op!(
500        BitAndAssign,
501        bitand_assign,
502        "Perform `self &= rhs` as elementwise bit and (in place).\n"
503    );
504    impl_assign_op!(
505        BitOrAssign,
506        bitor_assign,
507        "Perform `self |= rhs` as elementwise bit or (in place).\n"
508    );
509    impl_assign_op!(
510        BitXorAssign,
511        bitxor_assign,
512        "Perform `self ^= rhs` as elementwise bit xor (in place).\n"
513    );
514    impl_assign_op!(
515        ShlAssign,
516        shl_assign,
517        "Perform `self <<= rhs` as elementwise left shift (in place).\n"
518    );
519    impl_assign_op!(
520        ShrAssign,
521        shr_assign,
522        "Perform `self >>= rhs` as elementwise right shift (in place).\n"
523    );
524}