1use super::*;
2
3#[inline]
4fn new_is_min<T: NativeType + IsFloat + PartialOrd>(old: &T, new: &T) -> bool {
5 compare_fn_nan_min(old, new).is_ge()
6}
7
8#[inline]
9fn new_is_max<T: NativeType + IsFloat + PartialOrd>(old: &T, new: &T) -> bool {
10 compare_fn_nan_max(old, new).is_le()
11}
12
13#[inline]
14unsafe fn get_min_and_idx<T>(
15 slice: &[T],
16 start: usize,
17 end: usize,
18 sorted_to: usize,
19) -> Option<(usize, &T)>
20where
21 T: NativeType + IsFloat + PartialOrd,
22{
23 if sorted_to >= end {
24 Some((start, slice.get_unchecked(start)))
27 } else if sorted_to <= start {
28 slice
31 .get_unchecked(start..end)
32 .iter()
33 .enumerate()
34 .rev()
35 .min_by(|&a, &b| compare_fn_nan_min(a.1, b.1))
36 .map(|v| (v.0 + start, v.1))
37 } else {
38 let s = (start, slice.get_unchecked(start));
40 slice
41 .get_unchecked(sorted_to..end)
42 .iter()
43 .enumerate()
44 .rev()
45 .min_by(|&a, &b| compare_fn_nan_min(a.1, b.1))
46 .map(|v| {
47 if new_is_min(s.1, v.1) {
48 (v.0 + sorted_to, v.1)
49 } else {
50 s
51 }
52 })
53 }
54}
55
56#[inline]
57unsafe fn get_max_and_idx<T>(
58 slice: &[T],
59 start: usize,
60 end: usize,
61 sorted_to: usize,
62) -> Option<(usize, &T)>
63where
64 T: NativeType + IsFloat + PartialOrd,
65{
66 if sorted_to >= end {
67 Some((start, slice.get_unchecked(start)))
68 } else if sorted_to <= start {
69 slice
70 .get_unchecked(start..end)
71 .iter()
72 .enumerate()
73 .max_by(|&a, &b| compare_fn_nan_max(a.1, b.1))
74 .map(|v| (v.0 + start, v.1))
75 } else {
76 let s = (start, slice.get_unchecked(start));
77 slice
78 .get_unchecked(sorted_to..end)
79 .iter()
80 .enumerate()
81 .max_by(|&a, &b| compare_fn_nan_max(a.1, b.1))
82 .map(|v| {
83 if new_is_max(s.1, v.1) {
84 (v.0 + sorted_to, v.1)
85 } else {
86 s
87 }
88 })
89 }
90}
91
92#[inline]
93fn n_sorted_past_min<T: NativeType + IsFloat + PartialOrd>(slice: &[T]) -> usize {
94 slice
95 .windows(2)
96 .position(|x| compare_fn_nan_min(&x[0], &x[1]).is_gt())
97 .unwrap_or(slice.len() - 1)
98}
99
100#[inline]
101fn n_sorted_past_max<T: NativeType + IsFloat + PartialOrd>(slice: &[T]) -> usize {
102 slice
103 .windows(2)
104 .position(|x| compare_fn_nan_max(&x[0], &x[1]).is_lt())
105 .unwrap_or(slice.len() - 1)
106}
107
/// Generates a rolling min/max window type `$m_window` that implements
/// `RollingAggWindowNoNulls`.
///
/// * `$get_m_and_idx` - finds the extremum (and its index) of a sub-range,
///   using a `sorted_to` hint to skip scanning a known-sorted prefix.
/// * `$new_is_m` - decides whether a candidate replaces the current extremum.
/// * `$n_sorted_past` - length of the sorted run following a given element.
///
/// Instead of rescanning every window, the type caches the current extremum
/// and its index. It only searches again when that index has left the window.
macro_rules! minmax_window {
    ($m_window:tt, $get_m_and_idx:ident, $new_is_m:ident, $n_sorted_past:ident) => {
        pub struct $m_window<'a, T: NativeType + PartialOrd + IsFloat> {
            /// The full input values; windows are index ranges into it.
            slice: &'a [T],
            /// Extremum of the current window.
            m: T,
            /// Index of `m` in `slice`.
            m_idx: usize,
            /// Exclusive end of a sorted run. It is computed from `m_idx` as
            /// `m_idx + 1 + $n_sorted_past(&slice[m_idx..])` and is recomputed
            /// only once `m_idx` reaches or passes it.
            sorted_to: usize,
            /// Start of the last window seen (recorded, but not read below).
            last_start: usize,
            /// End (exclusive) of the last window seen.
            last_end: usize,
        }

        impl<'a, T: NativeType + IsFloat + PartialOrd> $m_window<'a, T> {
            /// Stores a new extremum and its index. If the new index is at or
            /// beyond the known sorted run, the run is recomputed from it.
            ///
            /// # Safety
            /// `m_and_idx.0` must be `<= self.slice.len()` (it is used in an
            /// unchecked slice).
            #[inline]
            unsafe fn update_m_and_m_idx(&mut self, m_and_idx: (usize, &T)) {
                self.m = *m_and_idx.1;
                self.m_idx = m_and_idx.0;
                if self.sorted_to <= self.m_idx {
                    self.sorted_to =
                        self.m_idx + 1 + $n_sorted_past(&self.slice.get_unchecked(self.m_idx..));
                }
            }
        }

        impl<'a, T: NativeType + IsFloat + PartialOrd> RollingAggWindowNoNulls<'a, T>
            for $m_window<'a, T>
        {
            /// Builds the window for `slice[start..end]` with a full scan
            /// (`sorted_to == 0`, i.e. no sortedness known yet).
            fn new(
                slice: &'a [T],
                start: usize,
                end: usize,
                _params: Option<RollingFnParams>,
            ) -> Self {
                // NOTE(review): the fallback is only taken for an empty initial
                // window. It pairs index 0 with the value `slice[start]`, and
                // because `unwrap_or` is eager, `&slice[start]` is bounds-checked
                // even when the fallback is not used.
                let (idx, m) =
                    unsafe { $get_m_and_idx(slice, start, end, 0).unwrap_or((0, &slice[start])) };
                Self {
                    slice,
                    m: *m,
                    m_idx: idx,
                    sorted_to: idx + 1 + $n_sorted_past(&slice[idx..]),
                    last_start: start,
                    last_end: end,
                }
            }

            /// Moves the window to `slice[start..end]` and returns its extremum.
            ///
            /// # Safety
            /// Indices are not bounds-checked. This presumably requires
            /// `start <= end <= slice.len()` and `end >= ` the previous end;
            /// otherwise `end - entering_start` underflows. TODO confirm the
            /// caller guarantees monotonic windows.
            unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
                self.last_start = start;
                // The previous end is needed to know which values are new.
                let old_last_end = self.last_end;
                self.last_end = end;
                // Values in `entering_start..end` were not part of the previous window.
                let entering_start = std::cmp::max(old_last_end, start);
                let entering = if end - entering_start == 1 {
                    // Common case (fixed window sliding by one): a single new value.
                    Some((entering_start, self.slice.get_unchecked(entering_start)))
                } else if old_last_end == end {
                    // The end did not move: no new values entered.
                    None
                } else {
                    $get_m_and_idx(self.slice, entering_start, end, self.sorted_to)
                };
                // The new window shares no values with the previous one.
                let empty_overlap = old_last_end <= start;

                if entering.map(|em| $new_is_m(&self.m, em.1) || empty_overlap) == Some(true) {
                    // The entering extremum beats (or ties) the cached one, or the
                    // old window is gone entirely: the overlap need not be searched.
                    self.update_m_and_m_idx(entering.unwrap());
                    return Some(self.m);
                } else if self.m_idx >= start || empty_overlap {
                    // The cached extremum is still inside the window.
                    // NOTE(review): with `empty_overlap` and no entering values
                    // (an empty window), this returns the stale cached value.
                    return Some(self.m);
                }
                // The cached extremum left the window: search the overlap with the
                // previous window and compare it against the entering extremum.
                match (
                    $get_m_and_idx(self.slice, start, old_last_end, self.sorted_to),
                    entering,
                ) {
                    (Some(pm), Some(em)) => {
                        // Ties go to the entering (later) value.
                        if $new_is_m(pm.1, em.1) {
                            self.update_m_and_m_idx(em);
                        } else {
                            self.update_m_and_m_idx(pm);
                        }
                    },
                    (Some(pm), None) => self.update_m_and_m_idx(pm),
                    (None, Some(em)) => self.update_m_and_m_idx(em),
                    // Both the overlap and the entering range would be empty.
                    (None, None) => unreachable!(),
                }

                Some(self.m)
            }
        }
    };
}

// Concrete window types used by `rolling_min` / `rolling_max`.
minmax_window!(MinWindow, get_min_and_idx, new_is_min, n_sorted_past_min);
minmax_window!(MaxWindow, get_max_and_idx, new_is_max, n_sorted_past_max);
207
208pub(crate) fn compute_min_weights<T>(values: &[T], weights: &[T]) -> T
209where
210 T: NativeType + PartialOrd + std::ops::Mul<Output = T>,
211{
212 values
213 .iter()
214 .zip(weights)
215 .map(|(v, w)| *v * *w)
216 .min_by(|a, b| a.partial_cmp(b).unwrap())
217 .unwrap()
218}
219
220pub(crate) fn compute_max_weights<T>(values: &[T], weights: &[T]) -> T
221where
222 T: NativeType + PartialOrd + IsFloat + Bounded + Mul<Output = T>,
223{
224 let mut max = T::min_value();
225 for v in values.iter().zip(weights).map(|(v, w)| *v * *w) {
226 if T::is_float() && v.is_nan() {
227 return v;
228 }
229 if v > max {
230 max = v
231 }
232 }
233
234 max
235}
236
237macro_rules! rolling_minmax_func {
239 ($rolling_m:ident, $window:tt, $wtd_f:ident) => {
240 pub fn $rolling_m<T>(
241 values: &[T],
242 window_size: usize,
243 min_periods: usize,
244 center: bool,
245 weights: Option<&[f64]>,
246 _params: Option<RollingFnParams>,
247 ) -> PolarsResult<ArrayRef>
248 where
249 T: NativeType + PartialOrd + IsFloat + Bounded + NumCast + Mul<Output = T> + Num,
250 {
251 let offset_fn = match center {
252 true => det_offsets_center,
253 false => det_offsets,
254 };
255 match weights {
256 None => rolling_apply_agg_window::<$window<_>, _, _>(
257 values,
258 window_size,
259 min_periods,
260 offset_fn,
261 None,
262 ),
263 Some(weights) => {
264 assert!(
265 T::is_float(),
266 "implementation error, should only be reachable by float types"
267 );
268 let weights = weights
269 .iter()
270 .map(|v| NumCast::from(*v).unwrap())
271 .collect::<Vec<_>>();
272 no_nulls::rolling_apply_weights(
273 values,
274 window_size,
275 min_periods,
276 offset_fn,
277 $wtd_f,
278 &weights,
279 )
280 },
281 }
282 }
283 };
284}
285
286rolling_minmax_func!(rolling_min, MinWindow, compute_min_weights);
287rolling_minmax_func!(rolling_max, MaxWindow, compute_max_weights);
288
#[cfg(test)]
mod test {
    use super::*;

    /// Downcasts a rolling kernel's output to `f64` and collects it as
    /// `Option<f64>` values (`None` for null slots).
    fn collect_f64(arr: ArrayRef) -> Vec<Option<f64>> {
        arr.as_any()
            .downcast_ref::<PrimitiveArray<f64>>()
            .unwrap()
            .into_iter()
            .map(|v| v.copied())
            .collect()
    }

    #[test]
    fn test_rolling_min_max() {
        let values = &[1.0f64, 5.0, 3.0, 4.0];

        // window 2, min_periods 2: the first slot has too few values.
        let min_out = collect_f64(rolling_min(values, 2, 2, false, None, None).unwrap());
        assert_eq!(min_out, &[None, Some(1.0), Some(3.0), Some(3.0)]);
        let max_out = collect_f64(rolling_max(values, 2, 2, false, None, None).unwrap());
        assert_eq!(max_out, &[None, Some(5.0), Some(5.0), Some(4.0)]);

        // window 2, min_periods 1: the first slot is filled.
        let min_out = collect_f64(rolling_min(values, 2, 1, false, None, None).unwrap());
        assert_eq!(min_out, &[Some(1.0), Some(1.0), Some(3.0), Some(3.0)]);
        let max_out = collect_f64(rolling_max(values, 2, 1, false, None, None).unwrap());
        assert_eq!(max_out, &[Some(1.0), Some(5.0), Some(5.0), Some(4.0)]);

        // window 3, min_periods 1.
        let max_out = collect_f64(rolling_max(values, 3, 1, false, None, None).unwrap());
        assert_eq!(max_out, &[Some(1.0), Some(5.0), Some(5.0), Some(5.0)]);

        // NaN propagates through every window that contains it. The results are
        // compared via their Debug form because NaN != NaN.
        let values = &[1.0, 2.0, 3.0, f64::nan(), 5.0, 6.0, 7.0];
        let nan = f64::nan();

        let min_out = collect_f64(rolling_min(values, 3, 3, false, None, None).unwrap());
        let expected_min = [
            None,
            None,
            Some(1.0),
            Some(nan),
            Some(nan),
            Some(nan),
            Some(5.0),
        ];
        assert_eq!(
            format!("{:?}", min_out.as_slice()),
            format!("{:?}", &expected_min)
        );

        let max_out = collect_f64(rolling_max(values, 3, 3, false, None, None).unwrap());
        let expected_max = [
            None,
            None,
            Some(3.0),
            Some(nan),
            Some(nan),
            Some(nan),
            Some(7.0),
        ];
        assert_eq!(
            format!("{:?}", max_out.as_slice()),
            format!("{:?}", &expected_max)
        );
    }
}