ndarray/
numeric_util.rs

1// Copyright 2014-2016 bluss and ndarray developers.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9use std::cmp;
10
11use crate::LinalgScalar;
12
/// Folds `xs` into a single value with `f`, using a manually unrolled loop.
///
/// The slice is consumed eight elements at a time, with each of the eight
/// positions feeding its own partial accumulator. The partial results are
/// then combined pairwise into the final accumulator, and any leftover
/// elements (at most seven) are folded in last.
///
/// `init` is called nine times (once for the final accumulator and once for
/// each partial accumulator) and every one of those values ends up combined
/// into the result, so it should return a neutral element for `f`. Elements
/// are not combined strictly left to right, so `f` should be associative and
/// commutative if the result is expected to match a plain sequential fold.
///
/// Each element of `xs` is cloned exactly once; the accumulators are moved
/// through `f` and never cloned.
pub fn unrolled_fold<A, I, F>(mut xs: &[A], init: I, f: F) -> A
where
    A: Clone,
    I: Fn() -> A,
    F: Fn(A, A) -> A,
{
    // eightfold unrolled so that floating point can be vectorized
    // (even with strict floating point accuracy semantics)
    let mut acc = init();
    let (mut p0, mut p1, mut p2, mut p3, mut p4, mut p5, mut p6, mut p7) = (
        init(),
        init(),
        init(),
        init(),
        init(),
        init(),
        init(),
        init(),
    );
    while xs.len() >= 8 {
        p0 = f(p0, xs[0].clone());
        p1 = f(p1, xs[1].clone());
        p2 = f(p2, xs[2].clone());
        p3 = f(p3, xs[3].clone());
        p4 = f(p4, xs[4].clone());
        p5 = f(p5, xs[5].clone());
        p6 = f(p6, xs[6].clone());
        p7 = f(p7, xs[7].clone());

        xs = &xs[8..];
    }

    // Combine the partial accumulators. `acc` is moved into `f` and then
    // reassigned, so no clone of the accumulator is needed.
    acc = f(acc, f(p0, p4));
    acc = f(acc, f(p1, p5));
    acc = f(acc, f(p2, p6));
    acc = f(acc, f(p3, p7));

    // At most seven elements remain here; `take(7)` makes it clear to the
    // optimizer that this loop is short and can not be autovectorized.
    for x in xs.iter().take(7) {
        acc = f(acc, x.clone());
    }
    acc
}
60
61/// Compute the dot product.
62///
63/// `xs` and `ys` must be the same length
64pub fn unrolled_dot<A>(xs: &[A], ys: &[A]) -> A
65where
66    A: LinalgScalar,
67{
68    debug_assert_eq!(xs.len(), ys.len());
69    // eightfold unrolled so that floating point can be vectorized
70    // (even with strict floating point accuracy semantics)
71    let len = cmp::min(xs.len(), ys.len());
72    let mut xs = &xs[..len];
73    let mut ys = &ys[..len];
74    let mut sum = A::zero();
75    let (mut p0, mut p1, mut p2, mut p3, mut p4, mut p5, mut p6, mut p7) = (
76        A::zero(),
77        A::zero(),
78        A::zero(),
79        A::zero(),
80        A::zero(),
81        A::zero(),
82        A::zero(),
83        A::zero(),
84    );
85    while xs.len() >= 8 {
86        p0 = p0 + xs[0] * ys[0];
87        p1 = p1 + xs[1] * ys[1];
88        p2 = p2 + xs[2] * ys[2];
89        p3 = p3 + xs[3] * ys[3];
90        p4 = p4 + xs[4] * ys[4];
91        p5 = p5 + xs[5] * ys[5];
92        p6 = p6 + xs[6] * ys[6];
93        p7 = p7 + xs[7] * ys[7];
94
95        xs = &xs[8..];
96        ys = &ys[8..];
97    }
98    sum = sum + (p0 + p4);
99    sum = sum + (p1 + p5);
100    sum = sum + (p2 + p6);
101    sum = sum + (p3 + p7);
102
103    for (i, (&x, &y)) in xs.iter().zip(ys).enumerate() {
104        if i >= 7 {
105            break;
106        }
107        sum = sum + x * y;
108    }
109    sum
110}
111
/// Reports whether `xs` and `ys` are equal element by element.
///
/// Both slices are expected to have the same length. This is only checked
/// by a debug assertion; when that check is compiled out, only the common
/// prefix of the two slices is compared.
pub fn unrolled_eq<A, B>(xs: &[A], ys: &[B]) -> bool
where
    A: PartialEq<B>,
{
    debug_assert_eq!(xs.len(), ys.len());

    // Clamp both inputs to the shared length, then separate the part that
    // divides evenly into blocks of eight from the short remainder.
    let common = cmp::min(xs.len(), ys.len());
    let blocked = common - common % 8;
    let (body_x, tail_x) = xs[..common].split_at(blocked);
    let (body_y, tail_y) = ys[..common].split_at(blocked);

    // eightfold unrolled for performance (this is not done by llvm automatically)
    for (cx, cy) in body_x.chunks_exact(8).zip(body_y.chunks_exact(8)) {
        // Non-short-circuiting `|`: all eight comparisons in a block are
        // evaluated before the branch.
        let block_differs = (cx[0] != cy[0])
            | (cx[1] != cy[1])
            | (cx[2] != cy[2])
            | (cx[3] != cy[3])
            | (cx[4] != cy[4])
            | (cx[5] != cy[5])
            | (cx[6] != cy[6])
            | (cx[7] != cy[7]);
        if block_differs {
            return false;
        }
    }

    // Compare the remaining elements one by one, stopping at the first
    // mismatch.
    !tail_x.iter().zip(tail_y).any(|(x, y)| *x != *y)
}