polars_arrow/legacy/kernels/ewm/
average.rs1use std::ops::{AddAssign, MulAssign};
2
3use num_traits::Float;
4
5use crate::array::PrimitiveArray;
6use crate::legacy::utils::CustomIterTools;
7use crate::trusted_len::TrustedLen;
8use crate::types::NativeType;
9
10pub fn ewm_mean<I, T>(
11 xs: I,
12 alpha: T,
13 adjust: bool,
14 min_periods: usize,
15 ignore_nulls: bool,
16) -> PrimitiveArray<T>
17where
18 I: IntoIterator<Item = Option<T>>,
19 I::IntoIter: TrustedLen,
20 T: Float + NativeType + AddAssign + MulAssign,
21{
22 let new_wt = if adjust { T::one() } else { alpha };
23 let old_wt_factor = T::one() - alpha;
24 let mut old_wt = T::one();
25 let mut weighted_avg = None;
26 let mut non_null_cnt = 0usize;
27
28 xs.into_iter()
29 .enumerate()
30 .map(|(i, opt_x)| {
31 if opt_x.is_some() {
32 non_null_cnt += 1;
33 }
34 match (i, weighted_avg) {
35 (0, _) | (_, None) => weighted_avg = opt_x,
36 (_, Some(w_avg)) => {
37 if opt_x.is_some() || !ignore_nulls {
38 old_wt *= old_wt_factor;
39 if let Some(x) = opt_x {
40 if w_avg != x {
41 weighted_avg =
42 Some((old_wt * w_avg + new_wt * x) / (old_wt + new_wt));
43 }
44 old_wt = if adjust { old_wt + new_wt } else { T::one() };
45 }
46 }
47 },
48 }
49 match (non_null_cnt < min_periods, opt_x.is_some()) {
50 (_, false) => None,
51 (true, true) => None,
52 (false, true) => weighted_avg,
53 }
54 })
55 .collect_trusted()
56}
57
#[cfg(test)]
mod test {
    use super::super::assert_allclose;
    use super::*;
    const ALPHA: f64 = 0.5;
    // Absolute tolerance shared by every `assert_allclose!` comparison in this module.
    const EPS: f64 = 1e-15;

    #[test]
    fn test_ewm_mean_without_null() {
        let xs: Vec<Option<f64>> = vec![Some(1.0), Some(2.0), Some(3.0)];
        for adjust in [false, true] {
            for ignore_nulls in [false, true] {
                for min_periods in [0, 1] {
                    let result = ewm_mean(xs.clone(), ALPHA, adjust, min_periods, ignore_nulls);
                    let expected = match adjust {
                        false => PrimitiveArray::from([Some(1.0f64), Some(1.5f64), Some(2.25f64)]),
                        true => PrimitiveArray::from([
                            Some(1.0),
                            Some(1.666_666_666_666_666_7),
                            Some(2.428_571_428_571_428_4),
                        ]),
                    };
                    // Use the shared tolerance rather than a duplicated literal.
                    assert_allclose!(result, expected, EPS);
                }
                let result = ewm_mean(xs.clone(), ALPHA, adjust, 2, ignore_nulls);
                let expected = match adjust {
                    false => PrimitiveArray::from([None, Some(1.5f64), Some(2.25f64)]),
                    true => PrimitiveArray::from([
                        None,
                        Some(1.666_666_666_666_666_7),
                        Some(2.428_571_428_571_428_4),
                    ]),
                };
                assert_allclose!(result, expected, EPS);
            }
        }
    }

    #[test]
    fn test_ewm_mean_with_null() {
        let xs1 = vec![
            None,
            None,
            Some(5.0f64),
            Some(7.0f64),
            None,
            Some(2.0f64),
            Some(1.0f64),
            Some(4.0f64),
        ];
        assert_allclose!(
            ewm_mean(xs1.clone(), 0.5, true, 0, true),
            PrimitiveArray::from([
                None,
                None,
                Some(5.0),
                Some(6.333_333_333_333_333),
                None,
                Some(3.857_142_857_142_857),
                Some(2.333_333_333_333_333_5),
                Some(3.193_548_387_096_774),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_mean(xs1.clone(), 0.5, true, 0, false),
            PrimitiveArray::from([
                None,
                None,
                Some(5.0),
                Some(6.333_333_333_333_333),
                None,
                Some(3.181_818_181_818_181_7),
                Some(1.888_888_888_888_888_8),
                Some(3.033_898_305_084_745_7),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_mean(xs1.clone(), 0.5, false, 0, true),
            PrimitiveArray::from([
                None,
                None,
                Some(5.0),
                Some(6.0),
                None,
                Some(4.0),
                Some(2.5),
                Some(3.25),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_mean(xs1, 0.5, false, 0, false),
            PrimitiveArray::from([
                None,
                None,
                Some(5.0),
                Some(6.0),
                None,
                Some(3.333_333_333_333_333_5),
                Some(2.166_666_666_666_667),
                Some(3.083_333_333_333_333_5),
            ]),
            EPS
        );
    }

    /// A constant input must come back unchanged: the kernel skips the weighted update
    /// when the running mean already equals the new value.
    #[test]
    fn test_ewm_mean_constant_series_is_exact() {
        let xs = vec![Some(3.0f64); 4];
        for adjust in [false, true] {
            for ignore_nulls in [false, true] {
                assert_allclose!(
                    ewm_mean(xs.clone(), ALPHA, adjust, 0, ignore_nulls),
                    PrimitiveArray::from([Some(3.0f64); 4]),
                    EPS
                );
            }
        }
    }

    /// With no observations at all, every output slot stays null.
    #[test]
    fn test_ewm_mean_all_null() {
        let xs = vec![None::<f64>; 3];
        for adjust in [false, true] {
            for ignore_nulls in [false, true] {
                assert_allclose!(
                    ewm_mean(xs.clone(), ALPHA, adjust, 0, ignore_nulls),
                    PrimitiveArray::from([None::<f64>, None, None]),
                    EPS
                );
            }
        }
    }

    /// Empty input yields an empty output array.
    #[test]
    fn test_ewm_mean_empty() {
        let result = ewm_mean(Vec::<Option<f64>>::new(), ALPHA, true, 0, true);
        assert_eq!(result.len(), 0);
    }
}