polars_arrow/legacy/kernels/ewm/mod.rs

1mod average;
2mod variance;
3
4use std::hash::{Hash, Hasher};
5
6pub use average::*;
7#[cfg(feature = "serde")]
8use serde::{Deserialize, Serialize};
9pub use variance::*;
10
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Copy, Debug, PartialEq)]
#[must_use]
/// Parameters shared by the exponentially weighted (EWM) kernels re-exported from the
/// `average` and `variance` submodules.
///
/// Start from [`Default::default`] and refine with the `and_*` builder methods.
/// Field order is significant for the derived `Debug` output and (with the `serde`
/// feature) for order-sensitive serialization formats, so do not reorder fields.
pub struct EWMOptions {
    /// Smoothing factor. The builders derive it as `2 / (span + 1)`,
    /// `1 - exp(-ln 2 / half_life)` or `1 / (1 + com)`.
    pub alpha: f64,
    /// Forwarded to the kernels. Presumably selects the "adjusted" weighting as in
    /// pandas' `ewm(adjust=...)` — TODO confirm in `average`/`variance`.
    pub adjust: bool,
    /// Forwarded to the kernels. Presumably controls bias correction of the variance
    /// estimate — TODO confirm in `variance`.
    pub bias: bool,
    /// Forwarded to the kernels. Presumably the minimum number of observations needed
    /// before a value is emitted — TODO confirm.
    pub min_periods: usize,
    /// Forwarded to the kernels. Presumably whether nulls are skipped when computing
    /// weights — TODO confirm.
    pub ignore_nulls: bool,
}
21
22impl Default for EWMOptions {
23    fn default() -> Self {
24        Self {
25            alpha: 0.5,
26            adjust: true,
27            bias: false,
28            min_periods: 1,
29            ignore_nulls: true,
30        }
31    }
32}
33
34impl Hash for EWMOptions {
35    fn hash<H: Hasher>(&self, state: &mut H) {
36        self.alpha.to_bits().hash(state);
37        self.adjust.hash(state);
38        self.bias.hash(state);
39        self.min_periods.hash(state);
40        self.ignore_nulls.hash(state);
41    }
42}
43
44impl EWMOptions {
45    pub fn and_min_periods(mut self, min_periods: usize) -> Self {
46        self.min_periods = min_periods;
47        self
48    }
49    pub fn and_adjust(mut self, adjust: bool) -> Self {
50        self.adjust = adjust;
51        self
52    }
53    pub fn and_span(mut self, span: usize) -> Self {
54        assert!(span >= 1);
55        self.alpha = 2.0 / (span as f64 + 1.0);
56        self
57    }
58    pub fn and_half_life(mut self, half_life: f64) -> Self {
59        assert!(half_life > 0.0);
60        self.alpha = 1.0 - (-(2.0f64.ln()) / half_life).exp();
61        self
62    }
63    pub fn and_com(mut self, com: f64) -> Self {
64        assert!(com > 0.0);
65        self.alpha = 1.0 / (1.0 + com);
66        self
67    }
68    pub fn and_ignore_nulls(mut self, ignore_nulls: bool) -> Self {
69        self.ignore_nulls = ignore_nulls;
70        self
71    }
72}
73
/// Test helper: asserts that two iterables of `Option<number>` are element-wise close.
///
/// Elements pass if both are `Some` and differ by strictly less than `$tol`, or if both
/// are `None`. Any other pairing fails. The inputs must also have the same length;
/// unlike a plain `zip`, a shorter input does not silently truncate the comparison.
/// Each argument expression is evaluated exactly once.
#[cfg(test)]
macro_rules! assert_allclose {
    ($xs:expr, $ys:expr, $tol:expr) => {{
        let xs = &$xs;
        let ys = &$ys;
        let tol = $tol;
        let mut lhs = xs.iter();
        let mut rhs = ys.iter();
        for idx in 0usize.. {
            match (lhs.next(), rhs.next()) {
                (None, None) => break,
                (Some(x), Some(z)) => {
                    let close = match (x, z) {
                        (Some(a), Some(b)) => (a - b).abs() < tol,
                        (None, None) => true,
                        _ => false,
                    };
                    assert!(
                        close,
                        "assert_allclose: mismatch at index {}: {:?} vs {:?} (tol = {:?})",
                        idx, x, z, tol
                    );
                },
                _ => panic!(
                    "assert_allclose: inputs differ in length (one ended at index {})",
                    idx
                ),
            }
        }
    }};
}
// Re-export so submodules (e.g. `average`, `variance`) can import the macro by path.
#[cfg(test)]
pub(crate) use assert_allclose;