polars_arrow/legacy/kernels/rolling/no_nulls/
sum.rs1use super::*;
2
3pub struct SumWindow<'a, T> {
4 slice: &'a [T],
5 sum: T,
6 last_start: usize,
7 last_end: usize,
8}
9
10impl<'a, T: NativeType + IsFloat + std::iter::Sum + AddAssign + SubAssign>
11 RollingAggWindowNoNulls<'a, T> for SumWindow<'a, T>
12{
13 fn new(slice: &'a [T], start: usize, end: usize, _params: Option<RollingFnParams>) -> Self {
14 let sum = slice[start..end].iter().copied().sum::<T>();
15 Self {
16 slice,
17 sum,
18 last_start: start,
19 last_end: end,
20 }
21 }
22
23 unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
24 let recompute_sum = if start >= self.last_end {
27 true
28 } else {
29 let mut recompute_sum = false;
31 for idx in self.last_start..start {
32 let leaving_value = self.slice.get_unchecked(idx);
35
36 if T::is_float() && !leaving_value.is_finite() {
37 recompute_sum = true;
38 break;
39 }
40
41 self.sum -= *leaving_value;
42 }
43 recompute_sum
44 };
45 self.last_start = start;
46
47 if recompute_sum {
49 self.sum = self
50 .slice
51 .get_unchecked(start..end)
52 .iter()
53 .copied()
54 .sum::<T>();
55 }
56 else {
58 for idx in self.last_end..end {
59 self.sum += *self.slice.get_unchecked(idx);
60 }
61 }
62 self.last_end = end;
63 Some(self.sum)
64 }
65}
66
67pub fn rolling_sum<T>(
68 values: &[T],
69 window_size: usize,
70 min_periods: usize,
71 center: bool,
72 weights: Option<&[f64]>,
73 _params: Option<RollingFnParams>,
74) -> PolarsResult<ArrayRef>
75where
76 T: NativeType
77 + std::iter::Sum
78 + NumCast
79 + Mul<Output = T>
80 + AddAssign
81 + SubAssign
82 + IsFloat
83 + Num,
84{
85 match (center, weights) {
86 (true, None) => rolling_apply_agg_window::<SumWindow<_>, _, _>(
87 values,
88 window_size,
89 min_periods,
90 det_offsets_center,
91 None,
92 ),
93 (false, None) => rolling_apply_agg_window::<SumWindow<_>, _, _>(
94 values,
95 window_size,
96 min_periods,
97 det_offsets,
98 None,
99 ),
100 (true, Some(weights)) => {
101 let weights = no_nulls::coerce_weights(weights);
102 no_nulls::rolling_apply_weights(
103 values,
104 window_size,
105 min_periods,
106 det_offsets_center,
107 no_nulls::compute_sum_weights,
108 &weights,
109 )
110 },
111 (false, Some(weights)) => {
112 let weights = no_nulls::coerce_weights(weights);
113 no_nulls::rolling_apply_weights(
114 values,
115 window_size,
116 min_periods,
117 det_offsets,
118 no_nulls::compute_sum_weights,
119 &weights,
120 )
121 },
122 }
123}
124
#[cfg(test)]
mod test {
    use super::*;

    /// Runs the unweighted `rolling_sum` on `values` and unpacks the resulting
    /// `PrimitiveArray<f64>` into a vector of optional values.
    fn sum_f64(
        values: &[f64],
        window_size: usize,
        min_periods: usize,
        center: bool,
    ) -> Vec<Option<f64>> {
        let arr = rolling_sum(values, window_size, min_periods, center, None, None).unwrap();
        let prim = arr.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
        prim.into_iter().map(|v| v.copied()).collect()
    }

    #[test]
    fn test_rolling_sum() {
        let values = &[1.0f64, 2.0, 3.0, 4.0];

        assert_eq!(
            sum_f64(values, 2, 2, false),
            &[None, Some(3.0), Some(5.0), Some(7.0)]
        );
        assert_eq!(
            sum_f64(values, 2, 1, false),
            &[Some(1.0), Some(3.0), Some(5.0), Some(7.0)]
        );
        assert_eq!(
            sum_f64(values, 4, 1, false),
            &[Some(1.0), Some(3.0), Some(6.0), Some(10.0)]
        );
        assert_eq!(
            sum_f64(values, 4, 1, true),
            &[Some(3.0), Some(6.0), Some(10.0), Some(9.0)]
        );
        assert_eq!(sum_f64(values, 4, 4, true), &[None, None, Some(10.0), None]);

        // A NaN poisons every window containing it; once it leaves, sums recover.
        let values = &[1.0, 2.0, 3.0, f64::nan(), 5.0, 6.0, 7.0];
        let out = sum_f64(values, 3, 3, false);
        let expected = &[
            None,
            None,
            Some(6.0),
            Some(f64::nan()),
            Some(f64::nan()),
            Some(f64::nan()),
            Some(18.0),
        ];
        // NaN != NaN, so compare the Debug renderings instead of the values.
        assert_eq!(format!("{:?}", out.as_slice()), format!("{:?}", expected));
    }
}