polars_arrow/legacy/kernels/rolling/no_nulls/min_max.rs

1use super::*;
2
3#[inline]
4fn new_is_min<T: NativeType + IsFloat + PartialOrd>(old: &T, new: &T) -> bool {
5    compare_fn_nan_min(old, new).is_ge()
6}
7
8#[inline]
9fn new_is_max<T: NativeType + IsFloat + PartialOrd>(old: &T, new: &T) -> bool {
10    compare_fn_nan_max(old, new).is_le()
11}
12
13#[inline]
14unsafe fn get_min_and_idx<T>(
15    slice: &[T],
16    start: usize,
17    end: usize,
18    sorted_to: usize,
19) -> Option<(usize, &T)>
20where
21    T: NativeType + IsFloat + PartialOrd,
22{
23    if sorted_to >= end {
24        // If we're sorted past the end we can just take the first element because this function
25        // won't be called on intervals that contain the previous min
26        Some((start, slice.get_unchecked(start)))
27    } else if sorted_to <= start {
28        // We have to inspect the whole range
29        // Reversed because min_by returns the first min if there's a tie but we want the last
30        slice
31            .get_unchecked(start..end)
32            .iter()
33            .enumerate()
34            .rev()
35            .min_by(|&a, &b| compare_fn_nan_min(a.1, b.1))
36            .map(|v| (v.0 + start, v.1))
37    } else {
38        // It's sorted in range start..sorted_to. Compare slice[start] to min over sorted_to..end
39        let s = (start, slice.get_unchecked(start));
40        slice
41            .get_unchecked(sorted_to..end)
42            .iter()
43            .enumerate()
44            .rev()
45            .min_by(|&a, &b| compare_fn_nan_min(a.1, b.1))
46            .map(|v| {
47                if new_is_min(s.1, v.1) {
48                    (v.0 + sorted_to, v.1)
49                } else {
50                    s
51                }
52            })
53    }
54}
55
56#[inline]
57unsafe fn get_max_and_idx<T>(
58    slice: &[T],
59    start: usize,
60    end: usize,
61    sorted_to: usize,
62) -> Option<(usize, &T)>
63where
64    T: NativeType + IsFloat + PartialOrd,
65{
66    if sorted_to >= end {
67        Some((start, slice.get_unchecked(start)))
68    } else if sorted_to <= start {
69        slice
70            .get_unchecked(start..end)
71            .iter()
72            .enumerate()
73            .max_by(|&a, &b| compare_fn_nan_max(a.1, b.1))
74            .map(|v| (v.0 + start, v.1))
75    } else {
76        let s = (start, slice.get_unchecked(start));
77        slice
78            .get_unchecked(sorted_to..end)
79            .iter()
80            .enumerate()
81            .max_by(|&a, &b| compare_fn_nan_max(a.1, b.1))
82            .map(|v| {
83                if new_is_max(s.1, v.1) {
84                    (v.0 + sorted_to, v.1)
85                } else {
86                    s
87                }
88            })
89    }
90}
91
92#[inline]
93fn n_sorted_past_min<T: NativeType + IsFloat + PartialOrd>(slice: &[T]) -> usize {
94    slice
95        .windows(2)
96        .position(|x| compare_fn_nan_min(&x[0], &x[1]).is_gt())
97        .unwrap_or(slice.len() - 1)
98}
99
100#[inline]
101fn n_sorted_past_max<T: NativeType + IsFloat + PartialOrd>(slice: &[T]) -> usize {
102    slice
103        .windows(2)
104        .position(|x| compare_fn_nan_max(&x[0], &x[1]).is_lt())
105        .unwrap_or(slice.len() - 1)
106}
107
// Min and max really are the same thing up to a difference in comparison direction, as represented
// here by the helpers we pass in. Making both with a macro helps keep behavior synchronized.
//
// - `$m_window`: name of the generated window type.
// - `$get_m_and_idx`: extremum (and its index) of a sub-range, e.g. `get_min_and_idx`.
// - `$new_is_m`: whether the second argument replaces the first as the running extremum.
// - `$n_sorted_past`: length of the sorted run that follows the first element of a slice.
macro_rules! minmax_window {
    ($m_window:tt, $get_m_and_idx:ident, $new_is_m:ident, $n_sorted_past:ident) => {
        /// Incremental rolling min/max over a slice without nulls.
        ///
        /// Besides the current extremum `m` and its position `m_idx`, it tracks `sorted_to`: the
        /// exclusive end of a run that is sorted in the extremum's direction and starts at or
        /// before `m_idx`. Knowing that run lets updates skip comparisons.
        pub struct $m_window<'a, T: NativeType + PartialOrd + IsFloat> {
            slice: &'a [T],
            m: T,
            m_idx: usize,
            sorted_to: usize,
            // Start of the most recent window (recorded but not read in this file).
            last_start: usize,
            // Exclusive end of the most recent window.
            last_end: usize,
        }

        impl<'a, T: NativeType + IsFloat + PartialOrd> $m_window<'a, T> {
            /// Stores a new extremum and, if it lies at or beyond the tracked sorted run,
            /// recomputes how far the run extends past it.
            ///
            /// # Safety
            /// `m_and_idx.0` must be a valid index into `self.slice`.
            #[inline]
            unsafe fn update_m_and_m_idx(&mut self, m_and_idx: (usize, &T)) {
                self.m = *m_and_idx.1;
                self.m_idx = m_and_idx.0;
                if self.sorted_to <= self.m_idx {
                    // Track how far past the current extremum values are sorted. Direction
                    // depends on min/max. Tracking sorted ranges lets us only do comparisons
                    // when we have to.
                    self.sorted_to =
                        self.m_idx + 1 + $n_sorted_past(self.slice.get_unchecked(self.m_idx..));
                }
            }
        }

        impl<'a, T: NativeType + IsFloat + PartialOrd> RollingAggWindowNoNulls<'a, T>
            for $m_window<'a, T>
        {
            fn new(
                slice: &'a [T],
                start: usize,
                end: usize,
                _params: Option<RollingFnParams>,
            ) -> Self {
                // SAFETY (assumed, TODO confirm): callers pass `start <= end <= slice.len()`
                // with `start < slice.len()`, as `get_m_and_idx` requires.
                let (idx, m) = unsafe { $get_m_and_idx(slice, start, end, 0) }
                    // An empty initial window yields no extremum. Fall back to `slice[start]`
                    // and keep the stored index consistent with that value.
                    .unwrap_or((start, &slice[start]));
                Self {
                    slice,
                    m: *m,
                    m_idx: idx,
                    sorted_to: idx + 1 + $n_sorted_past(&slice[idx..]),
                    last_start: start,
                    last_end: end,
                }
            }

            unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
                // For details see: https://github.com/pola-rs/polars/pull/9277#issuecomment-1581401692
                self.last_start = start; // Don't care where the last one started
                let old_last_end = self.last_end; // But we need this
                self.last_end = end;

                // Elements in `entering_start..end` were not part of the previous window.
                let entering_start = std::cmp::max(old_last_end, start);
                let entering = if end - entering_start == 1 {
                    // Faster in the special, but common, case of a fixed window rolling by one
                    Some((entering_start, self.slice.get_unchecked(entering_start)))
                } else if old_last_end == end {
                    // Edge case for shrinking windows: nothing enters.
                    None
                } else {
                    $get_m_and_idx(self.slice, entering_start, end, self.sorted_to)
                };
                // The new window shares no elements with the previous one.
                let empty_overlap = old_last_end <= start;

                if let Some(em) = entering {
                    if empty_overlap || $new_is_m(&self.m, em.1) {
                        // The entering extremum "beats" the previous extremum (or there is no
                        // overlap to consider), so we can ignore the overlap.
                        self.update_m_and_m_idx(em);
                        return Some(self.m);
                    }
                }
                if self.m_idx >= start || empty_overlap {
                    // The previous extremum didn't drop off. Keep it.
                    return Some(self.m);
                }

                // Otherwise take the extremum of the overlapping window and the entering one.
                match (
                    $get_m_and_idx(self.slice, start, old_last_end, self.sorted_to),
                    entering,
                ) {
                    (Some(pm), Some(em)) => {
                        if $new_is_m(pm.1, em.1) {
                            self.update_m_and_m_idx(em);
                        } else {
                            self.update_m_and_m_idx(pm);
                        }
                    },
                    (Some(pm), None) => self.update_m_and_m_idx(pm),
                    (None, Some(em)) => self.update_m_and_m_idx(em),
                    // This would mean both the entering and previous windows are empty
                    (None, None) => unreachable!(),
                }

                Some(self.m)
            }
        }
    };
}
204
// Instantiate the incremental window types: `MinWindow` from the `*_min` helpers and
// `MaxWindow` from the `*_max` helpers.
minmax_window!(MinWindow, get_min_and_idx, new_is_min, n_sorted_past_min);
minmax_window!(MaxWindow, get_max_and_idx, new_is_max, n_sorted_past_max);
207
208pub(crate) fn compute_min_weights<T>(values: &[T], weights: &[T]) -> T
209where
210    T: NativeType + PartialOrd + std::ops::Mul<Output = T>,
211{
212    values
213        .iter()
214        .zip(weights)
215        .map(|(v, w)| *v * *w)
216        .min_by(|a, b| a.partial_cmp(b).unwrap())
217        .unwrap()
218}
219
220pub(crate) fn compute_max_weights<T>(values: &[T], weights: &[T]) -> T
221where
222    T: NativeType + PartialOrd + IsFloat + Bounded + Mul<Output = T>,
223{
224    let mut max = T::min_value();
225    for v in values.iter().zip(weights).map(|(v, w)| *v * *w) {
226        if T::is_float() && v.is_nan() {
227            return v;
228        }
229        if v > max {
230            max = v
231        }
232    }
233
234    max
235}
236
// As with the window types above, min and max dispatch identically apart from the window type and
// the weighted reducer, so both public entry points are generated from this one template.
macro_rules! rolling_minmax_func {
    ($rolling_m:ident, $window:tt, $wtd_f:ident) => {
        /// Rolling min/max over `values`, which must not contain nulls.
        ///
        /// Without `weights`, the incremental window aggregator is used. With `weights` (float
        /// types only), each window is multiplied element-wise by the weights and then reduced.
        /// `center` selects centered rather than trailing windows; `_params` is ignored.
        ///
        /// # Panics
        /// Panics if `weights` is provided for a non-float `T`.
        pub fn $rolling_m<T>(
            values: &[T],
            window_size: usize,
            min_periods: usize,
            center: bool,
            weights: Option<&[f64]>,
            _params: Option<RollingFnParams>,
        ) -> PolarsResult<ArrayRef>
        where
            T: NativeType + PartialOrd + IsFloat + Bounded + NumCast + Mul<Output = T> + Num,
        {
            let offset_fn = if center {
                det_offsets_center
            } else {
                det_offsets
            };

            if let Some(weights) = weights {
                assert!(
                    T::is_float(),
                    "implementation error, should only be reachable by float types"
                );
                // Bring the weights into the value type so products stay in `T`.
                let weights = weights
                    .iter()
                    .map(|&w| NumCast::from(w).unwrap())
                    .collect::<Vec<_>>();
                no_nulls::rolling_apply_weights(
                    values,
                    window_size,
                    min_periods,
                    offset_fn,
                    $wtd_f,
                    &weights,
                )
            } else {
                rolling_apply_agg_window::<$window<_>, _, _>(
                    values,
                    window_size,
                    min_periods,
                    offset_fn,
                    None,
                )
            }
        }
    };
}
285
// Public entry points. Unweighted calls use `MinWindow`/`MaxWindow`; weighted calls reduce each
// window with `compute_min_weights`/`compute_max_weights`.
rolling_minmax_func!(rolling_min, MinWindow, compute_min_weights);
rolling_minmax_func!(rolling_max, MaxWindow, compute_max_weights);
288
#[cfg(test)]
mod test {
    use super::*;

    /// Turns a rolling kernel's output into plain `Option<f64>` values (`None` marks nulls).
    fn collect_f64(arr: ArrayRef) -> Vec<Option<f64>> {
        arr.as_any()
            .downcast_ref::<PrimitiveArray<f64>>()
            .unwrap()
            .into_iter()
            .map(|v| v.copied())
            .collect()
    }

    #[test]
    fn test_rolling_min_max() {
        let values = &[1.0f64, 5.0, 3.0, 4.0];

        // Windows of two that must be full to produce a value.
        assert_eq!(
            collect_f64(rolling_min(values, 2, 2, false, None, None).unwrap()),
            &[None, Some(1.0), Some(3.0), Some(3.0)]
        );
        assert_eq!(
            collect_f64(rolling_max(values, 2, 2, false, None, None).unwrap()),
            &[None, Some(5.0), Some(5.0), Some(4.0)]
        );

        // One observation suffices, so the leading partial window also yields a value.
        assert_eq!(
            collect_f64(rolling_min(values, 2, 1, false, None, None).unwrap()),
            &[Some(1.0), Some(1.0), Some(3.0), Some(3.0)]
        );
        assert_eq!(
            collect_f64(rolling_max(values, 2, 1, false, None, None).unwrap()),
            &[Some(1.0), Some(5.0), Some(5.0), Some(4.0)]
        );

        // Wider window.
        assert_eq!(
            collect_f64(rolling_max(values, 3, 1, false, None, None).unwrap()),
            &[Some(1.0), Some(5.0), Some(5.0), Some(5.0)]
        );

        // NaN handling. NaN never equals itself, so compare the Debug renderings instead.
        let values = &[1.0, 2.0, 3.0, f64::nan(), 5.0, 6.0, 7.0];

        let mins = collect_f64(rolling_min(values, 3, 3, false, None, None).unwrap());
        let expected_mins = [
            None,
            None,
            Some(1.0),
            Some(f64::nan()),
            Some(f64::nan()),
            Some(f64::nan()),
            Some(5.0),
        ];
        assert_eq!(
            format!("{:?}", mins.as_slice()),
            format!("{:?}", &expected_mins)
        );

        let maxs = collect_f64(rolling_max(values, 3, 3, false, None, None).unwrap());
        let expected_maxs = [
            None,
            None,
            Some(3.0),
            Some(f64::nan()),
            Some(f64::nan()),
            Some(f64::nan()),
            Some(7.0),
        ];
        assert_eq!(
            format!("{:?}", maxs.as_slice()),
            format!("{:?}", &expected_maxs)
        );
    }
}