polars_compute/
var_cov.rs

1// Some formulae:
2//     mean_x = sum(weight[i] * x[i]) / sum(weight)
3//     dp_xy = weighted sum of deviation products of variables x, y, written in
4//             the paper as simply XY.
5//     dp_xy = sum(weight[i] * (x[i] - mean_x) * (y[i] - mean_y))
6//
//     cov(x, y) = dp_xy / (sum(weight) - ddof), where ddof = 0 gives the
//                 population covariance and ddof = 1 the sample covariance.
8//     var(x) = cov(x, x)
9//
10// Algorithms from:
11// Numerically stable parallel computation of (co-)variance.
12// Schubert, E., & Gertz, M. (2018).
13//
14// Key equations from the paper:
15// (17) for mean update, (23) for dp update (and also Table 1).
16
17use arrow::array::{Array, PrimitiveArray};
18use arrow::types::NativeType;
19use num_traits::AsPrimitive;
20use polars_utils::algebraic_ops::*;
21
// Capacity of the stack buffers used by `chunk_as_float` and
// `chunk_as_float_binary`: values are converted to f64 and handed to the
// callback in slices of at most this many elements.
const CHUNK_SIZE: usize = 128;
23
/// Running state of a numerically stable variance computation.
///
/// States built over disjoint sets of values can be merged with `combine`;
/// `finalize` converts the state into a variance.
#[derive(Debug, Default, Clone)]
pub struct VarState {
    /// Total weight accumulated so far (the number of values when built from
    /// a plain slice).
    weight: f64,
    /// Mean of the accumulated values.
    mean: f64,
    /// Sum of squared deviations of the accumulated values from `mean`.
    dp: f64,
}
30
/// Running state of a numerically stable covariance computation over paired
/// values `(x, y)`.
///
/// States built over disjoint sets of pairs can be merged with `combine`;
/// `finalize` converts the state into a covariance.
#[derive(Debug, Default, Clone)]
pub struct CovState {
    /// Total weight accumulated so far (the number of pairs when built from
    /// plain slices).
    weight: f64,
    /// Mean of the accumulated `x` values.
    mean_x: f64,
    /// Mean of the accumulated `y` values.
    mean_y: f64,
    /// Sum of the products of deviations `(x - mean_x) * (y - mean_y)`.
    dp_xy: f64,
}
38
/// Running state of a numerically stable Pearson correlation computation
/// over paired values `(x, y)`.
///
/// States built over disjoint sets of pairs can be merged with `combine`;
/// `finalize` converts the state into a correlation coefficient.
#[derive(Debug, Default, Clone)]
pub struct PearsonState {
    /// Total weight accumulated so far (the number of pairs when built from
    /// plain slices).
    weight: f64,
    /// Mean of the accumulated `x` values.
    mean_x: f64,
    /// Mean of the accumulated `y` values.
    mean_y: f64,
    /// Sum of squared deviations of `x` from `mean_x`.
    dp_xx: f64,
    /// Sum of the products of deviations `(x - mean_x) * (y - mean_y)`.
    dp_xy: f64,
    /// Sum of squared deviations of `y` from `mean_y`.
    dp_yy: f64,
}
48
49impl VarState {
50    fn new(x: &[f64]) -> Self {
51        if x.is_empty() {
52            return Self::default();
53        }
54
55        let weight = x.len() as f64;
56        let mean = alg_sum_f64(x.iter().copied()) / weight;
57        Self {
58            weight,
59            mean,
60            dp: alg_sum_f64(x.iter().map(|&xi| (xi - mean) * (xi - mean))),
61        }
62    }
63
64    pub fn add_one(&mut self, x: f64) {
65        // Just a specialized version of
66        // self.combine(&Self { weight: 1.0, mean: x, dp: 0.0 })
67        let new_weight = self.weight + 1.0;
68        let delta_mean = self.mean - x;
69        let new_mean = self.mean - delta_mean / new_weight;
70        self.dp += (new_mean - x) * delta_mean;
71        self.weight = new_weight;
72        self.mean = new_mean;
73    }
74
75    pub fn combine(&mut self, other: &Self) {
76        if other.weight == 0.0 {
77            return;
78        }
79
80        let new_weight = self.weight + other.weight;
81        let other_weight_frac = other.weight / new_weight;
82        let delta_mean = self.mean - other.mean;
83        let new_mean = self.mean - delta_mean * other_weight_frac;
84        self.dp += other.dp + other.weight * (new_mean - other.mean) * delta_mean;
85        self.weight = new_weight;
86        self.mean = new_mean;
87    }
88
89    pub fn finalize(&self, ddof: u8) -> Option<f64> {
90        if self.weight <= ddof as f64 {
91            None
92        } else {
93            Some(self.dp / (self.weight - ddof as f64))
94        }
95    }
96}
97
98impl CovState {
99    fn new(x: &[f64], y: &[f64]) -> Self {
100        assert!(x.len() == y.len());
101        if x.is_empty() {
102            return Self::default();
103        }
104
105        let weight = x.len() as f64;
106        let inv_weight = 1.0 / weight;
107        let mean_x = alg_sum_f64(x.iter().copied()) * inv_weight;
108        let mean_y = alg_sum_f64(y.iter().copied()) * inv_weight;
109        Self {
110            weight,
111            mean_x,
112            mean_y,
113            dp_xy: alg_sum_f64(
114                x.iter()
115                    .zip(y)
116                    .map(|(&xi, &yi)| (xi - mean_x) * (yi - mean_y)),
117            ),
118        }
119    }
120
121    pub fn combine(&mut self, other: &Self) {
122        if other.weight == 0.0 {
123            return;
124        }
125
126        let new_weight = self.weight + other.weight;
127        let other_weight_frac = other.weight / new_weight;
128        let delta_mean_x = self.mean_x - other.mean_x;
129        let delta_mean_y = self.mean_y - other.mean_y;
130        let new_mean_x = self.mean_x - delta_mean_x * other_weight_frac;
131        let new_mean_y = self.mean_y - delta_mean_y * other_weight_frac;
132        self.dp_xy += other.dp_xy + other.weight * (new_mean_x - other.mean_x) * delta_mean_y;
133        self.weight = new_weight;
134        self.mean_x = new_mean_x;
135        self.mean_y = new_mean_y;
136    }
137
138    pub fn finalize(&self, ddof: u8) -> Option<f64> {
139        if self.weight <= ddof as f64 {
140            None
141        } else {
142            Some(self.dp_xy / (self.weight - ddof as f64))
143        }
144    }
145}
146
147impl PearsonState {
148    fn new(x: &[f64], y: &[f64]) -> Self {
149        assert!(x.len() == y.len());
150        if x.is_empty() {
151            return Self::default();
152        }
153
154        let weight = x.len() as f64;
155        let inv_weight = 1.0 / weight;
156        let mean_x = alg_sum_f64(x.iter().copied()) * inv_weight;
157        let mean_y = alg_sum_f64(y.iter().copied()) * inv_weight;
158        let mut dp_xx = 0.0;
159        let mut dp_xy = 0.0;
160        let mut dp_yy = 0.0;
161        for (xi, yi) in x.iter().zip(y.iter()) {
162            dp_xx = alg_add_f64(dp_xx, (xi - mean_x) * (xi - mean_x));
163            dp_xy = alg_add_f64(dp_xy, (xi - mean_x) * (yi - mean_y));
164            dp_yy = alg_add_f64(dp_yy, (yi - mean_y) * (yi - mean_y));
165        }
166        Self {
167            weight,
168            mean_x,
169            mean_y,
170            dp_xx,
171            dp_xy,
172            dp_yy,
173        }
174    }
175
176    pub fn combine(&mut self, other: &Self) {
177        if other.weight == 0.0 {
178            return;
179        }
180
181        let new_weight = self.weight + other.weight;
182        let other_weight_frac = other.weight / new_weight;
183        let delta_mean_x = self.mean_x - other.mean_x;
184        let delta_mean_y = self.mean_y - other.mean_y;
185        let new_mean_x = self.mean_x - delta_mean_x * other_weight_frac;
186        let new_mean_y = self.mean_y - delta_mean_y * other_weight_frac;
187        self.dp_xx += other.dp_xx + other.weight * (new_mean_x - other.mean_x) * delta_mean_x;
188        self.dp_xy += other.dp_xy + other.weight * (new_mean_x - other.mean_x) * delta_mean_y;
189        self.dp_yy += other.dp_yy + other.weight * (new_mean_y - other.mean_y) * delta_mean_y;
190        self.weight = new_weight;
191        self.mean_x = new_mean_x;
192        self.mean_y = new_mean_y;
193    }
194
195    pub fn finalize(&self) -> f64 {
196        let denom = (self.dp_xx * self.dp_yy).sqrt();
197        if denom == 0.0 {
198            f64::NAN
199        } else {
200            self.dp_xy / denom
201        }
202    }
203}
204
205fn chunk_as_float<T, I, F>(it: I, mut f: F)
206where
207    T: NativeType + AsPrimitive<f64>,
208    I: IntoIterator<Item = T>,
209    F: FnMut(&[f64]),
210{
211    let mut chunk = [0.0; CHUNK_SIZE];
212    let mut i = 0;
213    for val in it {
214        if i >= CHUNK_SIZE {
215            f(&chunk);
216            i = 0;
217        }
218        chunk[i] = val.as_();
219        i += 1;
220    }
221    if i > 0 {
222        f(&chunk[..i]);
223    }
224}
225
226fn chunk_as_float_binary<T, U, I, F>(it: I, mut f: F)
227where
228    T: NativeType + AsPrimitive<f64>,
229    U: NativeType + AsPrimitive<f64>,
230    I: IntoIterator<Item = (T, U)>,
231    F: FnMut(&[f64], &[f64]),
232{
233    let mut left_chunk = [0.0; CHUNK_SIZE];
234    let mut right_chunk = [0.0; CHUNK_SIZE];
235    let mut i = 0;
236    for (l, r) in it {
237        if i >= CHUNK_SIZE {
238            f(&left_chunk, &right_chunk);
239            i = 0;
240        }
241        left_chunk[i] = l.as_();
242        right_chunk[i] = r.as_();
243        i += 1;
244    }
245    if i > 0 {
246        f(&left_chunk[..i], &right_chunk[..i]);
247    }
248}
249
250pub fn var<T>(arr: &PrimitiveArray<T>) -> VarState
251where
252    T: NativeType + AsPrimitive<f64>,
253{
254    let mut out = VarState::default();
255    if arr.has_nulls() {
256        chunk_as_float(arr.non_null_values_iter(), |chunk| {
257            out.combine(&VarState::new(chunk))
258        });
259    } else {
260        chunk_as_float(arr.values().iter().copied(), |chunk| {
261            out.combine(&VarState::new(chunk))
262        });
263    }
264    out
265}
266
267pub fn cov<T, U>(x: &PrimitiveArray<T>, y: &PrimitiveArray<U>) -> CovState
268where
269    T: NativeType + AsPrimitive<f64>,
270    U: NativeType + AsPrimitive<f64>,
271{
272    assert!(x.len() == y.len());
273    let mut out = CovState::default();
274    if x.has_nulls() || y.has_nulls() {
275        chunk_as_float_binary(
276            x.iter()
277                .zip(y.iter())
278                .filter_map(|(l, r)| l.copied().zip(r.copied())),
279            |l, r| out.combine(&CovState::new(l, r)),
280        );
281    } else {
282        chunk_as_float_binary(
283            x.values().iter().copied().zip(y.values().iter().copied()),
284            |l, r| out.combine(&CovState::new(l, r)),
285        );
286    }
287    out
288}
289
290pub fn pearson_corr<T, U>(x: &PrimitiveArray<T>, y: &PrimitiveArray<U>) -> PearsonState
291where
292    T: NativeType + AsPrimitive<f64>,
293    U: NativeType + AsPrimitive<f64>,
294{
295    assert!(x.len() == y.len());
296    let mut out = PearsonState::default();
297    if x.has_nulls() || y.has_nulls() {
298        chunk_as_float_binary(
299            x.iter()
300                .zip(y.iter())
301                .filter_map(|(l, r)| l.copied().zip(r.copied())),
302            |l, r| out.combine(&PearsonState::new(l, r)),
303        );
304    } else {
305        chunk_as_float_binary(
306            x.values().iter().copied().zip(y.values().iter().copied()),
307            |l, r| out.combine(&PearsonState::new(l, r)),
308        );
309    }
310    out
311}