polars_arrow/legacy/kernels/rolling/
mod.rs

1pub mod no_nulls;
2pub mod nulls;
3pub mod quantile_filter;
4mod window;
5
6use std::ops::{Add, AddAssign, Div, Mul, Sub, SubAssign};
7
8use num_traits::{Bounded, Float, NumCast, One, Zero};
9use polars_utils::float::IsFloat;
10use polars_utils::ord::{compare_fn_nan_max, compare_fn_nan_min};
11#[cfg(feature = "serde")]
12use serde::{Deserialize, Serialize};
13use window::*;
14
15use crate::array::{ArrayRef, PrimitiveArray};
16use crate::bitmap::{Bitmap, MutableBitmap};
17use crate::legacy::prelude::*;
18use crate::legacy::utils::CustomIterTools;
19use crate::types::NativeType;
20
21type Start = usize;
22type End = usize;
23type Idx = usize;
24type WindowSize = usize;
25type Len = usize;
26
27#[derive(Clone, Debug, PartialEq)]
28#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
29pub enum RollingFnParams {
30    Quantile(RollingQuantileParams),
31    Var(RollingVarParams),
32}
33
34fn det_offsets(i: Idx, window_size: WindowSize, _len: Len) -> (usize, usize) {
35    (i.saturating_sub(window_size - 1), i + 1)
36}
37fn det_offsets_center(i: Idx, window_size: WindowSize, len: Len) -> (usize, usize) {
38    let right_window = (window_size + 1) / 2;
39    (
40        i.saturating_sub(window_size - right_window),
41        std::cmp::min(len, i + right_window),
42    )
43}
44
45fn create_validity<Fo>(
46    min_periods: usize,
47    len: usize,
48    window_size: usize,
49    det_offsets_fn: Fo,
50) -> Option<MutableBitmap>
51where
52    Fo: Fn(Idx, WindowSize, Len) -> (Start, End),
53{
54    if min_periods > 1 {
55        let mut validity = MutableBitmap::with_capacity(len);
56        validity.extend_constant(len, true);
57
58        // Set the null values at the boundaries
59
60        // Head.
61        for i in 0..len {
62            let (start, end) = det_offsets_fn(i, window_size, len);
63            if (end - start) < min_periods {
64                validity.set(i, false)
65            } else {
66                break;
67            }
68        }
69        // Tail.
70        for i in (0..len).rev() {
71            let (start, end) = det_offsets_fn(i, window_size, len);
72            if (end - start) < min_periods {
73                validity.set(i, false)
74            } else {
75                break;
76            }
77        }
78
79        Some(validity)
80    } else {
81        None
82    }
83}
84
/// Parameters for rolling variance computations.
#[derive(Debug, PartialEq)]
#[derive(Clone, Copy)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct RollingVarParams {
    /// Presumably the "delta degrees of freedom" adjustment applied to the
    /// variance divisor — confirm against the variance kernels.
    pub ddof: u8,
}
91
92#[derive(Clone, Copy, Debug, PartialEq)]
93#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
94pub struct RollingQuantileParams {
95    pub prob: f64,
96    pub method: QuantileMethod,
97}