polars_arrow/legacy/kernels/ewm/
variance.rs

1use std::ops::{AddAssign, DivAssign, MulAssign};
2
3use num_traits::Float;
4
5use crate::array::PrimitiveArray;
6use crate::legacy::utils::CustomIterTools;
7use crate::trusted_len::TrustedLen;
8use crate::types::NativeType;
9
10#[allow(clippy::too_many_arguments)]
11fn ewm_cov_internal<I, T>(
12    xs: I,
13    ys: I,
14    alpha: T,
15    adjust: bool,
16    bias: bool,
17    min_periods: usize,
18    ignore_nulls: bool,
19    do_sqrt: bool,
20) -> PrimitiveArray<T>
21where
22    I: IntoIterator<Item = Option<T>>,
23    I::IntoIter: TrustedLen,
24    T: Float + NativeType + AddAssign + MulAssign + DivAssign,
25{
26    let old_wt_factor = T::one() - alpha;
27    let new_wt = if adjust { T::one() } else { alpha };
28    let mut sum_wt = T::one();
29    let mut sum_wt2 = T::one();
30    let mut old_wt = T::one();
31
32    let mut opt_mean_x = None;
33    let mut opt_mean_y = None;
34    let mut cov = T::zero();
35    let mut non_na_cnt = 0usize;
36    let min_periods_fixed = if min_periods == 0 { 1 } else { min_periods };
37
38    let res = xs
39        .into_iter()
40        .zip(ys)
41        .enumerate()
42        .map(|(i, (opt_x, opt_y))| {
43            let is_observation = opt_x.is_some() && opt_y.is_some();
44            if is_observation {
45                non_na_cnt += 1;
46            }
47            match (i, opt_mean_x, opt_mean_y) {
48                (0, _, _) => {
49                    if is_observation {
50                        opt_mean_x = opt_x;
51                        opt_mean_y = opt_y;
52                    }
53                },
54                (_, Some(mean_x), Some(mean_y)) => {
55                    if is_observation || !ignore_nulls {
56                        sum_wt *= old_wt_factor;
57                        sum_wt2 *= old_wt_factor * old_wt_factor;
58                        old_wt *= old_wt_factor;
59                        if is_observation {
60                            let x = opt_x.unwrap();
61                            let y = opt_y.unwrap();
62                            let old_mean_x = mean_x;
63                            let old_mean_y = mean_y;
64
65                            // avoid numerical errors on constant series
66                            if mean_x != x {
67                                opt_mean_x =
68                                    Some((old_wt * old_mean_x + new_wt * x) / (old_wt + new_wt));
69                            }
70
71                            // avoid numerical errors on constant series
72                            if mean_y != y {
73                                opt_mean_y =
74                                    Some((old_wt * old_mean_y + new_wt * y) / (old_wt + new_wt));
75                            }
76
77                            cov = ((old_wt
78                                * (cov
79                                    + ((old_mean_x - opt_mean_x.unwrap())
80                                        * (old_mean_y - opt_mean_y.unwrap()))))
81                                + (new_wt
82                                    * ((x - opt_mean_x.unwrap()) * (y - opt_mean_y.unwrap()))))
83                                / (old_wt + new_wt);
84
85                            sum_wt += new_wt;
86                            sum_wt2 += new_wt * new_wt;
87                            old_wt += new_wt;
88                            if !adjust {
89                                sum_wt /= old_wt;
90                                sum_wt2 /= old_wt * old_wt;
91                                old_wt = T::one();
92                            }
93                        }
94                    }
95                },
96                _ => {
97                    if is_observation {
98                        opt_mean_x = opt_x;
99                        opt_mean_y = opt_y;
100                    }
101                },
102            }
103            match (non_na_cnt >= min_periods_fixed, bias, is_observation) {
104                (_, _, false) => None,
105                (false, _, true) => None,
106                (true, false, true) => {
107                    if non_na_cnt == 1 {
108                        Some(cov)
109                    } else {
110                        let numerator = sum_wt * sum_wt;
111                        let denominator = numerator - sum_wt2;
112                        if denominator > T::zero() {
113                            Some((numerator / denominator) * cov)
114                        } else {
115                            None
116                        }
117                    }
118                },
119                (true, true, true) => Some(cov),
120            }
121        });
122
123    if do_sqrt {
124        res.map(|opt_x| opt_x.map(|x| x.sqrt())).collect_trusted()
125    } else {
126        res.collect_trusted()
127    }
128}
129
130pub fn ewm_cov<I, T>(
131    xs: I,
132    ys: I,
133    alpha: T,
134    adjust: bool,
135    bias: bool,
136    min_periods: usize,
137    ignore_nulls: bool,
138) -> PrimitiveArray<T>
139where
140    I: IntoIterator<Item = Option<T>>,
141    I::IntoIter: TrustedLen,
142    T: Float + NativeType + AddAssign + MulAssign + DivAssign,
143{
144    ewm_cov_internal(
145        xs,
146        ys,
147        alpha,
148        adjust,
149        bias,
150        min_periods,
151        ignore_nulls,
152        false,
153    )
154}
155
156pub fn ewm_var<I, T>(
157    xs: I,
158    alpha: T,
159    adjust: bool,
160    bias: bool,
161    min_periods: usize,
162    ignore_nulls: bool,
163) -> PrimitiveArray<T>
164where
165    I: IntoIterator<Item = Option<T>> + Clone,
166    I::IntoIter: TrustedLen,
167    T: Float + NativeType + AddAssign + MulAssign + DivAssign,
168{
169    ewm_cov_internal(
170        xs.clone(),
171        xs,
172        alpha,
173        adjust,
174        bias,
175        min_periods,
176        ignore_nulls,
177        false,
178    )
179}
180
181pub fn ewm_std<I, T>(
182    xs: I,
183    alpha: T,
184    adjust: bool,
185    bias: bool,
186    min_periods: usize,
187    ignore_nulls: bool,
188) -> PrimitiveArray<T>
189where
190    I: IntoIterator<Item = Option<T>> + Clone,
191    I::IntoIter: TrustedLen,
192    T: Float + NativeType + AddAssign + MulAssign + DivAssign,
193{
194    ewm_cov_internal(
195        xs.clone(),
196        xs,
197        alpha,
198        adjust,
199        bias,
200        min_periods,
201        ignore_nulls,
202        true,
203    )
204}
205
/// Tests for the exponentially weighted variance, covariance and standard
/// deviation kernels.
///
/// The public kernels take their flags in the order
/// `adjust, bias, min_periods, ignore_nulls`; each case below is labelled
/// with the flag combination it exercises.
#[cfg(test)]
mod test {
    use std::f64::consts::SQRT_2;

    use super::super::assert_allclose;
    use super::*;

    /// Smoothing factor shared by every case.
    const ALPHA: f64 = 0.5;
    /// Tolerance handed to `assert_allclose!`.
    const EPS: f64 = 1e-15;

    /// Series without nulls.
    const XS: [Option<f64>; 7] = [
        Some(1.0),
        Some(5.0),
        Some(7.0),
        Some(1.0),
        Some(2.0),
        Some(1.0),
        Some(4.0),
    ];
    /// Series with nulls at positions 0, 3 and 4.
    const YS: [Option<f64>; 7] = [None, Some(5.0), Some(7.0), None, None, Some(1.0), Some(4.0)];

    #[test]
    fn test_ewm_var() {
        // XS: adjust=true, bias=true, ignore_nulls=true
        assert_allclose!(
            ewm_var(XS.to_vec(), ALPHA, true, true, 0, true),
            PrimitiveArray::from([
                Some(0.0),
                Some(3.555_555_555_555_556),
                Some(4.244_897_959_183_674),
                Some(7.182_222_222_222_221),
                Some(3.796_045_785_639_958),
                Some(2.467_120_181_405_896),
                Some(2.476_036_952_073_904_3),
            ]),
            EPS
        );
        // XS: adjust=true, bias=true, ignore_nulls=false
        assert_allclose!(
            ewm_var(XS.to_vec(), ALPHA, true, true, 0, false),
            PrimitiveArray::from([
                Some(0.0),
                Some(3.555_555_555_555_556),
                Some(4.244_897_959_183_674),
                Some(7.182_222_222_222_221),
                Some(3.796_045_785_639_958),
                Some(2.467_120_181_405_896),
                Some(2.476_036_952_073_904_3),
            ]),
            EPS
        );
        // XS: adjust=true, bias=false, ignore_nulls=true
        assert_allclose!(
            ewm_var(XS.to_vec(), ALPHA, true, false, 0, true),
            PrimitiveArray::from([
                Some(0.0),
                Some(8.0),
                Some(7.428_571_428_571_429),
                Some(11.542_857_142_857_143),
                Some(5.883_870_967_741_934_5),
                Some(3.760_368_663_594_470_6),
                Some(3.743_532_058_492_688_6),
            ]),
            EPS
        );
        // XS: adjust=true, bias=false, ignore_nulls=false
        assert_allclose!(
            ewm_var(XS.to_vec(), ALPHA, true, false, 0, false),
            PrimitiveArray::from([
                Some(0.0),
                Some(8.0),
                Some(7.428_571_428_571_429),
                Some(11.542_857_142_857_143),
                Some(5.883_870_967_741_934_5),
                Some(3.760_368_663_594_470_6),
                Some(3.743_532_058_492_688_6),
            ]),
            EPS
        );
        // XS: adjust=false, bias=true, ignore_nulls=true
        assert_allclose!(
            ewm_var(XS.to_vec(), ALPHA, false, true, 0, true),
            PrimitiveArray::from([
                Some(0.0),
                Some(4.0),
                Some(6.0),
                Some(7.0),
                Some(3.75),
                Some(2.437_5),
                Some(2.484_375),
            ]),
            EPS
        );
        // XS: adjust=false, bias=true, ignore_nulls=false
        assert_allclose!(
            ewm_var(XS.to_vec(), ALPHA, false, true, 0, false),
            PrimitiveArray::from([
                Some(0.0),
                Some(4.0),
                Some(6.0),
                Some(7.0),
                Some(3.75),
                Some(2.437_5),
                Some(2.484_375),
            ]),
            EPS
        );
        // XS: adjust=false, bias=false, ignore_nulls=true
        assert_allclose!(
            ewm_var(XS.to_vec(), ALPHA, false, false, 0, true),
            PrimitiveArray::from([
                Some(0.0),
                Some(8.0),
                Some(9.600_000_000_000_001),
                Some(10.666_666_666_666_666),
                Some(5.647_058_823_529_411),
                Some(3.659_824_046_920_821),
                Some(3.727_472_527_472_527_6),
            ]),
            EPS
        );
        // XS: adjust=false, bias=false, ignore_nulls=false
        assert_allclose!(
            ewm_var(XS.to_vec(), ALPHA, false, false, 0, false),
            PrimitiveArray::from([
                Some(0.0),
                Some(8.0),
                Some(9.600_000_000_000_001),
                Some(10.666_666_666_666_666),
                Some(5.647_058_823_529_411),
                Some(3.659_824_046_920_821),
                Some(3.727_472_527_472_527_6),
            ]),
            EPS
        );
        // YS: adjust=true, bias=true, ignore_nulls=true
        assert_allclose!(
            ewm_var(YS.to_vec(), ALPHA, true, true, 0, true),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(0.888_888_888_888_889),
                None,
                None,
                Some(7.346_938_775_510_203),
                Some(3.555_555_555_555_555_4),
            ]),
            EPS
        );
        // YS: adjust=true, bias=true, ignore_nulls=false
        assert_allclose!(
            ewm_var(YS.to_vec(), ALPHA, true, true, 0, false),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(0.888_888_888_888_889),
                None,
                None,
                Some(3.922_437_673_130_193_3),
                Some(2.549_788_542_868_127_3),
            ]),
            EPS
        );
        // YS: adjust=true, bias=false, ignore_nulls=true
        assert_allclose!(
            ewm_var(YS.to_vec(), ALPHA, true, false, 0, true),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(2.0),
                None,
                None,
                Some(12.857_142_857_142_856),
                Some(5.714_285_714_285_714),
            ]),
            EPS
        );
        // YS: adjust=true, bias=false, ignore_nulls=false
        assert_allclose!(
            ewm_var(YS.to_vec(), ALPHA, true, false, 0, false),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(2.0),
                None,
                None,
                Some(14.159_999_999_999_997),
                Some(5.039_513_677_811_549_5),
            ]),
            EPS
        );
        // YS: adjust=false, bias=true, ignore_nulls=true
        assert_allclose!(
            ewm_var(YS.to_vec(), ALPHA, false, true, 0, true),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(1.0),
                None,
                None,
                Some(6.75),
                Some(3.437_5),
            ]),
            EPS
        );
        // YS: adjust=false, bias=true, ignore_nulls=false
        assert_allclose!(
            ewm_var(YS.to_vec(), ALPHA, false, true, 0, false),
            PrimitiveArray::from([None, Some(0.0), Some(1.0), None, None, Some(4.2), Some(3.1)]),
            EPS
        );
        // YS: adjust=false, bias=false, ignore_nulls=true
        assert_allclose!(
            ewm_var(YS.to_vec(), ALPHA, false, false, 0, true),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(2.0),
                None,
                None,
                Some(10.8),
                Some(5.238_095_238_095_238),
            ]),
            EPS
        );
        // YS: adjust=false, bias=false, ignore_nulls=false
        assert_allclose!(
            ewm_var(YS.to_vec(), ALPHA, false, false, 0, false),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(2.0),
                None,
                None,
                Some(12.352_941_176_470_589),
                Some(5.299_145_299_145_3),
            ]),
            EPS
        );
    }

    #[test]
    fn test_ewm_cov() {
        // XS vs YS: adjust=true, bias=true, ignore_nulls=true
        assert_allclose!(
            ewm_cov(XS.to_vec(), YS.to_vec(), ALPHA, true, true, 0, true),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(0.888_888_888_888_889),
                None,
                None,
                Some(7.346_938_775_510_203),
                Some(3.555_555_555_555_555_4)
            ]),
            EPS
        );
        // XS vs YS: adjust=true, bias=true, ignore_nulls=false
        assert_allclose!(
            ewm_cov(XS.to_vec(), YS.to_vec(), ALPHA, true, true, 0, false),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(0.888_888_888_888_889),
                None,
                None,
                Some(3.922_437_673_130_193_3),
                Some(2.549_788_542_868_127_3)
            ]),
            EPS
        );
        // XS vs YS: adjust=true, bias=false, ignore_nulls=true
        assert_allclose!(
            ewm_cov(XS.to_vec(), YS.to_vec(), ALPHA, true, false, 0, true),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(2.0),
                None,
                None,
                Some(12.857_142_857_142_856),
                Some(5.714_285_714_285_714)
            ]),
            EPS
        );
        // XS vs YS: adjust=true, bias=false, ignore_nulls=false
        assert_allclose!(
            ewm_cov(XS.to_vec(), YS.to_vec(), ALPHA, true, false, 0, false),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(2.0),
                None,
                None,
                Some(14.159_999_999_999_997),
                Some(5.039_513_677_811_549_5)
            ]),
            EPS
        );
        // XS vs YS: adjust=false, bias=true, ignore_nulls=true
        assert_allclose!(
            ewm_cov(XS.to_vec(), YS.to_vec(), ALPHA, false, true, 0, true),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(1.0),
                None,
                None,
                Some(6.75),
                Some(3.437_5)
            ]),
            EPS
        );
        // XS vs YS: adjust=false, bias=true, ignore_nulls=false
        assert_allclose!(
            ewm_cov(XS.to_vec(), YS.to_vec(), ALPHA, false, true, 0, false),
            PrimitiveArray::from([None, Some(0.0), Some(1.0), None, None, Some(4.2), Some(3.1)]),
            EPS
        );
        // XS vs YS: adjust=false, bias=false, ignore_nulls=true
        assert_allclose!(
            ewm_cov(XS.to_vec(), YS.to_vec(), ALPHA, false, false, 0, true),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(2.0),
                None,
                None,
                Some(10.8),
                Some(5.238_095_238_095_238)
            ]),
            EPS
        );
        // XS vs YS: adjust=false, bias=false, ignore_nulls=false
        assert_allclose!(
            ewm_cov(XS.to_vec(), YS.to_vec(), ALPHA, false, false, 0, false),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(2.0),
                None,
                None,
                Some(12.352_941_176_470_589),
                Some(5.299_145_299_145_3)
            ]),
            EPS
        );
    }

    #[test]
    fn test_ewm_std() {
        // XS: adjust=true, bias=true, ignore_nulls=true
        assert_allclose!(
            ewm_std(XS.to_vec(), ALPHA, true, true, 0, true),
            PrimitiveArray::from([
                Some(0.0),
                Some(1.885_618_083_164_126_7),
                Some(2.060_315_014_550_851_3),
                Some(2.679_966_832_298_904),
                Some(1.948_344_370_392_451_5),
                Some(1.570_706_904_997_204_2),
                Some(1.573_542_802_746_053_2),
            ]),
            EPS
        );
        // XS: adjust=true, bias=true, ignore_nulls=false
        assert_allclose!(
            ewm_std(XS.to_vec(), ALPHA, true, true, 0, false),
            PrimitiveArray::from([
                Some(0.0),
                Some(1.885_618_083_164_126_7),
                Some(2.060_315_014_550_851_3),
                Some(2.679_966_832_298_904),
                Some(1.948_344_370_392_451_5),
                Some(1.570_706_904_997_204_2),
                Some(1.573_542_802_746_053_2),
            ]),
            EPS
        );
        // XS: adjust=true, bias=false, ignore_nulls=true
        assert_allclose!(
            ewm_std(XS.to_vec(), ALPHA, true, false, 0, true),
            PrimitiveArray::from([
                Some(0.0),
                Some(2.828_427_124_746_190_3),
                Some(2.725_540_575_476_987_5),
                Some(3.397_478_056_273_085_3),
                Some(2.425_669_179_369_259),
                Some(1.939_167_002_502_484_5),
                Some(1.934_820_937_061_796_6),
            ]),
            EPS
        );
        // XS: adjust=true, bias=false, ignore_nulls=false
        assert_allclose!(
            ewm_std(XS.to_vec(), ALPHA, true, false, 0, false),
            PrimitiveArray::from([
                Some(0.0),
                Some(2.828_427_124_746_190_3),
                Some(2.725_540_575_476_987_5),
                Some(3.397_478_056_273_085_3),
                Some(2.425_669_179_369_259),
                Some(1.939_167_002_502_484_5),
                Some(1.934_820_937_061_796_6),
            ]),
            EPS
        );
        // XS: adjust=false, bias=true, ignore_nulls=true
        assert_allclose!(
            ewm_std(XS.to_vec(), ALPHA, false, true, 0, true),
            PrimitiveArray::from([
                Some(0.0),
                Some(2.0),
                Some(2.449_489_742_783_178),
                Some(2.645_751_311_064_590_7),
                Some(1.936_491_673_103_708_5),
                Some(1.561_249_499_599_599_6),
                Some(1.576_190_026_614_811_4),
            ]),
            EPS
        );
        // XS: adjust=false, bias=true, ignore_nulls=false
        assert_allclose!(
            ewm_std(XS.to_vec(), ALPHA, false, true, 0, false),
            PrimitiveArray::from([
                Some(0.0),
                Some(2.0),
                Some(2.449_489_742_783_178),
                Some(2.645_751_311_064_590_7),
                Some(1.936_491_673_103_708_5),
                Some(1.561_249_499_599_599_6),
                Some(1.576_190_026_614_811_4),
            ]),
            EPS
        );
        // XS: adjust=false, bias=false, ignore_nulls=true
        assert_allclose!(
            ewm_std(XS.to_vec(), ALPHA, false, false, 0, true),
            PrimitiveArray::from([
                Some(0.0),
                Some(2.828_427_124_746_190_3),
                Some(3.098_386_676_965_933_6),
                Some(3.265_986_323_710_904),
                Some(2.376_354_103_144_018_3),
                Some(1.913_066_660_344_281_2),
                Some(1.930_666_342_865_210_7),
            ]),
            EPS
        );
        // XS: adjust=false, bias=false, ignore_nulls=false
        assert_allclose!(
            ewm_std(XS.to_vec(), ALPHA, false, false, 0, false),
            PrimitiveArray::from([
                Some(0.0),
                Some(2.828_427_124_746_190_3),
                Some(3.098_386_676_965_933_6),
                Some(3.265_986_323_710_904),
                Some(2.376_354_103_144_018_3),
                Some(1.913_066_660_344_281_2),
                Some(1.930_666_342_865_210_7),
            ]),
            EPS
        );
        // YS: adjust=true, bias=true, ignore_nulls=true
        assert_allclose!(
            ewm_std(YS.to_vec(), ALPHA, true, true, 0, true),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(0.942_809_041_582_063_4),
                None,
                None,
                Some(2.710_523_708_715_753_4),
                Some(1.885_618_083_164_126_7),
            ]),
            EPS
        );
        // YS: adjust=true, bias=true, ignore_nulls=false
        assert_allclose!(
            ewm_std(YS.to_vec(), ALPHA, true, true, 0, false),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(0.942_809_041_582_063_4),
                None,
                None,
                Some(1.980_514_497_076_503),
                Some(1.596_805_731_098_222),
            ]),
            EPS
        );
        // YS: adjust=true, bias=false, ignore_nulls=true
        assert_allclose!(
            ewm_std(YS.to_vec(), ALPHA, true, false, 0, true),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(SQRT_2),
                None,
                None,
                Some(3.585_685_828_003_181),
                Some(2.390_457_218_668_787),
            ]),
            EPS
        );
        // YS: adjust=true, bias=false, ignore_nulls=false
        assert_allclose!(
            ewm_std(YS.to_vec(), ALPHA, true, false, 0, false),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(SQRT_2),
                None,
                None,
                Some(3.762_977_544_445_355_3),
                Some(2.244_886_116_891_356),
            ]),
            EPS
        );
        // YS: adjust=false, bias=true, ignore_nulls=true
        assert_allclose!(
            ewm_std(YS.to_vec(), ALPHA, false, true, 0, true),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(1.0),
                None,
                None,
                Some(2.598_076_211_353_316),
                Some(1.854_049_621_773_915_7),
            ]),
            EPS
        );
        // YS: adjust=false, bias=true, ignore_nulls=false
        assert_allclose!(
            ewm_std(YS.to_vec(), ALPHA, false, true, 0, false),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(1.0),
                None,
                None,
                Some(2.049_390_153_191_92),
                Some(1.760_681_686_165_901),
            ]),
            EPS
        );
        // YS: adjust=false, bias=false, ignore_nulls=true
        assert_allclose!(
            ewm_std(YS.to_vec(), ALPHA, false, false, 0, true),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(SQRT_2),
                None,
                None,
                Some(3.286_335_345_030_997),
                Some(2.288_688_541_085_317_5),
            ]),
            EPS
        );
        // YS: adjust=false, bias=false, ignore_nulls=false
        assert_allclose!(
            ewm_std(YS.to_vec(), ALPHA, false, false, 0, false),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(SQRT_2),
                None,
                None,
                Some(3.514_675_116_774_036_7),
                Some(2.301_987_249_996_250_4),
            ]),
            EPS
        );
    }

    #[test]
    fn test_ewm_min_periods() {
        // min_periods=0 gives the same output as min_periods=1.
        assert_allclose!(
            ewm_var(YS.to_vec(), ALPHA, true, true, 0, false),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(0.888_888_888_888_889),
                None,
                None,
                Some(3.922_437_673_130_193_3),
                Some(2.549_788_542_868_127_3),
            ]),
            EPS
        );
        // min_periods=1: the first observation already produces a value.
        assert_allclose!(
            ewm_var(YS.to_vec(), ALPHA, true, true, 1, false),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(0.888_888_888_888_889),
                None,
                None,
                Some(3.922_437_673_130_193_3),
                Some(2.549_788_542_868_127_3),
            ]),
            EPS
        );
        // min_periods=2: the first observation is masked out.
        assert_allclose!(
            ewm_var(YS.to_vec(), ALPHA, true, true, 2, false),
            PrimitiveArray::from([
                None,
                None,
                Some(0.888_888_888_888_889),
                None,
                None,
                Some(3.922_437_673_130_193_3),
                Some(2.549_788_542_868_127_3),
            ]),
            EPS
        );
    }
}