polars_arrow/legacy/kernels/rolling/nulls/
sum.rs1use super::*;
2
3pub struct SumWindow<'a, T> {
4 slice: &'a [T],
5 validity: &'a Bitmap,
6 sum: Option<T>,
7 last_start: usize,
8 last_end: usize,
9 pub(super) null_count: usize,
10}
11
12impl<T: NativeType + IsFloat + Add<Output = T> + Sub<Output = T>> SumWindow<'_, T> {
13 unsafe fn compute_sum_and_null_count(&mut self, start: usize, end: usize) -> Option<T> {
15 let mut sum = None;
16 let mut idx = start;
17 self.null_count = 0;
18 for value in &self.slice[start..end] {
19 let valid = self.validity.get_bit_unchecked(idx);
20 if valid {
21 match sum {
22 None => sum = Some(*value),
23 Some(current) => sum = Some(*value + current),
24 }
25 } else {
26 self.null_count += 1;
27 }
28 idx += 1;
29 }
30 self.sum = sum;
31 sum
32 }
33}
34
35impl<'a, T: NativeType + IsFloat + Add<Output = T> + Sub<Output = T>> RollingAggWindowNulls<'a, T>
36 for SumWindow<'a, T>
37{
38 unsafe fn new(
39 slice: &'a [T],
40 validity: &'a Bitmap,
41 start: usize,
42 end: usize,
43 _params: Option<RollingFnParams>,
44 ) -> Self {
45 let mut out = Self {
46 slice,
47 validity,
48 sum: None,
49 last_start: start,
50 last_end: end,
51 null_count: 0,
52 };
53 out.compute_sum_and_null_count(start, end);
54 out
55 }
56
57 unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
58 let recompute_sum = if start >= self.last_end {
61 true
62 } else {
63 let mut recompute_sum = false;
65 for idx in self.last_start..start {
66 let valid = self.validity.get_bit_unchecked(idx);
69 if valid {
70 let leaving_value = self.slice.get_unchecked(idx);
71
72 if T::is_float() && !leaving_value.is_finite() {
74 recompute_sum = true;
75 break;
76 }
77 self.sum = self.sum.map(|v| v - *leaving_value)
78 } else {
79 self.null_count -= 1;
81
82 if self.sum.is_none() {
85 recompute_sum = true;
86 break;
87 }
88 }
89 }
90 recompute_sum
91 };
92
93 self.last_start = start;
94
95 if recompute_sum {
97 self.compute_sum_and_null_count(start, end);
98 } else {
99 for idx in self.last_end..end {
100 let valid = self.validity.get_bit_unchecked(idx);
101
102 if valid {
103 let value = *self.slice.get_unchecked(idx);
104 match self.sum {
105 None => self.sum = Some(value),
106 Some(current) => self.sum = Some(current + value),
107 }
108 } else {
109 self.null_count += 1;
111 }
112 }
113 }
114 self.last_end = end;
115 self.sum
116 }
117
118 fn is_valid(&self, min_periods: usize) -> bool {
119 ((self.last_end - self.last_start) - self.null_count) >= min_periods
120 }
121}
122
123pub fn rolling_sum<T>(
124 arr: &PrimitiveArray<T>,
125 window_size: usize,
126 min_periods: usize,
127 center: bool,
128 weights: Option<&[f64]>,
129 _params: Option<RollingFnParams>,
130) -> ArrayRef
131where
132 T: NativeType + IsFloat + PartialOrd + Add<Output = T> + Sub<Output = T>,
133{
134 if weights.is_some() {
135 panic!("weights not yet supported on array with null values")
136 }
137 if center {
138 rolling_apply_agg_window::<SumWindow<_>, _, _>(
139 arr.values().as_slice(),
140 arr.validity().as_ref().unwrap(),
141 window_size,
142 min_periods,
143 det_offsets_center,
144 None,
145 )
146 } else {
147 rolling_apply_agg_window::<SumWindow<_>, _, _>(
148 arr.values().as_slice(),
149 arr.validity().as_ref().unwrap(),
150 window_size,
151 min_periods,
152 det_offsets,
153 None,
154 )
155 }
156}