polars_arrow/legacy/kernels/rolling/nulls/
quantile.rs1use super::*;
2use crate::array::MutablePrimitiveArray;
3
4pub struct QuantileWindow<'a, T: NativeType + IsFloat + PartialOrd> {
5 sorted: SortedBufNulls<'a, T>,
6 prob: f64,
7 method: QuantileMethod,
8}
9
10impl<
11 'a,
12 T: NativeType
13 + IsFloat
14 + Float
15 + std::iter::Sum
16 + AddAssign
17 + SubAssign
18 + Div<Output = T>
19 + NumCast
20 + One
21 + Zero
22 + PartialOrd
23 + Sub<Output = T>,
24 > RollingAggWindowNulls<'a, T> for QuantileWindow<'a, T>
25{
26 unsafe fn new(
27 slice: &'a [T],
28 validity: &'a Bitmap,
29 start: usize,
30 end: usize,
31 params: Option<RollingFnParams>,
32 ) -> Self {
33 let params = params.unwrap();
34 let RollingFnParams::Quantile(params) = params else {
35 unreachable!("expected Quantile params");
36 };
37 Self {
38 sorted: SortedBufNulls::new(slice, validity, start, end),
39 prob: params.prob,
40 method: params.method,
41 }
42 }
43
44 unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
45 let (values, null_count) = self.sorted.update(start, end);
46 if null_count == values.len() {
48 return None;
49 }
50 let values = &values[null_count..];
52 let length = values.len();
53
54 let mut idx = match self.method {
55 QuantileMethod::Nearest => ((length as f64) * self.prob) as usize,
56 QuantileMethod::Lower | QuantileMethod::Midpoint | QuantileMethod::Linear => {
57 ((length as f64 - 1.0) * self.prob).floor() as usize
58 },
59 QuantileMethod::Higher => ((length as f64 - 1.0) * self.prob).ceil() as usize,
60 QuantileMethod::Equiprobable => {
61 ((length as f64 * self.prob).ceil() - 1.0).max(0.0) as usize
62 },
63 };
64
65 idx = std::cmp::min(idx, length - 1);
66
67 match self.method {
69 QuantileMethod::Midpoint => {
70 let top_idx = ((length as f64 - 1.0) * self.prob).ceil() as usize;
71 Some(
72 (values.get_unchecked(idx).unwrap() + values.get_unchecked(top_idx).unwrap())
73 / T::from::<f64>(2.0f64).unwrap(),
74 )
75 },
76 QuantileMethod::Linear => {
77 let float_idx = (length as f64 - 1.0) * self.prob;
78 let top_idx = f64::ceil(float_idx) as usize;
79
80 if top_idx == idx {
81 Some(values.get_unchecked(idx).unwrap())
82 } else {
83 let proportion = T::from(float_idx - idx as f64).unwrap();
84 Some(
85 proportion
86 * (values.get_unchecked(top_idx).unwrap()
87 - values.get_unchecked(idx).unwrap())
88 + values.get_unchecked(idx).unwrap(),
89 )
90 }
91 },
92 _ => Some(values.get_unchecked(idx).unwrap()),
93 }
94 }
95
96 fn is_valid(&self, min_periods: usize) -> bool {
97 self.sorted.is_valid(min_periods)
98 }
99}
100
101pub fn rolling_quantile<T>(
102 arr: &PrimitiveArray<T>,
103 window_size: usize,
104 min_periods: usize,
105 center: bool,
106 weights: Option<&[f64]>,
107 params: Option<RollingFnParams>,
108) -> ArrayRef
109where
110 T: NativeType
111 + IsFloat
112 + Float
113 + std::iter::Sum
114 + AddAssign
115 + SubAssign
116 + Div<Output = T>
117 + NumCast
118 + One
119 + Zero
120 + PartialOrd
121 + Sub<Output = T>,
122{
123 if weights.is_some() {
124 panic!("weights not yet supported on array with null values")
125 }
126 let offset_fn = match center {
127 true => det_offsets_center,
128 false => det_offsets,
129 };
130 if !center {
131 let params = params.as_ref().unwrap();
132 let RollingFnParams::Quantile(params) = params else {
133 unreachable!("expected Quantile params");
134 };
135
136 let out = super::quantile_filter::rolling_quantile::<_, MutablePrimitiveArray<_>>(
137 params.method,
138 min_periods,
139 window_size,
140 arr.clone(),
141 params.prob,
142 );
143 let out: PrimitiveArray<T> = out.into();
144 return Box::new(out);
145 }
146 rolling_apply_agg_window::<QuantileWindow<_>, _, _>(
147 arr.values().as_slice(),
148 arr.validity().as_ref().unwrap(),
149 window_size,
150 min_periods,
151 offset_fn,
152 params,
153 )
154}
155
#[cfg(test)]
mod test {
    use super::*;
    use crate::buffer::Buffer;
    use crate::datatypes::ArrowDataType;

    /// Downcasts a kernel result to `PrimitiveArray<f64>` and collects it into
    /// plain `Option<f64>` values so it can be compared with `assert_eq!`.
    fn to_f64_vec(out: ArrayRef) -> Vec<Option<f64>> {
        out.as_any()
            .downcast_ref::<PrimitiveArray<f64>>()
            .expect("kernel output should be a PrimitiveArray<f64>")
            .into_iter()
            .map(|v| v.copied())
            .collect()
    }

    #[test]
    fn test_rolling_median_nulls() {
        let buf = Buffer::from(vec![1.0, 2.0, 3.0, 4.0]);
        let arr = &PrimitiveArray::new(
            ArrowDataType::Float64,
            buf,
            Some(Bitmap::from(&[true, false, true, true])),
        );
        let med_pars = Some(RollingFnParams::Quantile(RollingQuantileParams {
            prob: 0.5,
            method: QuantileMethod::Linear,
        }));

        // (window_size, min_periods, center, expected output)
        let cases: [(usize, usize, bool, [Option<f64>; 4]); 5] = [
            (2, 2, false, [None, None, None, Some(3.5)]),
            (2, 1, false, [Some(1.0), Some(1.0), Some(3.0), Some(3.5)]),
            (4, 1, false, [Some(1.0), Some(1.0), Some(2.0), Some(3.0)]),
            (4, 1, true, [Some(1.0), Some(2.0), Some(3.0), Some(3.5)]),
            (4, 4, true, [None, None, None, None]),
        ];

        for (window_size, min_periods, center, expected) in cases {
            let out = to_f64_vec(rolling_quantile(
                arr,
                window_size,
                min_periods,
                center,
                None,
                med_pars.clone(),
            ));
            assert_eq!(
                out, expected,
                "window_size={window_size}, min_periods={min_periods}, center={center}"
            );
        }
    }

    #[test]
    fn test_rolling_quantile_nulls_limits() {
        let buf = Buffer::<f64>::from(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let values = &PrimitiveArray::new(
            ArrowDataType::Float64,
            buf,
            Some(Bitmap::from(&[true, false, false, true, true])),
        );

        let methods = [
            QuantileMethod::Lower,
            QuantileMethod::Higher,
            QuantileMethod::Nearest,
            QuantileMethod::Midpoint,
            QuantileMethod::Linear,
            QuantileMethod::Equiprobable,
        ];

        for method in methods {
            let quantile_at = |prob: f64| {
                let pars = Some(RollingFnParams::Quantile(RollingQuantileParams {
                    prob,
                    method,
                }));
                to_f64_vec(rolling_quantile(values, 2, 1, false, None, pars))
            };

            // prob = 0.0 must agree with the rolling minimum ...
            let expected_min = to_f64_vec(rolling_min(values, 2, 1, false, None, None));
            assert_eq!(expected_min, quantile_at(0.0));

            // ... and prob = 1.0 with the rolling maximum, for every method.
            let expected_max = to_f64_vec(rolling_max(values, 2, 1, false, None, None));
            assert_eq!(expected_max, quantile_at(1.0));
        }
    }
}