1use crate::distribution::Continuous;
2use crate::distribution::Normal;
3use crate::statistics::{Max, MeanN, Min, Mode, VarianceN};
4use crate::{Result, StatsError};
5use nalgebra::{
6 base::allocator::Allocator, base::dimension::DimName, Cholesky, DefaultAllocator, Dim, DimMin,
7 LU, U1,
8};
9use nalgebra::{DMatrix, DVector};
10use rand::Rng;
11use std::f64;
12use std::f64::consts::{E, PI};
13
/// A multivariate normal (Gaussian) distribution with mean vector `μ` and
/// covariance matrix `Σ`, stored as dynamically sized nalgebra types.
///
/// Construct it with `MultivariateNormal::new`; the derived quantities
/// (Cholesky factor, precision matrix, density normalizer) are computed once
/// there and cached in the fields below.
#[derive(Debug, Clone, PartialEq)]
pub struct MultivariateNormal {
    /// Dimension `n` of the distribution (the length of the mean vector).
    dim: usize,
    /// Lower-triangular Cholesky factor `L` of the covariance (`Σ = L Lᵀ`),
    /// as returned by nalgebra's `Cholesky::unpack`; used for sampling.
    cov_chol_decomp: DMatrix<f64>,
    /// Mean vector `μ`.
    mu: DVector<f64>,
    /// Covariance matrix `Σ`.
    cov: DMatrix<f64>,
    /// Precision matrix `Σ⁻¹`, obtained from the Cholesky decomposition.
    precision: DMatrix<f64>,
    /// Normalizing constant of the density, `((2π)^n |Σ|)^(-1/2)`.
    pdf_const: f64,
}
38
39impl MultivariateNormal {
40 pub fn new(mean: Vec<f64>, cov: Vec<f64>) -> Result<Self> {
48 let mean = DVector::from_vec(mean);
49 let cov = DMatrix::from_vec(mean.len(), mean.len(), cov);
50 let dim = mean.len();
51 if cov.lower_triangle() != cov.upper_triangle().transpose()
53 || mean.iter().any(|f| f.is_nan())
55 || cov.iter().any(|f| f.is_nan())
56 || mean.nrows() != cov.nrows() || cov.nrows() != cov.ncols()
58 {
59 return Err(StatsError::BadParams);
60 }
61 let cov_det = cov.determinant();
62 let pdf_const = ((2. * PI).powi(mean.nrows() as i32) * cov_det.abs())
63 .recip()
64 .sqrt();
65 match Cholesky::new(cov.clone()) {
68 None => Err(StatsError::BadParams),
69 Some(cholesky_decomp) => {
70 let precision = cholesky_decomp.inverse();
71 Ok(MultivariateNormal {
72 dim,
73 cov_chol_decomp: cholesky_decomp.unpack(),
74 mu: mean,
75 cov,
76 precision,
77 pdf_const,
78 })
79 }
80 }
81 }
82 pub fn entropy(&self) -> Option<f64> {
92 Some(
93 0.5 * self
94 .variance()
95 .unwrap()
96 .scale(2. * PI * E)
97 .determinant()
98 .ln(),
99 )
100 }
101}
102
103impl ::rand::distributions::Distribution<DVector<f64>> for MultivariateNormal {
104 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> DVector<f64> {
114 let d = Normal::new(0., 1.).unwrap();
115 let z = DVector::<f64>::from_distribution(self.dim, &d, rng);
116 (&self.cov_chol_decomp * z) + &self.mu
117 }
118}
119
120impl Min<DVector<f64>> for MultivariateNormal {
121 fn min(&self) -> DVector<f64> {
124 DVector::from_vec(vec![f64::NEG_INFINITY; self.dim])
125 }
126}
127
128impl Max<DVector<f64>> for MultivariateNormal {
129 fn max(&self) -> DVector<f64> {
132 DVector::from_vec(vec![f64::INFINITY; self.dim])
133 }
134}
135
136impl MeanN<DVector<f64>> for MultivariateNormal {
137 fn mean(&self) -> Option<DVector<f64>> {
143 let mut vec = vec![];
144 for elt in self.mu.clone().into_iter() {
145 vec.push(*elt);
146 }
147 Some(DVector::from_vec(vec))
148 }
149}
150
151impl VarianceN<DMatrix<f64>> for MultivariateNormal {
152 fn variance(&self) -> Option<DMatrix<f64>> {
154 Some(self.cov.clone())
155 }
156}
157
158impl Mode<DVector<f64>> for MultivariateNormal {
159 fn mode(&self) -> DVector<f64> {
169 self.mu.clone()
170 }
171}
172
173impl<'a> Continuous<&'a DVector<f64>, f64> for MultivariateNormal {
174 fn pdf(&self, x: &'a DVector<f64>) -> f64 {
186 let dv = x - &self.mu;
187 let exp_term = -0.5
188 * *(&dv.transpose() * &self.precision * &dv)
189 .get((0, 0))
190 .unwrap();
191 self.pdf_const * exp_term.exp()
192 }
193 fn ln_pdf(&self, x: &'a DVector<f64>) -> f64 {
196 let dv = x - &self.mu;
197 let exp_term = -0.5
198 * *(&dv.transpose() * &self.precision * &dv)
199 .get((0, 0))
200 .unwrap();
201 self.pdf_const.ln() + exp_term
202 }
203}
204
#[rustfmt::skip]
#[cfg(all(test, feature = "nightly"))]
mod tests {
    use super::*;
    use crate::distribution::{Continuous, MultivariateNormal};
    use crate::statistics::*;
    use core::fmt::Debug;

    /// Builds a distribution, asserting that construction succeeds.
    fn try_create(mean: Vec<f64>, covariance: Vec<f64>) -> MultivariateNormal
    {
        let mvn = MultivariateNormal::new(mean, covariance);
        assert!(mvn.is_ok());
        mvn.unwrap()
    }

    /// Asserts that construction succeeds and that the stored mean and
    /// covariance round-trip through `mean()` / `variance()`.
    fn create_case(mean: Vec<f64>, covariance: Vec<f64>)
    {
        let mvn = try_create(mean.clone(), covariance.clone());
        assert_eq!(DVector::from_vec(mean.clone()), mvn.mean().unwrap());
        assert_eq!(DMatrix::from_vec(mean.len(), mean.len(), covariance), mvn.variance().unwrap());
    }

    /// Asserts that construction is rejected with an error.
    fn bad_create_case(mean: Vec<f64>, covariance: Vec<f64>)
    {
        let mvn = MultivariateNormal::new(mean, covariance);
        assert!(mvn.is_err());
    }

    /// Builds a distribution and asserts `eval` returns exactly `expected`.
    fn test_case<T, F>(mean: Vec<f64>, covariance: Vec<f64>, expected: T, eval: F)
    where
        T: Debug + PartialEq,
        F: FnOnce(MultivariateNormal) -> T,
    {
        let mvn = try_create(mean, covariance);
        let x = eval(mvn);
        assert_eq!(expected, x);
    }

    /// Builds a distribution and asserts `eval` is within `acc` of `expected`.
    fn test_almost<F>(
        mean: Vec<f64>,
        covariance: Vec<f64>,
        expected: f64,
        acc: f64,
        eval: F,
    ) where
        F: FnOnce(MultivariateNormal) -> f64,
    {
        let mvn = try_create(mean, covariance);
        let x = eval(mvn);
        assert_almost_eq!(expected, x, acc);
    }

    macro_rules! dvec {
        ($($x:expr),*) => (DVector::from_vec(vec![$($x),*]));
    }

    macro_rules! mat2 {
        ($x11:expr, $x12:expr, $x21:expr, $x22:expr) => (DMatrix::from_vec(2,2,vec![$x11, $x12, $x21, $x22]));
    }

    #[test]
    fn test_create() {
        create_case(vec![0., 0.], vec![1., 0., 0., 1.]);
        create_case(vec![10., 5.], vec![2., 1., 1., 2.]);
        create_case(vec![4., 5., 6.], vec![2., 1., 0., 1., 2., 1., 0., 1., 2.]);
        create_case(vec![0., f64::INFINITY], vec![1., 0., 0., 1.]);
        create_case(vec![0., 0.], vec![f64::INFINITY, 0., 0., f64::INFINITY]);
    }

    #[test]
    fn test_bad_create() {
        // Asymmetric covariance.
        bad_create_case(vec![0., 0.], vec![1., 1., 0., 1.]);
        // Symmetric but not positive definite.
        bad_create_case(vec![0., 0.], vec![1., 2., 2., 1.]);
        // NaN in the mean.
        bad_create_case(vec![0., f64::NAN], vec![1., 0., 0., 1.]);
        // NaN in the covariance.
        bad_create_case(vec![0., 0.], vec![1., 0., 0., f64::NAN]);
    }

    #[test]
    fn test_mean() {
        let mean = |x: MultivariateNormal| x.mean().unwrap();
        test_case(vec![0., 0.], vec![1., 0., 0., 1.], dvec![0., 0.], mean);
        test_case(vec![-1., 3.5], vec![2., 0.5, 0.5, 1.], dvec![-1., 3.5], mean);
        test_case(vec![4., 5., 6.], vec![2., 1., 0., 1., 2., 1., 0., 1., 2.], dvec![4., 5., 6.], mean);
    }

    #[test]
    fn test_variance() {
        let variance = |x: MultivariateNormal| x.variance().unwrap();
        test_case(vec![0., 0.], vec![1., 0., 0., 1.], mat2![1., 0., 0., 1.], variance);
        test_case(vec![0., 0.], vec![f64::INFINITY, 0., 0., f64::INFINITY], mat2![f64::INFINITY, 0., 0., f64::INFINITY], variance);
    }

    #[test]
    fn test_entropy() {
        let entropy = |x: MultivariateNormal| x.entropy().unwrap();
        test_case(vec![0., 0.], vec![1., 0., 0., 1.], 2.8378770664093453, entropy);
        test_case(vec![0., 0.], vec![1., 0.5, 0.5, 1.], 2.694036030183455, entropy);
        test_case(vec![0., 0.], vec![f64::INFINITY, 0., 0., f64::INFINITY], f64::INFINITY, entropy);
    }

    #[test]
    fn test_mode() {
        let mode = |x: MultivariateNormal| x.mode();
        test_case(vec![0., 0.], vec![1., 0., 0., 1.], dvec![0., 0.], mode);
        test_case(vec![f64::INFINITY, f64::INFINITY], vec![1., 0., 0., 1.], dvec![f64::INFINITY, f64::INFINITY], mode);
    }

    #[test]
    fn test_min_max() {
        let min = |x: MultivariateNormal| x.min();
        let max = |x: MultivariateNormal| x.max();
        test_case(vec![0., 0.], vec![1., 0., 0., 1.], dvec![f64::NEG_INFINITY, f64::NEG_INFINITY], min);
        test_case(vec![0., 0.], vec![1., 0., 0., 1.], dvec![f64::INFINITY, f64::INFINITY], max);
        test_case(vec![10., 1.], vec![1., 0., 0., 1.], dvec![f64::NEG_INFINITY, f64::NEG_INFINITY], min);
        test_case(vec![-3., 5.], vec![1., 0., 0., 1.], dvec![f64::INFINITY, f64::INFINITY], max);
    }

    #[test]
    fn test_pdf() {
        let pdf = |arg: DVector<f64>| move |x: MultivariateNormal| x.pdf(&arg);
        test_case(vec![0., 0.], vec![1., 0., 0., 1.], 0.05854983152431917, pdf(dvec![1., 1.]));
        test_almost(vec![0., 0.], vec![1., 0., 0., 1.], 0.013064233284684921, 1e-15, pdf(dvec![1., 2.]));
        test_almost(vec![0., 0.], vec![1., 0., 0., 1.], 1.8618676045881531e-23, 1e-35, pdf(dvec![1., 10.]));
        test_almost(vec![0., 0.], vec![1., 0., 0., 1.], 5.920684802611216e-45, 1e-58, pdf(dvec![10., 10.]));
        test_almost(vec![0., 0.], vec![1., 0.9, 0.9, 1.], 1.6576716577547003e-05, 1e-18, pdf(dvec![1., -1.]));
        test_almost(vec![0., 0.], vec![1., 0.99, 0.99, 1.], 4.1970621773477824e-44, 1e-54, pdf(dvec![1., -1.]));
        test_almost(vec![0.5, -0.2], vec![2.0, 0.3, 0.3, 0.5], 0.0013075203140666656, 1e-15, pdf(dvec![2., 2.]));
        test_case(vec![0., 0.], vec![f64::INFINITY, 0., 0., f64::INFINITY], 0.0, pdf(dvec![10., 10.]));
        test_case(vec![0., 0.], vec![f64::INFINITY, 0., 0., f64::INFINITY], 0.0, pdf(dvec![100., 100.]));
    }

    #[test]
    fn test_ln_pdf() {
        let ln_pdf = |arg: DVector<_>| move |x: MultivariateNormal| x.ln_pdf(&arg);
        test_case(vec![0., 0.], vec![1., 0., 0., 1.], (0.05854983152431917f64).ln(), ln_pdf(dvec![1., 1.]));
        test_almost(vec![0., 0.], vec![1., 0., 0., 1.], (0.013064233284684921f64).ln(), 1e-15, ln_pdf(dvec![1., 2.]));
        test_almost(vec![0., 0.], vec![1., 0., 0., 1.], (1.8618676045881531e-23f64).ln(), 1e-15, ln_pdf(dvec![1., 10.]));
        test_almost(vec![0., 0.], vec![1., 0., 0., 1.], (5.920684802611216e-45f64).ln(), 1e-15, ln_pdf(dvec![10., 10.]));
        test_almost(vec![0., 0.], vec![1., 0.9, 0.9, 1.], (1.6576716577547003e-05f64).ln(), 1e-14, ln_pdf(dvec![1., -1.]));
        test_almost(vec![0., 0.], vec![1., 0.99, 0.99, 1.], (4.1970621773477824e-44f64).ln(), 1e-12, ln_pdf(dvec![1., -1.]));
        test_almost(vec![0.5, -0.2], vec![2.0, 0.3, 0.3, 0.5], (0.0013075203140666656f64).ln(), 1e-15, ln_pdf(dvec![2., 2.]));
        test_case(vec![0., 0.], vec![f64::INFINITY, 0., 0., f64::INFINITY], f64::NEG_INFINITY, ln_pdf(dvec![10., 10.]));
        test_case(vec![0., 0.], vec![f64::INFINITY, 0., 0., f64::INFINITY], f64::NEG_INFINITY, ln_pdf(dvec![100., 100.]));
    }
}