polars_arrow/legacy/kernels/rolling/nulls/
mod.rs

1mod mean;
2mod min_max;
3mod quantile;
4mod sum;
5mod variance;
6
7pub use mean::*;
8pub use min_max::*;
9pub use quantile::*;
10pub use sum::*;
11pub use variance::*;
12
13use super::*;
14
15pub trait RollingAggWindowNulls<'a, T: NativeType> {
16    /// # Safety
17    /// `start` and `end` must be in bounds for `slice` and `validity`
18    unsafe fn new(
19        slice: &'a [T],
20        validity: &'a Bitmap,
21        start: usize,
22        end: usize,
23        params: Option<RollingFnParams>,
24    ) -> Self;
25
26    /// # Safety
27    /// `start` and `end` must be in bounds of `slice` and `bitmap`
28    unsafe fn update(&mut self, start: usize, end: usize) -> Option<T>;
29
30    fn is_valid(&self, min_periods: usize) -> bool;
31}
32
33// Use an aggregation window that maintains the state
34pub(super) fn rolling_apply_agg_window<'a, Agg, T, Fo>(
35    values: &'a [T],
36    validity: &'a Bitmap,
37    window_size: usize,
38    min_periods: usize,
39    det_offsets_fn: Fo,
40    params: Option<RollingFnParams>,
41) -> ArrayRef
42where
43    Fo: Fn(Idx, WindowSize, Len) -> (Start, End) + Copy,
44    Agg: RollingAggWindowNulls<'a, T>,
45    T: IsFloat + NativeType,
46{
47    let len = values.len();
48    let (start, end) = det_offsets_fn(0, window_size, len);
49    // SAFETY; we are in bounds
50    let mut agg_window = unsafe { Agg::new(values, validity, start, end, params) };
51
52    let mut validity = create_validity(min_periods, len, window_size, det_offsets_fn)
53        .unwrap_or_else(|| {
54            let mut validity = MutableBitmap::with_capacity(len);
55            validity.extend_constant(len, true);
56            validity
57        });
58
59    let out = (0..len)
60        .map(|idx| {
61            let (start, end) = det_offsets_fn(idx, window_size, len);
62            // SAFETY:
63            // we are in bounds
64            let agg = unsafe { agg_window.update(start, end) };
65            match agg {
66                Some(val) => {
67                    if agg_window.is_valid(min_periods) {
68                        val
69                    } else {
70                        // SAFETY: we are in bounds
71                        unsafe { validity.set_unchecked(idx, false) };
72                        T::default()
73                    }
74                },
75                None => {
76                    // SAFETY: we are in bounds
77                    unsafe { validity.set_unchecked(idx, false) };
78                    T::default()
79                },
80            }
81        })
82        .collect_trusted::<Vec<_>>();
83
84    Box::new(PrimitiveArray::new(
85        T::PRIMITIVE.into(),
86        out.into(),
87        Some(validity.into()),
88    ))
89}
90
#[cfg(test)]
mod test {
    use super::*;
    use crate::array::{Array, Int32Array};
    use crate::buffer::Buffer;
    use crate::datatypes::ArrowDataType;

    /// Builds a four-element `Float64` array from raw values and a validity mask.
    fn f64_arr(values: Vec<f64>, valid: [bool; 4]) -> PrimitiveArray<f64> {
        PrimitiveArray::new(
            ArrowDataType::Float64,
            Buffer::from(values),
            Some(Bitmap::from(&valid)),
        )
    }

    /// Fixture shared by the mean and variance tests.
    fn get_null_arr() -> PrimitiveArray<f64> {
        // 1, None, -1, 4
        f64_arr(vec![1.0, 0.0, -1.0, 4.0], [true, false, true, true])
    }

    /// Downcasts a kernel result to `Float64` and collects it into optional values.
    fn collect_f64(out: Box<dyn Array>) -> Vec<Option<f64>> {
        out.as_any()
            .downcast_ref::<PrimitiveArray<f64>>()
            .unwrap()
            .into_iter()
            .map(|v| v.copied())
            .collect()
    }

    #[test]
    fn test_rolling_sum_nulls() {
        let arr = &f64_arr(vec![1.0, 2.0, 3.0, 4.0], [true, false, true, true]);

        // (window_size, min_periods, center, expected)
        let cases = [
            (2, 2, false, [None, None, None, Some(7.0)]),
            (2, 1, false, [Some(1.0), Some(1.0), Some(3.0), Some(7.0)]),
            (4, 1, false, [Some(1.0), Some(1.0), Some(4.0), Some(8.0)]),
            (4, 1, true, [Some(1.0), Some(4.0), Some(8.0), Some(7.0)]),
            (4, 4, true, [None, None, None, None]),
        ];
        for (window_size, min_periods, center, expected) in cases {
            let got = collect_f64(rolling_sum(
                arr,
                window_size,
                min_periods,
                center,
                None,
                None,
            ));
            assert_eq!(got, expected);
        }
    }

    #[test]
    fn test_rolling_mean_nulls() {
        let fixture = get_null_arr();
        let arr = &fixture;

        // (window_size, min_periods, expected); the window is never centered here.
        let cases = [
            (2, 2, [None, None, None, Some(1.5)]),
            (2, 1, [Some(1.0), Some(1.0), Some(-1.0), Some(1.5)]),
            (4, 1, [Some(1.0), Some(1.0), Some(0.0), Some(4.0 / 3.0)]),
        ];
        for (window_size, min_periods, expected) in cases {
            let got = collect_f64(rolling_mean(
                arr,
                window_size,
                min_periods,
                false,
                None,
                None,
            ));
            assert_eq!(got, expected);
        }
    }

    #[test]
    fn test_rolling_var_nulls() {
        let fixture = get_null_arr();
        let arr = &fixture;
        // Every call uses min_periods = 1 and an uncentered window.
        let var = |window_size: usize, params: Option<RollingFnParams>| {
            collect_f64(rolling_var(arr, window_size, 1, false, None, params))
        };
        let testpars = Some(RollingFnParams::Var(RollingVarParams { ddof: 0 }));

        assert_eq!(var(3, None), &[None, None, Some(2.0), Some(12.5)]);
        assert_eq!(
            var(3, testpars.clone()),
            &[Some(0.0), Some(0.0), Some(1.0), Some(6.25)]
        );
        assert_eq!(
            var(4, None),
            &[None, None, Some(2.0), Some(6.333333333333334)]
        );
        assert_eq!(
            var(4, testpars.clone()),
            &[Some(0.), Some(0.0), Some(1.0), Some(4.222222222222222)]
        );
    }

    #[test]
    fn test_rolling_max_no_nulls() {
        // Ascending input where every slot is valid.
        let ascending = f64_arr(vec![1.0, 2.0, 3.0, 4.0], [true; 4]);
        let arr = &ascending;

        let got = collect_f64(rolling_max(arr, 4, 1, false, None, None));
        assert_eq!(got, &[Some(1.0), Some(2.0), Some(3.0), Some(4.0)]);

        let got = collect_f64(rolling_max(arr, 2, 2, false, None, None));
        assert_eq!(got, &[None, Some(2.0), Some(3.0), Some(4.0)]);

        let got = collect_f64(rolling_max(arr, 4, 4, false, None, None));
        assert_eq!(got, &[None, None, None, Some(4.0)]);

        // Descending input where every slot is valid.
        let descending = f64_arr(vec![4.0, 3.0, 2.0, 1.0], [true; 4]);
        let arr = &descending;

        let got = collect_f64(rolling_max(arr, 2, 1, false, None, None));
        assert_eq!(got, &[Some(4.0), Some(4.0), Some(3.0), Some(2.0)]);

        // The kernel for non-nullable input is run on the same values.
        let got = collect_f64(
            super::no_nulls::rolling_max(arr.values().as_slice(), 2, 1, false, None, None)
                .unwrap(),
        );
        assert_eq!(got, &[Some(4.0), Some(4.0), Some(3.0), Some(2.0)]);
    }

    #[test]
    fn test_rolling_extrema_nulls() {
        let vals = vec![3, 3, 3, 10, 10, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1];
        let all_valid = Bitmap::new_with_value(true, vals.len());
        let input = Int32Array::new(ArrowDataType::Int32, vals.into(), Some(all_valid));

        // window_size = 3, min_periods = 3
        let out = rolling_apply_agg_window::<MaxWindow<_>, _, _>(
            input.values().as_slice(),
            input.validity().as_ref().unwrap(),
            3,
            3,
            det_offsets,
            None,
        );
        let maxima = out.as_any().downcast_ref::<Int32Array>().unwrap();

        assert_eq!(maxima.null_count(), 2);
        assert_eq!(
            &maxima.values().as_slice()[2..],
            &[3, 10, 10, 10, 10, 10, 9, 8, 7, 6, 5, 4, 3]
        );
    }
}