polars_arrow/legacy/kernels/rolling/nulls/quantile.rs

1use super::*;
2use crate::array::MutablePrimitiveArray;
3
4pub struct QuantileWindow<'a, T: NativeType + IsFloat + PartialOrd> {
5    sorted: SortedBufNulls<'a, T>,
6    prob: f64,
7    method: QuantileMethod,
8}
9
10impl<
11        'a,
12        T: NativeType
13            + IsFloat
14            + Float
15            + std::iter::Sum
16            + AddAssign
17            + SubAssign
18            + Div<Output = T>
19            + NumCast
20            + One
21            + Zero
22            + PartialOrd
23            + Sub<Output = T>,
24    > RollingAggWindowNulls<'a, T> for QuantileWindow<'a, T>
25{
26    unsafe fn new(
27        slice: &'a [T],
28        validity: &'a Bitmap,
29        start: usize,
30        end: usize,
31        params: Option<RollingFnParams>,
32    ) -> Self {
33        let params = params.unwrap();
34        let RollingFnParams::Quantile(params) = params else {
35            unreachable!("expected Quantile params");
36        };
37        Self {
38            sorted: SortedBufNulls::new(slice, validity, start, end),
39            prob: params.prob,
40            method: params.method,
41        }
42    }
43
44    unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
45        let (values, null_count) = self.sorted.update(start, end);
46        // The min periods_issue will be taken care of when actually rolling
47        if null_count == values.len() {
48            return None;
49        }
50        // Nulls are guaranteed to be at the front
51        let values = &values[null_count..];
52        let length = values.len();
53
54        let mut idx = match self.method {
55            QuantileMethod::Nearest => ((length as f64) * self.prob) as usize,
56            QuantileMethod::Lower | QuantileMethod::Midpoint | QuantileMethod::Linear => {
57                ((length as f64 - 1.0) * self.prob).floor() as usize
58            },
59            QuantileMethod::Higher => ((length as f64 - 1.0) * self.prob).ceil() as usize,
60            QuantileMethod::Equiprobable => {
61                ((length as f64 * self.prob).ceil() - 1.0).max(0.0) as usize
62            },
63        };
64
65        idx = std::cmp::min(idx, length - 1);
66
67        // we can unwrap because we sliced of the nulls
68        match self.method {
69            QuantileMethod::Midpoint => {
70                let top_idx = ((length as f64 - 1.0) * self.prob).ceil() as usize;
71                Some(
72                    (values.get_unchecked(idx).unwrap() + values.get_unchecked(top_idx).unwrap())
73                        / T::from::<f64>(2.0f64).unwrap(),
74                )
75            },
76            QuantileMethod::Linear => {
77                let float_idx = (length as f64 - 1.0) * self.prob;
78                let top_idx = f64::ceil(float_idx) as usize;
79
80                if top_idx == idx {
81                    Some(values.get_unchecked(idx).unwrap())
82                } else {
83                    let proportion = T::from(float_idx - idx as f64).unwrap();
84                    Some(
85                        proportion
86                            * (values.get_unchecked(top_idx).unwrap()
87                                - values.get_unchecked(idx).unwrap())
88                            + values.get_unchecked(idx).unwrap(),
89                    )
90                }
91            },
92            _ => Some(values.get_unchecked(idx).unwrap()),
93        }
94    }
95
96    fn is_valid(&self, min_periods: usize) -> bool {
97        self.sorted.is_valid(min_periods)
98    }
99}
100
101pub fn rolling_quantile<T>(
102    arr: &PrimitiveArray<T>,
103    window_size: usize,
104    min_periods: usize,
105    center: bool,
106    weights: Option<&[f64]>,
107    params: Option<RollingFnParams>,
108) -> ArrayRef
109where
110    T: NativeType
111        + IsFloat
112        + Float
113        + std::iter::Sum
114        + AddAssign
115        + SubAssign
116        + Div<Output = T>
117        + NumCast
118        + One
119        + Zero
120        + PartialOrd
121        + Sub<Output = T>,
122{
123    if weights.is_some() {
124        panic!("weights not yet supported on array with null values")
125    }
126    let offset_fn = match center {
127        true => det_offsets_center,
128        false => det_offsets,
129    };
130    if !center {
131        let params = params.as_ref().unwrap();
132        let RollingFnParams::Quantile(params) = params else {
133            unreachable!("expected Quantile params");
134        };
135
136        let out = super::quantile_filter::rolling_quantile::<_, MutablePrimitiveArray<_>>(
137            params.method,
138            min_periods,
139            window_size,
140            arr.clone(),
141            params.prob,
142        );
143        let out: PrimitiveArray<T> = out.into();
144        return Box::new(out);
145    }
146    rolling_apply_agg_window::<QuantileWindow<_>, _, _>(
147        arr.values().as_slice(),
148        arr.validity().as_ref().unwrap(),
149        window_size,
150        min_periods,
151        offset_fn,
152        params,
153    )
154}
155
#[cfg(test)]
mod test {
    use super::*;
    use crate::buffer::Buffer;
    use crate::datatypes::ArrowDataType;

    /// Downcasts a rolling-kernel output to `PrimitiveArray<f64>` and collects it as
    /// nullable values so results can be compared with plain `assert_eq!`.
    fn to_vec(out: ArrayRef) -> Vec<Option<f64>> {
        out.as_any()
            .downcast_ref::<PrimitiveArray<f64>>()
            .unwrap()
            .into_iter()
            .map(|v| v.copied())
            .collect()
    }

    #[test]
    fn test_rolling_median_nulls() {
        let buf = Buffer::<f64>::from(vec![1.0, 2.0, 3.0, 4.0]);
        let arr = &PrimitiveArray::new(
            ArrowDataType::Float64,
            buf,
            Some(Bitmap::from(&[true, false, true, true])),
        );
        let med_pars = Some(RollingFnParams::Quantile(RollingQuantileParams {
            prob: 0.5,
            method: QuantileMethod::Linear,
        }));

        // (window_size, min_periods, center, expected)
        let cases = [
            (2, 2, false, [None, None, None, Some(3.5)]),
            (2, 1, false, [Some(1.0), Some(1.0), Some(3.0), Some(3.5)]),
            (4, 1, false, [Some(1.0), Some(1.0), Some(2.0), Some(3.0)]),
            (4, 1, true, [Some(1.0), Some(2.0), Some(3.0), Some(3.5)]),
            (4, 4, true, [None, None, None, None]),
        ];
        for (window_size, min_periods, center, expected) in cases {
            let out = to_vec(rolling_quantile(
                arr,
                window_size,
                min_periods,
                center,
                None,
                med_pars.clone(),
            ));
            assert_eq!(
                out, expected,
                "window_size={window_size}, min_periods={min_periods}, center={center}"
            );
        }
    }

    #[test]
    fn test_rolling_quantile_nulls_limits() {
        // Quantiles at prob 0.0 / 1.0 must match rolling min / max for every method. This is
        // checked for windows ending at each slot (quantile filter) and for centred windows
        // (`QuantileWindow`).
        let buf = Buffer::<f64>::from(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let values = &PrimitiveArray::new(
            ArrowDataType::Float64,
            buf,
            Some(Bitmap::from(&[true, false, false, true, true])),
        );

        let methods = [
            QuantileMethod::Lower,
            QuantileMethod::Higher,
            QuantileMethod::Nearest,
            QuantileMethod::Midpoint,
            QuantileMethod::Linear,
            QuantileMethod::Equiprobable,
        ];

        for center in [false, true] {
            for method in methods {
                let min_pars = Some(RollingFnParams::Quantile(RollingQuantileParams {
                    prob: 0.0,
                    method,
                }));
                let expected = to_vec(rolling_min(values, 2, 1, center, None, None));
                let actual = to_vec(rolling_quantile(values, 2, 1, center, None, min_pars));
                assert_eq!(expected, actual, "min, center={center}");

                let max_pars = Some(RollingFnParams::Quantile(RollingQuantileParams {
                    prob: 1.0,
                    method,
                }));
                let expected = to_vec(rolling_max(values, 2, 1, center, None, None));
                let actual = to_vec(rolling_quantile(values, 2, 1, center, None, max_pars));
                assert_eq!(expected, actual, "max, center={center}");
            }
        }
    }
}