statrs/distribution/dirichlet.rs
1use crate::distribution::Continuous;
2use crate::function::gamma;
3use crate::statistics::*;
4use crate::{prec, Result, StatsError};
5use nalgebra::DMatrix;
6use nalgebra::DVector;
7use nalgebra::{
8 base::allocator::Allocator, base::dimension::DimName, DefaultAllocator, Dim, DimMin, U1,
9};
10use rand::Rng;
11use std::f64;
12
13/// Implements the
14/// [Dirichlet](https://en.wikipedia.org/wiki/Dirichlet_distribution)
15/// distribution
16///
17/// # Examples
18///
19/// ```
20/// use statrs::distribution::{Dirichlet, Continuous};
21/// use statrs::statistics::Distribution;
22/// use nalgebra::DVector;
23/// use statrs::statistics::MeanN;
24///
25/// let n = Dirichlet::new(vec![1.0, 2.0, 3.0]).unwrap();
26/// assert_eq!(n.mean().unwrap(), DVector::from_vec(vec![1.0 / 6.0, 1.0 / 3.0, 0.5]));
27/// assert_eq!(n.pdf(&DVector::from_vec(vec![0.33333, 0.33333, 0.33333])), 2.222155556222205);
28/// ```
29#[derive(Debug, Clone, PartialEq)]
30pub struct Dirichlet {
31 alpha: DVector<f64>,
32}
33impl Dirichlet {
34 /// Constructs a new dirichlet distribution with the given
35 /// concentration parameters (alpha)
36 ///
37 /// # Errors
38 ///
39 /// Returns an error if any element `x` in alpha exist
40 /// such that `x < = 0.0` or `x` is `NaN`, or if the length of alpha is
41 /// less than 2
42 ///
43 /// # Examples
44 ///
45 /// ```
46 /// use statrs::distribution::Dirichlet;
47 /// use nalgebra::DVector;
48 ///
49 /// let alpha_ok = vec![1.0, 2.0, 3.0];
50 /// let mut result = Dirichlet::new(alpha_ok);
51 /// assert!(result.is_ok());
52 ///
53 /// let alpha_err = vec![0.0];
54 /// result = Dirichlet::new(alpha_err);
55 /// assert!(result.is_err());
56 /// ```
57 pub fn new(alpha: Vec<f64>) -> Result<Dirichlet> {
58 if !is_valid_alpha(&alpha) {
59 Err(StatsError::BadParams)
60 } else {
61 // let vec = alpha.to_vec();
62 Ok(Dirichlet {
63 alpha: DVector::from_vec(alpha.to_vec()),
64 })
65 }
66 }
67
68 /// Constructs a new dirichlet distribution with the given
69 /// concentration parameter (alpha) repeated `n` times
70 ///
71 /// # Errors
72 ///
73 /// Returns an error if `alpha < = 0.0` or `alpha` is `NaN`,
74 /// or if `n < 2`
75 ///
76 /// # Examples
77 ///
78 /// ```
79 /// use statrs::distribution::Dirichlet;
80 ///
81 /// let mut result = Dirichlet::new_with_param(1.0, 3);
82 /// assert!(result.is_ok());
83 ///
84 /// result = Dirichlet::new_with_param(0.0, 1);
85 /// assert!(result.is_err());
86 /// ```
87 pub fn new_with_param(alpha: f64, n: usize) -> Result<Dirichlet> {
88 Self::new(vec![alpha; n])
89 }
90
91 /// Returns the concentration parameters of
92 /// the dirichlet distribution as a slice
93 ///
94 /// # Examples
95 ///
96 /// ```
97 /// use statrs::distribution::Dirichlet;
98 /// use nalgebra::DVector;
99 ///
100 /// let n = Dirichlet::new(vec![1.0, 2.0, 3.0]).unwrap();
101 /// assert_eq!(n.alpha(), &DVector::from_vec(vec![1.0, 2.0, 3.0]));
102 /// ```
103 pub fn alpha(&self) -> &DVector<f64> {
104 &self.alpha
105 }
106
107 fn alpha_sum(&self) -> f64 {
108 self.alpha.fold(0.0, |acc, x| acc + x)
109 }
110 /// Returns the entropy of the dirichlet distribution
111 ///
112 /// # Formula
113 ///
114 /// ```ignore
115 /// ln(B(α)) - (K - α_0)ψ(α_0) - Σ((α_i - 1)ψ(α_i))
116 /// ```
117 ///
118 /// where
119 ///
120 /// ```ignore
121 /// B(α) = Π(Γ(α_i)) / Γ(Σ(α_i))
122 /// ```
123 ///
124 /// `α_0` is the sum of all concentration parameters,
125 /// `K` is the number of concentration parameters, `ψ` is the digamma
126 /// function, `α_i`
127 /// is the `i`th concentration parameter, and `Σ` is the sum from `1` to `K`
128 pub fn entropy(&self) -> Option<f64> {
129 let sum = self.alpha_sum();
130 let num = self.alpha.iter().fold(0.0, |acc, &x| {
131 acc + gamma::ln_gamma(x) + (x - 1.0) * gamma::digamma(x)
132 });
133 let entr =
134 -gamma::ln_gamma(sum) + (sum - self.alpha.len() as f64) * gamma::digamma(sum) - num;
135 Some(entr)
136 }
137}
138
139impl ::rand::distributions::Distribution<DVector<f64>> for Dirichlet {
140 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> DVector<f64> {
141 let mut sum = 0.0;
142 let mut samples: Vec<_> = self
143 .alpha
144 .iter()
145 .map(|&a| {
146 let sample = super::gamma::sample_unchecked(rng, a, 1.0);
147 sum += sample;
148 sample
149 })
150 .collect();
151 for _ in samples.iter_mut().map(|x| *x /= sum) {}
152 DVector::from_vec(samples)
153 }
154}
155
156impl MeanN<DVector<f64>> for Dirichlet {
157 /// Returns the means of the dirichlet distribution
158 ///
159 /// # Formula
160 ///
161 /// ```ignore
162 /// α_i / α_0
163 /// ```
164 ///
165 /// for the `i`th element where `α_i` is the `i`th concentration parameter
166 /// and `α_0` is the sum of all concentration parameters
167 fn mean(&self) -> Option<DVector<f64>> {
168 let sum = self.alpha_sum();
169 Some(self.alpha.map(|x| x / sum))
170 }
171}
172
173impl VarianceN<DMatrix<f64>> for Dirichlet {
174 /// Returns the variances of the dirichlet distribution
175 ///
176 /// # Formula
177 ///
178 /// ```ignore
179 /// (α_i * (α_0 - α_i)) / (α_0^2 * (α_0 + 1))
180 /// ```
181 ///
182 /// for the `i`th element where `α_i` is the `i`th concentration parameter
183 /// and `α_0` is the sum of all concentration parameters
184 fn variance(&self) -> Option<DMatrix<f64>> {
185 let sum = self.alpha_sum();
186 let normalizing = sum * sum * (sum + 1.0);
187 let mut cov = DMatrix::from_diagonal(&self.alpha.map(|x| x * (sum - x) / normalizing));
188 let mut offdiag = |x: usize, y: usize| {
189 let elt = -self.alpha[x] * self.alpha[y] / normalizing;
190 cov[(x, y)] = elt;
191 cov[(y, x)] = elt;
192 };
193 for i in 0..self.alpha.len() {
194 for j in 0..i {
195 offdiag(i, j);
196 }
197 }
198 Some(cov)
199 }
200}
201
202impl<'a> Continuous<&'a DVector<f64>, f64> for Dirichlet {
203 /// Calculates the probabiliy density function for the dirichlet
204 /// distribution
205 /// with given `x`'s corresponding to the concentration parameters for this
206 /// distribution
207 ///
208 /// # Panics
209 ///
210 /// If any element in `x` is not in `(0, 1)`, the elements in `x` do not
211 /// sum to
212 /// `1` with a tolerance of `1e-4`, or if `x` is not the same length as
213 /// the vector of
214 /// concentration parameters for this distribution
215 ///
216 /// # Formula
217 ///
218 /// ```ignore
219 /// (1 / B(α)) * Π(x_i^(α_i - 1))
220 /// ```
221 ///
222 /// where
223 ///
224 /// ```ignore
225 /// B(α) = Π(Γ(α_i)) / Γ(Σ(α_i))
226 /// ```
227 ///
228 /// `α` is the vector of concentration parameters, `α_i` is the `i`th
229 /// concentration parameter, `x_i` is the `i`th argument corresponding to
230 /// the `i`th concentration parameter, `Γ` is the gamma function,
231 /// `Π` is the product from `1` to `K`, `Σ` is the sum from `1` to `K`,
232 /// and `K` is the number of concentration parameters
233 fn pdf(&self, x: &DVector<f64>) -> f64 {
234 self.ln_pdf(x).exp()
235 }
236
237 /// Calculates the log probabiliy density function for the dirichlet
238 /// distribution
239 /// with given `x`'s corresponding to the concentration parameters for this
240 /// distribution
241 ///
242 /// # Panics
243 ///
244 /// If any element in `x` is not in `(0, 1)`, the elements in `x` do not
245 /// sum to
246 /// `1` with a tolerance of `1e-4`, or if `x` is not the same length as
247 /// the vector of
248 /// concentration parameters for this distribution
249 ///
250 /// # Formula
251 ///
252 /// ```ignore
253 /// ln((1 / B(α)) * Π(x_i^(α_i - 1)))
254 /// ```
255 ///
256 /// where
257 ///
258 /// ```ignore
259 /// B(α) = Π(Γ(α_i)) / Γ(Σ(α_i))
260 /// ```
261 ///
262 /// `α` is the vector of concentration parameters, `α_i` is the `i`th
263 /// concentration parameter, `x_i` is the `i`th argument corresponding to
264 /// the `i`th concentration parameter, `Γ` is the gamma function,
265 /// `Π` is the product from `1` to `K`, `Σ` is the sum from `1` to `K`,
266 /// and `K` is the number of concentration parameters
267 fn ln_pdf(&self, x: &DVector<f64>) -> f64 {
268 // TODO: would it be clearer here to just do a for loop instead
269 // of using iterators?
270 if self.alpha.len() != x.len() {
271 panic!("Arguments must have correct dimensions.");
272 }
273 if x.iter().any(|&x| x <= 0.0 || x >= 1.0) {
274 panic!("Arguments must be in (0, 1)");
275 }
276 let (term, sum_xi, sum_alpha) = x
277 .iter()
278 .enumerate()
279 .map(|pair| (pair.1, self.alpha[pair.0]))
280 .fold((0.0, 0.0, 0.0), |acc, pair| {
281 (
282 acc.0 + (pair.1 - 1.0) * pair.0.ln() - gamma::ln_gamma(pair.1),
283 acc.1 + pair.0,
284 acc.2 + pair.1,
285 )
286 });
287
288 if !prec::almost_eq(sum_xi, 1.0, 1e-4) {
289 panic!();
290 } else {
291 term + gamma::ln_gamma(sum_alpha)
292 }
293 }
294}
295
296// determines if `a` is a valid alpha array
297// for the Dirichlet distribution
298fn is_valid_alpha(a: &[f64]) -> bool {
299 a.len() >= 2 && super::internal::is_valid_multinomial(a, false)
300}
301
302#[rustfmt::skip]
303#[cfg(all(test, feature = "nightly"))]
304mod tests {
305 use super::*;
306 use nalgebra::{DVector};
307 use crate::function::gamma;
308 use crate::statistics::*;
309 use crate::distribution::{Continuous, Dirichlet};
310 use crate::consts::ACC;
311
312 #[test]
313 fn test_is_valid_alpha() {
314 let invalid = [1.0];
315 assert!(!is_valid_alpha(&invalid));
316 }
317
318 fn try_create(alpha: &[f64]) -> Dirichlet
319 {
320 let n = Dirichlet::new(alpha.to_vec());
321 assert!(n.is_ok());
322 n.unwrap()
323 }
324
325 fn create_case(alpha: &[f64])
326 {
327 let n = try_create(alpha);
328 let a2 = n.alpha();
329 for i in 0..alpha.len() {
330 assert_eq!(alpha[i], a2[i]);
331 }
332 }
333
334 fn bad_create_case(alpha: &[f64])
335 {
336 let n = Dirichlet::new(alpha.to_vec());
337 assert!(n.is_err());
338 }
339
340 #[test]
341 fn test_create() {
342 create_case(&[1.0, 2.0, 3.0, 4.0, 5.0]);
343 create_case(&[0.001, f64::INFINITY, 3756.0]);
344 }
345
346 #[test]
347 fn test_bad_create() {
348 bad_create_case(&[1.0]);
349 bad_create_case(&[1.0, 2.0, 0.0, 4.0, 5.0]);
350 bad_create_case(&[1.0, f64::NAN, 3.0, 4.0, 5.0]);
351 bad_create_case(&[0.0, 0.0, 0.0]);
352 }
353
354 // #[test]
355 // fn test_mean() {
356 // let n = Dirichlet::new_with_param(0.3, 5).unwrap();
357 // let res = n.mean();
358 // for x in res {
359 // assert_eq!(x, 0.3 / 1.5);
360 // }
361 // }
362
363 // #[test]
364 // fn test_variance() {
365 // let alpha = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
366 // let sum = alpha.iter().fold(0.0, |acc, x| acc + x);
367 // let n = Dirichlet::new(&alpha).unwrap();
368 // let res = n.variance();
369 // for i in 1..11 {
370 // let f = i as f64;
371 // assert_almost_eq!(res[i-1], f * (sum - f) / (sum * sum * (sum + 1.0)), 1e-15);
372 // }
373 // }
374
375 // #[test]
376 // fn test_std_dev() {
377 // let alpha = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
378 // let sum = alpha.iter().fold(0.0, |acc, x| acc + x);
379 // let n = Dirichlet::new(&alpha).unwrap();
380 // let res = n.std_dev();
381 // for i in 1..11 {
382 // let f = i as f64;
383 // assert_almost_eq!(res[i-1], (f * (sum - f) / (sum * sum * (sum + 1.0))).sqrt(), 1e-15);
384 // }
385 // }
386
387 #[test]
388 fn test_entropy() {
389 let mut n = try_create(&[0.1, 0.3, 0.5, 0.8]);
390 assert_eq!(n.entropy().unwrap(), -17.46469081094079);
391
392 n = try_create(&[0.1, 0.2, 0.3, 0.4]);
393 assert_eq!(n.entropy().unwrap(), -21.53881433791513);
394 }
395
396 macro_rules! dvec {
397 ($($x:expr),*) => (DVector::from_vec(vec![$($x),*]));
398 }
399
400 #[test]
401 fn test_pdf() {
402 let n = try_create(&[0.1, 0.3, 0.5, 0.8]);
403 assert_almost_eq!(n.pdf(&dvec![0.01, 0.03, 0.5, 0.46]), 18.77225681167061, 1e-12);
404 assert_almost_eq!(n.pdf(&dvec![0.1,0.2,0.3,0.4]), 0.8314656481199253, 1e-14);
405 }
406
407 #[test]
408 fn test_ln_pdf() {
409 let n = try_create(&[0.1, 0.3, 0.5, 0.8]);
410 assert_almost_eq!(n.ln_pdf(&dvec![0.01, 0.03, 0.5, 0.46]), 18.77225681167061f64.ln(), 1e-12);
411 assert_almost_eq!(n.ln_pdf(&dvec![0.1,0.2,0.3,0.4]), 0.8314656481199253f64.ln(), 1e-14);
412 }
413
414 #[test]
415 #[should_panic]
416 fn test_pdf_bad_input_length() {
417 let n = try_create(&[0.1, 0.3, 0.5, 0.8]);
418 n.pdf(&dvec![0.5]);
419 }
420
421 #[test]
422 #[should_panic]
423 fn test_pdf_bad_input_range() {
424 let n = try_create(&[0.1, 0.3, 0.5, 0.8]);
425 n.pdf(&dvec![1.5, 0.0, 0.0, 0.0]);
426 }
427
428 #[test]
429 #[should_panic]
430 fn test_pdf_bad_input_sum() {
431 let n = try_create(&[0.1, 0.3, 0.5, 0.8]);
432 n.pdf(&dvec![0.5, 0.25, 0.8, 0.9]);
433 }
434
435 #[test]
436 #[should_panic]
437 fn test_ln_pdf_bad_input_length() {
438 let n = try_create(&[0.1, 0.3, 0.5, 0.8]);
439 n.ln_pdf(&dvec![0.5]);
440 }
441
442 #[test]
443 #[should_panic]
444 fn test_ln_pdf_bad_input_range() {
445 let n = try_create(&[0.1, 0.3, 0.5, 0.8]);
446 n.ln_pdf(&dvec![1.5, 0.0, 0.0, 0.0]);
447 }
448
449 #[test]
450 #[should_panic]
451 fn test_ln_pdf_bad_input_sum() {
452 let n = try_create(&[0.1, 0.3, 0.5, 0.8]);
453 n.ln_pdf(&dvec![0.5, 0.25, 0.8, 0.9]);
454 }
455}