// argmin/core/math/dot_vec.rs

// Copyright 2018-2020 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.
7
8use crate::core::math::ArgminDot;
9use crate::core::math::ArgminTranspose;
10use num_complex::Complex;
11
12macro_rules! make_dot_vec {
13    ($t:ty) => {
14        impl<'a> ArgminDot<Vec<$t>, $t> for Vec<$t> {
15            #[inline]
16            fn dot(&self, other: &Vec<$t>) -> $t {
17                self.iter().zip(other.iter()).map(|(a, b)| a * b).sum()
18            }
19        }
20
21        impl<'a> ArgminDot<$t, Vec<$t>> for Vec<$t> {
22            #[inline]
23            fn dot(&self, other: &$t) -> Vec<$t> {
24                self.iter().map(|a| a * other).collect()
25            }
26        }
27
28        impl<'a> ArgminDot<Vec<$t>, Vec<$t>> for $t {
29            #[inline]
30            fn dot(&self, other: &Vec<$t>) -> Vec<$t> {
31                other.iter().map(|a| a * self).collect()
32            }
33        }
34
35        impl ArgminDot<Vec<$t>, Vec<Vec<$t>>> for Vec<$t> {
36            #[inline]
37            fn dot(&self, other: &Vec<$t>) -> Vec<Vec<$t>> {
38                self.iter()
39                    .map(|b| other.iter().map(|a| a * b).collect())
40                    .collect()
41            }
42        }
43
44        impl ArgminDot<Vec<$t>, Vec<$t>> for Vec<Vec<$t>> {
45            #[inline]
46            fn dot(&self, other: &Vec<$t>) -> Vec<$t> {
47                (0..self.len()).map(|i| self[i].dot(other)).collect()
48            }
49        }
50
51        impl ArgminDot<Vec<Vec<$t>>, Vec<Vec<$t>>> for Vec<Vec<$t>> {
52            #[inline]
53            fn dot(&self, other: &Vec<Vec<$t>>) -> Vec<Vec<$t>> {
54                // Would be more efficient if this wasn't necessary!
55                let other = other.clone().t();
56                let sr = self.len();
57                assert!(sr > 0);
58                let sc = self[0].len();
59                assert!(sc > 0);
60                let or = other.len();
61                assert!(or > 0);
62                let oc = other[0].len();
63                assert_eq!(sc, or);
64                assert!(oc > 0);
65                let mut v = Vec::with_capacity(oc);
66                unsafe {
67                    v.set_len(oc);
68                }
69                let mut out = vec![v; sr];
70                for i in 0..sr {
71                    assert_eq!(self[i].len(), sc);
72                    // assert_eq!(other[i].len(), oc);
73                    for j in 0..oc {
74                        out[i][j] = self[i].dot(&other[j]);
75                    }
76                }
77                out
78            }
79        }
80
81        impl<'a> ArgminDot<$t, Vec<Vec<$t>>> for Vec<Vec<$t>> {
82            #[inline]
83            fn dot(&self, other: &$t) -> Vec<Vec<$t>> {
84                (0..self.len())
85                    .map(|i| self[i].iter().map(|a| a * other).collect())
86                    .collect()
87            }
88        }
89
90        impl<'a> ArgminDot<Vec<Vec<$t>>, Vec<Vec<$t>>> for $t {
91            #[inline]
92            fn dot(&self, other: &Vec<Vec<$t>>) -> Vec<Vec<$t>> {
93                (0..other.len())
94                    .map(|i| other[i].iter().map(|a| a * self).collect())
95                    .collect()
96            }
97        }
98    };
99}
100
101make_dot_vec!(f32);
102make_dot_vec!(f64);
103make_dot_vec!(i8);
104make_dot_vec!(i16);
105make_dot_vec!(i32);
106make_dot_vec!(i64);
107make_dot_vec!(u8);
108make_dot_vec!(u16);
109make_dot_vec!(u32);
110make_dot_vec!(u64);
111make_dot_vec!(isize);
112make_dot_vec!(usize);
113make_dot_vec!(Complex<f32>);
114make_dot_vec!(Complex<f64>);
115make_dot_vec!(Complex<i8>);
116make_dot_vec!(Complex<i16>);
117make_dot_vec!(Complex<i32>);
118make_dot_vec!(Complex<i64>);
119make_dot_vec!(Complex<u8>);
120make_dot_vec!(Complex<u16>);
121make_dot_vec!(Complex<u32>);
122make_dot_vec!(Complex<u64>);
123make_dot_vec!(Complex<isize>);
124make_dot_vec!(Complex<usize>);
125
126#[cfg(test)]
127mod tests {
128    use super::*;
129    use paste::item;
130
131    macro_rules! make_test {
132        ($t:ty) => {
133            item! {
134                #[test]
135                fn [<test_vec_vec_ $t>]() {
136                    let a = vec![1 as $t, 2 as $t, 3 as $t];
137                    let b = vec![4 as $t, 5 as $t, 6 as $t];
138                    let res: $t = a.dot(&b);
139                    assert!((((res - 32 as $t) as f64).abs()) < std::f64::EPSILON);
140                }
141            }
142
143            item! {
144                #[test]
145                fn [<test_vec_scalar_ $t>]() {
146                    let a = vec![1 as $t, 2 as $t, 3 as $t];
147                    let b = 2 as $t;
148                    let product = a.dot(&b);
149                    let res = vec![2 as $t, 4 as $t, 6 as $t];
150                    for i in 0..3 {
151                        assert!((((res[i] - product[i]) as f64).abs()) < std::f64::EPSILON);
152                    }
153                }
154            }
155
156            item! {
157                #[test]
158                fn [<test_scalar_vec_ $t>]() {
159                    let a = vec![1 as $t, 2 as $t, 3 as $t];
160                    let b = 2 as $t;
161                    let product = b.dot(&a);
162                    let res = vec![2 as $t, 4 as $t, 6 as $t];
163                    for i in 0..3 {
164                        assert!((((res[i] - product[i]) as f64).abs()) < std::f64::EPSILON);
165                    }
166                }
167            }
168
169            item! {
170                #[test]
171                fn [<test_mat_vec_ $t>]() {
172                    let a = vec![1 as $t, 2 as $t, 3 as $t];
173                    let b = vec![4 as $t, 5 as $t, 6 as $t];
174                    let res = vec![
175                        vec![4 as $t, 5 as $t, 6 as $t],
176                        vec![8 as $t, 10 as $t, 12 as $t],
177                        vec![12 as $t, 15 as $t, 18 as $t]
178                    ];
179                    let product: Vec<Vec<$t>> = a.dot(&b);
180                    for i in 0..3 {
181                        for j in 0..3 {
182                            assert!((((res[i][j] - product[i][j]) as f64).abs()) < std::f64::EPSILON);
183                        }
184                    }
185                }
186            }
187
188            item! {
189                #[test]
190                fn [<test_mat_vec_2_ $t>]() {
191                    let a = vec![
192                        vec![1 as $t, 2 as $t, 3 as $t],
193                        vec![4 as $t, 5 as $t, 6 as $t],
194                        vec![7 as $t, 8 as $t, 9 as $t]
195                    ];
196                    let b = vec![1 as $t, 2 as $t, 3 as $t];
197                    let res = vec![14 as $t, 32 as $t, 50 as $t];
198                    let product = a.dot(&b);
199                    for i in 0..3 {
200                        assert!((((res[i] - product[i]) as f64).abs()) < std::f64::EPSILON);
201                    }
202                }
203            }
204
205            item! {
206                #[test]
207                fn [<test_mat_mat_ $t>]() {
208                    let a = vec![
209                        vec![1 as $t, 2 as $t, 3 as $t],
210                        vec![4 as $t, 5 as $t, 6 as $t],
211                        vec![3 as $t, 2 as $t, 1 as $t]
212                    ];
213                    let b = vec![
214                        vec![3 as $t, 2 as $t, 1 as $t],
215                        vec![6 as $t, 5 as $t, 4 as $t],
216                        vec![2 as $t, 4 as $t, 3 as $t]
217                    ];
218                    let res = vec![
219                        vec![21 as $t, 24 as $t, 18 as $t],
220                        vec![54 as $t, 57 as $t, 42 as $t],
221                        vec![23 as $t, 20 as $t, 14 as $t]
222                    ];
223                    let product = a.dot(&b);
224                    for i in 0..3 {
225                        for j in 0..3 {
226                            assert!((((res[i][j] - product[i][j]) as f64).abs()) < std::f64::EPSILON);
227                        }
228                    }
229                }
230            }
231
232            item! {
233                #[test]
234                #[should_panic]
235                fn [<test_mat_mat_panic_1_ $t>]() {
236                    let a = vec![];
237                    let b = vec![
238                        vec![3 as $t, 2 as $t, 1 as $t],
239                        vec![6 as $t, 5 as $t, 4 as $t],
240                        vec![2 as $t, 4 as $t, 3 as $t]
241                    ];
242                    a.dot(&b);
243                }
244            }
245
246            item! {
247                #[test]
248                #[should_panic]
249                fn [<test_mat_mat_panic_2_ $t>]() {
250                    let a: Vec<Vec<$t>> = vec![];
251                    let b = vec![
252                        vec![3 as $t, 2 as $t, 1 as $t],
253                        vec![6 as $t, 5 as $t, 4 as $t],
254                        vec![2 as $t, 4 as $t, 3 as $t]
255                    ];
256                    b.dot(&a);
257                }
258            }
259
260            item! {
261                #[test]
262                #[should_panic]
263                fn [<test_mat_mat_panic_3_ $t>]() {
264                    let a = vec![
265                        vec![1 as $t, 2 as $t],
266                        vec![4 as $t, 5 as $t],
267                        vec![3 as $t, 2 as $t]
268                    ];
269                    let b = vec![
270                        vec![3 as $t, 2 as $t, 1 as $t],
271                        vec![6 as $t, 5 as $t, 4 as $t],
272                        vec![2 as $t, 4 as $t, 3 as $t]
273                    ];
274                    a.dot(&b);
275                }
276            }
277
278            item! {
279                #[test]
280                #[should_panic]
281                fn [<test_mat_mat_panic_4_ $t>]() {
282                    let a = vec![
283                        vec![1 as $t, 2 as $t, 3 as $t],
284                        vec![4 as $t, 5 as $t, 6 as $t],
285                        vec![3 as $t, 2 as $t, 1 as $t]
286                    ];
287                    let b = vec![
288                        vec![3 as $t, 2 as $t],
289                        vec![6 as $t, 5 as $t],
290                        vec![3 as $t, 2 as $t]
291                    ];
292                    a.dot(&b);
293                }
294            }
295
296            item! {
297                #[test]
298                #[should_panic]
299                fn [<test_mat_mat_panic_5_ $t>]() {
300                    let a = vec![
301                        vec![1 as $t, 2 as $t, 3 as $t],
302                        vec![4 as $t, 5 as $t, 6 as $t],
303                        vec![3 as $t, 2 as $t, 1 as $t]
304                    ];
305                    let b = vec![
306                        vec![3 as $t, 2 as $t, 1 as $t],
307                        vec![6 as $t, 5 as $t, 4 as $t],
308                        vec![2 as $t, 3 as $t]
309                    ];
310                    a.dot(&b);
311                }
312            }
313
314            item! {
315                #[test]
316                #[should_panic]
317                fn [<test_mat_mat_panic_6_ $t>]() {
318                    let a = vec![
319                        vec![1 as $t, 2 as $t, 3 as $t],
320                        vec![4 as $t, 5 as $t],
321                        vec![3 as $t, 2 as $t, 1 as $t]
322                    ];
323                    let b = vec![
324                        vec![3 as $t, 2 as $t, 1 as $t],
325                        vec![6 as $t, 5 as $t, 4 as $t],
326                        vec![2 as $t, 4 as $t, 3 as $t]
327                    ];
328                    a.dot(&b);
329                }
330            }
331
332            item! {
333                #[test]
334                fn [<test_mat_primitive_ $t>]() {
335                    let a = vec![
336                        vec![1 as $t, 2 as $t, 3 as $t],
337                        vec![4 as $t, 5 as $t, 6 as $t],
338                        vec![3 as $t, 2 as $t, 1 as $t]
339                    ];
340                    let res = vec![
341                        vec![2 as $t, 4 as $t, 6 as $t],
342                        vec![8 as $t, 10 as $t, 12 as $t],
343                        vec![6 as $t, 4 as $t, 2 as $t]
344                    ];
345                    let product = a.dot(&(2 as $t));
346                    for i in 0..3 {
347                        for j in 0..3 {
348                            assert!((((res[i][j] - product[i][j]) as f64).abs()) < std::f64::EPSILON);
349                        }
350                    }
351                }
352            }
353
354            item! {
355                #[test]
356                fn [<test_primitive_mat_ $t>]() {
357                    let a = vec![
358                        vec![1 as $t, 2 as $t, 3 as $t],
359                        vec![4 as $t, 5 as $t, 6 as $t],
360                        vec![3 as $t, 2 as $t, 1 as $t]
361                    ];
362                    let res = vec![
363                        vec![2 as $t, 4 as $t, 6 as $t],
364                        vec![8 as $t, 10 as $t, 12 as $t],
365                        vec![6 as $t, 4 as $t, 2 as $t]
366                    ];
367                    let product = (2 as $t).dot(&a);
368                    for i in 0..3 {
369                        for j in 0..3 {
370                            assert!((((res[i][j] - product[i][j]) as f64).abs()) < std::f64::EPSILON);
371                        }
372                    }
373                }
374            }
375        };
376    }
377
378    make_test!(isize);
379    make_test!(usize);
380    make_test!(i8);
381    make_test!(u8);
382    make_test!(i16);
383    make_test!(u16);
384    make_test!(i32);
385    make_test!(u32);
386    make_test!(i64);
387    make_test!(u64);
388    make_test!(f32);
389    make_test!(f64);
390}