polars_arrow/legacy/kernels/ewm/
mod.rs1mod average;
2mod variance;
3
4use std::hash::{Hash, Hasher};
5
6pub use average::*;
7#[cfg(feature = "serde")]
8use serde::{Deserialize, Serialize};
9pub use variance::*;
10
/// Parameters for the exponentially weighted moving (EWM) kernels.
///
/// The decay is stored directly as the smoothing factor `alpha`; the
/// [`EWMOptions::and_span`], [`EWMOptions::and_half_life`] and
/// [`EWMOptions::and_com`] builders derive it from the equivalent alternative
/// parameterisations. The remaining fields are plain flags/counts —
/// NOTE(review): their exact semantics live in the kernels (presumably the
/// `average` / `variance` submodules); confirm there.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Copy, Clone, PartialEq)]
#[must_use]
pub struct EWMOptions {
    /// Smoothing factor. The builders always set it within `[0, 1]`
    /// (strictly positive for finite inputs).
    pub alpha: f64,
    /// `adjust` flag — presumably selects the adjusted (normalised-weight)
    /// formulation; TODO confirm against the kernels.
    pub adjust: bool,
    /// `bias` flag — presumably disables bias correction of variance-type
    /// statistics when `true`; TODO confirm in `variance`.
    pub bias: bool,
    /// Minimum observation count — presumably required before a non-null
    /// value is produced; TODO confirm against the kernels.
    pub min_periods: usize,
    /// Null-handling flag — presumably controls whether nulls are skipped
    /// when computing weights; TODO confirm against the kernels.
    pub ignore_nulls: bool,
}

impl Default for EWMOptions {
    /// `alpha = 0.5`, `adjust = true`, `bias = false`, `min_periods = 1`,
    /// `ignore_nulls = true`.
    fn default() -> Self {
        Self {
            alpha: 0.5,
            adjust: true,
            bias: false,
            min_periods: 1,
            ignore_nulls: true,
        }
    }
}

impl Hash for EWMOptions {
    fn hash<H: Hasher>(&self, state: &mut H) {
        // `f64` is not `Hash`, so hash its bit pattern instead. The derived
        // `PartialEq` treats `-0.0 == 0.0` even though their bits differ, so
        // normalise the sign of zero: `Hash` requires equal values to hash
        // equally. (NaN is never equal to itself, so it needs no special case.)
        let alpha = if self.alpha == 0.0 { 0.0f64 } else { self.alpha };
        alpha.to_bits().hash(state);
        self.adjust.hash(state);
        self.bias.hash(state);
        self.min_periods.hash(state);
        self.ignore_nulls.hash(state);
    }
}

impl EWMOptions {
    /// Set [`EWMOptions::min_periods`].
    pub fn and_min_periods(mut self, min_periods: usize) -> Self {
        self.min_periods = min_periods;
        self
    }

    /// Set [`EWMOptions::adjust`].
    pub fn and_adjust(mut self, adjust: bool) -> Self {
        self.adjust = adjust;
        self
    }

    /// Set [`EWMOptions::bias`].
    pub fn and_bias(mut self, bias: bool) -> Self {
        self.bias = bias;
        self
    }

    /// Set `alpha` from a span: `alpha = 2 / (span + 1)`.
    ///
    /// # Panics
    /// If `span < 1`.
    pub fn and_span(mut self, span: usize) -> Self {
        assert!(span >= 1, "`span` must be >= 1, got {span}");
        self.alpha = 2.0 / (span as f64 + 1.0);
        self
    }

    /// Set `alpha` from a half-life: `alpha = 1 - exp(-ln(2) / half_life)`.
    ///
    /// # Panics
    /// If `half_life` is not strictly positive (including NaN).
    pub fn and_half_life(mut self, half_life: f64) -> Self {
        assert!(
            half_life > 0.0,
            "`half_life` must be > 0, got {half_life}"
        );
        // `1 - exp(x) == -expm1(x)`; `exp_m1` keeps full precision when `x`
        // is close to zero, i.e. for large half-lives, where `1 - exp(x)`
        // would cancel away most significant digits.
        self.alpha = -(-std::f64::consts::LN_2 / half_life).exp_m1();
        self
    }

    /// Set `alpha` from a center of mass: `alpha = 1 / (1 + com)`.
    ///
    /// `com == 0` is valid and yields `alpha == 1` (no smoothing), matching
    /// `and_span(1)`.
    ///
    /// # Panics
    /// If `com` is negative or NaN.
    pub fn and_com(mut self, com: f64) -> Self {
        assert!(com >= 0.0, "`com` must be >= 0, got {com}");
        self.alpha = 1.0 / (1.0 + com);
        self
    }

    /// Set [`EWMOptions::ignore_nulls`].
    pub fn and_ignore_nulls(mut self, ignore_nulls: bool) -> Self {
        self.ignore_nulls = ignore_nulls;
        self
    }
}
73
/// Test helper: assert that two nullable float sequences are element-wise close.
///
/// `$xs` and `$ys` must provide `.iter()` yielding `Option`-like items
/// (e.g. a primitive array or a `Vec<Option<f64>>`). They are only borrowed,
/// never moved. Two elements match when both are null, or both are non-null
/// and `|a - b| < $tol`. A null paired with a non-null, or NaN on either side,
/// is a mismatch.
///
/// # Panics
/// If the sequences have different lengths, or on the first mismatching
/// element. The panic message reports the index and both values.
#[cfg(test)]
macro_rules! assert_allclose {
    ($xs:expr, $ys:expr, $tol:expr) => {{
        // Borrow (rather than move) so callers can keep using their values,
        // and so each argument expression is evaluated exactly once.
        let xs = &$xs;
        let ys = &$ys;
        let tol = $tol;
        // `zip` stops at the shorter input, so check lengths explicitly.
        // Otherwise a truncated result would silently compare as "close".
        assert_eq!(
            xs.iter().count(),
            ys.iter().count(),
            "assert_allclose: length mismatch"
        );
        for (i, (x, y)) in xs.iter().zip(ys.iter()).enumerate() {
            let close = match (x, y) {
                (Some(a), Some(b)) => (a - b).abs() < tol,
                (None, None) => true,
                _ => false,
            };
            assert!(
                close,
                "assert_allclose: mismatch at index {}: {:?} vs {:?} (tol = {:?})",
                i, x, y, tol
            );
        }
    }};
}
90#[cfg(test)]
91pub(crate) use assert_allclose;