1use crate::core::math::ArgminDot;
9use crate::core::math::ArgminTranspose;
10use num_complex::Complex;
11
12macro_rules! make_dot_vec {
13 ($t:ty) => {
14 impl<'a> ArgminDot<Vec<$t>, $t> for Vec<$t> {
15 #[inline]
16 fn dot(&self, other: &Vec<$t>) -> $t {
17 self.iter().zip(other.iter()).map(|(a, b)| a * b).sum()
18 }
19 }
20
21 impl<'a> ArgminDot<$t, Vec<$t>> for Vec<$t> {
22 #[inline]
23 fn dot(&self, other: &$t) -> Vec<$t> {
24 self.iter().map(|a| a * other).collect()
25 }
26 }
27
28 impl<'a> ArgminDot<Vec<$t>, Vec<$t>> for $t {
29 #[inline]
30 fn dot(&self, other: &Vec<$t>) -> Vec<$t> {
31 other.iter().map(|a| a * self).collect()
32 }
33 }
34
35 impl ArgminDot<Vec<$t>, Vec<Vec<$t>>> for Vec<$t> {
36 #[inline]
37 fn dot(&self, other: &Vec<$t>) -> Vec<Vec<$t>> {
38 self.iter()
39 .map(|b| other.iter().map(|a| a * b).collect())
40 .collect()
41 }
42 }
43
44 impl ArgminDot<Vec<$t>, Vec<$t>> for Vec<Vec<$t>> {
45 #[inline]
46 fn dot(&self, other: &Vec<$t>) -> Vec<$t> {
47 (0..self.len()).map(|i| self[i].dot(other)).collect()
48 }
49 }
50
51 impl ArgminDot<Vec<Vec<$t>>, Vec<Vec<$t>>> for Vec<Vec<$t>> {
52 #[inline]
53 fn dot(&self, other: &Vec<Vec<$t>>) -> Vec<Vec<$t>> {
54 let other = other.clone().t();
56 let sr = self.len();
57 assert!(sr > 0);
58 let sc = self[0].len();
59 assert!(sc > 0);
60 let or = other.len();
61 assert!(or > 0);
62 let oc = other[0].len();
63 assert_eq!(sc, or);
64 assert!(oc > 0);
65 let mut v = Vec::with_capacity(oc);
66 unsafe {
67 v.set_len(oc);
68 }
69 let mut out = vec![v; sr];
70 for i in 0..sr {
71 assert_eq!(self[i].len(), sc);
72 for j in 0..oc {
74 out[i][j] = self[i].dot(&other[j]);
75 }
76 }
77 out
78 }
79 }
80
81 impl<'a> ArgminDot<$t, Vec<Vec<$t>>> for Vec<Vec<$t>> {
82 #[inline]
83 fn dot(&self, other: &$t) -> Vec<Vec<$t>> {
84 (0..self.len())
85 .map(|i| self[i].iter().map(|a| a * other).collect())
86 .collect()
87 }
88 }
89
90 impl<'a> ArgminDot<Vec<Vec<$t>>, Vec<Vec<$t>>> for $t {
91 #[inline]
92 fn dot(&self, other: &Vec<Vec<$t>>) -> Vec<Vec<$t>> {
93 (0..other.len())
94 .map(|i| other[i].iter().map(|a| a * self).collect())
95 .collect()
96 }
97 }
98 };
99}
100
101make_dot_vec!(f32);
102make_dot_vec!(f64);
103make_dot_vec!(i8);
104make_dot_vec!(i16);
105make_dot_vec!(i32);
106make_dot_vec!(i64);
107make_dot_vec!(u8);
108make_dot_vec!(u16);
109make_dot_vec!(u32);
110make_dot_vec!(u64);
111make_dot_vec!(isize);
112make_dot_vec!(usize);
113make_dot_vec!(Complex<f32>);
114make_dot_vec!(Complex<f64>);
115make_dot_vec!(Complex<i8>);
116make_dot_vec!(Complex<i16>);
117make_dot_vec!(Complex<i32>);
118make_dot_vec!(Complex<i64>);
119make_dot_vec!(Complex<u8>);
120make_dot_vec!(Complex<u16>);
121make_dot_vec!(Complex<u32>);
122make_dot_vec!(Complex<u64>);
123make_dot_vec!(Complex<isize>);
124make_dot_vec!(Complex<usize>);
125
126#[cfg(test)]
127mod tests {
128 use super::*;
129 use paste::item;
130
131 macro_rules! make_test {
132 ($t:ty) => {
133 item! {
134 #[test]
135 fn [<test_vec_vec_ $t>]() {
136 let a = vec![1 as $t, 2 as $t, 3 as $t];
137 let b = vec![4 as $t, 5 as $t, 6 as $t];
138 let res: $t = a.dot(&b);
139 assert!((((res - 32 as $t) as f64).abs()) < std::f64::EPSILON);
140 }
141 }
142
143 item! {
144 #[test]
145 fn [<test_vec_scalar_ $t>]() {
146 let a = vec![1 as $t, 2 as $t, 3 as $t];
147 let b = 2 as $t;
148 let product = a.dot(&b);
149 let res = vec![2 as $t, 4 as $t, 6 as $t];
150 for i in 0..3 {
151 assert!((((res[i] - product[i]) as f64).abs()) < std::f64::EPSILON);
152 }
153 }
154 }
155
156 item! {
157 #[test]
158 fn [<test_scalar_vec_ $t>]() {
159 let a = vec![1 as $t, 2 as $t, 3 as $t];
160 let b = 2 as $t;
161 let product = b.dot(&a);
162 let res = vec![2 as $t, 4 as $t, 6 as $t];
163 for i in 0..3 {
164 assert!((((res[i] - product[i]) as f64).abs()) < std::f64::EPSILON);
165 }
166 }
167 }
168
169 item! {
170 #[test]
171 fn [<test_mat_vec_ $t>]() {
172 let a = vec![1 as $t, 2 as $t, 3 as $t];
173 let b = vec![4 as $t, 5 as $t, 6 as $t];
174 let res = vec![
175 vec![4 as $t, 5 as $t, 6 as $t],
176 vec![8 as $t, 10 as $t, 12 as $t],
177 vec![12 as $t, 15 as $t, 18 as $t]
178 ];
179 let product: Vec<Vec<$t>> = a.dot(&b);
180 for i in 0..3 {
181 for j in 0..3 {
182 assert!((((res[i][j] - product[i][j]) as f64).abs()) < std::f64::EPSILON);
183 }
184 }
185 }
186 }
187
188 item! {
189 #[test]
190 fn [<test_mat_vec_2_ $t>]() {
191 let a = vec![
192 vec![1 as $t, 2 as $t, 3 as $t],
193 vec![4 as $t, 5 as $t, 6 as $t],
194 vec![7 as $t, 8 as $t, 9 as $t]
195 ];
196 let b = vec![1 as $t, 2 as $t, 3 as $t];
197 let res = vec![14 as $t, 32 as $t, 50 as $t];
198 let product = a.dot(&b);
199 for i in 0..3 {
200 assert!((((res[i] - product[i]) as f64).abs()) < std::f64::EPSILON);
201 }
202 }
203 }
204
205 item! {
206 #[test]
207 fn [<test_mat_mat_ $t>]() {
208 let a = vec![
209 vec![1 as $t, 2 as $t, 3 as $t],
210 vec![4 as $t, 5 as $t, 6 as $t],
211 vec![3 as $t, 2 as $t, 1 as $t]
212 ];
213 let b = vec![
214 vec![3 as $t, 2 as $t, 1 as $t],
215 vec![6 as $t, 5 as $t, 4 as $t],
216 vec![2 as $t, 4 as $t, 3 as $t]
217 ];
218 let res = vec![
219 vec![21 as $t, 24 as $t, 18 as $t],
220 vec![54 as $t, 57 as $t, 42 as $t],
221 vec![23 as $t, 20 as $t, 14 as $t]
222 ];
223 let product = a.dot(&b);
224 for i in 0..3 {
225 for j in 0..3 {
226 assert!((((res[i][j] - product[i][j]) as f64).abs()) < std::f64::EPSILON);
227 }
228 }
229 }
230 }
231
232 item! {
233 #[test]
234 #[should_panic]
235 fn [<test_mat_mat_panic_1_ $t>]() {
236 let a = vec![];
237 let b = vec![
238 vec![3 as $t, 2 as $t, 1 as $t],
239 vec![6 as $t, 5 as $t, 4 as $t],
240 vec![2 as $t, 4 as $t, 3 as $t]
241 ];
242 a.dot(&b);
243 }
244 }
245
246 item! {
247 #[test]
248 #[should_panic]
249 fn [<test_mat_mat_panic_2_ $t>]() {
250 let a: Vec<Vec<$t>> = vec![];
251 let b = vec![
252 vec![3 as $t, 2 as $t, 1 as $t],
253 vec![6 as $t, 5 as $t, 4 as $t],
254 vec![2 as $t, 4 as $t, 3 as $t]
255 ];
256 b.dot(&a);
257 }
258 }
259
260 item! {
261 #[test]
262 #[should_panic]
263 fn [<test_mat_mat_panic_3_ $t>]() {
264 let a = vec![
265 vec![1 as $t, 2 as $t],
266 vec![4 as $t, 5 as $t],
267 vec![3 as $t, 2 as $t]
268 ];
269 let b = vec![
270 vec![3 as $t, 2 as $t, 1 as $t],
271 vec![6 as $t, 5 as $t, 4 as $t],
272 vec![2 as $t, 4 as $t, 3 as $t]
273 ];
274 a.dot(&b);
275 }
276 }
277
278 item! {
279 #[test]
280 #[should_panic]
281 fn [<test_mat_mat_panic_4_ $t>]() {
282 let a = vec![
283 vec![1 as $t, 2 as $t, 3 as $t],
284 vec![4 as $t, 5 as $t, 6 as $t],
285 vec![3 as $t, 2 as $t, 1 as $t]
286 ];
287 let b = vec![
288 vec![3 as $t, 2 as $t],
289 vec![6 as $t, 5 as $t],
290 vec![3 as $t, 2 as $t]
291 ];
292 a.dot(&b);
293 }
294 }
295
296 item! {
297 #[test]
298 #[should_panic]
299 fn [<test_mat_mat_panic_5_ $t>]() {
300 let a = vec![
301 vec![1 as $t, 2 as $t, 3 as $t],
302 vec![4 as $t, 5 as $t, 6 as $t],
303 vec![3 as $t, 2 as $t, 1 as $t]
304 ];
305 let b = vec![
306 vec![3 as $t, 2 as $t, 1 as $t],
307 vec![6 as $t, 5 as $t, 4 as $t],
308 vec![2 as $t, 3 as $t]
309 ];
310 a.dot(&b);
311 }
312 }
313
314 item! {
315 #[test]
316 #[should_panic]
317 fn [<test_mat_mat_panic_6_ $t>]() {
318 let a = vec![
319 vec![1 as $t, 2 as $t, 3 as $t],
320 vec![4 as $t, 5 as $t],
321 vec![3 as $t, 2 as $t, 1 as $t]
322 ];
323 let b = vec![
324 vec![3 as $t, 2 as $t, 1 as $t],
325 vec![6 as $t, 5 as $t, 4 as $t],
326 vec![2 as $t, 4 as $t, 3 as $t]
327 ];
328 a.dot(&b);
329 }
330 }
331
332 item! {
333 #[test]
334 fn [<test_mat_primitive_ $t>]() {
335 let a = vec![
336 vec![1 as $t, 2 as $t, 3 as $t],
337 vec![4 as $t, 5 as $t, 6 as $t],
338 vec![3 as $t, 2 as $t, 1 as $t]
339 ];
340 let res = vec![
341 vec![2 as $t, 4 as $t, 6 as $t],
342 vec![8 as $t, 10 as $t, 12 as $t],
343 vec![6 as $t, 4 as $t, 2 as $t]
344 ];
345 let product = a.dot(&(2 as $t));
346 for i in 0..3 {
347 for j in 0..3 {
348 assert!((((res[i][j] - product[i][j]) as f64).abs()) < std::f64::EPSILON);
349 }
350 }
351 }
352 }
353
354 item! {
355 #[test]
356 fn [<test_primitive_mat_ $t>]() {
357 let a = vec![
358 vec![1 as $t, 2 as $t, 3 as $t],
359 vec![4 as $t, 5 as $t, 6 as $t],
360 vec![3 as $t, 2 as $t, 1 as $t]
361 ];
362 let res = vec![
363 vec![2 as $t, 4 as $t, 6 as $t],
364 vec![8 as $t, 10 as $t, 12 as $t],
365 vec![6 as $t, 4 as $t, 2 as $t]
366 ];
367 let product = (2 as $t).dot(&a);
368 for i in 0..3 {
369 for j in 0..3 {
370 assert!((((res[i][j] - product[i][j]) as f64).abs()) < std::f64::EPSILON);
371 }
372 }
373 }
374 }
375 };
376 }
377
378 make_test!(isize);
379 make_test!(usize);
380 make_test!(i8);
381 make_test!(u8);
382 make_test!(i16);
383 make_test!(u16);
384 make_test!(i32);
385 make_test!(u32);
386 make_test!(i64);
387 make_test!(u64);
388 make_test!(f32);
389 make_test!(f64);
390}