polars_arrow/legacy/kernels/rolling/no_nulls/
mod.rs1mod mean;
2mod min_max;
3mod quantile;
4mod sum;
5mod variance;
6use std::fmt::Debug;
7
8pub use mean::*;
9pub use min_max::*;
10use num_traits::{Float, Num, NumCast};
11pub use quantile::*;
12#[cfg(feature = "serde")]
13use serde::{Deserialize, Serialize};
14use strum_macros::IntoStaticStr;
15pub use sum::*;
16pub use variance::*;
17
18use super::*;
19use crate::array::PrimitiveArray;
20use crate::datatypes::ArrowDataType;
21use crate::legacy::error::PolarsResult;
22use crate::types::NativeType;
23
/// A sliding-window aggregation over a slice, intended (per the `NoNulls` name and
/// this module) for input without null values.
///
/// The state is created once for an initial `[start, end)` range via [`new`](Self::new)
/// and then moved along the slice with [`update`](Self::update), which returns the
/// aggregate of the new range.
pub trait RollingAggWindowNoNulls<'a, T: NativeType> {
    /// Create the aggregation state for the initial window `slice[start..end]`.
    ///
    /// `params` carries optional, function-specific settings.
    fn new(slice: &'a [T], start: usize, end: usize, params: Option<RollingFnParams>) -> Self;

    /// Move the window to `[start, end)` and return the aggregate for it.
    ///
    /// NOTE(review): the meaning of a `None` return is up to the implementation.
    ///
    /// # Safety
    /// NOTE(review): presumably the caller must guarantee `start <= end <= slice.len()`
    /// for the slice given to `new`, because implementations may index without bounds
    /// checks. Confirm against the implementations in the submodules.
    unsafe fn update(&mut self, start: usize, end: usize) -> Option<T>;
}
33
34pub(super) fn rolling_apply_agg_window<'a, Agg, T, Fo>(
36 values: &'a [T],
37 window_size: usize,
38 min_periods: usize,
39 det_offsets_fn: Fo,
40 params: Option<RollingFnParams>,
41) -> PolarsResult<ArrayRef>
42where
43 Fo: Fn(Idx, WindowSize, Len) -> (Start, End),
44 Agg: RollingAggWindowNoNulls<'a, T>,
45 T: Debug + NativeType + Num,
46{
47 let len = values.len();
48 let (start, end) = det_offsets_fn(0, window_size, len);
49 let mut agg_window = Agg::new(values, start, end, params);
50 if let Some(validity) = create_validity(min_periods, len, window_size, &det_offsets_fn) {
51 if validity.iter().all(|x| !x) {
52 return Ok(Box::new(PrimitiveArray::<T>::new_null(
53 T::PRIMITIVE.into(),
54 len,
55 )));
56 }
57 }
58
59 let out = (0..len).map(|idx| {
60 let (start, end) = det_offsets_fn(idx, window_size, len);
61 if end - start < min_periods {
62 None
63 } else {
64 unsafe { agg_window.update(start, end) }
67 }
68 });
69 let arr = PrimitiveArray::from_trusted_len_iter(out);
70 Ok(Box::new(arr))
71}
72
/// Strategy used to compute a quantile, e.g. whether to take a neighbouring
/// observation or to interpolate between neighbours.
///
/// NOTE(review): the per-variant descriptions below are inferred from the variant
/// names, which match the usual NumPy-style interpolation options. The actual behaviour
/// is implemented elsewhere (presumably the `quantile` submodule); confirm there.
///
/// `Nearest` is the `Default`. Through strum's `IntoStaticStr`, each variant converts
/// to its snake_case name (e.g. `"nearest"`, `"midpoint"`).
#[derive(Clone, Copy, PartialEq, Eq, Debug, Default, Hash, IntoStaticStr)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[strum(serialize_all = "snake_case")]
pub enum QuantileMethod {
    /// Presumably: take the observation nearest to the requested position. Default.
    #[default]
    Nearest,
    /// Presumably: take the lower of the two neighbouring observations.
    Lower,
    /// Presumably: take the higher of the two neighbouring observations.
    Higher,
    /// Presumably: average the two neighbouring observations.
    Midpoint,
    /// Presumably: interpolate linearly between the two neighbouring observations.
    Linear,
    /// NOTE(review): semantics are not evident from the name. Document from the implementation.
    Equiprobable,
}
85
/// Deprecated alias of [`QuantileMethod`]; new code should use `QuantileMethod` directly.
///
/// Presumably the enum's former name, kept so existing callers still compile.
#[deprecated(note = "use QuantileMethod instead")]
pub type QuantileInterpolOptions = QuantileMethod;
88
89pub(super) fn rolling_apply_weights<T, Fo, Fa>(
90 values: &[T],
91 window_size: usize,
92 min_periods: usize,
93 det_offsets_fn: Fo,
94 aggregator: Fa,
95 weights: &[T],
96) -> PolarsResult<ArrayRef>
97where
98 T: NativeType,
99 Fo: Fn(Idx, WindowSize, Len) -> (Start, End),
100 Fa: Fn(&[T], &[T]) -> T,
101{
102 assert_eq!(weights.len(), window_size);
103 let len = values.len();
104 let out = (0..len)
105 .map(|idx| {
106 let (start, end) = det_offsets_fn(idx, window_size, len);
107 let vals = unsafe { values.get_unchecked(start..end) };
108
109 aggregator(vals, weights)
110 })
111 .collect_trusted::<Vec<T>>();
112
113 let validity = create_validity(min_periods, len, window_size, det_offsets_fn);
114 Ok(Box::new(PrimitiveArray::new(
115 ArrowDataType::from(T::PRIMITIVE),
116 out.into(),
117 validity.map(|b| b.into()),
118 )))
119}
120
121fn compute_var_weights<T>(vals: &[T], weights: &[T]) -> T
122where
123 T: Float + std::ops::AddAssign,
124{
125 debug_assert!(
127 weights.iter().fold(T::zero(), |acc, x| acc + *x) == T::one(),
128 "Rolling weighted variance Weights don't sum to 1"
129 );
130 let (wssq, wmean) = vals
131 .iter()
132 .zip(weights)
133 .fold((T::zero(), T::zero()), |(wssq, wsum), (&v, &w)| {
134 (wssq + v * v * w, wsum + v * w)
135 });
136
137 wssq - wmean * wmean
138}
139
/// Dot product of `values` and `weights`: `Σ values[i] * weights[i]`.
///
/// Elements are paired with `zip`, so any surplus in the longer slice is ignored. An
/// empty pairing yields whatever `T::sum` returns for an empty iterator (zero for the
/// primitive numeric types).
pub(crate) fn compute_sum_weights<T>(values: &[T], weights: &[T]) -> T
where
    T: std::iter::Sum<T> + Copy + std::ops::Mul<Output = T>,
{
    let products = values
        .iter()
        .copied()
        .zip(weights.iter().copied())
        .map(|(value, weight)| value * weight);
    products.sum()
}
146
147pub(super) fn coerce_weights<T: NumCast>(weights: &[f64]) -> Vec<T>
148where
149{
150 weights
151 .iter()
152 .map(|v| NumCast::from(*v).unwrap())
153 .collect::<Vec<_>>()
154}