statrs/distribution/
dirichlet.rs

1use crate::distribution::Continuous;
2use crate::function::gamma;
3use crate::statistics::*;
4use crate::{prec, Result, StatsError};
5use nalgebra::DMatrix;
6use nalgebra::DVector;
7use nalgebra::{
8    base::allocator::Allocator, base::dimension::DimName, DefaultAllocator, Dim, DimMin, U1,
9};
10use rand::Rng;
11use std::f64;
12
13/// Implements the
14/// [Dirichlet](https://en.wikipedia.org/wiki/Dirichlet_distribution)
15/// distribution
16///
17/// # Examples
18///
19/// ```
20/// use statrs::distribution::{Dirichlet, Continuous};
21/// use statrs::statistics::Distribution;
22/// use nalgebra::DVector;
23/// use statrs::statistics::MeanN;
24///
25/// let n = Dirichlet::new(vec![1.0, 2.0, 3.0]).unwrap();
26/// assert_eq!(n.mean().unwrap(), DVector::from_vec(vec![1.0 / 6.0, 1.0 / 3.0, 0.5]));
27/// assert_eq!(n.pdf(&DVector::from_vec(vec![0.33333, 0.33333, 0.33333])), 2.222155556222205);
28/// ```
29#[derive(Debug, Clone, PartialEq)]
30pub struct Dirichlet {
31    alpha: DVector<f64>,
32}
33impl Dirichlet {
34    /// Constructs a new dirichlet distribution with the given
35    /// concentration parameters (alpha)
36    ///
37    /// # Errors
38    ///
39    /// Returns an error if any element `x` in alpha exist
40    /// such that `x < = 0.0` or `x` is `NaN`, or if the length of alpha is
41    /// less than 2
42    ///
43    /// # Examples
44    ///
45    /// ```
46    /// use statrs::distribution::Dirichlet;
47    /// use nalgebra::DVector;
48    ///
49    /// let alpha_ok = vec![1.0, 2.0, 3.0];
50    /// let mut result = Dirichlet::new(alpha_ok);
51    /// assert!(result.is_ok());
52    ///
53    /// let alpha_err = vec![0.0];
54    /// result = Dirichlet::new(alpha_err);
55    /// assert!(result.is_err());
56    /// ```
57    pub fn new(alpha: Vec<f64>) -> Result<Dirichlet> {
58        if !is_valid_alpha(&alpha) {
59            Err(StatsError::BadParams)
60        } else {
61            // let vec = alpha.to_vec();
62            Ok(Dirichlet {
63                alpha: DVector::from_vec(alpha.to_vec()),
64            })
65        }
66    }
67
68    /// Constructs a new dirichlet distribution with the given
69    /// concentration parameter (alpha) repeated `n` times
70    ///
71    /// # Errors
72    ///
73    /// Returns an error if `alpha < = 0.0` or `alpha` is `NaN`,
74    /// or if `n < 2`
75    ///
76    /// # Examples
77    ///
78    /// ```
79    /// use statrs::distribution::Dirichlet;
80    ///
81    /// let mut result = Dirichlet::new_with_param(1.0, 3);
82    /// assert!(result.is_ok());
83    ///
84    /// result = Dirichlet::new_with_param(0.0, 1);
85    /// assert!(result.is_err());
86    /// ```
87    pub fn new_with_param(alpha: f64, n: usize) -> Result<Dirichlet> {
88        Self::new(vec![alpha; n])
89    }
90
91    /// Returns the concentration parameters of
92    /// the dirichlet distribution as a slice
93    ///
94    /// # Examples
95    ///
96    /// ```
97    /// use statrs::distribution::Dirichlet;
98    /// use nalgebra::DVector;
99    ///
100    /// let n = Dirichlet::new(vec![1.0, 2.0, 3.0]).unwrap();
101    /// assert_eq!(n.alpha(), &DVector::from_vec(vec![1.0, 2.0, 3.0]));
102    /// ```
103    pub fn alpha(&self) -> &DVector<f64> {
104        &self.alpha
105    }
106
107    fn alpha_sum(&self) -> f64 {
108        self.alpha.fold(0.0, |acc, x| acc + x)
109    }
110    /// Returns the entropy of the dirichlet distribution
111    ///
112    /// # Formula
113    ///
114    /// ```ignore
115    /// ln(B(α)) - (K - α_0)ψ(α_0) - Σ((α_i - 1)ψ(α_i))
116    /// ```
117    ///
118    /// where
119    ///
120    /// ```ignore
121    /// B(α) = Π(Γ(α_i)) / Γ(Σ(α_i))
122    /// ```
123    ///
124    /// `α_0` is the sum of all concentration parameters,
125    /// `K` is the number of concentration parameters, `ψ` is the digamma
126    /// function, `α_i`
127    /// is the `i`th concentration parameter, and `Σ` is the sum from `1` to `K`
128    pub fn entropy(&self) -> Option<f64> {
129        let sum = self.alpha_sum();
130        let num = self.alpha.iter().fold(0.0, |acc, &x| {
131            acc + gamma::ln_gamma(x) + (x - 1.0) * gamma::digamma(x)
132        });
133        let entr =
134            -gamma::ln_gamma(sum) + (sum - self.alpha.len() as f64) * gamma::digamma(sum) - num;
135        Some(entr)
136    }
137}
138
139impl ::rand::distributions::Distribution<DVector<f64>> for Dirichlet {
140    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> DVector<f64> {
141        let mut sum = 0.0;
142        let mut samples: Vec<_> = self
143            .alpha
144            .iter()
145            .map(|&a| {
146                let sample = super::gamma::sample_unchecked(rng, a, 1.0);
147                sum += sample;
148                sample
149            })
150            .collect();
151        for _ in samples.iter_mut().map(|x| *x /= sum) {}
152        DVector::from_vec(samples)
153    }
154}
155
156impl MeanN<DVector<f64>> for Dirichlet {
157    /// Returns the means of the dirichlet distribution
158    ///
159    /// # Formula
160    ///
161    /// ```ignore
162    /// α_i / α_0
163    /// ```
164    ///
165    /// for the `i`th element where `α_i` is the `i`th concentration parameter
166    /// and `α_0` is the sum of all concentration parameters
167    fn mean(&self) -> Option<DVector<f64>> {
168        let sum = self.alpha_sum();
169        Some(self.alpha.map(|x| x / sum))
170    }
171}
172
173impl VarianceN<DMatrix<f64>> for Dirichlet {
174    /// Returns the variances of the dirichlet distribution
175    ///
176    /// # Formula
177    ///
178    /// ```ignore
179    /// (α_i * (α_0 - α_i)) / (α_0^2 * (α_0 + 1))
180    /// ```
181    ///
182    /// for the `i`th element where `α_i` is the `i`th concentration parameter
183    /// and `α_0` is the sum of all concentration parameters
184    fn variance(&self) -> Option<DMatrix<f64>> {
185        let sum = self.alpha_sum();
186        let normalizing = sum * sum * (sum + 1.0);
187        let mut cov = DMatrix::from_diagonal(&self.alpha.map(|x| x * (sum - x) / normalizing));
188        let mut offdiag = |x: usize, y: usize| {
189            let elt = -self.alpha[x] * self.alpha[y] / normalizing;
190            cov[(x, y)] = elt;
191            cov[(y, x)] = elt;
192        };
193        for i in 0..self.alpha.len() {
194            for j in 0..i {
195                offdiag(i, j);
196            }
197        }
198        Some(cov)
199    }
200}
201
202impl<'a> Continuous<&'a DVector<f64>, f64> for Dirichlet {
203    /// Calculates the probabiliy density function for the dirichlet
204    /// distribution
205    /// with given `x`'s corresponding to the concentration parameters for this
206    /// distribution
207    ///
208    /// # Panics
209    ///
210    /// If any element in `x` is not in `(0, 1)`, the elements in `x` do not
211    /// sum to
212    /// `1` with a tolerance of `1e-4`,  or if `x` is not the same length as
213    /// the vector of
214    /// concentration parameters for this distribution
215    ///
216    /// # Formula
217    ///
218    /// ```ignore
219    /// (1 / B(α)) * Π(x_i^(α_i - 1))
220    /// ```
221    ///
222    /// where
223    ///
224    /// ```ignore
225    /// B(α) = Π(Γ(α_i)) / Γ(Σ(α_i))
226    /// ```
227    ///
228    /// `α` is the vector of concentration parameters, `α_i` is the `i`th
229    /// concentration parameter, `x_i` is the `i`th argument corresponding to
230    /// the `i`th concentration parameter, `Γ` is the gamma function,
231    /// `Π` is the product from `1` to `K`, `Σ` is the sum from `1` to `K`,
232    /// and `K` is the number of concentration parameters
233    fn pdf(&self, x: &DVector<f64>) -> f64 {
234        self.ln_pdf(x).exp()
235    }
236
237    /// Calculates the log probabiliy density function for the dirichlet
238    /// distribution
239    /// with given `x`'s corresponding to the concentration parameters for this
240    /// distribution
241    ///
242    /// # Panics
243    ///
244    /// If any element in `x` is not in `(0, 1)`, the elements in `x` do not
245    /// sum to
246    /// `1` with a tolerance of `1e-4`,  or if `x` is not the same length as
247    /// the vector of
248    /// concentration parameters for this distribution
249    ///
250    /// # Formula
251    ///
252    /// ```ignore
253    /// ln((1 / B(α)) * Π(x_i^(α_i - 1)))
254    /// ```
255    ///
256    /// where
257    ///
258    /// ```ignore
259    /// B(α) = Π(Γ(α_i)) / Γ(Σ(α_i))
260    /// ```
261    ///
262    /// `α` is the vector of concentration parameters, `α_i` is the `i`th
263    /// concentration parameter, `x_i` is the `i`th argument corresponding to
264    /// the `i`th concentration parameter, `Γ` is the gamma function,
265    /// `Π` is the product from `1` to `K`, `Σ` is the sum from `1` to `K`,
266    /// and `K` is the number of concentration parameters
267    fn ln_pdf(&self, x: &DVector<f64>) -> f64 {
268        // TODO: would it be clearer here to just do a for loop instead
269        // of using iterators?
270        if self.alpha.len() != x.len() {
271            panic!("Arguments must have correct dimensions.");
272        }
273        if x.iter().any(|&x| x <= 0.0 || x >= 1.0) {
274            panic!("Arguments must be in (0, 1)");
275        }
276        let (term, sum_xi, sum_alpha) = x
277            .iter()
278            .enumerate()
279            .map(|pair| (pair.1, self.alpha[pair.0]))
280            .fold((0.0, 0.0, 0.0), |acc, pair| {
281                (
282                    acc.0 + (pair.1 - 1.0) * pair.0.ln() - gamma::ln_gamma(pair.1),
283                    acc.1 + pair.0,
284                    acc.2 + pair.1,
285                )
286            });
287
288        if !prec::almost_eq(sum_xi, 1.0, 1e-4) {
289            panic!();
290        } else {
291            term + gamma::ln_gamma(sum_alpha)
292        }
293    }
294}
295
296// determines if `a` is a valid alpha array
297// for the Dirichlet distribution
298fn is_valid_alpha(a: &[f64]) -> bool {
299    a.len() >= 2 && super::internal::is_valid_multinomial(a, false)
300}
301
#[rustfmt::skip]
#[cfg(all(test, feature = "nightly"))]
mod tests {
    use super::*;
    use nalgebra::{DVector};
    use crate::function::gamma;
    use crate::statistics::*;
    use crate::distribution::{Continuous, Dirichlet};
    use crate::consts::ACC;

    #[test]
    fn test_is_valid_alpha() {
        let invalid = [1.0];
        assert!(!is_valid_alpha(&invalid));
    }

    fn try_create(alpha: &[f64]) -> Dirichlet
    {
        let n = Dirichlet::new(alpha.to_vec());
        assert!(n.is_ok());
        n.unwrap()
    }

    fn create_case(alpha: &[f64])
    {
        let n = try_create(alpha);
        let a2 = n.alpha();
        for i in 0..alpha.len() {
            assert_eq!(alpha[i], a2[i]);
        }
    }

    fn bad_create_case(alpha: &[f64])
    {
        let n = Dirichlet::new(alpha.to_vec());
        assert!(n.is_err());
    }

    #[test]
    fn test_create() {
        create_case(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        create_case(&[0.001, f64::INFINITY, 3756.0]);
    }

    #[test]
    fn test_bad_create() {
        bad_create_case(&[1.0]);
        bad_create_case(&[1.0, 2.0, 0.0, 4.0, 5.0]);
        bad_create_case(&[1.0, f64::NAN, 3.0, 4.0, 5.0]);
        bad_create_case(&[0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_mean() {
        let n = Dirichlet::new_with_param(0.3, 5).unwrap();
        let res = n.mean().unwrap();
        assert_eq!(res.len(), 5);
        for &x in res.iter() {
            assert_almost_eq!(x, 0.3 / 1.5, 1e-15);
        }
    }

    #[test]
    fn test_variance() {
        let alpha = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let sum: f64 = alpha.iter().sum();
        let norm = sum * sum * (sum + 1.0);
        let n = try_create(&alpha);
        let cov = n.variance().unwrap();
        for i in 0..alpha.len() {
            for j in 0..alpha.len() {
                let expected = if i == j {
                    alpha[i] * (sum - alpha[i]) / norm
                } else {
                    -alpha[i] * alpha[j] / norm
                };
                assert_almost_eq!(cov[(i, j)], expected, 1e-15);
            }
        }
    }

    #[test]
    fn test_entropy() {
        // H = ln B(α) + (α_0 - K)ψ(α_0) - Σ (α_i - 1)ψ(α_i),  ln B(α) = Σ ln Γ(α_i) - ln Γ(α_0)
        let mut n = try_create(&[0.1, 0.3, 0.5, 0.8]);
        assert_almost_eq!(n.entropy().unwrap(), -9.31882027518715, 1e-10);

        n = try_create(&[0.1, 0.2, 0.3, 0.4]);
        assert_almost_eq!(n.entropy().unwrap(), -10.20030976454543, 1e-10);

        // Dirichlet(1/2, 1/2) is the arcsine distribution: entropy ln(π/4).
        n = try_create(&[0.5, 0.5]);
        assert_almost_eq!(n.entropy().unwrap(), (std::f64::consts::PI / 4.0).ln(), 1e-10);

        // Dirichlet(1, 1, 1) is uniform on the 2-simplex with density Γ(3) = 2: entropy -ln 2.
        n = try_create(&[1.0, 1.0, 1.0]);
        assert_almost_eq!(n.entropy().unwrap(), -(2.0f64).ln(), 1e-10);
    }

    macro_rules! dvec {
        ($($x:expr),*) => (DVector::from_vec(vec![$($x),*]));
    }

    #[test]
    fn test_pdf() {
        let n = try_create(&[0.1, 0.3, 0.5, 0.8]);
        assert_almost_eq!(n.pdf(&dvec![0.01, 0.03, 0.5, 0.46]), 18.77225681167061, 1e-12);
        assert_almost_eq!(n.pdf(&dvec![0.1,0.2,0.3,0.4]), 0.8314656481199253, 1e-14);
    }

    #[test]
    fn test_ln_pdf() {
        let n = try_create(&[0.1, 0.3, 0.5, 0.8]);
        assert_almost_eq!(n.ln_pdf(&dvec![0.01, 0.03, 0.5, 0.46]), 18.77225681167061f64.ln(), 1e-12);
        assert_almost_eq!(n.ln_pdf(&dvec![0.1,0.2,0.3,0.4]), 0.8314656481199253f64.ln(), 1e-14);
    }

    #[test]
    #[should_panic]
    fn test_pdf_bad_input_length() {
        let n = try_create(&[0.1, 0.3, 0.5, 0.8]);
        n.pdf(&dvec![0.5]);
    }

    #[test]
    #[should_panic]
    fn test_pdf_bad_input_range() {
        let n = try_create(&[0.1, 0.3, 0.5, 0.8]);
        n.pdf(&dvec![1.5, 0.0, 0.0, 0.0]);
    }

    #[test]
    #[should_panic]
    fn test_pdf_bad_input_sum() {
        let n = try_create(&[0.1, 0.3, 0.5, 0.8]);
        n.pdf(&dvec![0.5, 0.25, 0.8, 0.9]);
    }

    #[test]
    #[should_panic]
    fn test_ln_pdf_bad_input_length() {
        let n = try_create(&[0.1, 0.3, 0.5, 0.8]);
        n.ln_pdf(&dvec![0.5]);
    }

    #[test]
    #[should_panic]
    fn test_ln_pdf_bad_input_range() {
        let n = try_create(&[0.1, 0.3, 0.5, 0.8]);
        n.ln_pdf(&dvec![1.5, 0.0, 0.0, 0.0]);
    }

    #[test]
    #[should_panic]
    fn test_ln_pdf_bad_input_sum() {
        let n = try_create(&[0.1, 0.3, 0.5, 0.8]);
        n.ln_pdf(&dvec![0.5, 0.25, 0.8, 0.9]);
    }
}