statrs/distribution/
multivariate_normal.rs

1use crate::distribution::Continuous;
2use crate::distribution::Normal;
3use crate::statistics::{Max, MeanN, Min, Mode, VarianceN};
4use crate::{Result, StatsError};
5use nalgebra::{
6    base::allocator::Allocator, base::dimension::DimName, Cholesky, DefaultAllocator, Dim, DimMin,
7    LU, U1,
8};
9use nalgebra::{DMatrix, DVector};
10use rand::Rng;
11use std::f64;
12use std::f64::consts::{E, PI};
13
14/// Implements the [Multivariate Normal](https://en.wikipedia.org/wiki/Multivariate_normal_distribution)
15/// distribution using the "nalgebra" crate for matrix operations
16///
17/// # Examples
18///
19/// ```
20/// use statrs::distribution::{MultivariateNormal, Continuous};
21/// use nalgebra::{DVector, DMatrix};
22/// use statrs::statistics::{MeanN, VarianceN};
23///
24/// let mvn = MultivariateNormal::new(vec![0., 0.], vec![1., 0., 0., 1.]).unwrap();
25/// assert_eq!(mvn.mean().unwrap(), DVector::from_vec(vec![0., 0.]));
26/// assert_eq!(mvn.variance().unwrap(), DMatrix::from_vec(2, 2, vec![1., 0., 0., 1.]));
27/// assert_eq!(mvn.pdf(&DVector::from_vec(vec![1.,  1.])), 0.05854983152431917);
28/// ```
29#[derive(Debug, Clone, PartialEq)]
30pub struct MultivariateNormal {
31    dim: usize,
32    cov_chol_decomp: DMatrix<f64>,
33    mu: DVector<f64>,
34    cov: DMatrix<f64>,
35    precision: DMatrix<f64>,
36    pdf_const: f64,
37}
38
39impl MultivariateNormal {
40    ///  Constructs a new multivariate normal distribution with a mean of `mean`
41    /// and covariance matrix `cov`
42    ///
43    /// # Errors
44    ///
45    /// Returns an error if the given covariance matrix is not
46    /// symmetric or positive-definite
47    pub fn new(mean: Vec<f64>, cov: Vec<f64>) -> Result<Self> {
48        let mean = DVector::from_vec(mean);
49        let cov = DMatrix::from_vec(mean.len(), mean.len(), cov);
50        let dim = mean.len();
51        // Check that the provided covariance matrix is symmetric
52        if cov.lower_triangle() != cov.upper_triangle().transpose()
53        // Check that mean and covariance do not contain NaN
54            || mean.iter().any(|f| f.is_nan())
55            || cov.iter().any(|f| f.is_nan())
56        // Check that the dimensions match
57            || mean.nrows() != cov.nrows() || cov.nrows() != cov.ncols()
58        {
59            return Err(StatsError::BadParams);
60        }
61        let cov_det = cov.determinant();
62        let pdf_const = ((2. * PI).powi(mean.nrows() as i32) * cov_det.abs())
63            .recip()
64            .sqrt();
65        // Store the Cholesky decomposition of the covariance matrix
66        // for sampling
67        match Cholesky::new(cov.clone()) {
68            None => Err(StatsError::BadParams),
69            Some(cholesky_decomp) => {
70                let precision = cholesky_decomp.inverse();
71                Ok(MultivariateNormal {
72                    dim,
73                    cov_chol_decomp: cholesky_decomp.unpack(),
74                    mu: mean,
75                    cov,
76                    precision,
77                    pdf_const,
78                })
79            }
80        }
81    }
82    /// Returns the entropy of the multivariate normal distribution
83    ///
84    /// # Formula
85    ///
86    /// ```ignore
87    /// (1 / 2) * ln(det(2 * π * e * Σ))
88    /// ```
89    ///
90    /// where `Σ` is the covariance matrix and `det` is the determinant
91    pub fn entropy(&self) -> Option<f64> {
92        Some(
93            0.5 * self
94                .variance()
95                .unwrap()
96                .scale(2. * PI * E)
97                .determinant()
98                .ln(),
99        )
100    }
101}
102
103impl ::rand::distributions::Distribution<DVector<f64>> for MultivariateNormal {
104    /// Samples from the multivariate normal distribution
105    ///
106    /// # Formula
107    /// L * Z + μ
108    ///
109    /// where `L` is the Cholesky decomposition of the covariance matrix,
110    /// `Z` is a vector of normally distributed random variables, and
111    /// `μ` is the mean vector
112
113    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> DVector<f64> {
114        let d = Normal::new(0., 1.).unwrap();
115        let z = DVector::<f64>::from_distribution(self.dim, &d, rng);
116        (&self.cov_chol_decomp * z) + &self.mu
117    }
118}
119
120impl Min<DVector<f64>> for MultivariateNormal {
121    /// Returns the minimum value in the domain of the
122    /// multivariate normal distribution represented by a real vector
123    fn min(&self) -> DVector<f64> {
124        DVector::from_vec(vec![f64::NEG_INFINITY; self.dim])
125    }
126}
127
128impl Max<DVector<f64>> for MultivariateNormal {
129    /// Returns the maximum value in the domain of the
130    /// multivariate normal distribution represented by a real vector
131    fn max(&self) -> DVector<f64> {
132        DVector::from_vec(vec![f64::INFINITY; self.dim])
133    }
134}
135
136impl MeanN<DVector<f64>> for MultivariateNormal {
137    /// Returns the mean of the normal distribution
138    ///
139    /// # Remarks
140    ///
141    /// This is the same mean used to construct the distribution
142    fn mean(&self) -> Option<DVector<f64>> {
143        let mut vec = vec![];
144        for elt in self.mu.clone().into_iter() {
145            vec.push(*elt);
146        }
147        Some(DVector::from_vec(vec))
148    }
149}
150
151impl VarianceN<DMatrix<f64>> for MultivariateNormal {
152    /// Returns the covariance matrix of the multivariate normal distribution
153    fn variance(&self) -> Option<DMatrix<f64>> {
154        Some(self.cov.clone())
155    }
156}
157
158impl Mode<DVector<f64>> for MultivariateNormal {
159    /// Returns the mode of the multivariate normal distribution
160    ///
161    /// # Formula
162    ///
163    /// ```ignore
164    /// μ
165    /// ```
166    ///
167    /// where `μ` is the mean
168    fn mode(&self) -> DVector<f64> {
169        self.mu.clone()
170    }
171}
172
173impl<'a> Continuous<&'a DVector<f64>, f64> for MultivariateNormal {
174    /// Calculates the probability density function for the multivariate
175    /// normal distribution at `x`
176    ///
177    /// # Formula
178    ///
179    /// ```ignore
180    /// (2 * π) ^ (-k / 2) * det(Σ) ^ (1 / 2) * e ^ ( -(1 / 2) * transpose(x - μ) * inv(Σ) * (x - μ))
181    /// ```
182    ///
183    /// where `μ` is the mean, `inv(Σ)` is the precision matrix, `det(Σ)` is the determinant
184    /// of the covariance matrix, and `k` is the dimension of the distribution
185    fn pdf(&self, x: &'a DVector<f64>) -> f64 {
186        let dv = x - &self.mu;
187        let exp_term = -0.5
188            * *(&dv.transpose() * &self.precision * &dv)
189                .get((0, 0))
190                .unwrap();
191        self.pdf_const * exp_term.exp()
192    }
193    /// Calculates the log probability density function for the multivariate
194    /// normal distribution at `x`. Equivalent to pdf(x).ln().
195    fn ln_pdf(&self, x: &'a DVector<f64>) -> f64 {
196        let dv = x - &self.mu;
197        let exp_term = -0.5
198            * *(&dv.transpose() * &self.precision * &dv)
199                .get((0, 0))
200                .unwrap();
201        self.pdf_const.ln() + exp_term
202    }
203}
204
#[rustfmt::skip]
#[cfg(all(test, feature = "nightly"))]
mod tests  {
    use super::*;
    use crate::distribution::{Continuous, MultivariateNormal};
    use crate::statistics::*;
    use core::fmt::Debug;

    /// Builds a `DVector<f64>` from a comma-separated list of entries
    macro_rules! dvec {
        ($($x:expr),*) => (DVector::from_vec(vec![$($x),*]));
    }

    /// Builds a 2x2 `DMatrix<f64>` from four entries
    macro_rules! mat2 {
        ($x11:expr, $x12:expr, $x21:expr, $x22:expr) => (DMatrix::from_vec(2,2,vec![$x11, $x12, $x21, $x22]));
    }

    /// Constructs a distribution, failing the test if the parameters are rejected
    fn try_create(mean: Vec<f64>, covariance: Vec<f64>) -> MultivariateNormal
    {
        let mvn = MultivariateNormal::new(mean, covariance);
        assert!(mvn.is_ok());
        mvn.unwrap()
    }

    /// Checks that construction succeeds and that the mean and covariance
    /// are reported back unchanged
    fn create_case(mean: Vec<f64>, covariance: Vec<f64>)
    {
        let dim = mean.len();
        let mvn = try_create(mean.clone(), covariance.clone());
        assert_eq!(DVector::from_vec(mean), mvn.mean().unwrap());
        assert_eq!(DMatrix::from_vec(dim, dim, covariance), mvn.variance().unwrap());
    }

    /// Checks that construction is rejected for the given parameters
    fn bad_create_case(mean: Vec<f64>, covariance: Vec<f64>)
    {
        let mvn = MultivariateNormal::new(mean, covariance);
        assert!(mvn.is_err());
    }

    /// Evaluates `eval` on a freshly built distribution and requires an
    /// exact match with `expected`
    fn test_case<T, F>(mean: Vec<f64>, covariance: Vec<f64>, expected: T, eval: F)
    where
        T: Debug + PartialEq,
        F: FnOnce(MultivariateNormal) -> T,
    {
        let mvn = try_create(mean, covariance);
        let x = eval(mvn);
        assert_eq!(expected, x);
    }

    /// Evaluates `eval` on a freshly built distribution and requires the
    /// result to match `expected` within the accuracy `acc`
    fn test_almost<F>(
        mean: Vec<f64>,
        covariance: Vec<f64>,
        expected: f64,
        acc: f64,
        eval: F,
    ) where
        F: FnOnce(MultivariateNormal) -> f64,
    {
        let mvn = try_create(mean, covariance);
        let x = eval(mvn);
        assert_almost_eq!(expected, x, acc);
    }

    #[test]
    fn test_create() {
        create_case(vec![0., 0.], vec![1., 0., 0., 1.]);
        create_case(vec![10.,  5.], vec![2., 1., 1., 2.]);
        create_case(vec![4., 5., 6.], vec![2., 1., 0., 1., 2., 1., 0., 1., 2.]);
        create_case(vec![0., f64::INFINITY], vec![1., 0., 0., 1.]);
        create_case(vec![0., 0.], vec![f64::INFINITY, 0., 0., f64::INFINITY]);
    }

    #[test]
    fn test_bad_create() {
        // Covariance not symmetric
        bad_create_case(vec![0., 0.], vec![1., 1., 0., 1.]);
        // Covariance not positive-definite
        bad_create_case(vec![0., 0.], vec![1., 2., 2., 1.]);
        // NaN in mean
        bad_create_case(vec![0., f64::NAN], vec![1., 0., 0., 1.]);
        // NaN in Covariance Matrix
        bad_create_case(vec![0., 0.], vec![1., 0., 0., f64::NAN]);
    }

    #[test]
    fn test_variance() {
        let variance = |x: MultivariateNormal| x.variance().unwrap();
        test_case(vec![0., 0.], vec![1., 0., 0., 1.], mat2![1., 0., 0., 1.], variance);
        test_case(vec![0., 0.], vec![f64::INFINITY, 0., 0., f64::INFINITY], mat2![f64::INFINITY, 0., 0., f64::INFINITY], variance);
    }

    #[test]
    fn test_entropy() {
        let entropy = |x: MultivariateNormal| x.entropy().unwrap();
        test_case(vec![0., 0.], vec![1., 0., 0., 1.], 2.8378770664093453, entropy);
        test_case(vec![0., 0.], vec![1., 0.5, 0.5, 1.], 2.694036030183455, entropy);
        test_case(vec![0., 0.], vec![f64::INFINITY, 0., 0., f64::INFINITY], f64::INFINITY, entropy);
    }

    #[test]
    fn test_mode() {
        let mode = |x: MultivariateNormal| x.mode();
        test_case(vec![0., 0.], vec![1., 0., 0., 1.], dvec![0.,  0.], mode);
        test_case(vec![f64::INFINITY, f64::INFINITY], vec![1., 0., 0., 1.], dvec![f64::INFINITY,  f64::INFINITY], mode);
    }

    #[test]
    fn test_min_max() {
        let min = |x: MultivariateNormal| x.min();
        let max = |x: MultivariateNormal| x.max();
        test_case(vec![0., 0.], vec![1., 0., 0., 1.], dvec![f64::NEG_INFINITY, f64::NEG_INFINITY], min);
        test_case(vec![0., 0.], vec![1., 0., 0., 1.], dvec![f64::INFINITY, f64::INFINITY], max);
        test_case(vec![10., 1.], vec![1., 0., 0., 1.], dvec![f64::NEG_INFINITY, f64::NEG_INFINITY], min);
        test_case(vec![-3., 5.], vec![1., 0., 0., 1.], dvec![f64::INFINITY, f64::INFINITY], max);
    }

    #[test]
    fn test_pdf() {
        let pdf = |arg: DVector<f64>| move |x: MultivariateNormal| x.pdf(&arg);
        test_case(vec![0., 0.], vec![1., 0., 0., 1.], 0.05854983152431917, pdf(dvec![1., 1.]));
        test_almost(vec![0., 0.], vec![1., 0., 0., 1.], 0.013064233284684921, 1e-15, pdf(dvec![1., 2.]));
        test_almost(vec![0., 0.], vec![1., 0., 0., 1.], 1.8618676045881531e-23, 1e-35, pdf(dvec![1., 10.]));
        test_almost(vec![0., 0.], vec![1., 0., 0., 1.], 5.920684802611216e-45, 1e-58, pdf(dvec![10., 10.]));
        test_almost(vec![0., 0.], vec![1., 0.9, 0.9, 1.], 1.6576716577547003e-05, 1e-18, pdf(dvec![1., -1.]));
        test_almost(vec![0., 0.], vec![1., 0.99, 0.99, 1.], 4.1970621773477824e-44, 1e-54, pdf(dvec![1., -1.]));
        test_almost(vec![0.5, -0.2], vec![2.0, 0.3, 0.3,  0.5], 0.0013075203140666656, 1e-15, pdf(dvec![2., 2.]));
        test_case(vec![0., 0.], vec![f64::INFINITY, 0., 0., f64::INFINITY], 0.0, pdf(dvec![10., 10.]));
        test_case(vec![0., 0.], vec![f64::INFINITY, 0., 0., f64::INFINITY], 0.0, pdf(dvec![100., 100.]));
    }

    #[test]
    fn test_ln_pdf() {
        let ln_pdf = |arg: DVector<_>| move |x: MultivariateNormal| x.ln_pdf(&arg);
        test_case(vec![0., 0.], vec![1., 0., 0., 1.], (0.05854983152431917f64).ln(), ln_pdf(dvec![1., 1.]));
        test_almost(vec![0., 0.], vec![1., 0., 0., 1.], (0.013064233284684921f64).ln(), 1e-15, ln_pdf(dvec![1., 2.]));
        test_almost(vec![0., 0.], vec![1., 0., 0., 1.], (1.8618676045881531e-23f64).ln(), 1e-15, ln_pdf(dvec![1., 10.]));
        test_almost(vec![0., 0.], vec![1., 0., 0., 1.], (5.920684802611216e-45f64).ln(), 1e-15, ln_pdf(dvec![10., 10.]));
        test_almost(vec![0., 0.], vec![1., 0.9, 0.9, 1.], (1.6576716577547003e-05f64).ln(), 1e-14, ln_pdf(dvec![1., -1.]));
        test_almost(vec![0., 0.], vec![1., 0.99, 0.99, 1.], (4.1970621773477824e-44f64).ln(), 1e-12, ln_pdf(dvec![1., -1.]));
        test_almost(vec![0.5, -0.2], vec![2.0, 0.3, 0.3, 0.5],  (0.0013075203140666656f64).ln(), 1e-15, ln_pdf(dvec![2., 2.]));
        test_case(vec![0., 0.], vec![f64::INFINITY, 0., 0., f64::INFINITY], f64::NEG_INFINITY, ln_pdf(dvec![10., 10.]));
        test_case(vec![0., 0.], vec![f64::INFINITY, 0., 0., f64::INFINITY], f64::NEG_INFINITY, ln_pdf(dvec![100., 100.]));
    }
}
356}