polars_arrow/legacy/kernels/rolling/no_nulls/
sum.rs

1use super::*;
2
/// Running-sum state for a window that slides over a slice without nulls.
///
/// The struct remembers the bounds of the most recently evaluated window
/// together with the sum of the elements inside it, so that the next window
/// can be derived from the previous one.
#[derive(Debug)]
pub struct SumWindow<'a, T> {
    /// The complete input the window slides over.
    slice: &'a [T],
    /// Sum of `slice[last_start..last_end]`.
    sum: T,
    /// Inclusive start index of the most recently evaluated window.
    last_start: usize,
    /// Exclusive end index of the most recently evaluated window.
    last_end: usize,
}
9
10impl<'a, T: NativeType + IsFloat + std::iter::Sum + AddAssign + SubAssign>
11    RollingAggWindowNoNulls<'a, T> for SumWindow<'a, T>
12{
13    fn new(slice: &'a [T], start: usize, end: usize, _params: Option<RollingFnParams>) -> Self {
14        let sum = slice[start..end].iter().copied().sum::<T>();
15        Self {
16            slice,
17            sum,
18            last_start: start,
19            last_end: end,
20        }
21    }
22
23    unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
24        // if we exceed the end, we have a completely new window
25        // so we recompute
26        let recompute_sum = if start >= self.last_end {
27            true
28        } else {
29            // remove elements that should leave the window
30            let mut recompute_sum = false;
31            for idx in self.last_start..start {
32                // SAFETY:
33                // we are in bounds
34                let leaving_value = self.slice.get_unchecked(idx);
35
36                if T::is_float() && !leaving_value.is_finite() {
37                    recompute_sum = true;
38                    break;
39                }
40
41                self.sum -= *leaving_value;
42            }
43            recompute_sum
44        };
45        self.last_start = start;
46
47        // we traverse all values and compute
48        if recompute_sum {
49            self.sum = self
50                .slice
51                .get_unchecked(start..end)
52                .iter()
53                .copied()
54                .sum::<T>();
55        }
56        // remove leaving values.
57        else {
58            for idx in self.last_end..end {
59                self.sum += *self.slice.get_unchecked(idx);
60            }
61        }
62        self.last_end = end;
63        Some(self.sum)
64    }
65}
66
67pub fn rolling_sum<T>(
68    values: &[T],
69    window_size: usize,
70    min_periods: usize,
71    center: bool,
72    weights: Option<&[f64]>,
73    _params: Option<RollingFnParams>,
74) -> PolarsResult<ArrayRef>
75where
76    T: NativeType
77        + std::iter::Sum
78        + NumCast
79        + Mul<Output = T>
80        + AddAssign
81        + SubAssign
82        + IsFloat
83        + Num,
84{
85    match (center, weights) {
86        (true, None) => rolling_apply_agg_window::<SumWindow<_>, _, _>(
87            values,
88            window_size,
89            min_periods,
90            det_offsets_center,
91            None,
92        ),
93        (false, None) => rolling_apply_agg_window::<SumWindow<_>, _, _>(
94            values,
95            window_size,
96            min_periods,
97            det_offsets,
98            None,
99        ),
100        (true, Some(weights)) => {
101            let weights = no_nulls::coerce_weights(weights);
102            no_nulls::rolling_apply_weights(
103                values,
104                window_size,
105                min_periods,
106                det_offsets_center,
107                no_nulls::compute_sum_weights,
108                &weights,
109            )
110        },
111        (false, Some(weights)) => {
112            let weights = no_nulls::coerce_weights(weights);
113            no_nulls::rolling_apply_weights(
114                values,
115                window_size,
116                min_periods,
117                det_offsets,
118                no_nulls::compute_sum_weights,
119                &weights,
120            )
121        },
122    }
123}
124
#[cfg(test)]
mod test {
    use super::*;

    /// Runs an unweighted `rolling_sum` over `f64` input and collects the
    /// resulting array into a vector of optional values.
    fn unweighted_sum(
        values: &[f64],
        window_size: usize,
        min_periods: usize,
        center: bool,
    ) -> Vec<Option<f64>> {
        let array = rolling_sum(values, window_size, min_periods, center, None, None).unwrap();
        let typed = array
            .as_any()
            .downcast_ref::<PrimitiveArray<f64>>()
            .unwrap();
        typed.into_iter().map(|v| v.copied()).collect::<Vec<_>>()
    }

    #[test]
    fn test_rolling_sum() {
        let values = &[1.0f64, 2.0, 3.0, 4.0];

        // Trailing window of 2 that requires both observations.
        let out = unweighted_sum(values, 2, 2, false);
        assert_eq!(out, &[None, Some(3.0), Some(5.0), Some(7.0)]);

        // Trailing window of 2 where a single observation is enough.
        let out = unweighted_sum(values, 2, 1, false);
        assert_eq!(out, &[Some(1.0), Some(3.0), Some(5.0), Some(7.0)]);

        // Trailing window spanning the whole input.
        let out = unweighted_sum(values, 4, 1, false);
        assert_eq!(out, &[Some(1.0), Some(3.0), Some(6.0), Some(10.0)]);

        // Centered window spanning the whole input.
        let out = unweighted_sum(values, 4, 1, true);
        assert_eq!(out, &[Some(3.0), Some(6.0), Some(10.0), Some(9.0)]);

        // Centered window that needs all four observations.
        let out = unweighted_sum(values, 4, 4, true);
        assert_eq!(out, &[None, None, Some(10.0), None]);

        // test nan handling.
        let values = &[1.0, 2.0, 3.0, f64::nan(), 5.0, 6.0, 7.0];
        let out = unweighted_sum(values, 3, 3, false);

        // NaN never compares equal to itself, so compare the debug output.
        let expected: [Option<f64>; 7] = [
            None,
            None,
            Some(6.0),
            Some(f64::nan()),
            Some(f64::nan()),
            Some(f64::nan()),
            Some(18.0),
        ];
        assert_eq!(
            format!("{:?}", out.as_slice()),
            format!("{:?}", &expected)
        );
    }
}