1use std::ops::{AddAssign, DivAssign, MulAssign};
2
3use num_traits::Float;
4
5use crate::array::PrimitiveArray;
6use crate::legacy::utils::CustomIterTools;
7use crate::trusted_len::TrustedLen;
8use crate::types::NativeType;
9
10#[allow(clippy::too_many_arguments)]
11fn ewm_cov_internal<I, T>(
12 xs: I,
13 ys: I,
14 alpha: T,
15 adjust: bool,
16 bias: bool,
17 min_periods: usize,
18 ignore_nulls: bool,
19 do_sqrt: bool,
20) -> PrimitiveArray<T>
21where
22 I: IntoIterator<Item = Option<T>>,
23 I::IntoIter: TrustedLen,
24 T: Float + NativeType + AddAssign + MulAssign + DivAssign,
25{
26 let old_wt_factor = T::one() - alpha;
27 let new_wt = if adjust { T::one() } else { alpha };
28 let mut sum_wt = T::one();
29 let mut sum_wt2 = T::one();
30 let mut old_wt = T::one();
31
32 let mut opt_mean_x = None;
33 let mut opt_mean_y = None;
34 let mut cov = T::zero();
35 let mut non_na_cnt = 0usize;
36 let min_periods_fixed = if min_periods == 0 { 1 } else { min_periods };
37
38 let res = xs
39 .into_iter()
40 .zip(ys)
41 .enumerate()
42 .map(|(i, (opt_x, opt_y))| {
43 let is_observation = opt_x.is_some() && opt_y.is_some();
44 if is_observation {
45 non_na_cnt += 1;
46 }
47 match (i, opt_mean_x, opt_mean_y) {
48 (0, _, _) => {
49 if is_observation {
50 opt_mean_x = opt_x;
51 opt_mean_y = opt_y;
52 }
53 },
54 (_, Some(mean_x), Some(mean_y)) => {
55 if is_observation || !ignore_nulls {
56 sum_wt *= old_wt_factor;
57 sum_wt2 *= old_wt_factor * old_wt_factor;
58 old_wt *= old_wt_factor;
59 if is_observation {
60 let x = opt_x.unwrap();
61 let y = opt_y.unwrap();
62 let old_mean_x = mean_x;
63 let old_mean_y = mean_y;
64
65 if mean_x != x {
67 opt_mean_x =
68 Some((old_wt * old_mean_x + new_wt * x) / (old_wt + new_wt));
69 }
70
71 if mean_y != y {
73 opt_mean_y =
74 Some((old_wt * old_mean_y + new_wt * y) / (old_wt + new_wt));
75 }
76
77 cov = ((old_wt
78 * (cov
79 + ((old_mean_x - opt_mean_x.unwrap())
80 * (old_mean_y - opt_mean_y.unwrap()))))
81 + (new_wt
82 * ((x - opt_mean_x.unwrap()) * (y - opt_mean_y.unwrap()))))
83 / (old_wt + new_wt);
84
85 sum_wt += new_wt;
86 sum_wt2 += new_wt * new_wt;
87 old_wt += new_wt;
88 if !adjust {
89 sum_wt /= old_wt;
90 sum_wt2 /= old_wt * old_wt;
91 old_wt = T::one();
92 }
93 }
94 }
95 },
96 _ => {
97 if is_observation {
98 opt_mean_x = opt_x;
99 opt_mean_y = opt_y;
100 }
101 },
102 }
103 match (non_na_cnt >= min_periods_fixed, bias, is_observation) {
104 (_, _, false) => None,
105 (false, _, true) => None,
106 (true, false, true) => {
107 if non_na_cnt == 1 {
108 Some(cov)
109 } else {
110 let numerator = sum_wt * sum_wt;
111 let denominator = numerator - sum_wt2;
112 if denominator > T::zero() {
113 Some((numerator / denominator) * cov)
114 } else {
115 None
116 }
117 }
118 },
119 (true, true, true) => Some(cov),
120 }
121 });
122
123 if do_sqrt {
124 res.map(|opt_x| opt_x.map(|x| x.sqrt())).collect_trusted()
125 } else {
126 res.collect_trusted()
127 }
128}
129
130pub fn ewm_cov<I, T>(
131 xs: I,
132 ys: I,
133 alpha: T,
134 adjust: bool,
135 bias: bool,
136 min_periods: usize,
137 ignore_nulls: bool,
138) -> PrimitiveArray<T>
139where
140 I: IntoIterator<Item = Option<T>>,
141 I::IntoIter: TrustedLen,
142 T: Float + NativeType + AddAssign + MulAssign + DivAssign,
143{
144 ewm_cov_internal(
145 xs,
146 ys,
147 alpha,
148 adjust,
149 bias,
150 min_periods,
151 ignore_nulls,
152 false,
153 )
154}
155
156pub fn ewm_var<I, T>(
157 xs: I,
158 alpha: T,
159 adjust: bool,
160 bias: bool,
161 min_periods: usize,
162 ignore_nulls: bool,
163) -> PrimitiveArray<T>
164where
165 I: IntoIterator<Item = Option<T>> + Clone,
166 I::IntoIter: TrustedLen,
167 T: Float + NativeType + AddAssign + MulAssign + DivAssign,
168{
169 ewm_cov_internal(
170 xs.clone(),
171 xs,
172 alpha,
173 adjust,
174 bias,
175 min_periods,
176 ignore_nulls,
177 false,
178 )
179}
180
181pub fn ewm_std<I, T>(
182 xs: I,
183 alpha: T,
184 adjust: bool,
185 bias: bool,
186 min_periods: usize,
187 ignore_nulls: bool,
188) -> PrimitiveArray<T>
189where
190 I: IntoIterator<Item = Option<T>> + Clone,
191 I::IntoIter: TrustedLen,
192 T: Float + NativeType + AddAssign + MulAssign + DivAssign,
193{
194 ewm_cov_internal(
195 xs.clone(),
196 xs,
197 alpha,
198 adjust,
199 bias,
200 min_periods,
201 ignore_nulls,
202 true,
203 )
204}
205
#[cfg(test)]
mod test {
    use super::super::assert_allclose;
    use super::*;
    const ALPHA: f64 = 0.5;
    const EPS: f64 = 1e-15;
    use std::f64::consts::SQRT_2;

    // Fully populated series.
    const XS: [Option<f64>; 7] = [
        Some(1.0),
        Some(5.0),
        Some(7.0),
        Some(1.0),
        Some(2.0),
        Some(1.0),
        Some(4.0),
    ];
    // Series with nulls, including a leading null and a run of two nulls.
    const YS: [Option<f64>; 7] = [None, Some(5.0), Some(7.0), None, None, Some(1.0), Some(4.0)];

    // Each assertion below is one `(adjust, bias, min_periods, ignore_nulls)` combination.
    #[test]
    fn test_ewm_var() {
        assert_allclose!(
            ewm_var(XS.to_vec(), ALPHA, true, true, 0, true),
            PrimitiveArray::from([
                Some(0.0),
                Some(3.555_555_555_555_556),
                Some(4.244_897_959_183_674),
                Some(7.182_222_222_222_221),
                Some(3.796_045_785_639_958),
                Some(2.467_120_181_405_896),
                Some(2.476_036_952_073_904_3),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_var(XS.to_vec(), ALPHA, true, true, 0, false),
            PrimitiveArray::from([
                Some(0.0),
                Some(3.555_555_555_555_556),
                Some(4.244_897_959_183_674),
                Some(7.182_222_222_222_221),
                Some(3.796_045_785_639_958),
                Some(2.467_120_181_405_896),
                Some(2.476_036_952_073_904_3),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_var(XS.to_vec(), ALPHA, true, false, 0, true),
            PrimitiveArray::from([
                Some(0.0),
                Some(8.0),
                Some(7.428_571_428_571_429),
                Some(11.542_857_142_857_143),
                Some(5.883_870_967_741_934_5),
                Some(3.760_368_663_594_470_6),
                Some(3.743_532_058_492_688_6),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_var(XS.to_vec(), ALPHA, true, false, 0, false),
            PrimitiveArray::from([
                Some(0.0),
                Some(8.0),
                Some(7.428_571_428_571_429),
                Some(11.542_857_142_857_143),
                Some(5.883_870_967_741_934_5),
                Some(3.760_368_663_594_470_6),
                Some(3.743_532_058_492_688_6),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_var(XS.to_vec(), ALPHA, false, true, 0, true),
            PrimitiveArray::from([
                Some(0.0),
                Some(4.0),
                Some(6.0),
                Some(7.0),
                Some(3.75),
                Some(2.437_5),
                Some(2.484_375),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_var(XS.to_vec(), ALPHA, false, true, 0, false),
            PrimitiveArray::from([
                Some(0.0),
                Some(4.0),
                Some(6.0),
                Some(7.0),
                Some(3.75),
                Some(2.437_5),
                Some(2.484_375),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_var(XS.to_vec(), ALPHA, false, false, 0, true),
            PrimitiveArray::from([
                Some(0.0),
                Some(8.0),
                Some(9.600_000_000_000_001),
                Some(10.666_666_666_666_666),
                Some(5.647_058_823_529_411),
                Some(3.659_824_046_920_821),
                Some(3.727_472_527_472_527_6),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_var(XS.to_vec(), ALPHA, false, false, 0, false),
            PrimitiveArray::from([
                Some(0.0),
                Some(8.0),
                Some(9.600_000_000_000_001),
                Some(10.666_666_666_666_666),
                Some(5.647_058_823_529_411),
                Some(3.659_824_046_920_821),
                Some(3.727_472_527_472_527_6),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_var(YS.to_vec(), ALPHA, true, true, 0, true),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(0.888_888_888_888_889),
                None,
                None,
                Some(7.346_938_775_510_203),
                Some(3.555_555_555_555_555_4),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_var(YS.to_vec(), ALPHA, true, true, 0, false),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(0.888_888_888_888_889),
                None,
                None,
                Some(3.922_437_673_130_193_3),
                Some(2.549_788_542_868_127_3),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_var(YS.to_vec(), ALPHA, true, false, 0, true),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(2.0),
                None,
                None,
                Some(12.857_142_857_142_856),
                Some(5.714_285_714_285_714),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_var(YS.to_vec(), ALPHA, true, false, 0, false),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(2.0),
                None,
                None,
                Some(14.159_999_999_999_997),
                Some(5.039_513_677_811_549_5),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_var(YS.to_vec(), ALPHA, false, true, 0, true),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(1.0),
                None,
                None,
                Some(6.75),
                Some(3.437_5),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_var(YS.to_vec(), ALPHA, false, true, 0, false),
            PrimitiveArray::from([None, Some(0.0), Some(1.0), None, None, Some(4.2), Some(3.1)]),
            EPS
        );
        assert_allclose!(
            ewm_var(YS.to_vec(), ALPHA, false, false, 0, true),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(2.0),
                None,
                None,
                Some(10.8),
                Some(5.238_095_238_095_238),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_var(YS.to_vec(), ALPHA, false, false, 0, false),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(2.0),
                None,
                None,
                Some(12.352_941_176_470_589),
                Some(5.299_145_299_145_3),
            ]),
            EPS
        );
    }

    // A constant series must have a variance of exactly zero: the kernel skips the mean update
    // when the new value equals the current mean, so no rounding error can creep in.
    #[test]
    fn test_ewm_var_constant_series() {
        assert_allclose!(
            ewm_var(vec![Some(2.0); 5], ALPHA, true, false, 0, false),
            PrimitiveArray::from([Some(0.0); 5]),
            EPS
        );
    }

    // Covariance of XS with YS: nulls in YS make the pairwise observations coincide with the
    // non-null YS positions, so results match `ewm_var(YS, ...)`.
    #[test]
    fn test_ewm_cov() {
        assert_allclose!(
            ewm_cov(XS.to_vec(), YS.to_vec(), ALPHA, true, true, 0, true),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(0.888_888_888_888_889),
                None,
                None,
                Some(7.346_938_775_510_203),
                Some(3.555_555_555_555_555_4),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_cov(XS.to_vec(), YS.to_vec(), ALPHA, true, true, 0, false),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(0.888_888_888_888_889),
                None,
                None,
                Some(3.922_437_673_130_193_3),
                Some(2.549_788_542_868_127_3),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_cov(XS.to_vec(), YS.to_vec(), ALPHA, true, false, 0, true),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(2.0),
                None,
                None,
                Some(12.857_142_857_142_856),
                Some(5.714_285_714_285_714),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_cov(XS.to_vec(), YS.to_vec(), ALPHA, true, false, 0, false),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(2.0),
                None,
                None,
                Some(14.159_999_999_999_997),
                Some(5.039_513_677_811_549_5),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_cov(XS.to_vec(), YS.to_vec(), ALPHA, false, true, 0, true),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(1.0),
                None,
                None,
                Some(6.75),
                Some(3.437_5),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_cov(XS.to_vec(), YS.to_vec(), ALPHA, false, true, 0, false),
            PrimitiveArray::from([None, Some(0.0), Some(1.0), None, None, Some(4.2), Some(3.1)]),
            EPS
        );
        assert_allclose!(
            ewm_cov(XS.to_vec(), YS.to_vec(), ALPHA, false, false, 0, true),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(2.0),
                None,
                None,
                Some(10.8),
                Some(5.238_095_238_095_238),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_cov(XS.to_vec(), YS.to_vec(), ALPHA, false, false, 0, false),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(2.0),
                None,
                None,
                Some(12.352_941_176_470_589),
                Some(5.299_145_299_145_3),
            ]),
            EPS
        );
    }

    #[test]
    fn test_ewm_std() {
        assert_allclose!(
            ewm_std(XS.to_vec(), ALPHA, true, true, 0, true),
            PrimitiveArray::from([
                Some(0.0),
                Some(1.885_618_083_164_126_7),
                Some(2.060_315_014_550_851_3),
                Some(2.679_966_832_298_904),
                Some(1.948_344_370_392_451_5),
                Some(1.570_706_904_997_204_2),
                Some(1.573_542_802_746_053_2),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_std(XS.to_vec(), ALPHA, true, true, 0, false),
            PrimitiveArray::from([
                Some(0.0),
                Some(1.885_618_083_164_126_7),
                Some(2.060_315_014_550_851_3),
                Some(2.679_966_832_298_904),
                Some(1.948_344_370_392_451_5),
                Some(1.570_706_904_997_204_2),
                Some(1.573_542_802_746_053_2),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_std(XS.to_vec(), ALPHA, true, false, 0, true),
            PrimitiveArray::from([
                Some(0.0),
                Some(2.828_427_124_746_190_3),
                Some(2.725_540_575_476_987_5),
                Some(3.397_478_056_273_085_3),
                Some(2.425_669_179_369_259),
                Some(1.939_167_002_502_484_5),
                Some(1.934_820_937_061_796_6),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_std(XS.to_vec(), ALPHA, true, false, 0, false),
            PrimitiveArray::from([
                Some(0.0),
                Some(2.828_427_124_746_190_3),
                Some(2.725_540_575_476_987_5),
                Some(3.397_478_056_273_085_3),
                Some(2.425_669_179_369_259),
                Some(1.939_167_002_502_484_5),
                Some(1.934_820_937_061_796_6),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_std(XS.to_vec(), ALPHA, false, true, 0, true),
            PrimitiveArray::from([
                Some(0.0),
                Some(2.0),
                Some(2.449_489_742_783_178),
                Some(2.645_751_311_064_590_7),
                Some(1.936_491_673_103_708_5),
                Some(1.561_249_499_599_599_6),
                Some(1.576_190_026_614_811_4),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_std(XS.to_vec(), ALPHA, false, true, 0, false),
            PrimitiveArray::from([
                Some(0.0),
                Some(2.0),
                Some(2.449_489_742_783_178),
                Some(2.645_751_311_064_590_7),
                Some(1.936_491_673_103_708_5),
                Some(1.561_249_499_599_599_6),
                Some(1.576_190_026_614_811_4),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_std(XS.to_vec(), ALPHA, false, false, 0, true),
            PrimitiveArray::from([
                Some(0.0),
                Some(2.828_427_124_746_190_3),
                Some(3.098_386_676_965_933_6),
                Some(3.265_986_323_710_904),
                Some(2.376_354_103_144_018_3),
                Some(1.913_066_660_344_281_2),
                Some(1.930_666_342_865_210_7),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_std(XS.to_vec(), ALPHA, false, false, 0, false),
            PrimitiveArray::from([
                Some(0.0),
                Some(2.828_427_124_746_190_3),
                Some(3.098_386_676_965_933_6),
                Some(3.265_986_323_710_904),
                Some(2.376_354_103_144_018_3),
                Some(1.913_066_660_344_281_2),
                Some(1.930_666_342_865_210_7),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_std(YS.to_vec(), ALPHA, true, true, 0, true),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(0.942_809_041_582_063_4),
                None,
                None,
                Some(2.710_523_708_715_753_4),
                Some(1.885_618_083_164_126_7),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_std(YS.to_vec(), ALPHA, true, true, 0, false),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(0.942_809_041_582_063_4),
                None,
                None,
                Some(1.980_514_497_076_503),
                Some(1.596_805_731_098_222),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_std(YS.to_vec(), ALPHA, true, false, 0, true),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(SQRT_2),
                None,
                None,
                Some(3.585_685_828_003_181),
                Some(2.390_457_218_668_787),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_std(YS.to_vec(), ALPHA, true, false, 0, false),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(SQRT_2),
                None,
                None,
                Some(3.762_977_544_445_355_3),
                Some(2.244_886_116_891_356),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_std(YS.to_vec(), ALPHA, false, true, 0, true),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(1.0),
                None,
                None,
                Some(2.598_076_211_353_316),
                Some(1.854_049_621_773_915_7),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_std(YS.to_vec(), ALPHA, false, true, 0, false),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(1.0),
                None,
                None,
                Some(2.049_390_153_191_92),
                Some(1.760_681_686_165_901),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_std(YS.to_vec(), ALPHA, false, false, 0, true),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(SQRT_2),
                None,
                None,
                Some(3.286_335_345_030_997),
                Some(2.288_688_541_085_317_5),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_std(YS.to_vec(), ALPHA, false, false, 0, false),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(SQRT_2),
                None,
                None,
                Some(3.514_675_116_774_036_7),
                Some(2.301_987_249_996_250_4),
            ]),
            EPS
        );
    }

    // `min_periods` of 0 and 1 are equivalent; 2 suppresses the first (single-observation) value.
    #[test]
    fn test_ewm_min_periods() {
        assert_allclose!(
            ewm_var(YS.to_vec(), ALPHA, true, true, 0, false),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(0.888_888_888_888_889),
                None,
                None,
                Some(3.922_437_673_130_193_3),
                Some(2.549_788_542_868_127_3),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_var(YS.to_vec(), ALPHA, true, true, 1, false),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(0.888_888_888_888_889),
                None,
                None,
                Some(3.922_437_673_130_193_3),
                Some(2.549_788_542_868_127_3),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_var(YS.to_vec(), ALPHA, true, true, 2, false),
            PrimitiveArray::from([
                None,
                None,
                Some(0.888_888_888_888_889),
                None,
                None,
                Some(3.922_437_673_130_193_3),
                Some(2.549_788_542_868_127_3),
            ]),
            EPS
        );
    }
}