1use crate::distribution::Continuous;
2use crate::function::gamma;
3use crate::statistics::{Max, MeanN, Min, Mode, VarianceN};
4use nalgebra::{Cholesky, Const, DMatrix, Dim, DimMin, Dyn, OMatrix, OVector};
5use std::f64::consts::PI;
6
7#[derive(Debug, Clone, PartialEq)]
25pub struct MultivariateStudent<D>
26where
27 D: Dim,
28 nalgebra::DefaultAllocator:
29 nalgebra::allocator::Allocator<D> + nalgebra::allocator::Allocator<D, D>,
30{
31 scale_chol_decomp: OMatrix<f64, D, D>,
32 location: OVector<f64, D>,
33 scale: OMatrix<f64, D, D>,
34 freedom: f64,
35 precision: OMatrix<f64, D, D>,
36 ln_pdf_const: f64,
37}
38
/// Reasons why constructing a [`MultivariateStudent`] can be rejected.
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
#[non_exhaustive]
pub enum MultivariateStudentError {
    /// The scale matrix is not square, not symmetric, or contains a NaN.
    ScaleInvalid,

    /// The location vector contains a NaN.
    LocationInvalid,

    /// The degrees of freedom are NaN, zero, or negative.
    FreedomInvalid,

    /// The location vector and the scale matrix have different row counts.
    DimensionMismatch,

    /// The Cholesky decomposition of the scale matrix could not be computed.
    CholeskyFailed,
}
60
61impl std::fmt::Display for MultivariateStudentError {
62 #[cfg_attr(coverage_nightly, coverage(off))]
63 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
64 match self {
65 MultivariateStudentError::ScaleInvalid => {
66 write!(f, "Scale matrix is asymmetric or contains a NaN")
67 }
68 MultivariateStudentError::LocationInvalid => {
69 write!(f, "Location vector contains a NaN")
70 }
71 MultivariateStudentError::FreedomInvalid => {
72 write!(f, "Degrees of freedom are NaN, zero or less than zero")
73 }
74 MultivariateStudentError::DimensionMismatch => write!(
75 f,
76 "Location vector and scale matrix do not have the same number of rows"
77 ),
78 MultivariateStudentError::CholeskyFailed => {
79 write!(f, "Computing the Cholesky decomposition failed")
80 }
81 }
82 }
83}
84
85impl std::error::Error for MultivariateStudentError {}
86
87impl MultivariateStudent<Dyn> {
88 pub fn new(
96 location: Vec<f64>,
97 scale: Vec<f64>,
98 freedom: f64,
99 ) -> Result<Self, MultivariateStudentError> {
100 let dim = location.len();
101 Self::new_from_nalgebra(location.into(), DMatrix::from_vec(dim, dim, scale), freedom)
102 }
103
104 pub fn dim(&self) -> usize {
106 self.location.len()
107 }
108}
109
110impl<D> MultivariateStudent<D>
111where
112 D: DimMin<D, Output = D>,
113 nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>
114 + nalgebra::allocator::Allocator<D, D>
115 + nalgebra::allocator::Allocator<D>,
116{
117 pub fn new_from_nalgebra(
118 location: OVector<f64, D>,
119 scale: OMatrix<f64, D, D>,
120 freedom: f64,
121 ) -> Result<Self, MultivariateStudentError> {
122 let dim = location.len();
123
124 if location.iter().any(|f| f.is_nan()) {
125 return Err(MultivariateStudentError::LocationInvalid);
126 }
127
128 if !scale.is_square()
129 || scale.lower_triangle() != scale.upper_triangle().transpose()
130 || scale.iter().any(|f| f.is_nan())
131 {
132 return Err(MultivariateStudentError::ScaleInvalid);
133 }
134
135 if freedom.is_nan() || freedom <= 0.0 {
136 return Err(MultivariateStudentError::FreedomInvalid);
137 }
138
139 if location.nrows() != scale.nrows() {
140 return Err(MultivariateStudentError::DimensionMismatch);
141 }
142
143 let scale_det = scale.determinant();
144 let ln_pdf_const = gamma::ln_gamma(0.5 * (freedom + dim as f64))
145 - gamma::ln_gamma(0.5 * freedom)
146 - 0.5 * (dim as f64) * (freedom * PI).ln()
147 - 0.5 * scale_det.ln();
148
149 match Cholesky::new(scale.clone()) {
150 None => Err(MultivariateStudentError::CholeskyFailed),
151 Some(cholesky_decomp) => {
152 let precision = cholesky_decomp.inverse();
153 Ok(MultivariateStudent {
154 scale_chol_decomp: cholesky_decomp.unpack(),
155 location,
156 scale,
157 freedom,
158 precision,
159 ln_pdf_const,
160 })
161 }
162 }
163 }
164
165 pub fn scale_chol_decomp(&self) -> &OMatrix<f64, D, D> {
169 &self.scale_chol_decomp
170 }
171
172 pub fn location(&self) -> &OVector<f64, D> {
174 &self.location
175 }
176
177 pub fn scale(&self) -> &OMatrix<f64, D, D> {
179 &self.scale
180 }
181
182 pub fn freedom(&self) -> f64 {
184 self.freedom
185 }
186
187 pub fn precision(&self) -> &OMatrix<f64, D, D> {
189 &self.precision
190 }
191
192 pub fn ln_pdf_const(&self) -> f64 {
195 self.ln_pdf_const
196 }
197}
198
#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl<D> ::rand::distributions::Distribution<OVector<f64, D>> for MultivariateStudent<D>
where
    D: Dim,
    nalgebra::DefaultAllocator:
        nalgebra::allocator::Allocator<D> + nalgebra::allocator::Allocator<D, D>,
{
    /// Draws one sample as `W ⋅ L ⋅ Z + μ`, where `W = √(ν / S)` with `S`
    /// drawn from a chi-squared distribution with `ν` degrees of freedom,
    /// `L` is the stored Cholesky factor of the scale matrix, `Z` is a vector
    /// of independent standard normal draws and `μ` is the location vector.
    fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> OVector<f64, D> {
        use crate::distribution::{ChiSquared, Normal};

        let standard_normal = Normal::new(0., 1.).unwrap();
        // NOTE(review): behaviour for an infinite `freedom` depends on how
        // `ChiSquared` treats that parameter - confirm before relying on it.
        let chi_squared = ChiSquared::new(self.freedom).unwrap();

        // The chi-squared draw comes before the normal draws so the RNG
        // stream is consumed in a fixed order.
        let mixing_weight = (self.freedom / chi_squared.sample(rng)).sqrt();

        let (rows, cols) = self.location.shape_generic();
        let gaussian =
            OVector::<f64, D>::from_distribution_generic(rows, cols, &standard_normal, rng);

        let weighted_factor = mixing_weight * &self.scale_chol_decomp;
        weighted_factor * gaussian + &self.location
    }
}
231
232impl<D> Min<OVector<f64, D>> for MultivariateStudent<D>
233where
234 D: Dim,
235 nalgebra::DefaultAllocator:
236 nalgebra::allocator::Allocator<D> + nalgebra::allocator::Allocator<D, D>,
237{
238 fn min(&self) -> OVector<f64, D> {
241 OMatrix::repeat_generic(
242 self.location.shape_generic().0,
243 Const::<1>,
244 f64::NEG_INFINITY,
245 )
246 }
247}
248
249impl<D> Max<OVector<f64, D>> for MultivariateStudent<D>
250where
251 D: Dim,
252 nalgebra::DefaultAllocator:
253 nalgebra::allocator::Allocator<D> + nalgebra::allocator::Allocator<D, D>,
254{
255 fn max(&self) -> OVector<f64, D> {
258 OMatrix::repeat_generic(self.location.shape_generic().0, Const::<1>, f64::INFINITY)
259 }
260}
261
262impl<D> MeanN<OVector<f64, D>> for MultivariateStudent<D>
263where
264 D: Dim,
265 nalgebra::DefaultAllocator:
266 nalgebra::allocator::Allocator<D> + nalgebra::allocator::Allocator<D, D>,
267{
268 fn mean(&self) -> Option<OVector<f64, D>> {
275 if self.freedom > 1. {
276 Some(self.location.clone())
277 } else {
278 None
279 }
280 }
281}
282
283impl<D> VarianceN<OMatrix<f64, D, D>> for MultivariateStudent<D>
284where
285 D: Dim,
286 nalgebra::DefaultAllocator:
287 nalgebra::allocator::Allocator<D> + nalgebra::allocator::Allocator<D, D>,
288{
289 fn variance(&self) -> Option<OMatrix<f64, D, D>> {
300 if self.freedom > 2. {
301 Some(self.scale.clone() * self.freedom / (self.freedom - 2.))
302 } else {
303 None
304 }
305 }
306}
307
308impl<D> Mode<OVector<f64, D>> for MultivariateStudent<D>
309where
310 D: Dim,
311 nalgebra::DefaultAllocator:
312 nalgebra::allocator::Allocator<D> + nalgebra::allocator::Allocator<D, D>,
313{
314 fn mode(&self) -> OVector<f64, D> {
324 self.location.clone()
325 }
326}
327
328impl<D> Continuous<&OVector<f64, D>, f64> for MultivariateStudent<D>
329where
330 D: Dim + DimMin<D, Output = D>,
331 nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>
332 + nalgebra::allocator::Allocator<D, D>
333 + nalgebra::allocator::Allocator<D>,
334{
335 fn pdf(&self, x: &OVector<f64, D>) -> f64 {
352 if self.freedom.is_infinite() {
353 use super::multivariate_normal::density_normalization_and_exponential;
354 let (pdf_const, exp_arg) = density_normalization_and_exponential(
355 &self.location,
356 &self.scale,
357 &self.precision,
358 x,
359 )
360 .unwrap();
361 return pdf_const * exp_arg.exp();
362 }
363
364 let dv = x - &self.location;
365 let exp_arg: f64 = (&self.precision * &dv).dot(&dv);
366 let base_term = 1. + exp_arg / self.freedom;
367 self.ln_pdf_const.exp() * base_term.powf(-(self.freedom + self.location.len() as f64) / 2.)
368 }
369
370 fn ln_pdf(&self, x: &OVector<f64, D>) -> f64 {
373 if self.freedom.is_infinite() {
374 use super::multivariate_normal::density_normalization_and_exponential;
375 let (pdf_const, exp_arg) = density_normalization_and_exponential(
376 &self.location,
377 &self.scale,
378 &self.precision,
379 x,
380 )
381 .unwrap();
382 return pdf_const.ln() + exp_arg;
383 }
384
385 let dv = x - &self.location;
386 let exp_arg: f64 = (&self.precision * &dv).dot(&dv);
387 let base_term = 1. + exp_arg / self.freedom;
388 self.ln_pdf_const - (self.freedom + self.location.len() as f64) / 2. * base_term.ln()
389 }
390}
391
#[rustfmt::skip]
#[cfg(test)]
mod tests {
    use core::fmt::Debug;

    use approx::RelativeEq;
    use nalgebra::{DMatrix, DVector, Dyn, OMatrix, OVector, U1, U2};

    use crate::{
        distribution::{Continuous, MultivariateStudent, MultivariateNormal},
        statistics::{Max, MeanN, Min, Mode, VarianceN},
    };

    use super::MultivariateStudentError;

    /// Builds a distribution and fails the calling test if construction is rejected.
    fn try_create(location: Vec<f64>, scale: Vec<f64>, freedom: f64) -> MultivariateStudent<Dyn> {
        let built = MultivariateStudent::new(location, scale, freedom);
        assert!(built.is_ok());
        built.unwrap()
    }

    /// Checks that valid parameters are stored unchanged.
    fn create_case(location: Vec<f64>, scale: Vec<f64>, freedom: f64) {
        let dim = location.len();
        let expected_scale = DMatrix::from_vec(dim, dim, scale.clone());
        let expected_location = DVector::from_vec(location.clone());

        let mvs = try_create(location, scale, freedom);
        assert_eq!(expected_scale, mvs.scale);
        assert_eq!(expected_location, mvs.location);
    }

    /// Checks that the given parameters are rejected by the constructor.
    fn bad_create_case(location: Vec<f64>, scale: Vec<f64>, freedom: f64) {
        assert!(MultivariateStudent::new(location, scale, freedom).is_err());
    }

    /// Evaluates `eval` on a freshly built distribution and demands exact equality.
    fn test_case<T, F>(location: Vec<f64>, scale: Vec<f64>, freedom: f64, expected: T, eval: F)
    where
        T: Debug + PartialEq,
        F: FnOnce(MultivariateStudent<Dyn>) -> T,
    {
        let actual = eval(try_create(location, scale, freedom));
        assert_eq!(expected, actual);
    }

    /// Evaluates `eval` on a freshly built distribution and compares within `acc`.
    fn test_almost<F>(
        location: Vec<f64>,
        scale: Vec<f64>,
        freedom: f64,
        expected: f64,
        acc: f64,
        eval: F,
    ) where
        F: FnOnce(MultivariateStudent<Dyn>) -> f64,
    {
        let actual = eval(try_create(location, scale, freedom));
        assert_almost_eq!(expected, actual, acc);
    }

    /// Compares a Student evaluation against the multivariate normal built
    /// from the same location and scale.
    fn test_almost_multivariate_normal<F1, F2>(
        location: Vec<f64>,
        scale: Vec<f64>,
        freedom: f64,
        acc: f64,
        x: DVector<f64>,
        eval_mvs: F1,
        eval_mvn: F2,
    ) where
        F1: FnOnce(MultivariateStudent<Dyn>, DVector<f64>) -> f64,
        F2: FnOnce(MultivariateNormal<Dyn>, DVector<f64>) -> f64,
    {
        let mvs = try_create(location.clone(), scale.clone(), freedom);

        let normal = MultivariateNormal::new(location, scale);
        assert!(normal.is_ok());
        let mvn = normal.unwrap();

        let mvs_x = eval_mvs(mvs, x.clone());
        let mvn_x = eval_mvn(mvn, x);
        assert!(mvs_x.relative_eq(&mvn_x, acc, acc), "mvn: {mvn_x} =/=\nmvs: {mvs_x}");
    }

    /// Shorthand for a dynamically sized vector literal.
    macro_rules! dvec {
        ($($x:expr),*) => (DVector::from_vec(vec![$($x),*]));
    }

    /// Shorthand for a 2x2 dynamically sized matrix literal.
    macro_rules! mat2 {
        ($x11:expr, $x12:expr, $x21:expr, $x22:expr) => (DMatrix::from_vec(2,2,vec![$x11, $x12, $x21, $x22]));
    }

    #[test]
    fn test_create() {
        create_case(vec![0., 0.], vec![1., 0., 0., 1.], 1.);
        create_case(vec![10., 5.], vec![2., 1., 1., 2.], 3.);
        create_case(vec![4., 5., 6.], vec![2., 1., 0., 1., 2., 1., 0., 1., 2.], 14.);
        create_case(vec![0., f64::INFINITY], vec![1., 0., 0., 1.], f64::INFINITY);
        create_case(vec![0., 0.], vec![f64::INFINITY, 0., 0., f64::INFINITY], 0.1);
    }

    #[test]
    fn test_bad_create() {
        // Scale matrix is not symmetric.
        bad_create_case(vec![0., 0.], vec![1., 1., 0., 1.], 1.);
        // Scale matrix is symmetric but not positive definite.
        bad_create_case(vec![0., 0.], vec![1., 2., 2., 1.], 1.);
        // NaN in the location vector.
        bad_create_case(vec![0., f64::NAN], vec![1., 0., 0., 1.], 1.);
        // NaN in the scale matrix.
        bad_create_case(vec![0., 0.], vec![1., 0., 0., f64::NAN], 1.);
        // NaN degrees of freedom.
        bad_create_case(vec![0., 0.], vec![1., 0., 0., 1.], f64::NAN);
        // Zero degrees of freedom.
        bad_create_case(vec![0., 0.], vec![1., 0., 0., 1.], 0.);
        // Negative degrees of freedom.
        bad_create_case(vec![0., 0.], vec![1., 0., 0., 1.], -1.);
    }

    #[test]
    fn test_variance() {
        let variance = |dist: MultivariateStudent<Dyn>| dist.variance().unwrap();
        test_case(vec![0., 0.], vec![1., 0., 0., 1.], 3., 3. * mat2![1., 0., 0., 1.], variance);
        test_case(vec![0., 0.], vec![f64::INFINITY, 0., 0., f64::INFINITY], 3., mat2![f64::INFINITY, 0., 0., f64::INFINITY], variance);
    }

    #[test]
    fn test_bad_variance() {
        // The covariance does not exist for two or fewer degrees of freedom.
        let variance = |dist: MultivariateStudent<Dyn>| dist.variance();
        test_case(vec![0., 0.], vec![1., 0., 0., 1.], 2., None, variance);
    }

    #[test]
    fn test_mode() {
        let mode = |dist: MultivariateStudent<Dyn>| dist.mode();
        test_case(vec![0., 0.], vec![1., 0., 0., 1.], 1., dvec![0., 0.], mode);
        test_case(vec![f64::INFINITY, f64::INFINITY], vec![1., 0., 0., 1.], 1., dvec![f64::INFINITY, f64::INFINITY], mode);
    }

    #[test]
    fn test_mean() {
        let mean = |dist: MultivariateStudent<Dyn>| dist.mean().unwrap();
        test_case(vec![0., 0.], vec![1., 0., 0., 1.], 2., dvec![0., 0.], mean);
        test_case(vec![-1., 1., 3.], vec![1., 0., 0.5, 0., 2.0, 0., 0.5, 0., 3.0], 2., dvec![-1., 1., 3.], mean);
    }

    #[test]
    fn test_bad_mean() {
        // The mean does not exist for one or fewer degrees of freedom.
        let mean = |dist: MultivariateStudent<Dyn>| dist.mean();
        test_case(vec![0., 0.], vec![1., 0., 0., 1.], 1., None, mean);
    }

    #[test]
    fn test_min_max() {
        let min = |dist: MultivariateStudent<Dyn>| dist.min();
        let max = |dist: MultivariateStudent<Dyn>| dist.max();
        test_case(vec![0., 0.], vec![1., 0., 0., 1.], 1., dvec![f64::NEG_INFINITY, f64::NEG_INFINITY], min);
        test_case(vec![0., 0.], vec![1., 0., 0., 1.], 1., dvec![f64::INFINITY, f64::INFINITY], max);
        test_case(vec![10., 1.], vec![1., 0., 0., 1.], 1., dvec![f64::NEG_INFINITY, f64::NEG_INFINITY], min);
        test_case(vec![-3., 5.], vec![1., 0., 0., 1.], 1., dvec![f64::INFINITY, f64::INFINITY], max);
    }

    #[test]
    fn test_pdf() {
        let pdf = |arg: DVector<f64>| move |dist: MultivariateStudent<Dyn>| dist.pdf(&arg);
        test_almost(vec![0., 0.], vec![1., 0., 0., 1.], 4., 0.047157020175376416, 1e-15, pdf(dvec![1., 1.]));
        test_almost(vec![0., 0.], vec![1., 0., 0., 1.], 4., 0.013972450422333741737457302178882, 1e-15, pdf(dvec![1., 2.]));
        test_almost(vec![0., 0.], vec![1., 0., 0., 1.], 2., 0.012992240252399619, 1e-17, pdf(dvec![1., 2.]));
        test_almost(vec![2., 1.], vec![5., 0., 0., 1.], 2.5, 2.639780816598878e-5, 1e-19, pdf(dvec![1., 10.]));
        test_almost(vec![-1., 0.], vec![2., 1., 1., 6.], 1.5, 6.438051574348526e-5, 1e-19, pdf(dvec![10., 10.]));
        test_almost(vec![-1., 1., 50.], vec![1., 0.5, 0.25, 0.5, 1., -0.1, 0.25, -0.1, 1.], 8., 6.960998836915657e-16, 1e-30, pdf(dvec![0.9718, 0.1298, 0.8134]));
        test_almost(vec![-1., 1., 50.], vec![1., 0.5, 0.25, 0.5, 1., -0.1, 0.25, -0.1, 1.], 8., 7.369987979187023e-16, 1e-30, pdf(dvec![0.4922, 0.5522, 0.7185]));
        test_almost(vec![-1., 1., 50.], vec![1., 0.5, 0.25, 0.5, 1., -0.1, 0.25, -0.1, 1.], 8.,6.951631724511314e-16, 1e-30, pdf(dvec![0.3020, 0.1491, 0.5008]));
        test_case(vec![-1., 0.], vec![f64::INFINITY, 0., 0., f64::INFINITY], 10., 0., pdf(dvec![10., 10.]));
    }

    #[test]
    fn test_ln_pdf() {
        let ln_pdf = |arg: DVector<f64>| move |dist: MultivariateStudent<Dyn>| dist.ln_pdf(&arg);
        test_almost(vec![0., 0.], vec![1., 0., 0., 1.], 4., -3.0542723907338383, 1e-14, ln_pdf(dvec![1., 1.]));
        test_almost(vec![0., 0.], vec![1., 0., 0., 1.], 2., -4.3434030034000815, 1e-14, ln_pdf(dvec![1., 2.]));
        test_almost(vec![2., 1.], vec![5., 0., 0., 1.], 2.5, -10.542229575274265, 1e-14, ln_pdf(dvec![1., 10.]));
        test_almost(vec![-1., 0.], vec![2., 1., 1., 6.], 1.5, -9.650699521198622, 1e-14, ln_pdf(dvec![10., 10.]));
    }

    #[test]
    fn test_pdf_freedom_large() {
        // For large or infinite degrees of freedom the density is compared with the normal one.
        let pdf_mvs = |mv: MultivariateStudent<Dyn>, arg: DVector<f64>| mv.pdf(&arg);
        let pdf_mvn = |mv: MultivariateNormal<Dyn>, arg: DVector<f64>| mv.pdf(&arg);
        test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], 1e5, 1e-6, dvec![1., 1.], pdf_mvs, pdf_mvn);
        test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], 1e10, 1e-7, dvec![1., 1.], pdf_mvs, pdf_mvn);
        test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], f64::INFINITY, 1e-300, dvec![1., 1.], pdf_mvs, pdf_mvn);
        test_almost_multivariate_normal(vec![5., -1.,], vec![1., 0.99, 0.99, 1.], f64::INFINITY, 1e-300, dvec![5., 1.], pdf_mvs, pdf_mvn);
    }

    #[test]
    fn test_ln_pdf_freedom_large() {
        let pdf_mvs = |mv: MultivariateStudent<Dyn>, arg: DVector<f64>| mv.ln_pdf(&arg);
        let pdf_mvn = |mv: MultivariateNormal<Dyn>, arg: DVector<f64>| mv.ln_pdf(&arg);
        test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], 1e5, 1e-5, dvec![1., 1.], pdf_mvs, pdf_mvn);
        test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], 1e10, 5e-6, dvec![1., 1.], pdf_mvs, pdf_mvn);
        test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], f64::INFINITY, 1e-300, dvec![1., 1.], pdf_mvs, pdf_mvn);
        test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0.99, 0.99, 1.], f64::INFINITY, 1e-300, dvec![1., 1.], pdf_mvs, pdf_mvn);
    }

    #[test]
    fn test_immut_field_access() {
        let mvs = MultivariateStudent::new(vec![1., 1.], vec![1., 0., 0., 1.], 2.)
            .expect("hard coded valid construction");
        assert_eq!(mvs.freedom(), 2.);
        assert_relative_eq!(mvs.ln_pdf_const(), std::f64::consts::TAU.recip().ln(), epsilon = 1e-15);

        assert_eq!(mvs.dim(), 2);

        // Comparison against statically sized matrices.
        let identity_static = OMatrix::<f64, U2, U2>::identity();
        assert!(mvs.location().eq(&OVector::<f64, U2>::new(1., 1.)));
        assert!(mvs.scale().eq(&identity_static));
        assert!(mvs.precision().eq(&identity_static));
        assert!(mvs.scale_chol_decomp().eq(&identity_static));

        // Comparison against dynamically sized matrices.
        let identity_dyn = OMatrix::<f64, Dyn, Dyn>::identity(2, 2);
        assert_eq!(mvs.location(),&OVector::<f64, Dyn>::from_element_generic(Dyn(2), U1, 1.));
        assert_eq!(mvs.scale(), &identity_dyn);
        assert_eq!(mvs.precision(), &identity_dyn);
        assert_eq!(mvs.scale_chol_decomp(), &identity_dyn);
    }

    #[test]
    fn test_error_is_sync_send() {
        fn assert_sync_send<T: Sync + Send>() {}
        assert_sync_send::<MultivariateStudentError>();
    }
}