polars_arrow/legacy/kernels/rolling/no_nulls/
mod.rs

1mod mean;
2mod min_max;
3mod quantile;
4mod sum;
5mod variance;
6use std::fmt::Debug;
7
8pub use mean::*;
9pub use min_max::*;
10use num_traits::{Float, Num, NumCast};
11pub use quantile::*;
12#[cfg(feature = "serde")]
13use serde::{Deserialize, Serialize};
14use strum_macros::IntoStaticStr;
15pub use sum::*;
16pub use variance::*;
17
18use super::*;
19use crate::array::PrimitiveArray;
20use crate::datatypes::ArrowDataType;
21use crate::legacy::error::PolarsResult;
22use crate::types::NativeType;
23
24pub trait RollingAggWindowNoNulls<'a, T: NativeType> {
25    fn new(slice: &'a [T], start: usize, end: usize, params: Option<RollingFnParams>) -> Self;
26
27    /// Update and recompute the window
28    ///
29    /// # Safety
30    /// `start` and `end` must be within the windows bounds
31    unsafe fn update(&mut self, start: usize, end: usize) -> Option<T>;
32}
33
34// Use an aggregation window that maintains the state
35pub(super) fn rolling_apply_agg_window<'a, Agg, T, Fo>(
36    values: &'a [T],
37    window_size: usize,
38    min_periods: usize,
39    det_offsets_fn: Fo,
40    params: Option<RollingFnParams>,
41) -> PolarsResult<ArrayRef>
42where
43    Fo: Fn(Idx, WindowSize, Len) -> (Start, End),
44    Agg: RollingAggWindowNoNulls<'a, T>,
45    T: Debug + NativeType + Num,
46{
47    let len = values.len();
48    let (start, end) = det_offsets_fn(0, window_size, len);
49    let mut agg_window = Agg::new(values, start, end, params);
50    if let Some(validity) = create_validity(min_periods, len, window_size, &det_offsets_fn) {
51        if validity.iter().all(|x| !x) {
52            return Ok(Box::new(PrimitiveArray::<T>::new_null(
53                T::PRIMITIVE.into(),
54                len,
55            )));
56        }
57    }
58
59    let out = (0..len).map(|idx| {
60        let (start, end) = det_offsets_fn(idx, window_size, len);
61        if end - start < min_periods {
62            None
63        } else {
64            // SAFETY:
65            // we are in bounds
66            unsafe { agg_window.update(start, end) }
67        }
68    });
69    let arr = PrimitiveArray::from_trusted_len_iter(out);
70    Ok(Box::new(arr))
71}
72
73#[derive(Clone, Copy, PartialEq, Eq, Debug, Default, Hash, IntoStaticStr)]
74#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
75#[strum(serialize_all = "snake_case")]
76pub enum QuantileMethod {
77    #[default]
78    Nearest,
79    Lower,
80    Higher,
81    Midpoint,
82    Linear,
83    Equiprobable,
84}
85
86#[deprecated(note = "use QuantileMethod instead")]
87pub type QuantileInterpolOptions = QuantileMethod;
88
89pub(super) fn rolling_apply_weights<T, Fo, Fa>(
90    values: &[T],
91    window_size: usize,
92    min_periods: usize,
93    det_offsets_fn: Fo,
94    aggregator: Fa,
95    weights: &[T],
96) -> PolarsResult<ArrayRef>
97where
98    T: NativeType,
99    Fo: Fn(Idx, WindowSize, Len) -> (Start, End),
100    Fa: Fn(&[T], &[T]) -> T,
101{
102    assert_eq!(weights.len(), window_size);
103    let len = values.len();
104    let out = (0..len)
105        .map(|idx| {
106            let (start, end) = det_offsets_fn(idx, window_size, len);
107            let vals = unsafe { values.get_unchecked(start..end) };
108
109            aggregator(vals, weights)
110        })
111        .collect_trusted::<Vec<T>>();
112
113    let validity = create_validity(min_periods, len, window_size, det_offsets_fn);
114    Ok(Box::new(PrimitiveArray::new(
115        ArrowDataType::from(T::PRIMITIVE),
116        out.into(),
117        validity.map(|b| b.into()),
118    )))
119}
120
121fn compute_var_weights<T>(vals: &[T], weights: &[T]) -> T
122where
123    T: Float + std::ops::AddAssign,
124{
125    // Assumes the weights have already been standardized to 1
126    debug_assert!(
127        weights.iter().fold(T::zero(), |acc, x| acc + *x) == T::one(),
128        "Rolling weighted variance Weights don't sum to 1"
129    );
130    let (wssq, wmean) = vals
131        .iter()
132        .zip(weights)
133        .fold((T::zero(), T::zero()), |(wssq, wsum), (&v, &w)| {
134            (wssq + v * v * w, wsum + v * w)
135        });
136
137    wssq - wmean * wmean
138}
139
/// Weighted sum of `values`: the sum of the element-wise products of `values`
/// and `weights`.
///
/// The slices are paired position by position; when their lengths differ, the
/// surplus elements of the longer one do not contribute.
pub(crate) fn compute_sum_weights<T>(values: &[T], weights: &[T]) -> T
where
    T: std::iter::Sum<T> + Copy + std::ops::Mul<Output = T>,
{
    let pairs = values.iter().copied().zip(weights.iter().copied());
    pairs.map(|(value, weight)| value * weight).sum()
}
146
147pub(super) fn coerce_weights<T: NumCast>(weights: &[f64]) -> Vec<T>
148where
149{
150    weights
151        .iter()
152        .map(|v| NumCast::from(*v).unwrap())
153        .collect::<Vec<_>>()
154}