// argmin/core/math/weighteddot.rs

// Copyright 2018-2020 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

use crate::core::math::ArgminDot;
use crate::core::math::ArgminWeightedDot;

11impl<T, U, V> ArgminWeightedDot<T, U, V> for T
12where
13    Self: ArgminDot<T, U>,
14    V: ArgminDot<T, T>,
15{
16    #[inline]
17    fn weighted_dot(&self, w: &V, v: &T) -> U {
18        self.dot(&w.dot(v))
19    }
20}
21
#[cfg(test)]
mod tests_vec {
    use super::*;
    use paste::item;

    // Generates one test per numeric type `$t`, checking `a^T W b` for
    // `Vec`-based vectors and a `Vec<Vec<_>>` weight matrix.
    //
    // Expected value: W b = [8+2+6, 3+10+7, 4+18+2] = [16, 20, 24] and
    // a^T (W b) = 2*16 + 1*20 + 2*24 = 100. Every intermediate stays below
    // 127, so the fixture is valid even for `i8`/`u8`.
    macro_rules! make_test {
        ($t:ty) => {
            item! {
                #[test]
                fn [<test_ $t>]() {
                    let a = vec![2 as $t, 1 as $t, 2 as $t];
                    let b = vec![1 as $t, 2 as $t, 1 as $t];
                    let w = vec![
                        vec![8 as $t, 1 as $t, 6 as $t],
                        vec![3 as $t, 5 as $t, 7 as $t],
                        vec![4 as $t, 9 as $t, 2 as $t],
                    ];
                    let res: $t = a.weighted_dot(&w, &b);
                    // Convert to f64 *before* subtracting: `res - 100` in an
                    // unsigned `$t` would overflow (panicking in debug builds)
                    // whenever `res < 100`, masking the real assertion failure.
                    assert!(
                        ((res as f64) - 100.0).abs() < std::f64::EPSILON,
                        "expected 100, got {}",
                        res
                    );
                }
            }
        };
    }

    make_test!(i8);
    make_test!(u8);
    make_test!(i16);
    make_test!(u16);
    make_test!(i32);
    make_test!(u32);
    make_test!(i64);
    make_test!(u64);
    make_test!(f32);
    make_test!(f64);
}

#[cfg(feature = "ndarrayl")]
#[cfg(test)]
mod tests_ndarray {
    use super::*;
    use ndarray::array;
    use paste::item;

    // Generates one test per numeric type `$t`, checking `a^T W b` for
    // `ndarray` 1-D vectors and a 2-D weight matrix.
    //
    // Expected value: W b = [8+2+6, 3+10+7, 4+18+2] = [16, 20, 24] and
    // a^T (W b) = 2*16 + 1*20 + 2*24 = 100. Every intermediate stays below
    // 127, so the fixture is valid even for `i8`/`u8`.
    macro_rules! make_test {
        ($t:ty) => {
            item! {
                #[test]
                fn [<test_ $t>]() {
                    let a = array![2 as $t, 1 as $t, 2 as $t];
                    let b = array![1 as $t, 2 as $t, 1 as $t];
                    let w = array![
                        [8 as $t, 1 as $t, 6 as $t],
                        [3 as $t, 5 as $t, 7 as $t],
                        [4 as $t, 9 as $t, 2 as $t],
                    ];
                    let res: $t = a.weighted_dot(&w, &b);
                    // Convert to f64 *before* subtracting: `res - 100` in an
                    // unsigned `$t` would overflow (panicking in debug builds)
                    // whenever `res < 100`, masking the real assertion failure.
                    assert!(
                        ((res as f64) - 100.0).abs() < std::f64::EPSILON,
                        "expected 100, got {}",
                        res
                    );
                }
            }
        };
    }

    make_test!(i8);
    make_test!(u8);
    make_test!(i16);
    make_test!(u16);
    make_test!(i32);
    make_test!(u32);
    make_test!(i64);
    make_test!(u64);
    make_test!(f32);
    make_test!(f64);
}