1use crate::distribution::Continuous;
2use crate::statistics::{Max, MeanN, Min, Mode, VarianceN};
3use nalgebra::{Cholesky, Const, DMatrix, DVector, Dim, DimMin, Dyn, OMatrix, OVector};
4use std::f64;
5use std::f64::consts::{E, PI};
6
7pub(super) fn density_normalization_and_exponential<D>(
10 mu: &OVector<f64, D>,
11 cov: &OMatrix<f64, D, D>,
12 precision: &OMatrix<f64, D, D>,
13 x: &OVector<f64, D>,
14) -> Option<(f64, f64)>
15where
16 D: DimMin<D, Output = D>,
17 nalgebra::DefaultAllocator:
18 nalgebra::allocator::Allocator<D> + nalgebra::allocator::Allocator<D, D>,
19{
20 Some((
21 density_distribution_pdf_const(mu, cov)?,
22 density_distribution_exponential(mu, precision, x)?,
23 ))
24}
25
26#[inline]
29fn density_distribution_exponential<D>(
30 mu: &OVector<f64, D>,
31 precision: &OMatrix<f64, D, D>,
32 x: &OVector<f64, D>,
33) -> Option<f64>
34where
35 D: Dim,
36 nalgebra::DefaultAllocator:
37 nalgebra::allocator::Allocator<D> + nalgebra::allocator::Allocator<D, D>,
38{
39 if x.shape_generic().0 != precision.shape_generic().0
40 || x.shape_generic().0 != mu.shape_generic().0
41 || !precision.is_square()
42 {
43 return None;
44 }
45
46 let dv = x - mu;
47 let exp_term: f64 = -0.5 * (precision * &dv).dot(&dv);
48 Some(exp_term)
49}
50
51#[inline]
54fn density_distribution_pdf_const<D>(mu: &OVector<f64, D>, cov: &OMatrix<f64, D, D>) -> Option<f64>
55where
56 D: DimMin<D, Output = D>,
57 nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>
58 + nalgebra::allocator::Allocator<D, D>
59 + nalgebra::allocator::Allocator<D>,
60{
61 if cov.shape_generic().0 != mu.shape_generic().0 || !cov.is_square() {
62 return None;
63 }
64 let cov_det = cov.determinant();
65 Some(
66 ((2. * PI).powi(mu.nrows() as i32) * cov_det.abs())
67 .recip()
68 .sqrt(),
69 )
70}
71
/// A multivariate normal (Gaussian) distribution parameterized by a mean
/// vector and a covariance matrix.
///
/// Besides the parameters, the struct stores quantities derived from the
/// covariance when the distribution is constructed (its Cholesky factor, its
/// inverse and the density's normalization constant), so density evaluation
/// and sampling reuse them.
#[derive(Clone, PartialEq, Debug)]
pub struct MultivariateNormal<D>
where
    D: Dim,
    nalgebra::DefaultAllocator:
        nalgebra::allocator::Allocator<D> + nalgebra::allocator::Allocator<D, D>,
{
    /// Matrix obtained from `Cholesky::unpack` on the covariance; sampling
    /// multiplies standard-normal draws by it.
    cov_chol_decomp: OMatrix<f64, D, D>,
    /// Mean vector.
    mu: OVector<f64, D>,
    /// Covariance matrix.
    cov: OMatrix<f64, D, D>,
    /// Inverse of the covariance matrix, obtained from its Cholesky decomposition.
    precision: OMatrix<f64, D, D>,
    /// Normalization constant `1 / sqrt((2π)^k |det Σ|)` of the density.
    pdf_const: f64,
}
100
/// Errors that can occur when constructing a [`MultivariateNormal`].
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
#[non_exhaustive]
pub enum MultivariateNormalError {
    /// The covariance matrix is not square, is not symmetric, or contains a NaN.
    CovInvalid,

    /// The mean vector contains a NaN.
    MeanInvalid,

    /// The dimensions of the mean vector and the covariance matrix are
    /// incompatible (they do not have the same number of rows).
    DimensionMismatch,

    /// nalgebra's `Cholesky::new` could not decompose the covariance matrix,
    /// e.g. because it is not positive definite.
    CholeskyFailed,
}
118
119impl std::fmt::Display for MultivariateNormalError {
120 #[cfg_attr(coverage_nightly, coverage(off))]
121 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
122 match self {
123 MultivariateNormalError::CovInvalid => {
124 write!(f, "Covariance matrix is asymmetric or contains a NaN")
125 }
126 MultivariateNormalError::MeanInvalid => write!(f, "Mean vector contains a NaN"),
127 MultivariateNormalError::DimensionMismatch => write!(
128 f,
129 "Mean vector and covariance matrix do not have the same number of rows"
130 ),
131 MultivariateNormalError::CholeskyFailed => {
132 write!(f, "Computing the Cholesky decomposition failed")
133 }
134 }
135 }
136}
137
138impl std::error::Error for MultivariateNormalError {}
139
140impl MultivariateNormal<Dyn> {
141 pub fn new(mean: Vec<f64>, cov: Vec<f64>) -> Result<Self, MultivariateNormalError> {
149 let mean = DVector::from_vec(mean);
150 let cov = DMatrix::from_vec(mean.len(), mean.len(), cov);
151 MultivariateNormal::new_from_nalgebra(mean, cov)
152 }
153}
154
155impl<D> MultivariateNormal<D>
156where
157 D: DimMin<D, Output = D>,
158 nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>
159 + nalgebra::allocator::Allocator<D, D>
160 + nalgebra::allocator::Allocator<D>,
161{
162 pub fn new_from_nalgebra(
171 mean: OVector<f64, D>,
172 cov: OMatrix<f64, D, D>,
173 ) -> Result<Self, MultivariateNormalError> {
174 if mean.iter().any(|f| f.is_nan()) {
175 return Err(MultivariateNormalError::MeanInvalid);
176 }
177
178 if !cov.is_square()
179 || cov.lower_triangle() != cov.upper_triangle().transpose()
180 || cov.iter().any(|f| f.is_nan())
181 {
182 return Err(MultivariateNormalError::CovInvalid);
183 }
184
185 if mean.shape_generic().0 != cov.shape_generic().0 {
187 return Err(MultivariateNormalError::DimensionMismatch);
188 }
189
190 match Cholesky::new(cov.clone()) {
193 None => Err(MultivariateNormalError::CholeskyFailed),
194 Some(cholesky_decomp) => {
195 let precision = cholesky_decomp.inverse();
196 Ok(MultivariateNormal {
197 pdf_const: density_distribution_pdf_const(&mean, &cov).unwrap(),
199 cov_chol_decomp: cholesky_decomp.unpack(),
200 mu: mean,
201 cov,
202 precision,
203 })
204 }
205 }
206 }
207
208 pub fn entropy(&self) -> Option<f64> {
218 Some(
219 0.5 * self
220 .variance()
221 .unwrap()
222 .scale(2. * PI * E)
223 .determinant()
224 .ln(),
225 )
226 }
227}
228
229impl<D> std::fmt::Display for MultivariateNormal<D>
230where
231 D: Dim,
232 nalgebra::DefaultAllocator:
233 nalgebra::allocator::Allocator<D> + nalgebra::allocator::Allocator<D, D>,
234{
235 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
236 write!(f, "N({}, {})", &self.mu, &self.cov)
237 }
238}
239
#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl<D> ::rand::distributions::Distribution<OVector<f64, D>> for MultivariateNormal<D>
where
    D: Dim,
    nalgebra::DefaultAllocator:
        nalgebra::allocator::Allocator<D> + nalgebra::allocator::Allocator<D, D>,
{
    /// Draws one sample as `L·z + μ`.
    ///
    /// `z` is a vector of independent standard-normal draws with the same
    /// length as the mean, and `L` is the Cholesky factor stored at
    /// construction.
    fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> OVector<f64, D> {
        let standard_normal = crate::distribution::Normal::new(0., 1.).unwrap();
        let (rows, _) = self.mu.shape_generic();
        let z = OVector::from_distribution_generic(rows, Const::<1>, &standard_normal, rng);
        let correlated = &self.cov_chol_decomp * z;
        correlated + &self.mu
    }
}
264
265impl<D> Min<OVector<f64, D>> for MultivariateNormal<D>
266where
267 D: Dim,
268 nalgebra::DefaultAllocator:
269 nalgebra::allocator::Allocator<D> + nalgebra::allocator::Allocator<D, D>,
270{
271 fn min(&self) -> OVector<f64, D> {
274 OMatrix::repeat_generic(self.mu.shape_generic().0, Const::<1>, f64::NEG_INFINITY)
275 }
276}
277
278impl<D> Max<OVector<f64, D>> for MultivariateNormal<D>
279where
280 D: Dim,
281 nalgebra::DefaultAllocator:
282 nalgebra::allocator::Allocator<D> + nalgebra::allocator::Allocator<D, D>,
283{
284 fn max(&self) -> OVector<f64, D> {
287 OMatrix::repeat_generic(self.mu.shape_generic().0, Const::<1>, f64::INFINITY)
288 }
289}
290
291impl<D> MeanN<OVector<f64, D>> for MultivariateNormal<D>
292where
293 D: Dim,
294 nalgebra::DefaultAllocator:
295 nalgebra::allocator::Allocator<D> + nalgebra::allocator::Allocator<D, D>,
296{
297 fn mean(&self) -> Option<OVector<f64, D>> {
303 Some(self.mu.clone())
304 }
305}
306
307impl<D> VarianceN<OMatrix<f64, D, D>> for MultivariateNormal<D>
308where
309 D: Dim,
310 nalgebra::DefaultAllocator:
311 nalgebra::allocator::Allocator<D> + nalgebra::allocator::Allocator<D, D>,
312{
313 fn variance(&self) -> Option<OMatrix<f64, D, D>> {
315 Some(self.cov.clone())
316 }
317}
318
319impl<D> Mode<OVector<f64, D>> for MultivariateNormal<D>
320where
321 D: Dim,
322 nalgebra::DefaultAllocator:
323 nalgebra::allocator::Allocator<D> + nalgebra::allocator::Allocator<D, D>,
324{
325 fn mode(&self) -> OVector<f64, D> {
335 self.mu.clone()
336 }
337}
338
339impl<D> Continuous<&OVector<f64, D>, f64> for MultivariateNormal<D>
340where
341 D: Dim,
342 nalgebra::DefaultAllocator:
343 nalgebra::allocator::Allocator<D> + nalgebra::allocator::Allocator<D, D>,
344{
345 fn pdf(&self, x: &OVector<f64, D>) -> f64 {
357 self.pdf_const
358 * density_distribution_exponential(&self.mu, &self.precision, x)
359 .unwrap()
360 .exp()
361 }
362
363 fn ln_pdf(&self, x: &OVector<f64, D>) -> f64 {
366 self.pdf_const.ln()
367 + density_distribution_exponential(&self.mu, &self.precision, x).unwrap()
368 }
369}
370
// Unit tests for `MultivariateNormal`: construction, summary statistics,
// density / log-density values and the error type's auto traits.
#[rustfmt::skip]
#[cfg(test)]
mod tests {
    use core::fmt::Debug;

    use nalgebra::{dmatrix, dvector, matrix, vector, DimMin, OMatrix, OVector};

    use crate::{
        distribution::{Continuous, MultivariateNormal},
        statistics::{Max, MeanN, Min, Mode, VarianceN},
    };

    use super::MultivariateNormalError;

    // Builds a distribution, asserting that construction succeeds.
    fn try_create<D>(mean: OVector<f64, D>, covariance: OMatrix<f64, D, D>) -> MultivariateNormal<D>
    where
        D: DimMin<D, Output = D>,
        nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>
            + nalgebra::allocator::Allocator<D, D>
            + nalgebra::allocator::Allocator<D>,
    {
        let mvn = MultivariateNormal::new_from_nalgebra(mean, covariance);
        assert!(mvn.is_ok());
        mvn.unwrap()
    }

    // Construction succeeds and the stored parameters round-trip through
    // `mean()` / `variance()`.
    fn create_case<D>(mean: OVector<f64, D>, covariance: OMatrix<f64, D, D>)
    where
        D: DimMin<D, Output = D>,
        nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>
            + nalgebra::allocator::Allocator<D, D>
            + nalgebra::allocator::Allocator<D>,
    {
        let mvn = try_create(mean.clone(), covariance.clone());
        assert_eq!(mean, mvn.mean().unwrap());
        assert_eq!(covariance, mvn.variance().unwrap());
    }

    // Construction must be rejected (any error variant).
    fn bad_create_case<D>(mean: OVector<f64, D>, covariance: OMatrix<f64, D, D>)
    where
        D: DimMin<D, Output = D>,
        nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>
            + nalgebra::allocator::Allocator<D, D>
            + nalgebra::allocator::Allocator<D>,
    {
        let mvn = MultivariateNormal::new_from_nalgebra(mean, covariance);
        assert!(mvn.is_err());
    }

    // Evaluates `eval` on a freshly built distribution and requires exact equality.
    fn test_case<T, F, D>(
        mean: OVector<f64, D>, covariance: OMatrix<f64, D, D>, expected: T, eval: F,
    ) where
        T: Debug + PartialEq,
        F: FnOnce(MultivariateNormal<D>) -> T,
        D: DimMin<D, Output = D>,
        nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>
            + nalgebra::allocator::Allocator<D, D>
            + nalgebra::allocator::Allocator<D>,
    {
        let mvn = try_create(mean, covariance);
        let x = eval(mvn);
        assert_eq!(expected, x);
    }

    // Like `test_case`, but compares an `f64` result within tolerance `acc`.
    fn test_almost<F, D>(
        mean: OVector<f64, D>, covariance: OMatrix<f64, D, D>, expected: f64, acc: f64, eval: F,
    ) where
        F: FnOnce(MultivariateNormal<D>) -> f64,
        D: DimMin<D, Output = D>,
        nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>
            + nalgebra::allocator::Allocator<D, D>
            + nalgebra::allocator::Allocator<D>,
    {
        let mvn = try_create(mean, covariance);
        let x = eval(mvn);
        assert_almost_eq!(expected, x, acc);
    }

    #[test]
    fn test_create() {
        // Statically sized (vector!/matrix!) inputs.
        create_case(vector![0., 0.], matrix![1., 0.; 0., 1.]);
        create_case(vector![10., 5.], matrix![2., 1.; 1., 2.]);
        create_case(
            vector![4., 5., 6.],
            matrix![2., 1., 0.; 1., 2., 1.; 0., 1., 2.],
        );
        // Infinite (but non-NaN) entries are accepted in both mean and covariance.
        create_case(dvector![0., f64::INFINITY], dmatrix![1., 0.; 0., 1.]);
        create_case(
            dvector![0., 0.],
            dmatrix![f64::INFINITY, 0.; 0., f64::INFINITY],
        );
    }

    #[test]
    fn test_bad_create() {
        // Asymmetric covariance.
        bad_create_case(vector![0., 0.], matrix![1., 1.; 0., 1.]);
        // Symmetric but with negative determinant, so the Cholesky decomposition fails.
        bad_create_case(vector![0., 0.], matrix![1., 2.; 2., 1.]);
        // NaN in the mean.
        bad_create_case(dvector![0., f64::NAN], dmatrix![1., 0.; 0., 1.]);
        // NaN in the covariance.
        bad_create_case(dvector![0., 0.], dmatrix![1., 0.; 0., f64::NAN]);
    }

    #[test]
    fn test_variance() {
        let variance = |x: MultivariateNormal<_>| x.variance().unwrap();
        test_case(
            vector![0., 0.],
            matrix![1., 0.; 0., 1.],
            matrix![1., 0.; 0., 1.],
            variance,
        );
        test_case(
            vector![0., 0.],
            matrix![f64::INFINITY, 0.; 0., f64::INFINITY],
            matrix![f64::INFINITY, 0.; 0., f64::INFINITY],
            variance,
        );
    }

    #[test]
    fn test_entropy() {
        let entropy = |x: MultivariateNormal<_>| x.entropy().unwrap();
        // 2-D identity covariance: entropy is ln(2πe).
        test_case(
            dvector![0., 0.],
            dmatrix![1., 0.; 0., 1.],
            2.8378770664093453,
            entropy,
        );
        test_case(
            dvector![0., 0.],
            dmatrix![1., 0.5; 0.5, 1.],
            2.694036030183455,
            entropy,
        );
        // Infinite variance yields infinite entropy.
        test_case(
            dvector![0., 0.],
            dmatrix![f64::INFINITY, 0.; 0., f64::INFINITY],
            f64::INFINITY,
            entropy,
        );
    }

    #[test]
    fn test_mode() {
        let mode = |x: MultivariateNormal<_>| x.mode();
        test_case(
            vector![0., 0.],
            matrix![1., 0.; 0., 1.],
            vector![0., 0.],
            mode,
        );
        test_case(
            vector![f64::INFINITY, f64::INFINITY],
            matrix![1., 0.; 0., 1.],
            vector![f64::INFINITY, f64::INFINITY],
            mode,
        );
    }

    #[test]
    fn test_min_max() {
        // The support is all of R^k regardless of the mean.
        let min = |x: MultivariateNormal<_>| x.min();
        let max = |x: MultivariateNormal<_>| x.max();
        test_case(
            dvector![0., 0.],
            dmatrix![1., 0.; 0., 1.],
            dvector![f64::NEG_INFINITY, f64::NEG_INFINITY],
            min,
        );
        test_case(
            dvector![0., 0.],
            dmatrix![1., 0.; 0., 1.],
            dvector![f64::INFINITY, f64::INFINITY],
            max,
        );
        test_case(
            dvector![10., 1.],
            dmatrix![1., 0.; 0., 1.],
            dvector![f64::NEG_INFINITY, f64::NEG_INFINITY],
            min,
        );
        test_case(
            dvector![-3., 5.],
            dmatrix![1., 0.; 0., 1.],
            dvector![f64::INFINITY, f64::INFINITY],
            max,
        );
    }

    #[test]
    fn test_pdf() {
        // Curried helper: `pdf(point)` yields a closure evaluating the density at `point`.
        let pdf = |arg| move |x: MultivariateNormal<_>| x.pdf(&arg);
        test_case(
            vector![0., 0.],
            matrix![1., 0.; 0., 1.],
            0.05854983152431917,
            pdf(vector![1., 1.]),
        );
        test_almost(
            vector![0., 0.],
            matrix![1., 0.; 0., 1.],
            0.013064233284684921,
            1e-15,
            pdf(vector![1., 2.]),
        );
        test_almost(
            vector![0., 0.],
            matrix![1., 0.; 0., 1.],
            1.8618676045881531e-23,
            1e-35,
            pdf(vector![1., 10.]),
        );
        test_almost(
            vector![0., 0.],
            matrix![1., 0.; 0., 1.],
            5.920684802611216e-45,
            1e-58,
            pdf(vector![10., 10.]),
        );
        // Strongly correlated covariances evaluated against the correlation direction.
        test_almost(
            vector![0., 0.],
            matrix![1., 0.9; 0.9, 1.],
            1.6576716577547003e-05,
            1e-18,
            pdf(vector![1., -1.]),
        );
        test_almost(
            vector![0., 0.],
            matrix![1., 0.99; 0.99, 1.],
            4.1970621773477824e-44,
            1e-54,
            pdf(vector![1., -1.]),
        );
        test_almost(
            vector![0.5, -0.2],
            matrix![2.0, 0.3; 0.3, 0.5],
            0.0013075203140666656,
            1e-15,
            pdf(vector![2., 2.]),
        );
        // Infinite variance: the density is exactly zero.
        test_case(
            vector![0., 0.],
            matrix![f64::INFINITY, 0.; 0., f64::INFINITY],
            0.0,
            pdf(vector![10., 10.]),
        );
        test_case(
            vector![0., 0.],
            matrix![f64::INFINITY, 0.; 0., f64::INFINITY],
            0.0,
            pdf(vector![100., 100.]),
        );
    }

    #[test]
    fn test_ln_pdf() {
        // Same points as `test_pdf`, compared against the log of those expected densities.
        let ln_pdf = |arg| move |x: MultivariateNormal<_>| x.ln_pdf(&arg);
        test_case(
            dvector![0., 0.],
            dmatrix![1., 0.; 0., 1.],
            (0.05854983152431917f64).ln(),
            ln_pdf(dvector![1., 1.]),
        );
        test_almost(
            dvector![0., 0.],
            dmatrix![1., 0.; 0., 1.],
            (0.013064233284684921f64).ln(),
            1e-15,
            ln_pdf(dvector![1., 2.]),
        );
        test_almost(
            dvector![0., 0.],
            dmatrix![1., 0.; 0., 1.],
            (1.8618676045881531e-23f64).ln(),
            1e-15,
            ln_pdf(dvector![1., 10.]),
        );
        test_almost(
            dvector![0., 0.],
            dmatrix![1., 0.; 0., 1.],
            (5.920684802611216e-45f64).ln(),
            1e-15,
            ln_pdf(dvector![10., 10.]),
        );
        test_almost(
            dvector![0., 0.],
            dmatrix![1., 0.9; 0.9, 1.],
            (1.6576716577547003e-05f64).ln(),
            1e-14,
            ln_pdf(dvector![1., -1.]),
        );
        test_almost(
            dvector![0., 0.],
            dmatrix![1., 0.99; 0.99, 1.],
            (4.1970621773477824e-44f64).ln(),
            1e-12,
            ln_pdf(dvector![1., -1.]),
        );
        test_almost(
            dvector![0.5, -0.2],
            dmatrix![2.0, 0.3; 0.3, 0.5],
            (0.0013075203140666656f64).ln(),
            1e-15,
            ln_pdf(dvector![2., 2.]),
        );
        // Infinite variance: the log-density is -inf.
        test_case(
            dvector![0., 0.],
            dmatrix![f64::INFINITY, 0.; 0., f64::INFINITY],
            f64::NEG_INFINITY,
            ln_pdf(dvector![10., 10.]),
        );
        test_case(
            dvector![0., 0.],
            dmatrix![f64::INFINITY, 0.; 0., f64::INFINITY],
            f64::NEG_INFINITY,
            ln_pdf(dvector![100., 100.]),
        );
    }

    #[test]
    #[should_panic]
    fn test_pdf_mismatched_arg_size() {
        // A 1-element argument against a 2-dimensional distribution must panic.
        let mvn = MultivariateNormal::new(vec![0., 0.], vec![1., 0., 0., 1.,]).unwrap();
        mvn.pdf(&vec![1.].into());
    }

    #[test]
    fn test_error_is_sync_send() {
        // Compile-time check that the error type can cross thread boundaries.
        fn assert_sync_send<T: Sync + Send>() {}
        assert_sync_send::<MultivariateNormalError>();
    }
}