1use num_traits::ToPrimitive;
2use polars_error::polars_ensure;
3
4use super::QuantileMethod::*;
5use super::*;
6
7pub struct QuantileWindow<'a, T: NativeType> {
8 sorted: SortedBuf<'a, T>,
9 prob: f64,
10 method: QuantileMethod,
11}
12
13impl<
14 'a,
15 T: NativeType
16 + Float
17 + std::iter::Sum
18 + AddAssign
19 + SubAssign
20 + Div<Output = T>
21 + NumCast
22 + One
23 + Zero
24 + Sub<Output = T>,
25 > RollingAggWindowNoNulls<'a, T> for QuantileWindow<'a, T>
26{
27 fn new(slice: &'a [T], start: usize, end: usize, params: Option<RollingFnParams>) -> Self {
28 let params = params.unwrap();
29 let RollingFnParams::Quantile(params) = params else {
30 unreachable!("expected Quantile params");
31 };
32
33 Self {
34 sorted: SortedBuf::new(slice, start, end),
35 prob: params.prob,
36 method: params.method,
37 }
38 }
39
40 unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
41 let vals = self.sorted.update(start, end);
42 let length = vals.len();
43
44 let idx = match self.method {
45 Linear => {
46 let length_f = length as f64;
48 let idx = ((length_f - 1.0) * self.prob).floor() as usize;
49
50 let float_idx_top = (length_f - 1.0) * self.prob;
51 let top_idx = float_idx_top.ceil() as usize;
52 return if idx == top_idx {
53 Some(unsafe { *vals.get_unchecked(idx) })
54 } else {
55 let proportion = T::from(float_idx_top - idx as f64).unwrap();
56 let vi = unsafe { *vals.get_unchecked(idx) };
57 let vj = unsafe { *vals.get_unchecked(top_idx) };
58
59 Some(proportion * (vj - vi) + vi)
60 };
61 },
62 Midpoint => {
63 let length_f = length as f64;
64 let idx = (length_f * self.prob) as usize;
65 let idx = std::cmp::min(idx, length - 1);
66
67 let top_idx = ((length_f - 1.0) * self.prob).ceil() as usize;
68 return if top_idx == idx {
69 Some(unsafe { *vals.get_unchecked(idx) })
72 } else {
73 let (mid, mid_plus_1) =
76 unsafe { (*vals.get_unchecked(idx), *vals.get_unchecked(idx + 1)) };
77
78 Some((mid + mid_plus_1) / (T::one() + T::one()))
79 };
80 },
81 Nearest => {
82 let idx = ((length as f64) * self.prob) as usize;
83 std::cmp::min(idx, length - 1)
84 },
85 Lower => ((length as f64 - 1.0) * self.prob).floor() as usize,
86 Higher => {
87 let idx = ((length as f64 - 1.0) * self.prob).ceil() as usize;
88 std::cmp::min(idx, length - 1)
89 },
90 Equiprobable => ((length as f64 * self.prob).ceil() - 1.0).max(0.0) as usize,
91 };
92
93 Some(unsafe { *vals.get_unchecked(idx) })
96 }
97}
98
99pub fn rolling_quantile<T>(
100 values: &[T],
101 window_size: usize,
102 min_periods: usize,
103 center: bool,
104 weights: Option<&[f64]>,
105 params: Option<RollingFnParams>,
106) -> PolarsResult<ArrayRef>
107where
108 T: NativeType
109 + IsFloat
110 + Float
111 + std::iter::Sum
112 + AddAssign
113 + SubAssign
114 + Div<Output = T>
115 + NumCast
116 + One
117 + Zero
118 + PartialOrd
119 + Sub<Output = T>,
120{
121 let offset_fn = match center {
122 true => det_offsets_center,
123 false => det_offsets,
124 };
125 match weights {
126 None => {
127 if !center {
128 let params = params.as_ref().unwrap();
129 let RollingFnParams::Quantile(params) = params else {
130 unreachable!("expected Quantile params");
131 };
132 let out = super::quantile_filter::rolling_quantile::<_, Vec<_>>(
133 params.method,
134 min_periods,
135 window_size,
136 values,
137 params.prob,
138 );
139 let validity = create_validity(min_periods, values.len(), window_size, offset_fn);
140 return Ok(Box::new(PrimitiveArray::new(
141 T::PRIMITIVE.into(),
142 out.into(),
143 validity.map(|b| b.into()),
144 )));
145 }
146
147 rolling_apply_agg_window::<QuantileWindow<_>, _, _>(
148 values,
149 window_size,
150 min_periods,
151 offset_fn,
152 params,
153 )
154 },
155 Some(weights) => {
156 let wsum = weights.iter().sum();
157 polars_ensure!(
158 wsum != 0.0,
159 ComputeError: "Weighted quantile is undefined if weights sum to 0"
160 );
161 let params = params.unwrap();
162 let RollingFnParams::Quantile(params) = params else {
163 unreachable!("expected Quantile params");
164 };
165
166 Ok(rolling_apply_weighted_quantile(
167 values,
168 params.prob,
169 params.method,
170 window_size,
171 min_periods,
172 offset_fn,
173 weights,
174 wsum,
175 ))
176 },
177 }
178}
179
180#[inline]
181fn compute_wq<T>(buf: &[(T, f64)], p: f64, wsum: f64, method: QuantileMethod) -> T
182where
183 T: Debug + NativeType + Mul<Output = T> + Sub<Output = T> + NumCast + ToPrimitive + Zero,
184{
185 let (mut s, mut s_old, mut vk, mut v_old) = (0.0, 0.0, T::zero(), T::zero());
189
190 let h: f64 = p * (wsum - buf[0].1) + buf[0].1;
193 for &(v, w) in buf.iter() {
194 if s > h {
195 break;
196 }
197 (s_old, v_old, vk) = (s, vk, v);
198 s += w;
199 }
200 match (h == s_old, method) {
201 (true, _) => v_old, (_, Lower) => v_old,
203 (_, Higher) => vk,
204 (_, Nearest) => {
205 if s - h > h - s_old {
206 v_old
207 } else {
208 vk
209 }
210 },
211 (_, Equiprobable) => {
212 let threshold = (wsum * p).ceil() - 1.0;
213 if s > threshold {
214 vk
215 } else {
216 v_old
217 }
218 },
219 (_, Midpoint) => (vk + v_old) * NumCast::from(0.5).unwrap(),
220 (_, Linear) => {
222 v_old + <T as NumCast>::from((h - s_old) / (s - s_old)).unwrap() * (vk - v_old)
223 },
224 }
225}
226
227#[allow(clippy::too_many_arguments)]
228fn rolling_apply_weighted_quantile<T, Fo>(
229 values: &[T],
230 p: f64,
231 method: QuantileMethod,
232 window_size: usize,
233 min_periods: usize,
234 det_offsets_fn: Fo,
235 weights: &[f64],
236 wsum: f64,
237) -> ArrayRef
238where
239 Fo: Fn(Idx, WindowSize, Len) -> (Start, End),
240 T: Debug + NativeType + Mul<Output = T> + Sub<Output = T> + NumCast + ToPrimitive + Zero,
241{
242 assert_eq!(weights.len(), window_size);
243 let nz_idx_wts: Vec<_> = weights.iter().enumerate().filter(|x| x.1 != &0.0).collect();
245 let mut buf = vec![(T::zero(), 0.0); nz_idx_wts.len()];
246 let len = values.len();
247 let out = (0..len)
248 .map(|idx| {
249 let (start, _) = det_offsets_fn(idx, window_size, len);
251
252 unsafe {
254 buf.iter_mut()
255 .zip(nz_idx_wts.iter())
256 .for_each(|(b, (i, w))| *b = (*values.get_unchecked(i + start), **w));
257 }
258 buf.sort_unstable_by(|&a, &b| a.0.tot_cmp(&b.0));
259 compute_wq(&buf, p, wsum, method)
260 })
261 .collect_trusted::<Vec<T>>();
262
263 let validity = create_validity(min_periods, len, window_size, det_offsets_fn);
264 Box::new(PrimitiveArray::new(
265 T::PRIMITIVE.into(),
266 out.into(),
267 validity.map(|b| b.into()),
268 ))
269}
270
#[cfg(test)]
mod test {
    use super::*;

    /// Downcasts a rolling kernel's output to `f64` and collects it as optional values.
    fn to_vec_f64(arr: ArrayRef) -> Vec<Option<f64>> {
        arr.as_any()
            .downcast_ref::<PrimitiveArray<f64>>()
            .unwrap()
            .into_iter()
            .map(|v| v.copied())
            .collect()
    }

    #[test]
    fn test_rolling_median() {
        let values = &[1.0, 2.0, 3.0, 4.0];
        let med_pars = Some(RollingFnParams::Quantile(RollingQuantileParams {
            prob: 0.5,
            method: Linear,
        }));

        // (window_size, min_periods, center, expected)
        let cases: [(usize, usize, bool, [Option<f64>; 4]); 5] = [
            (2, 2, false, [None, Some(1.5), Some(2.5), Some(3.5)]),
            (2, 1, false, [Some(1.0), Some(1.5), Some(2.5), Some(3.5)]),
            (4, 1, false, [Some(1.0), Some(1.5), Some(2.0), Some(2.5)]),
            (4, 1, true, [Some(1.5), Some(2.0), Some(2.5), Some(3.0)]),
            (4, 4, true, [None, None, Some(2.5), None]),
        ];

        for (window_size, min_periods, center, expected) in cases {
            let out = rolling_quantile(
                values,
                window_size,
                min_periods,
                center,
                None,
                med_pars.clone(),
            )
            .unwrap();
            assert_eq!(to_vec_f64(out), expected);
        }
    }

    #[test]
    fn test_rolling_quantile_limits() {
        let values = &[1.0f64, 2.0, 3.0, 4.0];

        let methods = [
            QuantileMethod::Lower,
            QuantileMethod::Higher,
            QuantileMethod::Nearest,
            QuantileMethod::Midpoint,
            QuantileMethod::Linear,
            QuantileMethod::Equiprobable,
        ];

        for method in methods {
            // prob = 0 must agree with the rolling minimum, prob = 1 with the rolling maximum.
            let quantile_at = |prob: f64| {
                let pars = Some(RollingFnParams::Quantile(RollingQuantileParams {
                    prob,
                    method,
                }));
                to_vec_f64(rolling_quantile(values, 2, 2, false, None, pars).unwrap())
            };

            let expected_min = to_vec_f64(rolling_min(values, 2, 2, false, None, None).unwrap());
            assert_eq!(expected_min, quantile_at(0.0));

            let expected_max = to_vec_f64(rolling_max(values, 2, 2, false, None, None).unwrap());
            assert_eq!(expected_max, quantile_at(1.0));
        }
    }
}