polars_arrow/legacy/kernels/rolling/nulls/
sum.rs

1use super::*;
2
3pub struct SumWindow<'a, T> {
4    slice: &'a [T],
5    validity: &'a Bitmap,
6    sum: Option<T>,
7    last_start: usize,
8    last_end: usize,
9    pub(super) null_count: usize,
10}
11
12impl<T: NativeType + IsFloat + Add<Output = T> + Sub<Output = T>> SumWindow<'_, T> {
13    // compute sum from the entire window
14    unsafe fn compute_sum_and_null_count(&mut self, start: usize, end: usize) -> Option<T> {
15        let mut sum = None;
16        let mut idx = start;
17        self.null_count = 0;
18        for value in &self.slice[start..end] {
19            let valid = self.validity.get_bit_unchecked(idx);
20            if valid {
21                match sum {
22                    None => sum = Some(*value),
23                    Some(current) => sum = Some(*value + current),
24                }
25            } else {
26                self.null_count += 1;
27            }
28            idx += 1;
29        }
30        self.sum = sum;
31        sum
32    }
33}
34
35impl<'a, T: NativeType + IsFloat + Add<Output = T> + Sub<Output = T>> RollingAggWindowNulls<'a, T>
36    for SumWindow<'a, T>
37{
38    unsafe fn new(
39        slice: &'a [T],
40        validity: &'a Bitmap,
41        start: usize,
42        end: usize,
43        _params: Option<RollingFnParams>,
44    ) -> Self {
45        let mut out = Self {
46            slice,
47            validity,
48            sum: None,
49            last_start: start,
50            last_end: end,
51            null_count: 0,
52        };
53        out.compute_sum_and_null_count(start, end);
54        out
55    }
56
57    unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
58        // if we exceed the end, we have a completely new window
59        // so we recompute
60        let recompute_sum = if start >= self.last_end {
61            true
62        } else {
63            // remove elements that should leave the window
64            let mut recompute_sum = false;
65            for idx in self.last_start..start {
66                // SAFETY:
67                // we are in bounds
68                let valid = self.validity.get_bit_unchecked(idx);
69                if valid {
70                    let leaving_value = self.slice.get_unchecked(idx);
71
72                    // if the leaving value is nan we need to recompute the window
73                    if T::is_float() && !leaving_value.is_finite() {
74                        recompute_sum = true;
75                        break;
76                    }
77                    self.sum = self.sum.map(|v| v - *leaving_value)
78                } else {
79                    // null value leaving the window
80                    self.null_count -= 1;
81
82                    // self.sum is None and the leaving value is None
83                    // if the entering value is valid, we might get a new sum.
84                    if self.sum.is_none() {
85                        recompute_sum = true;
86                        break;
87                    }
88                }
89            }
90            recompute_sum
91        };
92
93        self.last_start = start;
94
95        // we traverse all values and compute
96        if recompute_sum {
97            self.compute_sum_and_null_count(start, end);
98        } else {
99            for idx in self.last_end..end {
100                let valid = self.validity.get_bit_unchecked(idx);
101
102                if valid {
103                    let value = *self.slice.get_unchecked(idx);
104                    match self.sum {
105                        None => self.sum = Some(value),
106                        Some(current) => self.sum = Some(current + value),
107                    }
108                } else {
109                    // null value entering the window
110                    self.null_count += 1;
111                }
112            }
113        }
114        self.last_end = end;
115        self.sum
116    }
117
118    fn is_valid(&self, min_periods: usize) -> bool {
119        ((self.last_end - self.last_start) - self.null_count) >= min_periods
120    }
121}
122
123pub fn rolling_sum<T>(
124    arr: &PrimitiveArray<T>,
125    window_size: usize,
126    min_periods: usize,
127    center: bool,
128    weights: Option<&[f64]>,
129    _params: Option<RollingFnParams>,
130) -> ArrayRef
131where
132    T: NativeType + IsFloat + PartialOrd + Add<Output = T> + Sub<Output = T>,
133{
134    if weights.is_some() {
135        panic!("weights not yet supported on array with null values")
136    }
137    if center {
138        rolling_apply_agg_window::<SumWindow<_>, _, _>(
139            arr.values().as_slice(),
140            arr.validity().as_ref().unwrap(),
141            window_size,
142            min_periods,
143            det_offsets_center,
144            None,
145        )
146    } else {
147        rolling_apply_agg_window::<SumWindow<_>, _, _>(
148            arr.values().as_slice(),
149            arr.validity().as_ref().unwrap(),
150            window_size,
151            min_periods,
152            det_offsets,
153            None,
154        )
155    }
156}