1mod mean;
2mod min_max;
3mod quantile;
4mod sum;
5mod variance;
6
7pub use mean::*;
8pub use min_max::*;
9pub use quantile::*;
10pub use sum::*;
11pub use variance::*;
12
13use super::*;
14
15pub trait RollingAggWindowNulls<'a, T: NativeType> {
16 unsafe fn new(
19 slice: &'a [T],
20 validity: &'a Bitmap,
21 start: usize,
22 end: usize,
23 params: Option<RollingFnParams>,
24 ) -> Self;
25
26 unsafe fn update(&mut self, start: usize, end: usize) -> Option<T>;
29
30 fn is_valid(&self, min_periods: usize) -> bool;
31}
32
33pub(super) fn rolling_apply_agg_window<'a, Agg, T, Fo>(
35 values: &'a [T],
36 validity: &'a Bitmap,
37 window_size: usize,
38 min_periods: usize,
39 det_offsets_fn: Fo,
40 params: Option<RollingFnParams>,
41) -> ArrayRef
42where
43 Fo: Fn(Idx, WindowSize, Len) -> (Start, End) + Copy,
44 Agg: RollingAggWindowNulls<'a, T>,
45 T: IsFloat + NativeType,
46{
47 let len = values.len();
48 let (start, end) = det_offsets_fn(0, window_size, len);
49 let mut agg_window = unsafe { Agg::new(values, validity, start, end, params) };
51
52 let mut validity = create_validity(min_periods, len, window_size, det_offsets_fn)
53 .unwrap_or_else(|| {
54 let mut validity = MutableBitmap::with_capacity(len);
55 validity.extend_constant(len, true);
56 validity
57 });
58
59 let out = (0..len)
60 .map(|idx| {
61 let (start, end) = det_offsets_fn(idx, window_size, len);
62 let agg = unsafe { agg_window.update(start, end) };
65 match agg {
66 Some(val) => {
67 if agg_window.is_valid(min_periods) {
68 val
69 } else {
70 unsafe { validity.set_unchecked(idx, false) };
72 T::default()
73 }
74 },
75 None => {
76 unsafe { validity.set_unchecked(idx, false) };
78 T::default()
79 },
80 }
81 })
82 .collect_trusted::<Vec<_>>();
83
84 Box::new(PrimitiveArray::new(
85 T::PRIMITIVE.into(),
86 out.into(),
87 Some(validity.into()),
88 ))
89}
90
91#[cfg(test)]
92mod test {
93 use super::*;
94 use crate::array::{Array, Int32Array};
95 use crate::buffer::Buffer;
96 use crate::datatypes::ArrowDataType;
97
98 fn get_null_arr() -> PrimitiveArray<f64> {
99 let buf = Buffer::from(vec![1.0, 0.0, -1.0, 4.0]);
101 PrimitiveArray::new(
102 ArrowDataType::Float64,
103 buf,
104 Some(Bitmap::from(&[true, false, true, true])),
105 )
106 }
107
108 #[test]
109 fn test_rolling_sum_nulls() {
110 let buf = Buffer::from(vec![1.0, 2.0, 3.0, 4.0]);
111 let arr = &PrimitiveArray::new(
112 ArrowDataType::Float64,
113 buf,
114 Some(Bitmap::from(&[true, false, true, true])),
115 );
116
117 let out = rolling_sum(arr, 2, 2, false, None, None);
118 let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
119 let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
120 assert_eq!(out, &[None, None, None, Some(7.0)]);
121
122 let out = rolling_sum(arr, 2, 1, false, None, None);
123 let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
124 let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
125 assert_eq!(out, &[Some(1.0), Some(1.0), Some(3.0), Some(7.0)]);
126
127 let out = rolling_sum(arr, 4, 1, false, None, None);
128 let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
129 let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
130 assert_eq!(out, &[Some(1.0), Some(1.0), Some(4.0), Some(8.0)]);
131
132 let out = rolling_sum(arr, 4, 1, true, None, None);
133 let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
134 let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
135 assert_eq!(out, &[Some(1.0), Some(4.0), Some(8.0), Some(7.0)]);
136
137 let out = rolling_sum(arr, 4, 4, true, None, None);
138 let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
139 let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
140 assert_eq!(out, &[None, None, None, None]);
141 }
142
143 #[test]
144 fn test_rolling_mean_nulls() {
145 let arr = get_null_arr();
146 let arr = &arr;
147
148 let out = rolling_mean(arr, 2, 2, false, None, None);
149 let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
150 let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
151 assert_eq!(out, &[None, None, None, Some(1.5)]);
152
153 let out = rolling_mean(arr, 2, 1, false, None, None);
154 let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
155 let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
156 assert_eq!(out, &[Some(1.0), Some(1.0), Some(-1.0), Some(1.5)]);
157
158 let out = rolling_mean(arr, 4, 1, false, None, None);
159 let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
160 let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
161 assert_eq!(out, &[Some(1.0), Some(1.0), Some(0.0), Some(4.0 / 3.0)]);
162 }
163
164 #[test]
165 fn test_rolling_var_nulls() {
166 let arr = get_null_arr();
167 let arr = &arr;
168
169 let out = rolling_var(arr, 3, 1, false, None, None);
170 let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
171 let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
172
173 assert_eq!(out, &[None, None, Some(2.0), Some(12.5)]);
174
175 let testpars = Some(RollingFnParams::Var(RollingVarParams { ddof: 0 }));
176 let out = rolling_var(arr, 3, 1, false, None, testpars.clone());
177 let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
178 let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
179
180 assert_eq!(out, &[Some(0.0), Some(0.0), Some(1.0), Some(6.25)]);
181
182 let out = rolling_var(arr, 4, 1, false, None, None);
183 let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
184 let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
185 assert_eq!(out, &[None, None, Some(2.0), Some(6.333333333333334)]);
186
187 let out = rolling_var(arr, 4, 1, false, None, testpars.clone());
188 let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
189 let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
190 assert_eq!(
191 out,
192 &[Some(0.), Some(0.0), Some(1.0), Some(4.222222222222222)]
193 );
194 }
195
196 #[test]
197 fn test_rolling_max_no_nulls() {
198 let buf = Buffer::from(vec![1.0, 2.0, 3.0, 4.0]);
199 let arr = &PrimitiveArray::new(
200 ArrowDataType::Float64,
201 buf,
202 Some(Bitmap::from(&[true, true, true, true])),
203 );
204 let out = rolling_max(arr, 4, 1, false, None, None);
205 let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
206 let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
207 assert_eq!(out, &[Some(1.0), Some(2.0), Some(3.0), Some(4.0)]);
208
209 let out = rolling_max(arr, 2, 2, false, None, None);
210 let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
211 let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
212 assert_eq!(out, &[None, Some(2.0), Some(3.0), Some(4.0)]);
213
214 let out = rolling_max(arr, 4, 4, false, None, None);
215 let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
216 let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
217 assert_eq!(out, &[None, None, None, Some(4.0)]);
218
219 let buf = Buffer::from(vec![4.0, 3.0, 2.0, 1.0]);
220 let arr = &PrimitiveArray::new(
221 ArrowDataType::Float64,
222 buf,
223 Some(Bitmap::from(&[true, true, true, true])),
224 );
225 let out = rolling_max(arr, 2, 1, false, None, None);
226 let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
227 let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
228 assert_eq!(out, &[Some(4.0), Some(4.0), Some(3.0), Some(2.0)]);
229
230 let out =
231 super::no_nulls::rolling_max(arr.values().as_slice(), 2, 1, false, None, None).unwrap();
232 let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
233 let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
234 assert_eq!(out, &[Some(4.0), Some(4.0), Some(3.0), Some(2.0)]);
235 }
236
237 #[test]
238 fn test_rolling_extrema_nulls() {
239 let vals = vec![3, 3, 3, 10, 10, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1];
240 let validity = Bitmap::new_with_value(true, vals.len());
241 let window_size = 3;
242 let min_periods = 3;
243
244 let arr = Int32Array::new(ArrowDataType::Int32, vals.into(), Some(validity));
245
246 let out = rolling_apply_agg_window::<MaxWindow<_>, _, _>(
247 arr.values().as_slice(),
248 arr.validity().as_ref().unwrap(),
249 window_size,
250 min_periods,
251 det_offsets,
252 None,
253 );
254 let arr = out.as_any().downcast_ref::<Int32Array>().unwrap();
255 assert_eq!(arr.null_count(), 2);
256 assert_eq!(
257 &arr.values().as_slice()[2..],
258 &[3, 10, 10, 10, 10, 10, 9, 8, 7, 6, 5, 4, 3]
259 );
260 }
261}