polars_arrow/legacy/kernels/rolling/no_nulls/
variance.rs

1use polars_error::polars_ensure;
2
3use super::*;
4
/// Sliding-window state that tracks the sum of squared values over
/// `slice[last_start..last_end]`.
pub(super) struct SumSquaredWindow<'a, T> {
    // Full input buffer; windows are index ranges into it.
    slice: &'a [T],
    // Sum of `v * v` over the current window.
    sum_of_squares: T,
    // Bounds of the window passed to the most recent `new`/`update` call.
    last_start: usize,
    last_end: usize,
    // Number of incremental updates since the last full recomputation.
    // If we don't recompute every 'n' iterations,
    // we get an accumulated error/drift.
    last_recompute: u8,
}
14
15impl<'a, T: NativeType + IsFloat + std::iter::Sum + AddAssign + SubAssign + Mul<Output = T>>
16    RollingAggWindowNoNulls<'a, T> for SumSquaredWindow<'a, T>
17{
18    fn new(slice: &'a [T], start: usize, end: usize, _params: Option<RollingFnParams>) -> Self {
19        let sum = slice[start..end].iter().map(|v| *v * *v).sum::<T>();
20        Self {
21            slice,
22            sum_of_squares: sum,
23            last_start: start,
24            last_end: end,
25            last_recompute: 0,
26        }
27    }
28
29    unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
30        // if we exceed the end, we have a completely new window
31        // so we recompute
32        let recompute_sum = if start >= self.last_end || self.last_recompute > 128 {
33            self.last_recompute = 0;
34            true
35        } else {
36            self.last_recompute += 1;
37            // remove elements that should leave the window
38            let mut recompute_sum = false;
39            for idx in self.last_start..start {
40                // SAFETY:
41                // we are in bounds
42                let leaving_value = self.slice.get_unchecked(idx);
43
44                if T::is_float() && !leaving_value.is_finite() {
45                    recompute_sum = true;
46                    break;
47                }
48
49                self.sum_of_squares -= *leaving_value * *leaving_value;
50            }
51            recompute_sum
52        };
53
54        self.last_start = start;
55
56        // we traverse all values and compute
57        if T::is_float() && recompute_sum {
58            self.sum_of_squares = self
59                .slice
60                .get_unchecked(start..end)
61                .iter()
62                .map(|v| *v * *v)
63                .sum::<T>();
64        } else {
65            for idx in self.last_end..end {
66                let entering_value = *self.slice.get_unchecked(idx);
67                self.sum_of_squares += entering_value * entering_value;
68            }
69        }
70        self.last_end = end;
71        Some(self.sum_of_squares)
72    }
73}
74
/// Rolling variance state built from a rolling mean and a rolling sum of squares.
///
/// E[(xi - E[x])^2]
/// can be expanded to
/// E[x^2] - E[x]^2
pub struct VarWindow<'a, T> {
    // Rolling mean over the current window.
    mean: MeanWindow<'a, T>,
    // Rolling sum of squared values over the current window.
    sum_of_squares: SumSquaredWindow<'a, T>,
    // Delta degrees of freedom: the variance is divided by `count - ddof`.
    ddof: u8,
}
83
84impl<
85        'a,
86        T: NativeType
87            + IsFloat
88            + Float
89            + std::iter::Sum
90            + AddAssign
91            + SubAssign
92            + Div<Output = T>
93            + NumCast
94            + One
95            + Zero
96            + PartialOrd
97            + Sub<Output = T>,
98    > RollingAggWindowNoNulls<'a, T> for VarWindow<'a, T>
99{
100    fn new(slice: &'a [T], start: usize, end: usize, params: Option<RollingFnParams>) -> Self {
101        Self {
102            mean: MeanWindow::new(slice, start, end, None),
103            sum_of_squares: SumSquaredWindow::new(slice, start, end, None),
104            ddof: match params {
105                None => 1,
106                Some(pars) => {
107                    let RollingFnParams::Var(pars) = pars else {
108                        unreachable!("expected Var params");
109                    };
110                    pars.ddof
111                },
112            },
113        }
114    }
115
116    unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
117        let count: T = NumCast::from(end - start).unwrap();
118        let sum_of_squares = self.sum_of_squares.update(start, end).unwrap_unchecked();
119        let mean = self.mean.update(start, end).unwrap_unchecked();
120
121        let denom = count - NumCast::from(self.ddof).unwrap();
122        if denom <= T::zero() {
123            None
124        } else if end - start == 1 {
125            Some(T::zero())
126        } else {
127            let out = (sum_of_squares - count * mean * mean) / denom;
128            // variance cannot be negative.
129            // if it is negative it is due to numeric instability
130            if out < T::zero() {
131                Some(T::zero())
132            } else {
133                Some(out)
134            }
135        }
136    }
137}
138
139pub fn rolling_var<T>(
140    values: &[T],
141    window_size: usize,
142    min_periods: usize,
143    center: bool,
144    weights: Option<&[f64]>,
145    params: Option<RollingFnParams>,
146) -> PolarsResult<ArrayRef>
147where
148    T: NativeType
149        + Float
150        + IsFloat
151        + std::iter::Sum
152        + AddAssign
153        + SubAssign
154        + Div<Output = T>
155        + NumCast
156        + One
157        + Zero
158        + Sub<Output = T>,
159{
160    let offset_fn = match center {
161        true => det_offsets_center,
162        false => det_offsets,
163    };
164    match weights {
165        None => rolling_apply_agg_window::<VarWindow<_>, _, _>(
166            values,
167            window_size,
168            min_periods,
169            offset_fn,
170            params,
171        ),
172        Some(weights) => {
173            // Validate and standardize the weights like we do for the mean. This definition is fine
174            // because frequency weights and unbiasing don't make sense for rolling operations.
175            let mut wts = no_nulls::coerce_weights(weights);
176            let wsum = wts.iter().fold(T::zero(), |acc, x| acc + *x);
177            polars_ensure!(
178                wsum != T::zero(),
179                ComputeError: "Weighted variance is undefined if weights sum to 0"
180            );
181            wts.iter_mut().for_each(|w| *w = *w / wsum);
182            super::rolling_apply_weights(
183                values,
184                window_size,
185                min_periods,
186                offset_fn,
187                compute_var_weights,
188                &wts,
189            )
190        },
191    }
192}
193
#[cfg(test)]
mod test {
    use super::*;

    /// Downcasts a rolling result to an `f64` array and copies it into a plain vector.
    fn to_options(arr: ArrayRef) -> Vec<Option<f64>> {
        arr.as_any()
            .downcast_ref::<PrimitiveArray<f64>>()
            .unwrap()
            .into_iter()
            .map(|v| v.copied())
            .collect::<Vec<_>>()
    }

    #[test]
    fn test_rolling_var() {
        let values = &[1.0f64, 5.0, 3.0, 4.0];

        // No params given: ddof falls back to 1.
        let sample = to_options(rolling_var(values, 2, 2, false, None, None).unwrap());
        assert_eq!(sample, &[None, Some(8.0), Some(2.0), Some(0.5)]);

        // Explicit ddof = 0.
        let testpars = Some(RollingFnParams::Var(RollingVarParams { ddof: 0 }));
        let population = to_options(rolling_var(values, 2, 2, false, None, testpars).unwrap());
        assert_eq!(population, &[None, Some(4.0), Some(1.0), Some(0.25)]);

        // min_periods = 1: the leading single-element window still yields `None`.
        // Compared through the Debug representation, like the NaN case below.
        let relaxed = to_options(rolling_var(values, 2, 1, false, None, None).unwrap());
        assert_eq!(
            format!("{:?}", relaxed.as_slice()),
            format!("{:?}", &[None, Some(8.0), Some(2.0), Some(0.5)])
        );

        // test nan handling.
        let values = &[-10.0, 2.0, 3.0, f64::nan(), 5.0, 6.0, 7.0];
        let with_nan = to_options(rolling_var(values, 3, 3, false, None, None).unwrap());
        // we cannot compare nans, so we compare the string values
        let expected = [
            None,
            None,
            Some(52.333333333333336),
            Some(f64::nan()),
            Some(f64::nan()),
            Some(f64::nan()),
            Some(1.0),
        ];
        assert_eq!(
            format!("{:?}", with_nan.as_slice()),
            format!("{:?}", &expected)
        );
    }
}