1use crate::distribution::Continuous;
2use crate::function::gamma;
3use crate::prec;
4use crate::statistics::*;
5use nalgebra::{Dim, Dyn, OMatrix, OVector};
6use std::f64;
7
8#[derive(Clone, PartialEq, Debug)]
25pub struct Dirichlet<D>
26where
27 D: Dim,
28 nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>,
29{
30 alpha: OVector<f64, D>,
31}
32
/// Reasons why constructing a [`Dirichlet`] distribution can fail.
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
#[non_exhaustive]
pub enum DirichletError {
    /// The concentration vector `alpha` has fewer than two elements.
    AlphaTooShort,

    /// The concentration vector `alpha` contains an element that is NaN,
    /// infinite, zero or negative.
    AlphaHasInvalidElements,
}
43
44impl std::fmt::Display for DirichletError {
45 #[cfg_attr(coverage_nightly, coverage(off))]
46 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
47 match self {
48 DirichletError::AlphaTooShort => write!(f, "Alpha contains less than two elements"),
49 DirichletError::AlphaHasInvalidElements => write!(
50 f,
51 "Alpha contains an element that is NaN, infinite, zero or less than zero"
52 ),
53 }
54 }
55}
56
57impl std::error::Error for DirichletError {}
58
59impl Dirichlet<Dyn> {
60 pub fn new(alpha: Vec<f64>) -> Result<Self, DirichletError> {
84 Self::new_from_nalgebra(alpha.into())
85 }
86
87 pub fn new_with_param(alpha: f64, n: usize) -> Result<Self, DirichletError> {
107 Self::new(vec![alpha; n])
108 }
109}
110
111impl<D> Dirichlet<D>
112where
113 D: Dim,
114 nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>,
115{
116 pub fn new_from_nalgebra(alpha: OVector<f64, D>) -> Result<Self, DirichletError> {
124 if alpha.len() < 2 {
125 return Err(DirichletError::AlphaTooShort);
126 }
127
128 if alpha.iter().any(|&a_i| !a_i.is_finite() || a_i <= 0.0) {
129 return Err(DirichletError::AlphaHasInvalidElements);
130 }
131
132 Ok(Self { alpha })
133 }
134
135 pub fn alpha(&self) -> &nalgebra::OVector<f64, D> {
148 &self.alpha
149 }
150
151 fn alpha_sum(&self) -> f64 {
152 self.alpha.sum()
153 }
154
155 pub fn entropy(&self) -> Option<f64> {
174 let sum = self.alpha_sum();
175 let num = self.alpha.iter().fold(0.0, |acc, &x| {
176 acc + gamma::ln_gamma(x) + (x - 1.0) * gamma::digamma(x)
177 });
178 let entr =
179 -gamma::ln_gamma(sum) + (sum - self.alpha.len() as f64) * gamma::digamma(sum) - num;
180 Some(entr)
181 }
182}
183
184impl<D> std::fmt::Display for Dirichlet<D>
185where
186 D: Dim,
187 nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>,
188{
189 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
190 write!(f, "Dir({}, {})", self.alpha.len(), &self.alpha)
191 }
192}
193
194#[cfg(feature = "rand")]
195#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
196impl<D> ::rand::distributions::Distribution<OVector<f64, D>> for Dirichlet<D>
197where
198 D: Dim,
199 nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>,
200{
201 fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> OVector<f64, D> {
202 let mut sum = 0.0;
203 OVector::from_iterator_generic(
204 self.alpha.shape_generic().0,
205 nalgebra::Const::<1>,
206 self.alpha.iter().map(|&a| {
207 let sample = super::gamma::sample_unchecked(rng, a, 1.0);
208 sum += sample;
209 sample
210 }),
211 )
212 }
213}
214
215impl<D> MeanN<OVector<f64, D>> for Dirichlet<D>
216where
217 D: Dim,
218 nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>,
219{
220 fn mean(&self) -> Option<OVector<f64, D>> {
231 let sum = self.alpha_sum();
232 Some(self.alpha.map(|x| x / sum))
233 }
234}
235
236impl<D> VarianceN<OMatrix<f64, D, D>> for Dirichlet<D>
237where
238 D: Dim,
239 nalgebra::DefaultAllocator:
240 nalgebra::allocator::Allocator<D> + nalgebra::allocator::Allocator<D, D>,
241{
242 fn variance(&self) -> Option<OMatrix<f64, D, D>> {
253 let sum = self.alpha_sum();
254 let normalizing = sum * sum * (sum + 1.0);
255 let mut cov = OMatrix::from_diagonal(&self.alpha.map(|x| x * (sum - x) / normalizing));
256 let mut offdiag = |x: usize, y: usize| {
257 let elt = -self.alpha[x] * self.alpha[y] / normalizing;
258 cov[(x, y)] = elt;
259 cov[(y, x)] = elt;
260 };
261 for i in 0..self.alpha.len() {
262 for j in 0..i {
263 offdiag(i, j);
264 }
265 }
266 Some(cov)
267 }
268}
269
270impl<D> Continuous<&OVector<f64, D>, f64> for Dirichlet<D>
271where
272 D: Dim,
273 nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>
274 + nalgebra::allocator::Allocator<D, D>
275 + nalgebra::allocator::Allocator<nalgebra::Const<1>, D>,
276{
277 fn pdf(&self, x: &OVector<f64, D>) -> f64 {
308 self.ln_pdf(x).exp()
309 }
310
311 fn ln_pdf(&self, x: &OVector<f64, D>) -> f64 {
342 if self.alpha.len() != x.len() {
343 panic!("Arguments must have correct dimensions.");
344 }
345
346 let mut term = 0.0;
347 let mut sum_x = 0.0;
348 let mut sum_alpha = 0.0;
349
350 for (&x_i, &alpha_i) in x.iter().zip(self.alpha.iter()) {
351 assert!(0.0 < x_i && x_i < 1.0, "Arguments must be in (0, 1)");
352
353 term += (alpha_i - 1.0) * x_i.ln() - gamma::ln_gamma(alpha_i);
354 sum_x += x_i;
355 sum_alpha += alpha_i;
356 }
357
358 assert!(
359 prec::almost_eq(sum_x, 1.0, 1e-4),
360 "Arguments must sum up to 1"
361 );
362 term + gamma::ln_gamma(sum_alpha)
363 }
364}
365
366#[rustfmt::skip]
367#[cfg(test)]
368mod tests {
369 use super::*;
370
371 use std::fmt::{Debug, Display};
372
373 use nalgebra::{dmatrix, dvector, vector, DimMin, OVector};
374
375 fn try_create<D>(alpha: OVector<f64, D>) -> Dirichlet<D>
376 where
377 D: DimMin<D, Output = D>,
378 nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>,
379 {
380 let mvn = Dirichlet::new_from_nalgebra(alpha);
381 assert!(mvn.is_ok());
382 mvn.unwrap()
383 }
384
385 fn bad_create_case<D>(alpha: OVector<f64, D>)
386 where
387 D: DimMin<D, Output = D>,
388 nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>,
389 {
390 let dd = Dirichlet::new_from_nalgebra(alpha);
391 assert!(dd.is_err());
392 }
393
394 fn test_almost<F, T, D>(alpha: OVector<f64, D>, expected: T, acc: f64, eval: F)
395 where
396 T: Debug + Display + approx::RelativeEq<Epsilon = f64>,
397 F: FnOnce(Dirichlet<D>) -> T,
398 D: DimMin<D, Output = D>,
399 nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>,
400 {
401 let dd = try_create(alpha);
402 let x = eval(dd);
403 assert_relative_eq!(expected, x, epsilon = acc);
404 }
405
406 #[test]
407 fn test_create() {
408 try_create(vector![1.0, 2.0]);
409 try_create(vector![1.0, 2.0, 3.0, 4.0, 5.0]);
410 assert!(Dirichlet::new(vec![1.0, 2.0, 3.0, 4.0, 5.0]).is_ok());
411 }
413
414 #[test]
415 fn test_bad_create() {
416 bad_create_case(vector![1.0, f64::NAN]);
417 bad_create_case(vector![1.0, 0.0]);
418 bad_create_case(vector![1.0, f64::INFINITY]);
419 bad_create_case(vector![-1.0, 2.0]);
420 bad_create_case(vector![1.0]);
421 bad_create_case(vector![1.0, 2.0, 0.0, 4.0, 5.0]);
422 bad_create_case(vector![1.0, f64::NAN, 3.0, 4.0, 5.0]);
423 bad_create_case(vector![0.0, 0.0, 0.0]);
424 bad_create_case(vector![0.001, f64::INFINITY, 3756.0]); }
426
427 #[test]
428 fn test_mean() {
429 let mean = |dd: Dirichlet<_>| dd.mean().unwrap();
430
431 test_almost(vec![0.5; 5].into(), vec![1.0 / 5.0; 5].into(), 1e-15, mean);
432
433 test_almost(
434 dvector![0.1, 0.2, 0.3, 0.4],
435 dvector![0.1, 0.2, 0.3, 0.4],
436 1e-15,
437 mean,
438 );
439
440 test_almost(
441 dvector![1.0, 2.0, 3.0, 4.0],
442 dvector![0.1, 0.2, 0.3, 0.4],
443 1e-15,
444 mean,
445 );
446 }
447
448 #[test]
449 fn test_variance() {
450 let variance = |dd: Dirichlet<_>| dd.variance().unwrap();
451
452 test_almost(
453 dvector![1.0, 2.0],
454 dmatrix![0.055555555555555, -0.055555555555555;
455 -0.055555555555555, 0.055555555555555;
456 ],
457 1e-15,
458 variance,
459 );
460
461 test_almost(
462 dvector![0.1, 0.2, 0.3, 0.4],
463 dmatrix![0.045, -0.010, -0.015, -0.020;
464 -0.010, 0.080, -0.030, -0.040;
465 -0.015, -0.030, 0.105, -0.060;
466 -0.020, -0.040, -0.060, 0.120;
467 ],
468 1e-15,
469 variance,
470 );
471 }
472
473 #[test]
486 fn test_entropy() {
487 let entropy = |x: Dirichlet<_>| x.entropy().unwrap();
488 test_almost(
489 vector![0.1, 0.3, 0.5, 0.8],
490 -17.46469081094079,
491 1e-30,
492 entropy,
493 );
494 test_almost(
495 vector![0.1, 0.2, 0.3, 0.4],
496 -21.53881433791513,
497 1e-30,
498 entropy,
499 );
500 }
501
502 #[test]
503 fn test_pdf() {
504 let pdf = |arg| move |x: Dirichlet<_>| x.pdf(&arg);
505 test_almost(
506 vector![0.1, 0.3, 0.5, 0.8],
507 18.77225681167061,
508 1e-12,
509 pdf([0.01, 0.03, 0.5, 0.46].into()),
510 );
511 test_almost(
512 vector![0.1, 0.3, 0.5, 0.8],
513 0.8314656481199253,
514 1e-14,
515 pdf([0.1, 0.2, 0.3, 0.4].into()),
516 );
517 }
518
519 #[test]
520 fn test_ln_pdf() {
521 let ln_pdf = |arg| move |x: Dirichlet<_>| x.ln_pdf(&arg);
522 test_almost(
523 vector![0.1, 0.3, 0.5, 0.8],
524 18.77225681167061_f64.ln(),
525 1e-12,
526 ln_pdf([0.01, 0.03, 0.5, 0.46].into()),
527 );
528 test_almost(
529 vector![0.1, 0.3, 0.5, 0.8],
530 0.8314656481199253_f64.ln(),
531 1e-14,
532 ln_pdf([0.1, 0.2, 0.3, 0.4].into()),
533 );
534 }
535
536 #[test]
537 #[should_panic]
538 fn test_pdf_bad_input_length() {
539 let n = try_create(dvector![0.1, 0.3, 0.5, 0.8]);
540 n.pdf(&dvector![0.5]);
541 }
542
543 #[test]
544 #[should_panic]
545 fn test_pdf_bad_input_range() {
546 let n = try_create(vector![0.1, 0.3, 0.5, 0.8]);
547 n.pdf(&vector![1.5, 0.0, 0.0, 0.0]);
548 }
549
550 #[test]
551 #[should_panic]
552 fn test_pdf_bad_input_sum() {
553 let n = try_create(vector![0.1, 0.3, 0.5, 0.8]);
554 n.pdf(&vector![0.5, 0.25, 0.8, 0.9]);
555 }
556
557 #[test]
558 #[should_panic]
559 fn test_ln_pdf_bad_input_length() {
560 let n = try_create(dvector![0.1, 0.3, 0.5, 0.8]);
561 n.ln_pdf(&dvector![0.5]);
562 }
563
564 #[test]
565 #[should_panic]
566 fn test_ln_pdf_bad_input_range() {
567 let n = try_create(vector![0.1, 0.3, 0.5, 0.8]);
568 n.ln_pdf(&vector![1.5, 0.0, 0.0, 0.0]);
569 }
570
571 #[test]
572 #[should_panic]
573 fn test_ln_pdf_bad_input_sum() {
574 let n = try_create(vector![0.1, 0.3, 0.5, 0.8]);
575 n.ln_pdf(&vector![0.5, 0.25, 0.8, 0.9]);
576 }
577
578 #[test]
579 fn test_error_is_sync_send() {
580 fn assert_sync_send<T: Sync + Send>() {}
581 assert_sync_send::<DirichletError>();
582 }
583}