polars_arrow/legacy/kernels/rolling/no_nulls/
quantile.rs

1use num_traits::ToPrimitive;
2use polars_error::polars_ensure;
3
4use super::QuantileMethod::*;
5use super::*;
6
7pub struct QuantileWindow<'a, T: NativeType> {
8    sorted: SortedBuf<'a, T>,
9    prob: f64,
10    method: QuantileMethod,
11}
12
13impl<
14        'a,
15        T: NativeType
16            + Float
17            + std::iter::Sum
18            + AddAssign
19            + SubAssign
20            + Div<Output = T>
21            + NumCast
22            + One
23            + Zero
24            + Sub<Output = T>,
25    > RollingAggWindowNoNulls<'a, T> for QuantileWindow<'a, T>
26{
27    fn new(slice: &'a [T], start: usize, end: usize, params: Option<RollingFnParams>) -> Self {
28        let params = params.unwrap();
29        let RollingFnParams::Quantile(params) = params else {
30            unreachable!("expected Quantile params");
31        };
32
33        Self {
34            sorted: SortedBuf::new(slice, start, end),
35            prob: params.prob,
36            method: params.method,
37        }
38    }
39
40    unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
41        let vals = self.sorted.update(start, end);
42        let length = vals.len();
43
44        let idx = match self.method {
45            Linear => {
46                // Maybe add a fast path for median case? They could branch depending on odd/even.
47                let length_f = length as f64;
48                let idx = ((length_f - 1.0) * self.prob).floor() as usize;
49
50                let float_idx_top = (length_f - 1.0) * self.prob;
51                let top_idx = float_idx_top.ceil() as usize;
52                return if idx == top_idx {
53                    Some(unsafe { *vals.get_unchecked(idx) })
54                } else {
55                    let proportion = T::from(float_idx_top - idx as f64).unwrap();
56                    let vi = unsafe { *vals.get_unchecked(idx) };
57                    let vj = unsafe { *vals.get_unchecked(top_idx) };
58
59                    Some(proportion * (vj - vi) + vi)
60                };
61            },
62            Midpoint => {
63                let length_f = length as f64;
64                let idx = (length_f * self.prob) as usize;
65                let idx = std::cmp::min(idx, length - 1);
66
67                let top_idx = ((length_f - 1.0) * self.prob).ceil() as usize;
68                return if top_idx == idx {
69                    // SAFETY:
70                    // we are in bounds
71                    Some(unsafe { *vals.get_unchecked(idx) })
72                } else {
73                    // SAFETY:
74                    // we are in bounds
75                    let (mid, mid_plus_1) =
76                        unsafe { (*vals.get_unchecked(idx), *vals.get_unchecked(idx + 1)) };
77
78                    Some((mid + mid_plus_1) / (T::one() + T::one()))
79                };
80            },
81            Nearest => {
82                let idx = ((length as f64) * self.prob) as usize;
83                std::cmp::min(idx, length - 1)
84            },
85            Lower => ((length as f64 - 1.0) * self.prob).floor() as usize,
86            Higher => {
87                let idx = ((length as f64 - 1.0) * self.prob).ceil() as usize;
88                std::cmp::min(idx, length - 1)
89            },
90            Equiprobable => ((length as f64 * self.prob).ceil() - 1.0).max(0.0) as usize,
91        };
92
93        // SAFETY:
94        // we are in bounds
95        Some(unsafe { *vals.get_unchecked(idx) })
96    }
97}
98
99pub fn rolling_quantile<T>(
100    values: &[T],
101    window_size: usize,
102    min_periods: usize,
103    center: bool,
104    weights: Option<&[f64]>,
105    params: Option<RollingFnParams>,
106) -> PolarsResult<ArrayRef>
107where
108    T: NativeType
109        + IsFloat
110        + Float
111        + std::iter::Sum
112        + AddAssign
113        + SubAssign
114        + Div<Output = T>
115        + NumCast
116        + One
117        + Zero
118        + PartialOrd
119        + Sub<Output = T>,
120{
121    let offset_fn = match center {
122        true => det_offsets_center,
123        false => det_offsets,
124    };
125    match weights {
126        None => {
127            if !center {
128                let params = params.as_ref().unwrap();
129                let RollingFnParams::Quantile(params) = params else {
130                    unreachable!("expected Quantile params");
131                };
132                let out = super::quantile_filter::rolling_quantile::<_, Vec<_>>(
133                    params.method,
134                    min_periods,
135                    window_size,
136                    values,
137                    params.prob,
138                );
139                let validity = create_validity(min_periods, values.len(), window_size, offset_fn);
140                return Ok(Box::new(PrimitiveArray::new(
141                    T::PRIMITIVE.into(),
142                    out.into(),
143                    validity.map(|b| b.into()),
144                )));
145            }
146
147            rolling_apply_agg_window::<QuantileWindow<_>, _, _>(
148                values,
149                window_size,
150                min_periods,
151                offset_fn,
152                params,
153            )
154        },
155        Some(weights) => {
156            let wsum = weights.iter().sum();
157            polars_ensure!(
158                wsum != 0.0,
159                ComputeError: "Weighted quantile is undefined if weights sum to 0"
160            );
161            let params = params.unwrap();
162            let RollingFnParams::Quantile(params) = params else {
163                unreachable!("expected Quantile params");
164            };
165
166            Ok(rolling_apply_weighted_quantile(
167                values,
168                params.prob,
169                params.method,
170                window_size,
171                min_periods,
172                offset_fn,
173                weights,
174                wsum,
175            ))
176        },
177    }
178}
179
180#[inline]
181fn compute_wq<T>(buf: &[(T, f64)], p: f64, wsum: f64, method: QuantileMethod) -> T
182where
183    T: Debug + NativeType + Mul<Output = T> + Sub<Output = T> + NumCast + ToPrimitive + Zero,
184{
185    // There are a few ways to compute a weighted quantile but no "canonical" way.
186    // This is mostly taken from the Julia implementation which was readable and reasonable
187    // https://juliastats.org/StatsBase.jl/stable/scalarstats/#Quantile-and-Related-Functions-1
188    let (mut s, mut s_old, mut vk, mut v_old) = (0.0, 0.0, T::zero(), T::zero());
189
190    // Once the cumulative weight crosses h, we've found our ind{ex/ices}. The definition may look
191    // odd but it's the equivalent of taking h = p * (n - 1) + 1 if your data is indexed from 1.
192    let h: f64 = p * (wsum - buf[0].1) + buf[0].1;
193    for &(v, w) in buf.iter() {
194        if s > h {
195            break;
196        }
197        (s_old, v_old, vk) = (s, vk, v);
198        s += w;
199    }
200    match (h == s_old, method) {
201        (true, _) => v_old, // If we hit the break exactly interpolation shouldn't matter
202        (_, Lower) => v_old,
203        (_, Higher) => vk,
204        (_, Nearest) => {
205            if s - h > h - s_old {
206                v_old
207            } else {
208                vk
209            }
210        },
211        (_, Equiprobable) => {
212            let threshold = (wsum * p).ceil() - 1.0;
213            if s > threshold {
214                vk
215            } else {
216                v_old
217            }
218        },
219        (_, Midpoint) => (vk + v_old) * NumCast::from(0.5).unwrap(),
220        // This is seemingly the canonical way to do it.
221        (_, Linear) => {
222            v_old + <T as NumCast>::from((h - s_old) / (s - s_old)).unwrap() * (vk - v_old)
223        },
224    }
225}
226
227#[allow(clippy::too_many_arguments)]
228fn rolling_apply_weighted_quantile<T, Fo>(
229    values: &[T],
230    p: f64,
231    method: QuantileMethod,
232    window_size: usize,
233    min_periods: usize,
234    det_offsets_fn: Fo,
235    weights: &[f64],
236    wsum: f64,
237) -> ArrayRef
238where
239    Fo: Fn(Idx, WindowSize, Len) -> (Start, End),
240    T: Debug + NativeType + Mul<Output = T> + Sub<Output = T> + NumCast + ToPrimitive + Zero,
241{
242    assert_eq!(weights.len(), window_size);
243    // Keep nonzero weights and their indices to know which values we need each iteration.
244    let nz_idx_wts: Vec<_> = weights.iter().enumerate().filter(|x| x.1 != &0.0).collect();
245    let mut buf = vec![(T::zero(), 0.0); nz_idx_wts.len()];
246    let len = values.len();
247    let out = (0..len)
248        .map(|idx| {
249            // Don't need end. Window size is constant and we computed offsets from start above.
250            let (start, _) = det_offsets_fn(idx, window_size, len);
251
252            // Sorting is not ideal, see https://github.com/tobiasschoch/wquantile for something faster
253            unsafe {
254                buf.iter_mut()
255                    .zip(nz_idx_wts.iter())
256                    .for_each(|(b, (i, w))| *b = (*values.get_unchecked(i + start), **w));
257            }
258            buf.sort_unstable_by(|&a, &b| a.0.tot_cmp(&b.0));
259            compute_wq(&buf, p, wsum, method)
260        })
261        .collect_trusted::<Vec<T>>();
262
263    let validity = create_validity(min_periods, len, window_size, det_offsets_fn);
264    Box::new(PrimitiveArray::new(
265        T::PRIMITIVE.into(),
266        out.into(),
267        validity.map(|b| b.into()),
268    ))
269}
270
#[cfg(test)]
mod test {
    use super::*;

    /// Downcasts a rolling result to an `f64` primitive array and copies it out as
    /// optional values (`None` for null slots).
    fn to_options(arr: ArrayRef) -> Vec<Option<f64>> {
        arr.as_any()
            .downcast_ref::<PrimitiveArray<f64>>()
            .unwrap()
            .into_iter()
            .map(|v| v.copied())
            .collect()
    }

    /// Linear median over several window / min_periods / centering combinations.
    #[test]
    fn test_rolling_median() {
        let values = &[1.0, 2.0, 3.0, 4.0];
        let med_pars = Some(RollingFnParams::Quantile(RollingQuantileParams {
            prob: 0.5,
            method: Linear,
        }));

        // (window_size, min_periods, center, expected)
        let cases = [
            (2, 2, false, [None, Some(1.5), Some(2.5), Some(3.5)]),
            (2, 1, false, [Some(1.0), Some(1.5), Some(2.5), Some(3.5)]),
            (4, 1, false, [Some(1.0), Some(1.5), Some(2.0), Some(2.5)]),
            (4, 1, true, [Some(1.5), Some(2.0), Some(2.5), Some(3.0)]),
            (4, 4, true, [None, None, Some(2.5), None]),
        ];

        for (window_size, min_periods, center, expected) in cases {
            let out = rolling_quantile(
                values,
                window_size,
                min_periods,
                center,
                None,
                med_pars.clone(),
            )
            .unwrap();
            assert_eq!(to_options(out), expected);
        }
    }

    /// For every method, prob 0.0 must agree with the rolling min and prob 1.0 with the
    /// rolling max.
    #[test]
    fn test_rolling_quantile_limits() {
        let values = &[1.0f64, 2.0, 3.0, 4.0];

        let methods = [
            QuantileMethod::Lower,
            QuantileMethod::Higher,
            QuantileMethod::Nearest,
            QuantileMethod::Midpoint,
            QuantileMethod::Linear,
            QuantileMethod::Equiprobable,
        ];

        let expected_min = to_options(rolling_min(values, 2, 2, false, None, None).unwrap());
        let expected_max = to_options(rolling_max(values, 2, 2, false, None, None).unwrap());

        for method in methods {
            for (prob, expected) in [(0.0, &expected_min), (1.0, &expected_max)] {
                let pars = Some(RollingFnParams::Quantile(RollingQuantileParams {
                    prob,
                    method,
                }));
                let got = to_options(rolling_quantile(values, 2, 2, false, None, pars).unwrap());
                assert_eq!(expected, &got);
            }
        }
    }
}