statrs/distribution/
dirichlet.rs

1use crate::distribution::Continuous;
2use crate::function::gamma;
3use crate::prec;
4use crate::statistics::*;
5use nalgebra::{Dim, Dyn, OMatrix, OVector};
6use std::f64;
7
8/// Implements the
9/// [Dirichlet](https://en.wikipedia.org/wiki/Dirichlet_distribution)
10/// distribution
11///
12/// # Examples
13///
14/// ```
15/// use statrs::distribution::{Dirichlet, Continuous};
16/// use statrs::statistics::Distribution;
17/// use nalgebra::DVector;
18/// use statrs::statistics::MeanN;
19///
20/// let n = Dirichlet::new(vec![1.0, 2.0, 3.0]).unwrap();
21/// assert_eq!(n.mean().unwrap(), DVector::from_vec(vec![1.0 / 6.0, 1.0 / 3.0, 0.5]));
22/// assert_eq!(n.pdf(&DVector::from_vec(vec![0.33333, 0.33333, 0.33333])), 2.222155556222205);
23/// ```
24#[derive(Clone, PartialEq, Debug)]
25pub struct Dirichlet<D>
26where
27    D: Dim,
28    nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>,
29{
30    alpha: OVector<f64, D>,
31}
32
/// Represents the errors that can occur when creating a [`Dirichlet`].
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
#[non_exhaustive]
pub enum DirichletError {
    /// Alpha contains fewer than two elements.
    AlphaTooShort,

    /// Alpha contains an element that is NaN, infinite, zero or less than zero.
    AlphaHasInvalidElements,
}

impl std::fmt::Display for DirichletError {
    /// Writes the human-readable description of the error variant.
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The messages are fixed strings, so pick one and emit it directly.
        let message = match self {
            DirichletError::AlphaTooShort => "Alpha contains less than two elements",
            DirichletError::AlphaHasInvalidElements => {
                "Alpha contains an element that is NaN, infinite, zero or less than zero"
            }
        };
        f.write_str(message)
    }
}

impl std::error::Error for DirichletError {}
58
59impl Dirichlet<Dyn> {
60    /// Constructs a new dirichlet distribution with the given
61    /// concentration parameters (alpha)
62    ///
63    /// # Errors
64    ///
65    /// Returns an error if any element `x` in alpha exist
66    /// such that `x < = 0.0` or `x` is `NaN`, or if the length of alpha is
67    /// less than 2
68    ///
69    /// # Examples
70    ///
71    /// ```
72    /// use statrs::distribution::Dirichlet;
73    /// use nalgebra::DVector;
74    ///
75    /// let alpha_ok = vec![1.0, 2.0, 3.0];
76    /// let mut result = Dirichlet::new(alpha_ok);
77    /// assert!(result.is_ok());
78    ///
79    /// let alpha_err = vec![0.0];
80    /// result = Dirichlet::new(alpha_err);
81    /// assert!(result.is_err());
82    /// ```
83    pub fn new(alpha: Vec<f64>) -> Result<Self, DirichletError> {
84        Self::new_from_nalgebra(alpha.into())
85    }
86
87    /// Constructs a new dirichlet distribution with the given
88    /// concentration parameter (alpha) repeated `n` times
89    ///
90    /// # Errors
91    ///
92    /// Returns an error if `alpha < = 0.0` or `alpha` is `NaN`,
93    /// or if `n < 2`
94    ///
95    /// # Examples
96    ///
97    /// ```
98    /// use statrs::distribution::Dirichlet;
99    ///
100    /// let mut result = Dirichlet::new_with_param(1.0, 3);
101    /// assert!(result.is_ok());
102    ///
103    /// result = Dirichlet::new_with_param(0.0, 1);
104    /// assert!(result.is_err());
105    /// ```
106    pub fn new_with_param(alpha: f64, n: usize) -> Result<Self, DirichletError> {
107        Self::new(vec![alpha; n])
108    }
109}
110
111impl<D> Dirichlet<D>
112where
113    D: Dim,
114    nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>,
115{
116    /// Constructs a new distribution with the given vector for `alpha`
117    /// Does not clone the vector it takes ownership of
118    ///
119    /// # Error
120    ///
121    /// Returns an error if vector has length less than 2 or if any element
122    /// of alpha is NOT finite positive
123    pub fn new_from_nalgebra(alpha: OVector<f64, D>) -> Result<Self, DirichletError> {
124        if alpha.len() < 2 {
125            return Err(DirichletError::AlphaTooShort);
126        }
127
128        if alpha.iter().any(|&a_i| !a_i.is_finite() || a_i <= 0.0) {
129            return Err(DirichletError::AlphaHasInvalidElements);
130        }
131
132        Ok(Self { alpha })
133    }
134
135    /// Returns the concentration parameters of
136    /// the dirichlet distribution as a slice
137    ///
138    /// # Examples
139    ///
140    /// ```
141    /// use statrs::distribution::Dirichlet;
142    /// use nalgebra::DVector;
143    ///
144    /// let n = Dirichlet::new(vec![1.0, 2.0, 3.0]).unwrap();
145    /// assert_eq!(n.alpha(), &DVector::from_vec(vec![1.0, 2.0, 3.0]));
146    /// ```
147    pub fn alpha(&self) -> &nalgebra::OVector<f64, D> {
148        &self.alpha
149    }
150
151    fn alpha_sum(&self) -> f64 {
152        self.alpha.sum()
153    }
154
155    /// Returns the entropy of the dirichlet distribution
156    ///
157    /// # Formula
158    ///
159    /// ```text
160    /// ln(B(α)) - (K - α_0)ψ(α_0) - Σ((α_i - 1)ψ(α_i))
161    /// ```
162    ///
163    /// where
164    ///
165    /// ```text
166    /// B(α) = Π(Γ(α_i)) / Γ(Σ(α_i))
167    /// ```
168    ///
169    /// `α_0` is the sum of all concentration parameters,
170    /// `K` is the number of concentration parameters, `ψ` is the digamma
171    /// function, `α_i`
172    /// is the `i`th concentration parameter, and `Σ` is the sum from `1` to `K`
173    pub fn entropy(&self) -> Option<f64> {
174        let sum = self.alpha_sum();
175        let num = self.alpha.iter().fold(0.0, |acc, &x| {
176            acc + gamma::ln_gamma(x) + (x - 1.0) * gamma::digamma(x)
177        });
178        let entr =
179            -gamma::ln_gamma(sum) + (sum - self.alpha.len() as f64) * gamma::digamma(sum) - num;
180        Some(entr)
181    }
182}
183
184impl<D> std::fmt::Display for Dirichlet<D>
185where
186    D: Dim,
187    nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>,
188{
189    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
190        write!(f, "Dir({}, {})", self.alpha.len(), &self.alpha)
191    }
192}
193
194#[cfg(feature = "rand")]
195#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
196impl<D> ::rand::distributions::Distribution<OVector<f64, D>> for Dirichlet<D>
197where
198    D: Dim,
199    nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>,
200{
201    fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> OVector<f64, D> {
202        let mut sum = 0.0;
203        OVector::from_iterator_generic(
204            self.alpha.shape_generic().0,
205            nalgebra::Const::<1>,
206            self.alpha.iter().map(|&a| {
207                let sample = super::gamma::sample_unchecked(rng, a, 1.0);
208                sum += sample;
209                sample
210            }),
211        )
212    }
213}
214
215impl<D> MeanN<OVector<f64, D>> for Dirichlet<D>
216where
217    D: Dim,
218    nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>,
219{
220    /// Returns the means of the dirichlet distribution
221    ///
222    /// # Formula
223    ///
224    /// ```text
225    /// α_i / α_0
226    /// ```
227    ///
228    /// for the `i`th element where `α_i` is the `i`th concentration parameter
229    /// and `α_0` is the sum of all concentration parameters
230    fn mean(&self) -> Option<OVector<f64, D>> {
231        let sum = self.alpha_sum();
232        Some(self.alpha.map(|x| x / sum))
233    }
234}
235
236impl<D> VarianceN<OMatrix<f64, D, D>> for Dirichlet<D>
237where
238    D: Dim,
239    nalgebra::DefaultAllocator:
240        nalgebra::allocator::Allocator<D> + nalgebra::allocator::Allocator<D, D>,
241{
242    /// Returns the variances of the dirichlet distribution
243    ///
244    /// # Formula
245    ///
246    /// ```text
247    /// (α_i * (α_0 - α_i)) / (α_0^2 * (α_0 + 1))
248    /// ```
249    ///
250    /// for the `i`th element where `α_i` is the `i`th concentration parameter
251    /// and `α_0` is the sum of all concentration parameters
252    fn variance(&self) -> Option<OMatrix<f64, D, D>> {
253        let sum = self.alpha_sum();
254        let normalizing = sum * sum * (sum + 1.0);
255        let mut cov = OMatrix::from_diagonal(&self.alpha.map(|x| x * (sum - x) / normalizing));
256        let mut offdiag = |x: usize, y: usize| {
257            let elt = -self.alpha[x] * self.alpha[y] / normalizing;
258            cov[(x, y)] = elt;
259            cov[(y, x)] = elt;
260        };
261        for i in 0..self.alpha.len() {
262            for j in 0..i {
263                offdiag(i, j);
264            }
265        }
266        Some(cov)
267    }
268}
269
270impl<D> Continuous<&OVector<f64, D>, f64> for Dirichlet<D>
271where
272    D: Dim,
273    nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>
274        + nalgebra::allocator::Allocator<D, D>
275        + nalgebra::allocator::Allocator<nalgebra::Const<1>, D>,
276{
277    /// Calculates the probabiliy density function for the dirichlet
278    /// distribution
279    /// with given `x`'s corresponding to the concentration parameters for this
280    /// distribution
281    ///
282    /// # Panics
283    ///
284    /// If any element in `x` is not in `(0, 1)`, the elements in `x` do not
285    /// sum to
286    /// `1` with a tolerance of `1e-4`,  or if `x` is not the same length as
287    /// the vector of
288    /// concentration parameters for this distribution
289    ///
290    /// # Formula
291    ///
292    /// ```text
293    /// (1 / B(α)) * Π(x_i^(α_i - 1))
294    /// ```
295    ///
296    /// where
297    ///
298    /// ```text
299    /// B(α) = Π(Γ(α_i)) / Γ(Σ(α_i))
300    /// ```
301    ///
302    /// `α` is the vector of concentration parameters, `α_i` is the `i`th
303    /// concentration parameter, `x_i` is the `i`th argument corresponding to
304    /// the `i`th concentration parameter, `Γ` is the gamma function,
305    /// `Π` is the product from `1` to `K`, `Σ` is the sum from `1` to `K`,
306    /// and `K` is the number of concentration parameters
307    fn pdf(&self, x: &OVector<f64, D>) -> f64 {
308        self.ln_pdf(x).exp()
309    }
310
311    /// Calculates the log probabiliy density function for the dirichlet
312    /// distribution
313    /// with given `x`'s corresponding to the concentration parameters for this
314    /// distribution
315    ///
316    /// # Panics
317    ///
318    /// If any element in `x` is not in `(0, 1)`, the elements in `x` do not
319    /// sum to
320    /// `1` with a tolerance of `1e-4`,  or if `x` is not the same length as
321    /// the vector of
322    /// concentration parameters for this distribution
323    ///
324    /// # Formula
325    ///
326    /// ```text
327    /// ln((1 / B(α)) * Π(x_i^(α_i - 1)))
328    /// ```
329    ///
330    /// where
331    ///
332    /// ```text
333    /// B(α) = Π(Γ(α_i)) / Γ(Σ(α_i))
334    /// ```
335    ///
336    /// `α` is the vector of concentration parameters, `α_i` is the `i`th
337    /// concentration parameter, `x_i` is the `i`th argument corresponding to
338    /// the `i`th concentration parameter, `Γ` is the gamma function,
339    /// `Π` is the product from `1` to `K`, `Σ` is the sum from `1` to `K`,
340    /// and `K` is the number of concentration parameters
341    fn ln_pdf(&self, x: &OVector<f64, D>) -> f64 {
342        if self.alpha.len() != x.len() {
343            panic!("Arguments must have correct dimensions.");
344        }
345
346        let mut term = 0.0;
347        let mut sum_x = 0.0;
348        let mut sum_alpha = 0.0;
349
350        for (&x_i, &alpha_i) in x.iter().zip(self.alpha.iter()) {
351            assert!(0.0 < x_i && x_i < 1.0, "Arguments must be in (0, 1)");
352
353            term += (alpha_i - 1.0) * x_i.ln() - gamma::ln_gamma(alpha_i);
354            sum_x += x_i;
355            sum_alpha += alpha_i;
356        }
357
358        assert!(
359            prec::almost_eq(sum_x, 1.0, 1e-4),
360            "Arguments must sum up to 1"
361        );
362        term + gamma::ln_gamma(sum_alpha)
363    }
364}
365
366#[rustfmt::skip]
367#[cfg(test)]
368mod tests {
369    use super::*;
370
371    use std::fmt::{Debug, Display};
372
373    use nalgebra::{dmatrix, dvector, vector, DimMin, OVector};
374
375    fn try_create<D>(alpha: OVector<f64, D>) -> Dirichlet<D>
376    where
377        D: DimMin<D, Output = D>,
378        nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>,
379    {
380        let mvn = Dirichlet::new_from_nalgebra(alpha);
381        assert!(mvn.is_ok());
382        mvn.unwrap()
383    }
384
385    fn bad_create_case<D>(alpha: OVector<f64, D>)
386    where
387        D: DimMin<D, Output = D>,
388        nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>,
389    {
390        let dd = Dirichlet::new_from_nalgebra(alpha);
391        assert!(dd.is_err());
392    }
393
394    fn test_almost<F, T, D>(alpha: OVector<f64, D>, expected: T, acc: f64, eval: F)
395    where
396        T: Debug + Display + approx::RelativeEq<Epsilon = f64>,
397        F: FnOnce(Dirichlet<D>) -> T,
398        D: DimMin<D, Output = D>,
399        nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>,
400    {
401        let dd = try_create(alpha);
402        let x = eval(dd);
403        assert_relative_eq!(expected, x, epsilon = acc);
404    }
405
406    #[test]
407    fn test_create() {
408        try_create(vector![1.0, 2.0]);
409        try_create(vector![1.0, 2.0, 3.0, 4.0, 5.0]);
410        assert!(Dirichlet::new(vec![1.0, 2.0, 3.0, 4.0, 5.0]).is_ok());
411        // try_create(vector![0.001, f64::INFINITY, 3756.0]); // moved to bad case as this is degenerate
412    }
413
414    #[test]
415    fn test_bad_create() {
416        bad_create_case(vector![1.0, f64::NAN]);
417        bad_create_case(vector![1.0, 0.0]);
418        bad_create_case(vector![1.0, f64::INFINITY]);
419        bad_create_case(vector![-1.0, 2.0]);
420        bad_create_case(vector![1.0]);
421        bad_create_case(vector![1.0, 2.0, 0.0, 4.0, 5.0]);
422        bad_create_case(vector![1.0, f64::NAN, 3.0, 4.0, 5.0]);
423        bad_create_case(vector![0.0, 0.0, 0.0]);
424        bad_create_case(vector![0.001, f64::INFINITY, 3756.0]); // moved to bad case as this is degenerate
425    }
426
427    #[test]
428    fn test_mean() {
429        let mean = |dd: Dirichlet<_>| dd.mean().unwrap();
430
431        test_almost(vec![0.5; 5].into(), vec![1.0 / 5.0; 5].into(), 1e-15, mean);
432
433        test_almost(
434            dvector![0.1, 0.2, 0.3, 0.4],
435            dvector![0.1, 0.2, 0.3, 0.4],
436            1e-15,
437            mean,
438        );
439
440        test_almost(
441            dvector![1.0, 2.0, 3.0, 4.0],
442            dvector![0.1, 0.2, 0.3, 0.4],
443            1e-15,
444            mean,
445        );
446    }
447
448    #[test]
449    fn test_variance() {
450        let variance = |dd: Dirichlet<_>| dd.variance().unwrap();
451
452        test_almost(
453            dvector![1.0, 2.0],
454            dmatrix![0.055555555555555, -0.055555555555555;
455                    -0.055555555555555,  0.055555555555555;
456            ],
457            1e-15,
458            variance,
459        );
460
461        test_almost(
462            dvector![0.1, 0.2, 0.3, 0.4],
463            dmatrix![0.045, -0.010, -0.015, -0.020;
464                    -0.010,  0.080, -0.030, -0.040;
465                    -0.015, -0.030,  0.105, -0.060;
466                    -0.020, -0.040, -0.060,  0.120;
467            ],
468            1e-15,
469            variance,
470        );
471    }
472
473    // #[test]
474    // fn test_std_dev() {
475    //     let alpha = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
476    //     let sum = alpha.iter().fold(0.0, |acc, x| acc + x);
477    //     let n = Dirichlet::new(&alpha).unwrap();
478    //     let res = n.std_dev();
479    //     for i in 1..11 {
480    //         let f = i as f64;
481    //         assert_almost_eq!(res[i-1], (f * (sum - f) / (sum * sum * (sum + 1.0))).sqrt(), 1e-15);
482    //     }
483    // }
484
485    #[test]
486    fn test_entropy() {
487        let entropy = |x: Dirichlet<_>| x.entropy().unwrap();
488        test_almost(
489            vector![0.1, 0.3, 0.5, 0.8],
490            -17.46469081094079,
491            1e-30,
492            entropy,
493        );
494        test_almost(
495            vector![0.1, 0.2, 0.3, 0.4],
496            -21.53881433791513,
497            1e-30,
498            entropy,
499        );
500    }
501
502    #[test]
503    fn test_pdf() {
504        let pdf = |arg| move |x: Dirichlet<_>| x.pdf(&arg);
505        test_almost(
506            vector![0.1, 0.3, 0.5, 0.8],
507            18.77225681167061,
508            1e-12,
509            pdf([0.01, 0.03, 0.5, 0.46].into()),
510        );
511        test_almost(
512            vector![0.1, 0.3, 0.5, 0.8],
513            0.8314656481199253,
514            1e-14,
515            pdf([0.1, 0.2, 0.3, 0.4].into()),
516        );
517    }
518
519    #[test]
520    fn test_ln_pdf() {
521        let ln_pdf = |arg| move |x: Dirichlet<_>| x.ln_pdf(&arg);
522        test_almost(
523            vector![0.1, 0.3, 0.5, 0.8],
524            18.77225681167061_f64.ln(),
525            1e-12,
526            ln_pdf([0.01, 0.03, 0.5, 0.46].into()),
527        );
528        test_almost(
529            vector![0.1, 0.3, 0.5, 0.8],
530            0.8314656481199253_f64.ln(),
531            1e-14,
532            ln_pdf([0.1, 0.2, 0.3, 0.4].into()),
533        );
534    }
535
536    #[test]
537    #[should_panic]
538    fn test_pdf_bad_input_length() {
539        let n = try_create(dvector![0.1, 0.3, 0.5, 0.8]);
540        n.pdf(&dvector![0.5]);
541    }
542
543    #[test]
544    #[should_panic]
545    fn test_pdf_bad_input_range() {
546        let n = try_create(vector![0.1, 0.3, 0.5, 0.8]);
547        n.pdf(&vector![1.5, 0.0, 0.0, 0.0]);
548    }
549
550    #[test]
551    #[should_panic]
552    fn test_pdf_bad_input_sum() {
553        let n = try_create(vector![0.1, 0.3, 0.5, 0.8]);
554        n.pdf(&vector![0.5, 0.25, 0.8, 0.9]);
555    }
556
557    #[test]
558    #[should_panic]
559    fn test_ln_pdf_bad_input_length() {
560        let n = try_create(dvector![0.1, 0.3, 0.5, 0.8]);
561        n.ln_pdf(&dvector![0.5]);
562    }
563
564    #[test]
565    #[should_panic]
566    fn test_ln_pdf_bad_input_range() {
567        let n = try_create(vector![0.1, 0.3, 0.5, 0.8]);
568        n.ln_pdf(&vector![1.5, 0.0, 0.0, 0.0]);
569    }
570
571    #[test]
572    #[should_panic]
573    fn test_ln_pdf_bad_input_sum() {
574        let n = try_create(vector![0.1, 0.3, 0.5, 0.8]);
575        n.ln_pdf(&vector![0.5, 0.25, 0.8, 0.9]);
576    }
577
578    #[test]
579    fn test_error_is_sync_send() {
580        fn assert_sync_send<T: Sync + Send>() {}
581        assert_sync_send::<DirichletError>();
582    }
583}