1use crate::dimension::DimMax;
10use crate::Zip;
11use num_complex::Complex;
12
13pub trait ScalarOperand: 'static + Clone {}
35impl ScalarOperand for bool {}
36impl ScalarOperand for i8 {}
37impl ScalarOperand for u8 {}
38impl ScalarOperand for i16 {}
39impl ScalarOperand for u16 {}
40impl ScalarOperand for i32 {}
41impl ScalarOperand for u32 {}
42impl ScalarOperand for i64 {}
43impl ScalarOperand for u64 {}
44impl ScalarOperand for i128 {}
45impl ScalarOperand for u128 {}
46impl ScalarOperand for isize {}
47impl ScalarOperand for usize {}
48impl ScalarOperand for f32 {}
49impl ScalarOperand for f64 {}
50impl ScalarOperand for Complex<f32> {}
51impl ScalarOperand for Complex<f64> {}
52
// Generate the array-with-array and array-with-scalar impls of one binary
// operator trait.
//
// * `$trt`      - the `std::ops` trait (e.g. `Add`)
// * `$operator` - the operator token used by the scalar impls (e.g. `+`)
// * `$mth`      - the trait's method name (e.g. `add`)
// * `$iop`      - the compound-assignment token; accepted for the invocation
//                 format but not used in the generated code
// * `$doc`      - the operation's name, spliced into the rustdoc text
//
// The array-with-array impls rely on `clone_opf`, `clone_iopf` and
// `clone_iopf_rev` being in scope at the invocation site.
macro_rules! impl_binary_op(
    ($trt:ident, $operator:tt, $mth:ident, $iop:tt, $doc:expr) => (
/// Perform elementwise
#[doc=$doc]
/// between `self` and `rhs`,
/// and return the result.
///
/// `self` must be an owned array (storage `DataOwned + DataMut`). `rhs` is
/// consumed and the work is delegated to the `self @ &rhs` implementation.
///
/// If their shapes disagree, `self` and `rhs` are broadcast to a common shape.
///
/// **Panics** if broadcasting isn't possible.
impl<A, B, S, S2, D, E> $trt<ArrayBase<S2, E>> for ArrayBase<S, D>
where
    A: Clone + $trt<B, Output=A>,
    B: Clone,
    S: DataOwned<Elem=A> + DataMut,
    S2: Data<Elem=B>,
    D: Dimension + DimMax<E>,
    E: Dimension,
{
    type Output = ArrayBase<S, <D as DimMax<E>>::Output>;
    fn $mth(self, rhs: ArrayBase<S2, E>) -> Self::Output
    {
        self.$mth(&rhs)
    }
}

/// Perform elementwise
#[doc=$doc]
/// between `self` and reference `rhs`,
/// and return the result.
///
/// `self` must be an owned array (storage `DataOwned + DataMut`). When
/// `self` already has the shape of the result, the operation is applied in
/// place and `self`'s storage is returned. Otherwise a new array is collected.
///
/// If their shapes disagree, `self` and `rhs` are broadcast to a common shape.
///
/// **Panics** if broadcasting isn't possible.
impl<'a, A, B, S, S2, D, E> $trt<&'a ArrayBase<S2, E>> for ArrayBase<S, D>
where
    A: Clone + $trt<B, Output=A>,
    B: Clone,
    S: DataOwned<Elem=A> + DataMut,
    S2: Data<Elem=B>,
    D: Dimension + DimMax<E>,
    E: Dimension,
{
    type Output = ArrayBase<S, <D as DimMax<E>>::Output>;
    fn $mth(self, rhs: &ArrayBase<S2, E>) -> Self::Output
    {
        if self.ndim() == rhs.ndim() && self.shape() == rhs.shape() {
            // Identical shapes: no broadcasting needed, update `self` in place.
            let mut out = self.into_dimensionality::<<D as DimMax<E>>::Output>().unwrap();
            out.zip_mut_with_same_shape(rhs, clone_iopf(A::$mth));
            out
        } else {
            let (lhs_view, rhs_view) = self.broadcast_with(rhs).unwrap();
            if lhs_view.shape() == self.shape() {
                // Only `rhs` was stretched, so `self` can still hold the result.
                let mut out = self.into_dimensionality::<<D as DimMax<E>>::Output>().unwrap();
                out.zip_mut_with_same_shape(&rhs_view, clone_iopf(A::$mth));
                out
            } else {
                // `self` would have to grow: collect into a new array instead.
                Zip::from(&lhs_view).and(&rhs_view).map_collect_owned(clone_opf(A::$mth))
            }
        }
    }
}

/// Perform elementwise
#[doc=$doc]
/// between reference `self` and `rhs`,
/// and return the result.
///
/// `rhs` must be an owned array (storage `DataOwned + DataMut`). When `rhs`
/// already has the shape of the result, the result is written into `rhs`'s
/// storage (operands stay in `self @ rhs` order). Otherwise a new array is
/// collected.
///
/// If their shapes disagree, `self` and `rhs` are broadcast to a common shape.
///
/// **Panics** if broadcasting isn't possible.
impl<'a, A, B, S, S2, D, E> $trt<ArrayBase<S2, E>> for &'a ArrayBase<S, D>
where
    A: Clone + $trt<B, Output=B>,
    B: Clone,
    S: Data<Elem=A>,
    S2: DataOwned<Elem=B> + DataMut,
    D: Dimension,
    E: Dimension + DimMax<D>,
{
    type Output = ArrayBase<S2, <E as DimMax<D>>::Output>;
    fn $mth(self, rhs: ArrayBase<S2, E>) -> Self::Output
    {
        if self.ndim() == rhs.ndim() && self.shape() == rhs.shape() {
            // Identical shapes: write the result into `rhs` in place.
            let mut out = rhs.into_dimensionality::<<E as DimMax<D>>::Output>().unwrap();
            out.zip_mut_with_same_shape(self, clone_iopf_rev(A::$mth));
            out
        } else {
            let (rhs_view, lhs_view) = rhs.broadcast_with(self).unwrap();
            if rhs_view.shape() == rhs.shape() {
                // Only `self` was stretched, so `rhs` can still hold the result.
                let mut out = rhs.into_dimensionality::<<E as DimMax<D>>::Output>().unwrap();
                out.zip_mut_with_same_shape(&lhs_view, clone_iopf_rev(A::$mth));
                out
            } else {
                // `rhs` would have to grow: collect into a new array instead.
                Zip::from(&lhs_view).and(&rhs_view).map_collect_owned(clone_opf(A::$mth))
            }
        }
    }
}

/// Perform elementwise
#[doc=$doc]
/// between references `self` and `rhs`,
/// and return the result as a new `Array`.
///
/// If their shapes disagree, `self` and `rhs` are broadcast to a common shape.
///
/// **Panics** if broadcasting isn't possible.
impl<'a, A, B, S, S2, D, E> $trt<&'a ArrayBase<S2, E>> for &'a ArrayBase<S, D>
where
    A: Clone + $trt<B, Output=A>,
    B: Clone,
    S: Data<Elem=A>,
    S2: Data<Elem=B>,
    D: Dimension + DimMax<E>,
    E: Dimension,
{
    type Output = Array<A, <D as DimMax<E>>::Output>;
    fn $mth(self, rhs: &'a ArrayBase<S2, E>) -> Self::Output {
        let (lhs, rhs) = if self.ndim() == rhs.ndim() && self.shape() == rhs.shape() {
            // Same shape: only the dimension *type* needs converting.
            let lhs = self.view().into_dimensionality::<<D as DimMax<E>>::Output>().unwrap();
            let rhs = rhs.view().into_dimensionality::<<D as DimMax<E>>::Output>().unwrap();
            (lhs, rhs)
        } else {
            self.broadcast_with(rhs).unwrap()
        };
        Zip::from(lhs).and(rhs).map_collect(clone_opf(A::$mth))
    }
}

/// Perform elementwise
#[doc=$doc]
/// between `self` and the scalar `x`,
/// and return the result (based on `self`).
///
/// `self` must be an owned array (storage `DataOwned + DataMut`); it is
/// updated in place and returned.
impl<A, S, D, B> $trt<B> for ArrayBase<S, D>
    where A: Clone + $trt<B, Output=A>,
          S: DataOwned<Elem=A> + DataMut,
          D: Dimension,
          B: ScalarOperand,
{
    type Output = ArrayBase<S, D>;
    fn $mth(mut self, x: B) -> ArrayBase<S, D> {
        self.map_inplace(move |elt| {
            *elt = elt.clone() $operator x.clone();
        });
        self
    }
}

/// Perform elementwise
#[doc=$doc]
/// between the reference `self` and the scalar `x`,
/// and return the result as a new `Array`.
impl<'a, A, S, D, B> $trt<B> for &'a ArrayBase<S, D>
    where A: Clone + $trt<B, Output=A>,
          S: Data<Elem=A>,
          D: Dimension,
          B: ScalarOperand,
{
    type Output = Array<A, D>;
    fn $mth(self, x: B) -> Self::Output {
        self.map(move |elt| elt.clone() $operator x.clone())
    }
}
    );
);
231
// Choose between two expressions at expansion time from a commutativity tag:
// `Commute { a } or { b }` expands to `a`, and `Ordered { a } or { b }`
// expands to `b`. The unselected expression is discarded unexpanded.
macro_rules! if_commutative {
    (Ordered { $when_commutative:expr } or { $when_ordered:expr }) => {
        $when_ordered
    };
    (Commute { $when_commutative:expr } or { $when_ordered:expr }) => {
        $when_commutative
    };
}
241
// Generate `scalar @ array` impls for one concrete scalar type.
//
// `$commutative` selects the strategy:
// * `Commute` reuses the `array @ scalar` impl (`rhs.$mth(self)`);
// * `Ordered` evaluates `scalar @ element` explicitly so operand order is kept.
//
// `$doc` is accepted for uniformity with the other operator macros but is not
// emitted: these impls carry no rustdoc of their own.
//
// The owned impl moves `self` into a `move` closure and uses it by value for
// every element, so `$scalar` must be `Copy`; the by-reference impl relies on
// that too.
macro_rules! impl_scalar_lhs_op {
    ($scalar:ty, $commutative:ident, $operator:tt, $trt:ident, $mth:ident, $doc:expr) => (
// Elementwise `scalar @ array` with an owned array `rhs`; the result is
// written into `rhs`'s storage and returned.
impl<S, D> $trt<ArrayBase<S, D>> for $scalar
    where S: DataOwned<Elem=$scalar> + DataMut,
          D: Dimension,
{
    type Output = ArrayBase<S, D>;
    fn $mth(self, rhs: ArrayBase<S, D>) -> ArrayBase<S, D> {
        if_commutative!($commutative {
            rhs.$mth(self)
        } or {{
            let mut rhs = rhs;
            rhs.map_inplace(move |elt| {
                *elt = self $operator *elt;
            });
            rhs
        }})
    }
}

// Elementwise `scalar @ array` with a borrowed array `rhs`; the result is a
// new `Array`.
impl<'a, S, D> $trt<&'a ArrayBase<S, D>> for $scalar
    where S: Data<Elem=$scalar>,
          D: Dimension,
{
    type Output = Array<$scalar, D>;
    fn $mth(self, rhs: &ArrayBase<S, D>) -> Self::Output {
        if_commutative!($commutative {
            rhs.$mth(self)
        } or {
            // `$scalar` is `Copy` (see the owned impl above), so copy the
            // operands rather than calling `clone()` on them.
            rhs.map(move |elt| self $operator *elt)
        })
    }
}
    );
}
286
287mod arithmetic_ops {
288 use super::*;
289 use crate::imp_prelude::*;
290
291 use num_complex::Complex;
292 use std::ops::*;
293
294 fn clone_opf<A: Clone, B: Clone, C>(f: impl Fn(A, B) -> C) -> impl FnMut(&A, &B) -> C {
295 move |x, y| f(x.clone(), y.clone())
296 }
297
298 fn clone_iopf<A: Clone, B: Clone>(f: impl Fn(A, B) -> A) -> impl FnMut(&mut A, &B) {
299 move |x, y| *x = f(x.clone(), y.clone())
300 }
301
302 fn clone_iopf_rev<A: Clone, B: Clone>(f: impl Fn(A, B) -> B) -> impl FnMut(&mut B, &A) {
303 move |x, y| *x = f(y.clone(), x.clone())
304 }
305
306 impl_binary_op!(Add, +, add, +=, "addition");
307 impl_binary_op!(Sub, -, sub, -=, "subtraction");
308 impl_binary_op!(Mul, *, mul, *=, "multiplication");
309 impl_binary_op!(Div, /, div, /=, "division");
310 impl_binary_op!(Rem, %, rem, %=, "remainder");
311 impl_binary_op!(BitAnd, &, bitand, &=, "bit and");
312 impl_binary_op!(BitOr, |, bitor, |=, "bit or");
313 impl_binary_op!(BitXor, ^, bitxor, ^=, "bit xor");
314 impl_binary_op!(Shl, <<, shl, <<=, "left shift");
315 impl_binary_op!(Shr, >>, shr, >>=, "right shift");
316
317 macro_rules! all_scalar_ops {
318 ($int_scalar:ty) => (
319 impl_scalar_lhs_op!($int_scalar, Commute, +, Add, add, "addition");
320 impl_scalar_lhs_op!($int_scalar, Ordered, -, Sub, sub, "subtraction");
321 impl_scalar_lhs_op!($int_scalar, Commute, *, Mul, mul, "multiplication");
322 impl_scalar_lhs_op!($int_scalar, Ordered, /, Div, div, "division");
323 impl_scalar_lhs_op!($int_scalar, Ordered, %, Rem, rem, "remainder");
324 impl_scalar_lhs_op!($int_scalar, Commute, &, BitAnd, bitand, "bit and");
325 impl_scalar_lhs_op!($int_scalar, Commute, |, BitOr, bitor, "bit or");
326 impl_scalar_lhs_op!($int_scalar, Commute, ^, BitXor, bitxor, "bit xor");
327 impl_scalar_lhs_op!($int_scalar, Ordered, <<, Shl, shl, "left shift");
328 impl_scalar_lhs_op!($int_scalar, Ordered, >>, Shr, shr, "right shift");
329 );
330 }
331 all_scalar_ops!(i8);
332 all_scalar_ops!(u8);
333 all_scalar_ops!(i16);
334 all_scalar_ops!(u16);
335 all_scalar_ops!(i32);
336 all_scalar_ops!(u32);
337 all_scalar_ops!(i64);
338 all_scalar_ops!(u64);
339 all_scalar_ops!(isize);
340 all_scalar_ops!(usize);
341 all_scalar_ops!(i128);
342 all_scalar_ops!(u128);
343
344 impl_scalar_lhs_op!(bool, Commute, &, BitAnd, bitand, "bit and");
345 impl_scalar_lhs_op!(bool, Commute, |, BitOr, bitor, "bit or");
346 impl_scalar_lhs_op!(bool, Commute, ^, BitXor, bitxor, "bit xor");
347
348 impl_scalar_lhs_op!(f32, Commute, +, Add, add, "addition");
349 impl_scalar_lhs_op!(f32, Ordered, -, Sub, sub, "subtraction");
350 impl_scalar_lhs_op!(f32, Commute, *, Mul, mul, "multiplication");
351 impl_scalar_lhs_op!(f32, Ordered, /, Div, div, "division");
352 impl_scalar_lhs_op!(f32, Ordered, %, Rem, rem, "remainder");
353
354 impl_scalar_lhs_op!(f64, Commute, +, Add, add, "addition");
355 impl_scalar_lhs_op!(f64, Ordered, -, Sub, sub, "subtraction");
356 impl_scalar_lhs_op!(f64, Commute, *, Mul, mul, "multiplication");
357 impl_scalar_lhs_op!(f64, Ordered, /, Div, div, "division");
358 impl_scalar_lhs_op!(f64, Ordered, %, Rem, rem, "remainder");
359
360 impl_scalar_lhs_op!(Complex<f32>, Commute, +, Add, add, "addition");
361 impl_scalar_lhs_op!(Complex<f32>, Ordered, -, Sub, sub, "subtraction");
362 impl_scalar_lhs_op!(Complex<f32>, Commute, *, Mul, mul, "multiplication");
363 impl_scalar_lhs_op!(Complex<f32>, Ordered, /, Div, div, "division");
364
365 impl_scalar_lhs_op!(Complex<f64>, Commute, +, Add, add, "addition");
366 impl_scalar_lhs_op!(Complex<f64>, Ordered, -, Sub, sub, "subtraction");
367 impl_scalar_lhs_op!(Complex<f64>, Commute, *, Mul, mul, "multiplication");
368 impl_scalar_lhs_op!(Complex<f64>, Ordered, /, Div, div, "division");
369
370 impl<A, S, D> Neg for ArrayBase<S, D>
371 where
372 A: Clone + Neg<Output = A>,
373 S: DataOwned<Elem = A> + DataMut,
374 D: Dimension,
375 {
376 type Output = Self;
377 fn neg(mut self) -> Self {
379 self.map_inplace(|elt| {
380 *elt = -elt.clone();
381 });
382 self
383 }
384 }
385
386 impl<'a, A, S, D> Neg for &'a ArrayBase<S, D>
387 where
388 &'a A: 'a + Neg<Output = A>,
389 S: Data<Elem = A>,
390 D: Dimension,
391 {
392 type Output = Array<A, D>;
393 fn neg(self) -> Array<A, D> {
396 self.map(Neg::neg)
397 }
398 }
399
400 impl<A, S, D> Not for ArrayBase<S, D>
401 where
402 A: Clone + Not<Output = A>,
403 S: DataOwned<Elem = A> + DataMut,
404 D: Dimension,
405 {
406 type Output = Self;
407 fn not(mut self) -> Self {
409 self.map_inplace(|elt| {
410 *elt = !elt.clone();
411 });
412 self
413 }
414 }
415
416 impl<'a, A, S, D> Not for &'a ArrayBase<S, D>
417 where
418 &'a A: 'a + Not<Output = A>,
419 S: Data<Elem = A>,
420 D: Dimension,
421 {
422 type Output = Array<A, D>;
423 fn not(self) -> Array<A, D> {
426 self.map(Not::not)
427 }
428 }
429}
430
431mod assign_ops {
432 use super::*;
433 use crate::imp_prelude::*;
434
435 macro_rules! impl_assign_op {
436 ($trt:ident, $method:ident, $doc:expr) => {
437 use std::ops::$trt;
438
439 #[doc=$doc]
440 impl<'a, A, S, S2, D, E> $trt<&'a ArrayBase<S2, E>> for ArrayBase<S, D>
444 where
445 A: Clone + $trt<A>,
446 S: DataMut<Elem = A>,
447 S2: Data<Elem = A>,
448 D: Dimension,
449 E: Dimension,
450 {
451 fn $method(&mut self, rhs: &ArrayBase<S2, E>) {
452 self.zip_mut_with(rhs, |x, y| {
453 x.$method(y.clone());
454 });
455 }
456 }
457
458 #[doc=$doc]
459 impl<A, S, D> $trt<A> for ArrayBase<S, D>
460 where
461 A: ScalarOperand + $trt<A>,
462 S: DataMut<Elem = A>,
463 D: Dimension,
464 {
465 fn $method(&mut self, rhs: A) {
466 self.map_inplace(move |elt| {
467 elt.$method(rhs.clone());
468 });
469 }
470 }
471 };
472 }
473
474 impl_assign_op!(
475 AddAssign,
476 add_assign,
477 "Perform `self += rhs` as elementwise addition (in place).\n"
478 );
479 impl_assign_op!(
480 SubAssign,
481 sub_assign,
482 "Perform `self -= rhs` as elementwise subtraction (in place).\n"
483 );
484 impl_assign_op!(
485 MulAssign,
486 mul_assign,
487 "Perform `self *= rhs` as elementwise multiplication (in place).\n"
488 );
489 impl_assign_op!(
490 DivAssign,
491 div_assign,
492 "Perform `self /= rhs` as elementwise division (in place).\n"
493 );
494 impl_assign_op!(
495 RemAssign,
496 rem_assign,
497 "Perform `self %= rhs` as elementwise remainder (in place).\n"
498 );
499 impl_assign_op!(
500 BitAndAssign,
501 bitand_assign,
502 "Perform `self &= rhs` as elementwise bit and (in place).\n"
503 );
504 impl_assign_op!(
505 BitOrAssign,
506 bitor_assign,
507 "Perform `self |= rhs` as elementwise bit or (in place).\n"
508 );
509 impl_assign_op!(
510 BitXorAssign,
511 bitxor_assign,
512 "Perform `self ^= rhs` as elementwise bit xor (in place).\n"
513 );
514 impl_assign_op!(
515 ShlAssign,
516 shl_assign,
517 "Perform `self <<= rhs` as elementwise left shift (in place).\n"
518 );
519 impl_assign_op!(
520 ShrAssign,
521 shr_assign,
522 "Perform `self >>= rhs` as elementwise right shift (in place).\n"
523 );
524}