1use std::cmp;
10
11use crate::LinalgScalar;
12
/// Folds the elements of `xs` with `f`, using eight independent partial
/// accumulators.
///
/// Every accumulator (the final one and the eight partial ones) starts from a
/// fresh `init()` value. Element `k` of each full group of eight is folded
/// into partial accumulator `k`; the partials are then merged pairwise
/// (`0+4`, `1+5`, `2+6`, `3+7`) into the result, and finally the fewer-than-
/// eight leftover elements are folded in, in order.
///
/// Because the elements are not combined strictly left to right, `f` should be
/// associative and commutative and `init()` should be its identity if the
/// result is meant to equal a plain sequential fold.
///
/// Each element of `xs` is cloned exactly once; the accumulators themselves
/// are moved through `f` and never cloned.
pub fn unrolled_fold<A, I, F>(mut xs: &[A], init: I, f: F) -> A
where
    A: Clone,
    I: Fn() -> A,
    F: Fn(A, A) -> A,
{
    let mut acc = init();
    let (mut p0, mut p1, mut p2, mut p3, mut p4, mut p5, mut p6, mut p7) = (
        init(),
        init(),
        init(),
        init(),
        init(),
        init(),
        init(),
        init(),
    );
    // Main loop: consume eight elements per iteration, one per partial
    // accumulator, so the eight update chains are independent of each other.
    while xs.len() >= 8 {
        p0 = f(p0, xs[0].clone());
        p1 = f(p1, xs[1].clone());
        p2 = f(p2, xs[2].clone());
        p3 = f(p3, xs[3].clone());
        p4 = f(p4, xs[4].clone());
        p5 = f(p5, xs[5].clone());
        p6 = f(p6, xs[6].clone());
        p7 = f(p7, xs[7].clone());

        xs = &xs[8..];
    }
    // Merge the partials. `acc` is moved into `f` and rebound, so no clone of
    // the running value is needed.
    acc = f(acc, f(p0, p4));
    acc = f(acc, f(p1, p5));
    acc = f(acc, f(p2, p6));
    acc = f(acc, f(p3, p7));

    // Fewer than eight elements remain here, so the `i >= 7` bound can never
    // trigger; it presumably serves as a hint to the optimizer that this loop
    // is short.
    for (i, x) in xs.iter().enumerate() {
        if i >= 7 {
            break;
        }
        acc = f(acc, x.clone());
    }
    acc
}
60
61pub fn unrolled_dot<A>(xs: &[A], ys: &[A]) -> A
65where
66 A: LinalgScalar,
67{
68 debug_assert_eq!(xs.len(), ys.len());
69 let len = cmp::min(xs.len(), ys.len());
72 let mut xs = &xs[..len];
73 let mut ys = &ys[..len];
74 let mut sum = A::zero();
75 let (mut p0, mut p1, mut p2, mut p3, mut p4, mut p5, mut p6, mut p7) = (
76 A::zero(),
77 A::zero(),
78 A::zero(),
79 A::zero(),
80 A::zero(),
81 A::zero(),
82 A::zero(),
83 A::zero(),
84 );
85 while xs.len() >= 8 {
86 p0 = p0 + xs[0] * ys[0];
87 p1 = p1 + xs[1] * ys[1];
88 p2 = p2 + xs[2] * ys[2];
89 p3 = p3 + xs[3] * ys[3];
90 p4 = p4 + xs[4] * ys[4];
91 p5 = p5 + xs[5] * ys[5];
92 p6 = p6 + xs[6] * ys[6];
93 p7 = p7 + xs[7] * ys[7];
94
95 xs = &xs[8..];
96 ys = &ys[8..];
97 }
98 sum = sum + (p0 + p4);
99 sum = sum + (p1 + p5);
100 sum = sum + (p2 + p6);
101 sum = sum + (p3 + p7);
102
103 for (i, (&x, &y)) in xs.iter().zip(ys).enumerate() {
104 if i >= 7 {
105 break;
106 }
107 sum = sum + x * y;
108 }
109 sum
110}
111
/// Reports whether `xs` and `ys` are element-wise equal.
///
/// The slices are expected to have equal length; this is checked only when
/// debug assertions are enabled. Otherwise only the common prefix (up to the
/// shorter length) is compared.
///
/// Full groups of eight elements are compared without short-circuiting inside
/// the group: all eight `!=` tests run before the group's verdict is
/// inspected. The fewer-than-eight trailing elements are compared one at a
/// time, stopping at the first mismatch.
pub fn unrolled_eq<A, B>(xs: &[A], ys: &[B]) -> bool
where
    A: PartialEq<B>,
{
    debug_assert_eq!(xs.len(), ys.len());
    let shared = cmp::min(xs.len(), ys.len());
    let grouped = shared - shared % 8;
    let (body_x, rest_x) = xs[..shared].split_at(grouped);
    let (body_y, rest_y) = ys[..shared].split_at(grouped);

    for (cx, cy) in body_x.chunks_exact(8).zip(body_y.chunks_exact(8)) {
        // Bitwise `|` (not `||`) so every pair in the group is compared.
        let mismatch = cx
            .iter()
            .zip(cy)
            .fold(false, |seen, (a, b)| seen | (a != b));
        if mismatch {
            return false;
        }
    }

    !rest_x.iter().zip(rest_y).any(|(a, b)| a != b)
}