statrs/distribution/
multinomial.rs

1use crate::distribution::Discrete;
2use crate::function::factorial;
3use crate::statistics::*;
4use crate::{Result, StatsError};
5use ::nalgebra::{DMatrix, DVector};
6use rand::Rng;
7
/// The [Multinomial](https://en.wikipedia.org/wiki/Multinomial_distribution)
/// distribution: a generalization of the
/// [Binomial](https://en.wikipedia.org/wiki/Binomial_distribution)
/// distribution from two outcome categories to any number of them.
///
/// # Examples
///
/// ```
/// use statrs::distribution::Multinomial;
/// use statrs::statistics::MeanN;
/// use nalgebra::DVector;
///
/// let n = Multinomial::new(&[0.3, 0.7], 5).unwrap();
/// assert_eq!(n.mean().unwrap(), DVector::from_vec(vec![1.5, 3.5]));
/// ```
#[derive(Debug, Clone, PartialEq)]
pub struct Multinomial {
    /// Per-category weights as supplied by the caller; they are not
    /// required to sum to 1.
    p: Vec<f64>,
    /// Number of independent trials.
    n: u64,
}
29
30impl Multinomial {
31    /// Constructs a new multinomial distribution with probabilities `p`
32    /// and `n` number of trials.
33    ///
34    /// # Errors
35    ///
36    /// Returns an error if `p` is empty, the sum of the elements
37    /// in `p` is 0, or any element in `p` is less than 0 or is `f64::NAN`
38    ///
39    /// # Note
40    ///
41    /// The elements in `p` do not need to be normalized
42    ///
43    /// # Examples
44    ///
45    /// ```
46    /// use statrs::distribution::Multinomial;
47    ///
48    /// let mut result = Multinomial::new(&[0.0, 1.0, 2.0], 3);
49    /// assert!(result.is_ok());
50    ///
51    /// result = Multinomial::new(&[0.0, -1.0, 2.0], 3);
52    /// assert!(result.is_err());
53    /// ```
54    pub fn new(p: &[f64], n: u64) -> Result<Multinomial> {
55        if !super::internal::is_valid_multinomial(p, true) {
56            Err(StatsError::BadParams)
57        } else {
58            Ok(Multinomial { p: p.to_vec(), n })
59        }
60    }
61
62    /// Returns the probabilities of the multinomial
63    /// distribution as a slice
64    ///
65    /// # Examples
66    ///
67    /// ```
68    /// use statrs::distribution::Multinomial;
69    ///
70    /// let n = Multinomial::new(&[0.0, 1.0, 2.0], 3).unwrap();
71    /// assert_eq!(n.p(), [0.0, 1.0, 2.0]);
72    /// ```
73    pub fn p(&self) -> &[f64] {
74        &self.p
75    }
76
77    /// Returns the number of trials of the multinomial
78    /// distribution
79    ///
80    /// # Examples
81    ///
82    /// ```
83    /// use statrs::distribution::Multinomial;
84    ///
85    /// let n = Multinomial::new(&[0.0, 1.0, 2.0], 3).unwrap();
86    /// assert_eq!(n.n(), 3);
87    /// ```
88    pub fn n(&self) -> u64 {
89        self.n
90    }
91}
92
93impl ::rand::distributions::Distribution<Vec<f64>> for Multinomial {
94    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vec<f64> {
95        let p_cdf = super::categorical::prob_mass_to_cdf(self.p());
96        let mut res = vec![0.0; self.p.len()];
97        for _ in 0..self.n {
98            let i = super::categorical::sample_unchecked(rng, &p_cdf);
99            let el = res.get_mut(i as usize).unwrap();
100            *el += 1.0;
101        }
102        res
103    }
104}
105
106impl MeanN<DVector<f64>> for Multinomial {
107    /// Returns the mean of the multinomial distribution
108    ///
109    /// # Formula
110    ///
111    /// ```ignore
112    /// n * p_i for i in 1...k
113    /// ```
114    ///
115    /// where `n` is the number of trials, `p_i` is the `i`th probability,
116    /// and `k` is the total number of probabilities
117    fn mean(&self) -> Option<DVector<f64>> {
118        Some(DVector::from_vec(
119            self.p.iter().map(|x| x * self.n as f64).collect(),
120        ))
121    }
122}
123
124impl VarianceN<DMatrix<f64>> for Multinomial {
125    /// Returns the variance of the multinomial distribution
126    ///
127    /// # Formula
128    ///
129    /// ```ignore
130    /// n * p_i * (1 - p_i) for i in 1...k
131    /// ```
132    ///
133    /// where `n` is the number of trials, `p_i` is the `i`th probability,
134    /// and `k` is the total number of probabilities
135    fn variance(&self) -> Option<DMatrix<f64>> {
136        let cov: Vec<_> = self
137            .p
138            .iter()
139            .map(|x| x * self.n as f64 * (1.0 - x))
140            .collect();
141        Some(DMatrix::from_diagonal(&DVector::from_vec(cov)))
142    }
143}
144
145// impl Skewness<Vec<f64>> for Multinomial {
146//     /// Returns the skewness of the multinomial distribution
147//     ///
148//     /// # Formula
149//     ///
150//     /// ```ignore
//     /// (1 - 2 * p_i) / sqrt(n * p_i * (1 - p_i)) for i in 1...k
152//     /// ```
153//     ///
154//     /// where `n` is the number of trials, `p_i` is the `i`th probability,
155//     /// and `k` is the total number of probabilities
156//     fn skewness(&self) -> Option<Vec<f64>> {
157//         Some(
158//             self.p
159//                 .iter()
160//                 .map(|x| (1.0 - 2.0 * x) / (self.n as f64 * (1.0 - x) * x).sqrt())
161//                 .collect(),
162//         )
163//     }
164// }
165
166impl<'a> Discrete<&'a [u64], f64> for Multinomial {
167    /// Calculates the probability mass function for the multinomial
168    /// distribution
169    /// with the given `x`'s corresponding to the probabilities for this
170    /// distribution
171    ///
172    /// # Panics
173    ///
174    /// If the elements in `x` do not sum to `n` or if the length of `x` is not
175    /// equivalent to the length of `p`
176    ///
177    /// # Formula
178    ///
179    /// ```ignore
180    /// (n! / x_1!...x_k!) * p_i^x_i for i in 1...k
181    /// ```
182    ///
183    /// where `n` is the number of trials, `p_i` is the `i`th probability,
184    /// `x_i` is the `i`th `x` value, and `k` is the total number of
185    /// probabilities
186    fn pmf(&self, x: &[u64]) -> f64 {
187        if self.p.len() != x.len() {
188            panic!("Expected x and p to have equal lengths.");
189        }
190        if x.iter().sum::<u64>() != self.n {
191            return 0.0;
192        }
193        let coeff = factorial::multinomial(self.n, x);
194        let val = coeff
195            * self
196                .p
197                .iter()
198                .zip(x.iter())
199                .fold(1.0, |acc, (pi, xi)| acc * pi.powf(*xi as f64));
200        val
201    }
202
203    /// Calculates the log probability mass function for the multinomial
204    /// distribution
205    /// with the given `x`'s corresponding to the probabilities for this
206    /// distribution
207    ///
208    /// # Panics
209    ///
210    /// If the elements in `x` do not sum to `n` or if the length of `x` is not
211    /// equivalent to the length of `p`
212    ///
213    /// # Formula
214    ///
215    /// ```ignore
216    /// ln((n! / x_1!...x_k!) * p_i^x_i) for i in 1...k
217    /// ```
218    ///
219    /// where `n` is the number of trials, `p_i` is the `i`th probability,
220    /// `x_i` is the `i`th `x` value, and `k` is the total number of
221    /// probabilities
222    fn ln_pmf(&self, x: &[u64]) -> f64 {
223        if self.p.len() != x.len() {
224            panic!("Expected x and p to have equal lengths.");
225        }
226        if x.iter().sum::<u64>() != self.n {
227            return f64::NEG_INFINITY;
228        }
229        let coeff = factorial::multinomial(self.n, x).ln();
230        let val = coeff
231            + self
232                .p
233                .iter()
234                .zip(x.iter())
235                .map(|(pi, xi)| *xi as f64 * pi.ln())
236                .fold(0.0, |acc, x| acc + x);
237        val
238    }
239}
240
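// TODO: fix tests. They predate the current API. `test_mean`/`test_variance` compare against
// `Vec<f64>`, but `mean`/`variance` return `Option<DVector<f64>>`/`Option<DMatrix<f64>>`.
// The `*_wrong_sum` tests expect a panic, but `pmf`/`ln_pmf` return 0 / -inf when `x` does
// not sum to `n`.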
242// #[rustfmt::skip]
243// #[cfg(test)]
244// mod tests {
245//     use crate::statistics::*;
246//     use crate::distribution::{Discrete, Multinomial};
247//     use crate::consts::ACC;
248
249//     fn try_create(p: &[f64], n: u64) -> Multinomial {
250//         let dist = Multinomial::new(p, n);
251//         assert!(dist.is_ok());
252//         dist.unwrap()
253//     }
254
255//     fn create_case(p: &[f64], n: u64) {
256//         let dist = try_create(p, n);
257//         assert_eq!(dist.p(), p);
258//         assert_eq!(dist.n(), n);
259//     }
260
261//     fn bad_create_case(p: &[f64], n: u64) {
262//         let dist = Multinomial::new(p, n);
263//         assert!(dist.is_err());
264//     }
265
266//     fn test_case<F>(p: &[f64], n: u64, expected: &[f64], eval: F)
267//         where F: Fn(Multinomial) -> Vec<f64>
268//     {
269//         let dist = try_create(p, n);
270//         let x = eval(dist);
271//         assert_eq!(*expected, *x);
272//     }
273
274//     fn test_almost<F>(p: &[f64], n: u64, expected: &[f64], acc: f64, eval: F)
275//         where F: Fn(Multinomial) -> Vec<f64>
276//     {
277//         let dist = try_create(p, n);
278//         let x = eval(dist);
279//         assert_eq!(expected.len(), x.len());
280//         for i in 0..expected.len() {
281//             assert_almost_eq!(expected[i], x[i], acc);
282//         }
283//     }
284
285//     fn test_almost_sr<F>(p: &[f64], n: u64, expected: f64, acc:f64, eval: F)
286//         where F: Fn(Multinomial) -> f64
287//     {
288//         let dist = try_create(p, n);
289//         let x = eval(dist);
290//         assert_almost_eq!(expected, x, acc);
291//     }
292
293//     #[test]
294//     fn test_create() {
295//         create_case(&[1.0, 1.0, 1.0], 4);
296//         create_case(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0], 4);
297//     }
298
299//     #[test]
300//     fn test_bad_create() {
301//         bad_create_case(&[-1.0, 1.0], 4);
302//         bad_create_case(&[0.0, 0.0], 4);
303//     }
304
305//     #[test]
306//     fn test_mean() {
307//         let mean = |x: Multinomial| x.mean().unwrap();
308//         test_case(&[0.3, 0.7], 5, &[1.5, 3.5], mean);
309//         test_case(&[0.1, 0.3, 0.6], 10, &[1.0, 3.0, 6.0], mean);
310//         test_case(&[0.15, 0.35, 0.3, 0.2], 20, &[3.0, 7.0, 6.0, 4.0], mean);
311//     }
312
313//     #[test]
314//     fn test_variance() {
315//         let variance = |x: Multinomial| x.variance().unwrap();
316//         test_almost(&[0.3, 0.7], 5, &[1.05, 1.05], 1e-15, variance);
317//         test_almost(&[0.1, 0.3, 0.6], 10, &[0.9, 2.1, 2.4], 1e-15, variance);
318//         test_almost(&[0.15, 0.35, 0.3, 0.2], 20, &[2.55, 4.55, 4.2, 3.2], 1e-15, variance);
319//     }
320
321//     // #[test]
322//     // fn test_skewness() {
323//     //     let skewness = |x: Multinomial| x.skewness().unwrap();
324//     //     test_almost(&[0.3, 0.7], 5, &[0.390360029179413, -0.390360029179413], 1e-15, skewness);
325//     //     test_almost(&[0.1, 0.3, 0.6], 10, &[0.843274042711568, 0.276026223736942, -0.12909944487358], 1e-15, skewness);
326//     //     test_almost(&[0.15, 0.35, 0.3, 0.2], 20, &[0.438357003759605, 0.140642169281549, 0.195180014589707, 0.335410196624968], 1e-15, skewness);
327//     // }
328
329//     #[test]
330//     fn test_pmf() {
331//         let pmf = |arg: &[u64]| move |x: Multinomial| x.pmf(arg);
332//         test_almost_sr(&[0.3, 0.7], 10, 0.121060821, 1e-15, pmf(&[1, 9]));
333//         test_almost_sr(&[0.1, 0.3, 0.6], 10, 0.105815808, 1e-15, pmf(&[1, 3, 6]));
334//         test_almost_sr(&[0.15, 0.35, 0.3, 0.2], 10, 0.000145152, 1e-15, pmf(&[1, 1, 1, 7]));
335//     }
336
337//     #[test]
338//     #[should_panic]
339//     fn test_pmf_x_wrong_length() {
340//         let pmf = |arg: &[u64]| move |x: Multinomial| x.pmf(arg);
341//         let n = Multinomial::new(&[0.3, 0.7], 10).unwrap();
342//         n.pmf(&[1]);
343//     }
344
345//     #[test]
346//     #[should_panic]
347//     fn test_pmf_x_wrong_sum() {
348//         let pmf = |arg: &[u64]| move |x: Multinomial| x.pmf(arg);
349//         let n = Multinomial::new(&[0.3, 0.7], 10).unwrap();
350//         n.pmf(&[1, 3]);
351//     }
352
353//     #[test]
354//     fn test_ln_pmf() {
355//         let large_p = &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
356//         let n = Multinomial::new(large_p, 45).unwrap();
357//         let x = &[1, 2, 3, 4, 5, 6, 7, 8, 9];
358//         assert_almost_eq!(n.pmf(x).ln(), n.ln_pmf(x), 1e-13);
359//         let n2 = Multinomial::new(large_p, 18).unwrap();
360//         let x2 = &[1, 1, 1, 2, 2, 2, 3, 3, 3];
361//         assert_almost_eq!(n2.pmf(x2).ln(), n2.ln_pmf(x2), 1e-13);
362//         let n3 = Multinomial::new(large_p, 51).unwrap();
363//         let x3 = &[5, 6, 7, 8, 7, 6, 5, 4, 3];
364//         assert_almost_eq!(n3.pmf(x3).ln(), n3.ln_pmf(x3), 1e-13);
365//     }
366
367//     #[test]
368//     #[should_panic]
369//     fn test_ln_pmf_x_wrong_length() {
370//         let n = Multinomial::new(&[0.3, 0.7], 10).unwrap();
371//         n.ln_pmf(&[1]);
372//     }
373
374//     #[test]
375//     #[should_panic]
376//     fn test_ln_pmf_x_wrong_sum() {
377//         let n = Multinomial::new(&[0.3, 0.7], 10).unwrap();
378//         n.ln_pmf(&[1, 3]);
379//     }
380// }