polars_arrow/legacy/kernels/rolling/nulls/
min_max.rs

1use super::*;
2use crate::array::iterator::NonNullValuesIter;
3use crate::bitmap::utils::count_zeros;
4
5pub fn is_reverse_sorted_max_nulls<T: NativeType>(values: &[T], validity: &Bitmap) -> bool {
6    let mut it = NonNullValuesIter::new(values, Some(validity));
7    let Some(mut prev) = it.next() else {
8        return true;
9    };
10    for v in it {
11        if prev.tot_lt(&v) {
12            return false;
13        }
14        prev = v
15    }
16
17    true
18}
19
20pub struct SortedMinMax<'a, T: NativeType> {
21    slice: &'a [T],
22    validity: &'a Bitmap,
23    last_start: usize,
24    last_end: usize,
25    null_count: usize,
26}
27
28impl<T: NativeType> SortedMinMax<'_, T> {
29    fn count_nulls(&self, start: usize, end: usize) -> usize {
30        let (bytes, offset, _) = self.validity.as_slice();
31        count_zeros(bytes, offset + start, end - start)
32    }
33}
34
35impl<'a, T: NativeType> RollingAggWindowNulls<'a, T> for SortedMinMax<'a, T> {
36    unsafe fn new(
37        slice: &'a [T],
38        validity: &'a Bitmap,
39        start: usize,
40        end: usize,
41        _params: Option<RollingFnParams>,
42    ) -> Self {
43        let mut out = Self {
44            slice,
45            validity,
46            last_start: start,
47            last_end: end,
48            null_count: 0,
49        };
50        let nulls = out.count_nulls(start, end);
51        out.null_count = nulls;
52        out
53    }
54
55    unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
56        self.null_count -= self.count_nulls(self.last_start, start);
57        self.null_count += self.count_nulls(self.last_end, end);
58
59        self.last_start = start;
60        self.last_end = end;
61
62        // return first non null
63        for idx in start..end {
64            let valid = self.validity.get_bit_unchecked(idx);
65
66            if valid {
67                return Some(*self.slice.get_unchecked(idx));
68            }
69        }
70
71        None
72    }
73
74    fn is_valid(&self, min_periods: usize) -> bool {
75        ((self.last_end - self.last_start) - self.null_count) >= min_periods
76    }
77}
78
79/// Generic `Min` / `Max` kernel.
80pub struct MinMaxWindow<'a, T: NativeType + PartialOrd + IsFloat> {
81    slice: &'a [T],
82    validity: &'a Bitmap,
83    extremum: Option<T>,
84    last_start: usize,
85    last_end: usize,
86    null_count: usize,
87    is_better: fn(&T, &T) -> bool,
88    take_extremum: fn(T, T) -> T,
89    // ordering on which the window needs to act.
90    // for min kernel this is Less
91    // for max kernel this is Greater
92}
93
94impl<'a, T: NativeType + IsFloat + PartialOrd> MinMaxWindow<'a, T> {
95    unsafe fn compute_extremum_in_between_leaving_and_entering(&self, start: usize) -> Option<T> {
96        // check the values in between the window that remains e.g. is not leaving
97        // this between `start..last_end`
98        //
99        // because we know the current `min` (which might be leaving), we know we can stop
100        // searching if any value is equal to current `min`.
101        let mut extremum_in_between = None;
102        for idx in start..self.last_end {
103            let valid = self.validity.get_bit_unchecked(idx);
104            let value = self.slice.get_unchecked(idx);
105
106            if valid {
107                // early return
108                if let Some(current_min) = self.extremum {
109                    if value.tot_eq(&current_min) {
110                        return Some(current_min);
111                    }
112                }
113
114                match extremum_in_between {
115                    None => extremum_in_between = Some(*value),
116                    Some(current) => {
117                        extremum_in_between = Some((self.take_extremum)(*value, current))
118                    },
119                }
120            }
121        }
122        extremum_in_between
123    }
124
125    // compute min from the entire window
126    unsafe fn compute_extremum_and_update_null_count(
127        &mut self,
128        start: usize,
129        end: usize,
130    ) -> Option<T> {
131        let mut extremum = None;
132        let mut idx = start;
133        for value in &self.slice[start..end] {
134            let valid = self.validity.get_bit_unchecked(idx);
135            if valid {
136                match extremum {
137                    None => extremum = Some(*value),
138                    Some(current) => extremum = Some((self.take_extremum)(*value, current)),
139                }
140            } else {
141                self.null_count += 1;
142            }
143            idx += 1;
144        }
145        extremum
146    }
147
148    unsafe fn new(
149        slice: &'a [T],
150        validity: &'a Bitmap,
151        start: usize,
152        end: usize,
153        is_better: fn(&T, &T) -> bool,
154        take_extremum: fn(T, T) -> T,
155    ) -> Self {
156        let mut out = Self {
157            slice,
158            validity,
159            extremum: None,
160            last_start: start,
161            last_end: end,
162            null_count: 0,
163            is_better,
164            take_extremum,
165        };
166        let extremum = out.compute_extremum_and_update_null_count(start, end);
167        out.extremum = extremum;
168        out
169    }
170
171    unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
172        // recompute min
173        if start >= self.last_end {
174            self.extremum = self.compute_extremum_and_update_null_count(start, end);
175            self.last_end = end;
176            self.last_start = start;
177            return self.extremum;
178        }
179
180        // remove elements that should leave the window
181        let mut recompute_extremum = false;
182        for idx in self.last_start..start {
183            // SAFETY:
184            // we are in bounds
185            let valid = self.validity.get_bit_unchecked(idx);
186            if valid {
187                let leaving_value = self.slice.get_unchecked(idx);
188
189                // if the leaving value is the
190                // min value, we need to recompute the min.
191                if leaving_value.tot_eq(&self.extremum.unwrap()) {
192                    recompute_extremum = true;
193                    break;
194                }
195            } else {
196                // null value leaving the window
197                self.null_count -= 1;
198
199                // self.min is None and the leaving value is None
200                // if the entering value is valid, we might get a new min.
201                if self.extremum.is_none() {
202                    recompute_extremum = true;
203                    break;
204                }
205            }
206        }
207
208        let entering_extremum = self.compute_extremum_and_update_null_count(self.last_end, end);
209
210        match (self.extremum, entering_extremum) {
211            // all remains `None`
212            (None, None) => {},
213            (None, Some(new_min)) => self.extremum = Some(new_min),
214            // entering min is `None` and the `min` is leaving, so the `in_between` min is the new
215            // minimum.
216            // if min is not leaving, we don't do anything
217            (Some(_current_min), None) => {
218                if recompute_extremum {
219                    self.extremum = self.compute_extremum_in_between_leaving_and_entering(start);
220                }
221            },
222            (Some(current_extremum), Some(entering_extremum)) => {
223                if (self.is_better)(&entering_extremum, &current_extremum) {
224                    self.extremum = Some(entering_extremum)
225                } else if recompute_extremum
226                    && (self.is_better)(&current_extremum, &entering_extremum)
227                {
228                    // leaving value could be the smallest, we might need to recompute
229                    let min_in_between =
230                        self.compute_extremum_in_between_leaving_and_entering(start);
231                    match min_in_between {
232                        None => self.extremum = Some(entering_extremum),
233                        Some(extremum_in_between) => {
234                            self.extremum =
235                                Some((self.take_extremum)(extremum_in_between, entering_extremum));
236                        },
237                    }
238                }
239            },
240        }
241        self.last_start = start;
242        self.last_end = end;
243        self.extremum
244    }
245
246    fn is_valid(&self, min_periods: usize) -> bool {
247        ((self.last_end - self.last_start) - self.null_count) >= min_periods
248    }
249}
250
251pub struct MinWindow<'a, T: NativeType + PartialOrd + IsFloat> {
252    inner: MinMaxWindow<'a, T>,
253}
254
255impl<'a, T: NativeType + IsFloat + PartialOrd> RollingAggWindowNulls<'a, T> for MinWindow<'a, T> {
256    unsafe fn new(
257        slice: &'a [T],
258        validity: &'a Bitmap,
259        start: usize,
260        end: usize,
261        _params: Option<RollingFnParams>,
262    ) -> Self {
263        Self {
264            inner: MinMaxWindow::new(
265                slice,
266                validity,
267                start,
268                end,
269                |a, b| a.nan_max_lt(b),
270                |a, b| a.min_ignore_nan(b),
271            ),
272        }
273    }
274
275    unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
276        self.inner.update(start, end)
277    }
278
279    fn is_valid(&self, min_periods: usize) -> bool {
280        self.inner.is_valid(min_periods)
281    }
282}
283
284pub fn rolling_min<T>(
285    arr: &PrimitiveArray<T>,
286    window_size: usize,
287    min_periods: usize,
288    center: bool,
289    weights: Option<&[f64]>,
290    _params: Option<RollingFnParams>,
291) -> ArrayRef
292where
293    T: NativeType + std::iter::Sum + Zero + AddAssign + Copy + PartialOrd + Bounded + IsFloat,
294{
295    if weights.is_some() {
296        panic!("weights not yet supported on array with null values")
297    }
298    if center {
299        rolling_apply_agg_window::<MinWindow<_>, _, _>(
300            arr.values().as_slice(),
301            arr.validity().as_ref().unwrap(),
302            window_size,
303            min_periods,
304            det_offsets_center,
305            None,
306        )
307    } else {
308        rolling_apply_agg_window::<MinWindow<_>, _, _>(
309            arr.values().as_slice(),
310            arr.validity().as_ref().unwrap(),
311            window_size,
312            min_periods,
313            det_offsets,
314            None,
315        )
316    }
317}
318
319pub struct MaxWindow<'a, T: NativeType + PartialOrd + IsFloat> {
320    inner: MinMaxWindow<'a, T>,
321}
322
323impl<'a, T: NativeType + IsFloat + PartialOrd> RollingAggWindowNulls<'a, T> for MaxWindow<'a, T> {
324    unsafe fn new(
325        slice: &'a [T],
326        validity: &'a Bitmap,
327        start: usize,
328        end: usize,
329        _params: Option<RollingFnParams>,
330    ) -> Self {
331        Self {
332            inner: MinMaxWindow::new(
333                slice,
334                validity,
335                start,
336                end,
337                |a, b| b.nan_min_lt(a),
338                |a, b| a.max_ignore_nan(b),
339            ),
340        }
341    }
342
343    unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
344        self.inner.update(start, end)
345    }
346
347    fn is_valid(&self, min_periods: usize) -> bool {
348        self.inner.is_valid(min_periods)
349    }
350}
351
352pub fn rolling_max<T>(
353    arr: &PrimitiveArray<T>,
354    window_size: usize,
355    min_periods: usize,
356    center: bool,
357    weights: Option<&[f64]>,
358    _params: Option<RollingFnParams>,
359) -> ArrayRef
360where
361    T: NativeType + std::iter::Sum + Zero + AddAssign + Copy + PartialOrd + Bounded + IsFloat,
362{
363    if weights.is_some() {
364        panic!("weights not yet supported on array with null values")
365    }
366    if center {
367        if is_reverse_sorted_max_nulls(arr.values().as_slice(), arr.validity().as_ref().unwrap()) {
368            rolling_apply_agg_window::<SortedMinMax<_>, _, _>(
369                arr.values().as_slice(),
370                arr.validity().as_ref().unwrap(),
371                window_size,
372                min_periods,
373                det_offsets_center,
374                None,
375            )
376        } else {
377            rolling_apply_agg_window::<MaxWindow<_>, _, _>(
378                arr.values().as_slice(),
379                arr.validity().as_ref().unwrap(),
380                window_size,
381                min_periods,
382                det_offsets_center,
383                None,
384            )
385        }
386    } else if is_reverse_sorted_max_nulls(arr.values().as_slice(), arr.validity().as_ref().unwrap())
387    {
388        rolling_apply_agg_window::<SortedMinMax<_>, _, _>(
389            arr.values().as_slice(),
390            arr.validity().as_ref().unwrap(),
391            window_size,
392            min_periods,
393            det_offsets,
394            None,
395        )
396    } else {
397        rolling_apply_agg_window::<MaxWindow<_>, _, _>(
398            arr.values().as_slice(),
399            arr.validity().as_ref().unwrap(),
400            window_size,
401            min_periods,
402            det_offsets,
403            None,
404        )
405    }
406}