// polars_arrow/legacy/kernels/rolling/no_nulls/variance.rs
use polars_error::polars_ensure;

use super::*;

5pub(super) struct SumSquaredWindow<'a, T> {
6 slice: &'a [T],
7 sum_of_squares: T,
8 last_start: usize,
9 last_end: usize,
10 last_recompute: u8,
13}
14
15impl<'a, T: NativeType + IsFloat + std::iter::Sum + AddAssign + SubAssign + Mul<Output = T>>
16 RollingAggWindowNoNulls<'a, T> for SumSquaredWindow<'a, T>
17{
18 fn new(slice: &'a [T], start: usize, end: usize, _params: Option<RollingFnParams>) -> Self {
19 let sum = slice[start..end].iter().map(|v| *v * *v).sum::<T>();
20 Self {
21 slice,
22 sum_of_squares: sum,
23 last_start: start,
24 last_end: end,
25 last_recompute: 0,
26 }
27 }
28
29 unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
30 let recompute_sum = if start >= self.last_end || self.last_recompute > 128 {
33 self.last_recompute = 0;
34 true
35 } else {
36 self.last_recompute += 1;
37 let mut recompute_sum = false;
39 for idx in self.last_start..start {
40 let leaving_value = self.slice.get_unchecked(idx);
43
44 if T::is_float() && !leaving_value.is_finite() {
45 recompute_sum = true;
46 break;
47 }
48
49 self.sum_of_squares -= *leaving_value * *leaving_value;
50 }
51 recompute_sum
52 };
53
54 self.last_start = start;
55
56 if T::is_float() && recompute_sum {
58 self.sum_of_squares = self
59 .slice
60 .get_unchecked(start..end)
61 .iter()
62 .map(|v| *v * *v)
63 .sum::<T>();
64 } else {
65 for idx in self.last_end..end {
66 let entering_value = *self.slice.get_unchecked(idx);
67 self.sum_of_squares += entering_value * entering_value;
68 }
69 }
70 self.last_end = end;
71 Some(self.sum_of_squares)
72 }
73}
74
75pub struct VarWindow<'a, T> {
79 mean: MeanWindow<'a, T>,
80 sum_of_squares: SumSquaredWindow<'a, T>,
81 ddof: u8,
82}
83
84impl<
85 'a,
86 T: NativeType
87 + IsFloat
88 + Float
89 + std::iter::Sum
90 + AddAssign
91 + SubAssign
92 + Div<Output = T>
93 + NumCast
94 + One
95 + Zero
96 + PartialOrd
97 + Sub<Output = T>,
98 > RollingAggWindowNoNulls<'a, T> for VarWindow<'a, T>
99{
100 fn new(slice: &'a [T], start: usize, end: usize, params: Option<RollingFnParams>) -> Self {
101 Self {
102 mean: MeanWindow::new(slice, start, end, None),
103 sum_of_squares: SumSquaredWindow::new(slice, start, end, None),
104 ddof: match params {
105 None => 1,
106 Some(pars) => {
107 let RollingFnParams::Var(pars) = pars else {
108 unreachable!("expected Var params");
109 };
110 pars.ddof
111 },
112 },
113 }
114 }
115
116 unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
117 let count: T = NumCast::from(end - start).unwrap();
118 let sum_of_squares = self.sum_of_squares.update(start, end).unwrap_unchecked();
119 let mean = self.mean.update(start, end).unwrap_unchecked();
120
121 let denom = count - NumCast::from(self.ddof).unwrap();
122 if denom <= T::zero() {
123 None
124 } else if end - start == 1 {
125 Some(T::zero())
126 } else {
127 let out = (sum_of_squares - count * mean * mean) / denom;
128 if out < T::zero() {
131 Some(T::zero())
132 } else {
133 Some(out)
134 }
135 }
136 }
137}
138
139pub fn rolling_var<T>(
140 values: &[T],
141 window_size: usize,
142 min_periods: usize,
143 center: bool,
144 weights: Option<&[f64]>,
145 params: Option<RollingFnParams>,
146) -> PolarsResult<ArrayRef>
147where
148 T: NativeType
149 + Float
150 + IsFloat
151 + std::iter::Sum
152 + AddAssign
153 + SubAssign
154 + Div<Output = T>
155 + NumCast
156 + One
157 + Zero
158 + Sub<Output = T>,
159{
160 let offset_fn = match center {
161 true => det_offsets_center,
162 false => det_offsets,
163 };
164 match weights {
165 None => rolling_apply_agg_window::<VarWindow<_>, _, _>(
166 values,
167 window_size,
168 min_periods,
169 offset_fn,
170 params,
171 ),
172 Some(weights) => {
173 let mut wts = no_nulls::coerce_weights(weights);
176 let wsum = wts.iter().fold(T::zero(), |acc, x| acc + *x);
177 polars_ensure!(
178 wsum != T::zero(),
179 ComputeError: "Weighted variance is undefined if weights sum to 0"
180 );
181 wts.iter_mut().for_each(|w| *w = *w / wsum);
182 super::rolling_apply_weights(
183 values,
184 window_size,
185 min_periods,
186 offset_fn,
187 compute_var_weights,
188 &wts,
189 )
190 },
191 }
192}
193
#[cfg(test)]
mod test {
    use super::*;

    /// Runs an unweighted, non-centered `rolling_var` on `f64` input and collects
    /// the resulting array into a vector of optional values.
    fn run_var(
        values: &[f64],
        window_size: usize,
        min_periods: usize,
        params: Option<RollingFnParams>,
    ) -> Vec<Option<f64>> {
        let array = rolling_var(values, window_size, min_periods, false, None, params).unwrap();
        let primitive = array
            .as_any()
            .downcast_ref::<PrimitiveArray<f64>>()
            .unwrap();
        primitive.into_iter().map(|v| v.copied()).collect::<Vec<_>>()
    }

    #[test]
    fn test_rolling_var() {
        let values = &[1.0f64, 5.0, 3.0, 4.0];

        // Default params (ddof = 1).
        let out = run_var(values, 2, 2, None);
        assert_eq!(out, &[None, Some(8.0), Some(2.0), Some(0.5)]);

        // Explicit ddof = 0.
        let testpars = Some(RollingFnParams::Var(RollingVarParams { ddof: 0 }));
        let out = run_var(values, 2, 2, testpars);
        assert_eq!(out, &[None, Some(4.0), Some(1.0), Some(0.25)]);

        // min_periods = 1: the first, single-element window is still expected to be `None`.
        let out = run_var(values, 2, 1, None);
        assert_eq!(
            format!("{:?}", out.as_slice()),
            format!("{:?}", &[None, Some(8.0), Some(2.0), Some(0.5)])
        );

        // NaN input: compared through the debug representation because NaN != NaN.
        let values = &[-10.0, 2.0, 3.0, f64::nan(), 5.0, 6.0, 7.0];
        let out = run_var(values, 3, 3, None);
        assert_eq!(
            format!("{:?}", out.as_slice()),
            format!(
                "{:?}",
                &[
                    None,
                    None,
                    Some(52.333333333333336),
                    Some(f64::nan()),
                    Some(f64::nan()),
                    Some(f64::nan()),
                    Some(1.0)
                ]
            )
        );
    }
}