polars_arrow/legacy/kernels/ewm/
average.rs

1use std::ops::{AddAssign, MulAssign};
2
3use num_traits::Float;
4
5use crate::array::PrimitiveArray;
6use crate::legacy::utils::CustomIterTools;
7use crate::trusted_len::TrustedLen;
8use crate::types::NativeType;
9
10pub fn ewm_mean<I, T>(
11    xs: I,
12    alpha: T,
13    adjust: bool,
14    min_periods: usize,
15    ignore_nulls: bool,
16) -> PrimitiveArray<T>
17where
18    I: IntoIterator<Item = Option<T>>,
19    I::IntoIter: TrustedLen,
20    T: Float + NativeType + AddAssign + MulAssign,
21{
22    let new_wt = if adjust { T::one() } else { alpha };
23    let old_wt_factor = T::one() - alpha;
24    let mut old_wt = T::one();
25    let mut weighted_avg = None;
26    let mut non_null_cnt = 0usize;
27
28    xs.into_iter()
29        .enumerate()
30        .map(|(i, opt_x)| {
31            if opt_x.is_some() {
32                non_null_cnt += 1;
33            }
34            match (i, weighted_avg) {
35                (0, _) | (_, None) => weighted_avg = opt_x,
36                (_, Some(w_avg)) => {
37                    if opt_x.is_some() || !ignore_nulls {
38                        old_wt *= old_wt_factor;
39                        if let Some(x) = opt_x {
40                            if w_avg != x {
41                                weighted_avg =
42                                    Some((old_wt * w_avg + new_wt * x) / (old_wt + new_wt));
43                            }
44                            old_wt = if adjust { old_wt + new_wt } else { T::one() };
45                        }
46                    }
47                },
48            }
49            match (non_null_cnt < min_periods, opt_x.is_some()) {
50                (_, false) => None,
51                (true, true) => None,
52                (false, true) => weighted_avg,
53            }
54        })
55        .collect_trusted()
56}
57
#[cfg(test)]
mod test {
    use super::super::assert_allclose;
    use super::*;
    const ALPHA: f64 = 0.5;
    const EPS: f64 = 1e-15;

    /// Without nulls, `ignore_nulls` must not influence the result; only
    /// `adjust` and `min_periods` do.
    #[test]
    fn test_ewm_mean_without_null() {
        let input: Vec<Option<f64>> = vec![Some(1.0), Some(2.0), Some(3.0)];
        for adjust in [false, true] {
            // Expected second and third outputs; the first output only
            // depends on `min_periods`.
            let (second, third): (f64, f64) = if adjust {
                (1.666_666_666_666_666_7, 2.428_571_428_571_428_4)
            } else {
                (1.5, 2.25)
            };
            for ignore_nulls in [false, true] {
                // With `min_periods` of 0 or 1 the very first value is emitted.
                for min_periods in [0, 1] {
                    let want: PrimitiveArray<f64> =
                        PrimitiveArray::from([Some(1.0), Some(second), Some(third)]);
                    let got = ewm_mean(input.clone(), ALPHA, adjust, min_periods, ignore_nulls);
                    assert_allclose!(got, want, EPS);
                }

                // With `min_periods` of 2 the first slot is masked out.
                let want: PrimitiveArray<f64> =
                    PrimitiveArray::from([None, Some(second), Some(third)]);
                let got = ewm_mean(input.clone(), ALPHA, adjust, 2, ignore_nulls);
                assert_allclose!(got, want, EPS);
            }
        }
    }

    /// Exercises every `(adjust, ignore_nulls)` combination on an input with
    /// leading and interior nulls.
    #[test]
    fn test_ewm_mean_with_null() {
        let input: Vec<Option<f64>> = vec![
            None,
            None,
            Some(5.0),
            Some(7.0),
            None,
            Some(2.0),
            Some(1.0),
            Some(4.0),
        ];

        // (adjust, ignore_nulls, expected output)
        let cases: [(bool, bool, [Option<f64>; 8]); 4] = [
            (
                true,
                true,
                [
                    None,
                    None,
                    Some(5.0),
                    Some(6.333_333_333_333_333),
                    None,
                    Some(3.857_142_857_142_857),
                    Some(2.333_333_333_333_333_5),
                    Some(3.193_548_387_096_774),
                ],
            ),
            (
                true,
                false,
                [
                    None,
                    None,
                    Some(5.0),
                    Some(6.333_333_333_333_333),
                    None,
                    Some(3.181_818_181_818_181_7),
                    Some(1.888_888_888_888_888_8),
                    Some(3.033_898_305_084_745_7),
                ],
            ),
            (
                false,
                true,
                [
                    None,
                    None,
                    Some(5.0),
                    Some(6.0),
                    None,
                    Some(4.0),
                    Some(2.5),
                    Some(3.25),
                ],
            ),
            (
                false,
                false,
                [
                    None,
                    None,
                    Some(5.0),
                    Some(6.0),
                    None,
                    Some(3.333_333_333_333_333_5),
                    Some(2.166_666_666_666_667),
                    Some(3.083_333_333_333_333_5),
                ],
            ),
        ];

        for (adjust, ignore_nulls, want) in cases {
            let expected: PrimitiveArray<f64> = PrimitiveArray::from(want);
            let got = ewm_mean(input.clone(), ALPHA, adjust, 0, ignore_nulls);
            assert_allclose!(got, expected, EPS);
        }
    }
}