// polars_arrow/legacy/kernels/rolling/nulls/mean.rs

use super::*;

3pub struct MeanWindow<'a, T> {
4    sum: SumWindow<'a, T>,
5}
6
7impl<
8        'a,
9        T: NativeType + IsFloat + Add<Output = T> + Sub<Output = T> + NumCast + Div<Output = T>,
10    > RollingAggWindowNulls<'a, T> for MeanWindow<'a, T>
11{
12    unsafe fn new(
13        slice: &'a [T],
14        validity: &'a Bitmap,
15        start: usize,
16        end: usize,
17        params: Option<RollingFnParams>,
18    ) -> Self {
19        Self {
20            sum: SumWindow::new(slice, validity, start, end, params),
21        }
22    }
23
24    unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
25        let sum = self.sum.update(start, end);
26        sum.map(|sum| sum / NumCast::from(end - start - self.sum.null_count).unwrap())
27    }
28    fn is_valid(&self, min_periods: usize) -> bool {
29        self.sum.is_valid(min_periods)
30    }
31}
32
33pub fn rolling_mean<T>(
34    arr: &PrimitiveArray<T>,
35    window_size: usize,
36    min_periods: usize,
37    center: bool,
38    weights: Option<&[f64]>,
39    _params: Option<RollingFnParams>,
40) -> ArrayRef
41where
42    T: NativeType
43        + IsFloat
44        + PartialOrd
45        + Add<Output = T>
46        + Sub<Output = T>
47        + NumCast
48        + Div<Output = T>,
49{
50    if weights.is_some() {
51        panic!("weights not yet supported on array with null values")
52    }
53    if center {
54        rolling_apply_agg_window::<MeanWindow<_>, _, _>(
55            arr.values().as_slice(),
56            arr.validity().as_ref().unwrap(),
57            window_size,
58            min_periods,
59            det_offsets_center,
60            None,
61        )
62    } else {
63        rolling_apply_agg_window::<MeanWindow<_>, _, _>(
64            arr.values().as_slice(),
65            arr.validity().as_ref().unwrap(),
66            window_size,
67            min_periods,
68            det_offsets,
69            None,
70        )
71    }
72}