// polars_arrow/legacy/kernels/rolling/nulls/mean.rs
use super::*;

3pub struct MeanWindow<'a, T> {
4 sum: SumWindow<'a, T>,
5}
6
7impl<
8 'a,
9 T: NativeType + IsFloat + Add<Output = T> + Sub<Output = T> + NumCast + Div<Output = T>,
10 > RollingAggWindowNulls<'a, T> for MeanWindow<'a, T>
11{
12 unsafe fn new(
13 slice: &'a [T],
14 validity: &'a Bitmap,
15 start: usize,
16 end: usize,
17 params: Option<RollingFnParams>,
18 ) -> Self {
19 Self {
20 sum: SumWindow::new(slice, validity, start, end, params),
21 }
22 }
23
24 unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
25 let sum = self.sum.update(start, end);
26 sum.map(|sum| sum / NumCast::from(end - start - self.sum.null_count).unwrap())
27 }
28 fn is_valid(&self, min_periods: usize) -> bool {
29 self.sum.is_valid(min_periods)
30 }
31}
32
33pub fn rolling_mean<T>(
34 arr: &PrimitiveArray<T>,
35 window_size: usize,
36 min_periods: usize,
37 center: bool,
38 weights: Option<&[f64]>,
39 _params: Option<RollingFnParams>,
40) -> ArrayRef
41where
42 T: NativeType
43 + IsFloat
44 + PartialOrd
45 + Add<Output = T>
46 + Sub<Output = T>
47 + NumCast
48 + Div<Output = T>,
49{
50 if weights.is_some() {
51 panic!("weights not yet supported on array with null values")
52 }
53 if center {
54 rolling_apply_agg_window::<MeanWindow<_>, _, _>(
55 arr.values().as_slice(),
56 arr.validity().as_ref().unwrap(),
57 window_size,
58 min_periods,
59 det_offsets_center,
60 None,
61 )
62 } else {
63 rolling_apply_agg_window::<MeanWindow<_>, _, _>(
64 arr.values().as_slice(),
65 arr.validity().as_ref().unwrap(),
66 window_size,
67 min_periods,
68 det_offsets,
69 None,
70 )
71 }
72}