// polars_arrow/legacy/kernels/rolling/nulls/variance.rs

1use super::*;
2
3pub(super) struct SumSquaredWindow<'a, T> {
4    slice: &'a [T],
5    validity: &'a Bitmap,
6    sum_of_squares: Option<T>,
7    last_start: usize,
8    last_end: usize,
9    null_count: usize,
10}
11
12impl<T: NativeType + IsFloat + Add<Output = T> + Sub<Output = T> + Mul<Output = T>>
13    SumSquaredWindow<'_, T>
14{
15    // compute sum from the entire window
16    unsafe fn compute_sum_and_null_count(&mut self, start: usize, end: usize) -> Option<T> {
17        let mut sum_of_squares = None;
18        let mut idx = start;
19        self.null_count = 0;
20        for value in &self.slice[start..end] {
21            let valid = self.validity.get_bit_unchecked(idx);
22            if valid {
23                match sum_of_squares {
24                    None => sum_of_squares = Some(*value * *value),
25                    Some(current) => sum_of_squares = Some(*value * *value + current),
26                }
27            } else {
28                self.null_count += 1;
29            }
30            idx += 1;
31        }
32        self.sum_of_squares = sum_of_squares;
33        sum_of_squares
34    }
35}
36
37impl<'a, T: NativeType + IsFloat + Add<Output = T> + Sub<Output = T> + Mul<Output = T>>
38    RollingAggWindowNulls<'a, T> for SumSquaredWindow<'a, T>
39{
40    unsafe fn new(
41        slice: &'a [T],
42        validity: &'a Bitmap,
43        start: usize,
44        end: usize,
45        _params: Option<RollingFnParams>,
46    ) -> Self {
47        let mut out = Self {
48            slice,
49            validity,
50            sum_of_squares: None,
51            last_start: start,
52            last_end: end,
53            null_count: 0,
54        };
55        out.compute_sum_and_null_count(start, end);
56        out
57    }
58
59    unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
60        let recompute_sum = if start >= self.last_end {
61            true
62        } else {
63            // remove elements that should leave the window
64            let mut recompute_sum = false;
65            for idx in self.last_start..start {
66                // SAFETY:
67                // we are in bounds
68                let valid = self.validity.get_bit_unchecked(idx);
69                if valid {
70                    let leaving_value = *self.slice.get_unchecked(idx);
71
72                    // if the leaving value is nan we need to recompute the window
73                    if T::is_float() && !leaving_value.is_finite() {
74                        recompute_sum = true;
75                        break;
76                    }
77                    self.sum_of_squares = self
78                        .sum_of_squares
79                        .map(|v| v - leaving_value * leaving_value)
80                } else {
81                    // null value leaving the window
82                    self.null_count -= 1;
83
84                    // self.sum is None and the leaving value is None
85                    // if the entering value is valid, we might get a new sum.
86                    if self.sum_of_squares.is_none() {
87                        recompute_sum = true;
88                        break;
89                    }
90                }
91            }
92            recompute_sum
93        };
94
95        self.last_start = start;
96
97        // we traverse all values and compute
98        if recompute_sum {
99            self.compute_sum_and_null_count(start, end);
100        } else {
101            for idx in self.last_end..end {
102                let valid = self.validity.get_bit_unchecked(idx);
103
104                if valid {
105                    let value = *self.slice.get_unchecked(idx);
106                    let value = value * value;
107                    match self.sum_of_squares {
108                        None => self.sum_of_squares = Some(value),
109                        Some(current) => self.sum_of_squares = Some(current + value),
110                    }
111                } else {
112                    // null value entering the window
113                    self.null_count += 1;
114                }
115            }
116        }
117        self.last_end = end;
118        self.sum_of_squares
119    }
120    fn is_valid(&self, min_periods: usize) -> bool {
121        ((self.last_end - self.last_start) - self.null_count) >= min_periods
122    }
123}
124
125// E[(xi - E[x])^2]
126// can be expanded to
127// E[x^2] - E[x]^2
128pub struct VarWindow<'a, T> {
129    mean: MeanWindow<'a, T>,
130    sum_of_squares: SumSquaredWindow<'a, T>,
131    ddof: u8,
132}
133
134impl<
135        'a,
136        T: NativeType
137            + IsFloat
138            + Float
139            + std::iter::Sum
140            + AddAssign
141            + SubAssign
142            + Div<Output = T>
143            + NumCast
144            + One
145            + Zero
146            + PartialOrd
147            + Add<Output = T>
148            + Sub<Output = T>,
149    > RollingAggWindowNulls<'a, T> for VarWindow<'a, T>
150{
151    unsafe fn new(
152        slice: &'a [T],
153        validity: &'a Bitmap,
154        start: usize,
155        end: usize,
156        params: Option<RollingFnParams>,
157    ) -> Self {
158        Self {
159            mean: MeanWindow::new(slice, validity, start, end, None),
160            sum_of_squares: SumSquaredWindow::new(slice, validity, start, end, None),
161            ddof: match params {
162                None => 1,
163                Some(pars) => match pars {
164                    RollingFnParams::Var(p) => p.ddof,
165                    _ => unreachable!("expected Var params"),
166                },
167            },
168        }
169    }
170
171    unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
172        let sum_of_squares = self.sum_of_squares.update(start, end)?;
173        let null_count = self.sum_of_squares.null_count;
174        let count: T = NumCast::from(end - start - null_count).unwrap();
175
176        let mean = self.mean.update(start, end)?;
177        let ddof = NumCast::from(self.ddof).unwrap();
178
179        let denom = count - ddof;
180
181        if denom <= T::zero() {
182            None
183        } else if count == T::one() {
184            Some(T::zero())
185        } else if denom <= T::zero() {
186            Some(T::infinity())
187        } else {
188            let var = (sum_of_squares - count * mean * mean) / denom;
189            Some(if var < T::zero() { T::zero() } else { var })
190        }
191    }
192    fn is_valid(&self, min_periods: usize) -> bool {
193        self.mean.is_valid(min_periods)
194    }
195}
196
197pub fn rolling_var<T>(
198    arr: &PrimitiveArray<T>,
199    window_size: usize,
200    min_periods: usize,
201    center: bool,
202    weights: Option<&[f64]>,
203    params: Option<RollingFnParams>,
204) -> ArrayRef
205where
206    T: NativeType + std::iter::Sum<T> + Zero + AddAssign + SubAssign + IsFloat + Float,
207{
208    if weights.is_some() {
209        panic!("weights not yet supported on array with null values")
210    }
211    let offsets_fn = if center {
212        det_offsets_center
213    } else {
214        det_offsets
215    };
216    rolling_apply_agg_window::<VarWindow<_>, _, _>(
217        arr.values().as_slice(),
218        arr.validity().as_ref().unwrap(),
219        window_size,
220        min_periods,
221        offsets_fn,
222        params,
223    )
224}