polars_arrow/legacy/kernels/rolling/nulls/
variance.rs1use super::*;
2
3pub(super) struct SumSquaredWindow<'a, T> {
4 slice: &'a [T],
5 validity: &'a Bitmap,
6 sum_of_squares: Option<T>,
7 last_start: usize,
8 last_end: usize,
9 null_count: usize,
10}
11
12impl<T: NativeType + IsFloat + Add<Output = T> + Sub<Output = T> + Mul<Output = T>>
13 SumSquaredWindow<'_, T>
14{
15 unsafe fn compute_sum_and_null_count(&mut self, start: usize, end: usize) -> Option<T> {
17 let mut sum_of_squares = None;
18 let mut idx = start;
19 self.null_count = 0;
20 for value in &self.slice[start..end] {
21 let valid = self.validity.get_bit_unchecked(idx);
22 if valid {
23 match sum_of_squares {
24 None => sum_of_squares = Some(*value * *value),
25 Some(current) => sum_of_squares = Some(*value * *value + current),
26 }
27 } else {
28 self.null_count += 1;
29 }
30 idx += 1;
31 }
32 self.sum_of_squares = sum_of_squares;
33 sum_of_squares
34 }
35}
36
37impl<'a, T: NativeType + IsFloat + Add<Output = T> + Sub<Output = T> + Mul<Output = T>>
38 RollingAggWindowNulls<'a, T> for SumSquaredWindow<'a, T>
39{
40 unsafe fn new(
41 slice: &'a [T],
42 validity: &'a Bitmap,
43 start: usize,
44 end: usize,
45 _params: Option<RollingFnParams>,
46 ) -> Self {
47 let mut out = Self {
48 slice,
49 validity,
50 sum_of_squares: None,
51 last_start: start,
52 last_end: end,
53 null_count: 0,
54 };
55 out.compute_sum_and_null_count(start, end);
56 out
57 }
58
59 unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
60 let recompute_sum = if start >= self.last_end {
61 true
62 } else {
63 let mut recompute_sum = false;
65 for idx in self.last_start..start {
66 let valid = self.validity.get_bit_unchecked(idx);
69 if valid {
70 let leaving_value = *self.slice.get_unchecked(idx);
71
72 if T::is_float() && !leaving_value.is_finite() {
74 recompute_sum = true;
75 break;
76 }
77 self.sum_of_squares = self
78 .sum_of_squares
79 .map(|v| v - leaving_value * leaving_value)
80 } else {
81 self.null_count -= 1;
83
84 if self.sum_of_squares.is_none() {
87 recompute_sum = true;
88 break;
89 }
90 }
91 }
92 recompute_sum
93 };
94
95 self.last_start = start;
96
97 if recompute_sum {
99 self.compute_sum_and_null_count(start, end);
100 } else {
101 for idx in self.last_end..end {
102 let valid = self.validity.get_bit_unchecked(idx);
103
104 if valid {
105 let value = *self.slice.get_unchecked(idx);
106 let value = value * value;
107 match self.sum_of_squares {
108 None => self.sum_of_squares = Some(value),
109 Some(current) => self.sum_of_squares = Some(current + value),
110 }
111 } else {
112 self.null_count += 1;
114 }
115 }
116 }
117 self.last_end = end;
118 self.sum_of_squares
119 }
120 fn is_valid(&self, min_periods: usize) -> bool {
121 ((self.last_end - self.last_start) - self.null_count) >= min_periods
122 }
123}
124
125pub struct VarWindow<'a, T> {
129 mean: MeanWindow<'a, T>,
130 sum_of_squares: SumSquaredWindow<'a, T>,
131 ddof: u8,
132}
133
134impl<
135 'a,
136 T: NativeType
137 + IsFloat
138 + Float
139 + std::iter::Sum
140 + AddAssign
141 + SubAssign
142 + Div<Output = T>
143 + NumCast
144 + One
145 + Zero
146 + PartialOrd
147 + Add<Output = T>
148 + Sub<Output = T>,
149 > RollingAggWindowNulls<'a, T> for VarWindow<'a, T>
150{
151 unsafe fn new(
152 slice: &'a [T],
153 validity: &'a Bitmap,
154 start: usize,
155 end: usize,
156 params: Option<RollingFnParams>,
157 ) -> Self {
158 Self {
159 mean: MeanWindow::new(slice, validity, start, end, None),
160 sum_of_squares: SumSquaredWindow::new(slice, validity, start, end, None),
161 ddof: match params {
162 None => 1,
163 Some(pars) => match pars {
164 RollingFnParams::Var(p) => p.ddof,
165 _ => unreachable!("expected Var params"),
166 },
167 },
168 }
169 }
170
171 unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
172 let sum_of_squares = self.sum_of_squares.update(start, end)?;
173 let null_count = self.sum_of_squares.null_count;
174 let count: T = NumCast::from(end - start - null_count).unwrap();
175
176 let mean = self.mean.update(start, end)?;
177 let ddof = NumCast::from(self.ddof).unwrap();
178
179 let denom = count - ddof;
180
181 if denom <= T::zero() {
182 None
183 } else if count == T::one() {
184 Some(T::zero())
185 } else if denom <= T::zero() {
186 Some(T::infinity())
187 } else {
188 let var = (sum_of_squares - count * mean * mean) / denom;
189 Some(if var < T::zero() { T::zero() } else { var })
190 }
191 }
192 fn is_valid(&self, min_periods: usize) -> bool {
193 self.mean.is_valid(min_periods)
194 }
195}
196
197pub fn rolling_var<T>(
198 arr: &PrimitiveArray<T>,
199 window_size: usize,
200 min_periods: usize,
201 center: bool,
202 weights: Option<&[f64]>,
203 params: Option<RollingFnParams>,
204) -> ArrayRef
205where
206 T: NativeType + std::iter::Sum<T> + Zero + AddAssign + SubAssign + IsFloat + Float,
207{
208 if weights.is_some() {
209 panic!("weights not yet supported on array with null values")
210 }
211 let offsets_fn = if center {
212 det_offsets_center
213 } else {
214 det_offsets
215 };
216 rolling_apply_agg_window::<VarWindow<_>, _, _>(
217 arr.values().as_slice(),
218 arr.validity().as_ref().unwrap(),
219 window_size,
220 min_periods,
221 offsets_fn,
222 params,
223 )
224}