statrs/distribution/multinomial.rs
1use crate::distribution::Discrete;
2use crate::function::factorial;
3use crate::statistics::*;
4use crate::{Result, StatsError};
5use ::nalgebra::{DMatrix, DVector};
6use rand::Rng;
7
/// A [Multinomial](https://en.wikipedia.org/wiki/Multinomial_distribution)
/// distribution, i.e. the generalization of the
/// [Binomial](https://en.wikipedia.org/wiki/Binomial_distribution)
/// distribution to more than two outcome categories.
///
/// # Examples
///
/// ```
/// use statrs::distribution::Multinomial;
/// use statrs::statistics::MeanN;
/// use nalgebra::DVector;
///
/// let n = Multinomial::new(&[0.3, 0.7], 5).unwrap();
/// assert_eq!(n.mean().unwrap(), DVector::from_vec(vec![1.5, 3.5]));
/// ```
#[derive(Clone, Debug, PartialEq)]
pub struct Multinomial {
    // Per-category weights, stored exactly as they were handed to `new`.
    p: Vec<f64>,
    // Number of trials.
    n: u64,
}
29
30impl Multinomial {
31 /// Constructs a new multinomial distribution with probabilities `p`
32 /// and `n` number of trials.
33 ///
34 /// # Errors
35 ///
36 /// Returns an error if `p` is empty, the sum of the elements
37 /// in `p` is 0, or any element in `p` is less than 0 or is `f64::NAN`
38 ///
39 /// # Note
40 ///
41 /// The elements in `p` do not need to be normalized
42 ///
43 /// # Examples
44 ///
45 /// ```
46 /// use statrs::distribution::Multinomial;
47 ///
48 /// let mut result = Multinomial::new(&[0.0, 1.0, 2.0], 3);
49 /// assert!(result.is_ok());
50 ///
51 /// result = Multinomial::new(&[0.0, -1.0, 2.0], 3);
52 /// assert!(result.is_err());
53 /// ```
54 pub fn new(p: &[f64], n: u64) -> Result<Multinomial> {
55 if !super::internal::is_valid_multinomial(p, true) {
56 Err(StatsError::BadParams)
57 } else {
58 Ok(Multinomial { p: p.to_vec(), n })
59 }
60 }
61
62 /// Returns the probabilities of the multinomial
63 /// distribution as a slice
64 ///
65 /// # Examples
66 ///
67 /// ```
68 /// use statrs::distribution::Multinomial;
69 ///
70 /// let n = Multinomial::new(&[0.0, 1.0, 2.0], 3).unwrap();
71 /// assert_eq!(n.p(), [0.0, 1.0, 2.0]);
72 /// ```
73 pub fn p(&self) -> &[f64] {
74 &self.p
75 }
76
77 /// Returns the number of trials of the multinomial
78 /// distribution
79 ///
80 /// # Examples
81 ///
82 /// ```
83 /// use statrs::distribution::Multinomial;
84 ///
85 /// let n = Multinomial::new(&[0.0, 1.0, 2.0], 3).unwrap();
86 /// assert_eq!(n.n(), 3);
87 /// ```
88 pub fn n(&self) -> u64 {
89 self.n
90 }
91}
92
93impl ::rand::distributions::Distribution<Vec<f64>> for Multinomial {
94 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vec<f64> {
95 let p_cdf = super::categorical::prob_mass_to_cdf(self.p());
96 let mut res = vec![0.0; self.p.len()];
97 for _ in 0..self.n {
98 let i = super::categorical::sample_unchecked(rng, &p_cdf);
99 let el = res.get_mut(i as usize).unwrap();
100 *el += 1.0;
101 }
102 res
103 }
104}
105
106impl MeanN<DVector<f64>> for Multinomial {
107 /// Returns the mean of the multinomial distribution
108 ///
109 /// # Formula
110 ///
111 /// ```ignore
112 /// n * p_i for i in 1...k
113 /// ```
114 ///
115 /// where `n` is the number of trials, `p_i` is the `i`th probability,
116 /// and `k` is the total number of probabilities
117 fn mean(&self) -> Option<DVector<f64>> {
118 Some(DVector::from_vec(
119 self.p.iter().map(|x| x * self.n as f64).collect(),
120 ))
121 }
122}
123
124impl VarianceN<DMatrix<f64>> for Multinomial {
125 /// Returns the variance of the multinomial distribution
126 ///
127 /// # Formula
128 ///
129 /// ```ignore
130 /// n * p_i * (1 - p_i) for i in 1...k
131 /// ```
132 ///
133 /// where `n` is the number of trials, `p_i` is the `i`th probability,
134 /// and `k` is the total number of probabilities
135 fn variance(&self) -> Option<DMatrix<f64>> {
136 let cov: Vec<_> = self
137 .p
138 .iter()
139 .map(|x| x * self.n as f64 * (1.0 - x))
140 .collect();
141 Some(DMatrix::from_diagonal(&DVector::from_vec(cov)))
142 }
143}
144
145// impl Skewness<Vec<f64>> for Multinomial {
146// /// Returns the skewness of the multinomial distribution
147// ///
148// /// # Formula
149// ///
150// /// ```ignore
// /// (1 - 2 * p_i) / sqrt(n * p_i * (1 - p_i)) for i in 1...k
152// /// ```
153// ///
154// /// where `n` is the number of trials, `p_i` is the `i`th probability,
155// /// and `k` is the total number of probabilities
156// fn skewness(&self) -> Option<Vec<f64>> {
157// Some(
158// self.p
159// .iter()
160// .map(|x| (1.0 - 2.0 * x) / (self.n as f64 * (1.0 - x) * x).sqrt())
161// .collect(),
162// )
163// }
164// }
165
166impl<'a> Discrete<&'a [u64], f64> for Multinomial {
167 /// Calculates the probability mass function for the multinomial
168 /// distribution
169 /// with the given `x`'s corresponding to the probabilities for this
170 /// distribution
171 ///
172 /// # Panics
173 ///
174 /// If the elements in `x` do not sum to `n` or if the length of `x` is not
175 /// equivalent to the length of `p`
176 ///
177 /// # Formula
178 ///
179 /// ```ignore
180 /// (n! / x_1!...x_k!) * p_i^x_i for i in 1...k
181 /// ```
182 ///
183 /// where `n` is the number of trials, `p_i` is the `i`th probability,
184 /// `x_i` is the `i`th `x` value, and `k` is the total number of
185 /// probabilities
186 fn pmf(&self, x: &[u64]) -> f64 {
187 if self.p.len() != x.len() {
188 panic!("Expected x and p to have equal lengths.");
189 }
190 if x.iter().sum::<u64>() != self.n {
191 return 0.0;
192 }
193 let coeff = factorial::multinomial(self.n, x);
194 let val = coeff
195 * self
196 .p
197 .iter()
198 .zip(x.iter())
199 .fold(1.0, |acc, (pi, xi)| acc * pi.powf(*xi as f64));
200 val
201 }
202
203 /// Calculates the log probability mass function for the multinomial
204 /// distribution
205 /// with the given `x`'s corresponding to the probabilities for this
206 /// distribution
207 ///
208 /// # Panics
209 ///
210 /// If the elements in `x` do not sum to `n` or if the length of `x` is not
211 /// equivalent to the length of `p`
212 ///
213 /// # Formula
214 ///
215 /// ```ignore
216 /// ln((n! / x_1!...x_k!) * p_i^x_i) for i in 1...k
217 /// ```
218 ///
219 /// where `n` is the number of trials, `p_i` is the `i`th probability,
220 /// `x_i` is the `i`th `x` value, and `k` is the total number of
221 /// probabilities
222 fn ln_pmf(&self, x: &[u64]) -> f64 {
223 if self.p.len() != x.len() {
224 panic!("Expected x and p to have equal lengths.");
225 }
226 if x.iter().sum::<u64>() != self.n {
227 return f64::NEG_INFINITY;
228 }
229 let coeff = factorial::multinomial(self.n, x).ln();
230 let val = coeff
231 + self
232 .p
233 .iter()
234 .zip(x.iter())
235 .map(|(pi, xi)| *xi as f64 * pi.ln())
236 .fold(0.0, |acc, x| acc + x);
237 val
238 }
239}
240
241// TODO: fix tests
242// #[rustfmt::skip]
243// #[cfg(test)]
244// mod tests {
245// use crate::statistics::*;
246// use crate::distribution::{Discrete, Multinomial};
247// use crate::consts::ACC;
248
249// fn try_create(p: &[f64], n: u64) -> Multinomial {
250// let dist = Multinomial::new(p, n);
251// assert!(dist.is_ok());
252// dist.unwrap()
253// }
254
255// fn create_case(p: &[f64], n: u64) {
256// let dist = try_create(p, n);
257// assert_eq!(dist.p(), p);
258// assert_eq!(dist.n(), n);
259// }
260
261// fn bad_create_case(p: &[f64], n: u64) {
262// let dist = Multinomial::new(p, n);
263// assert!(dist.is_err());
264// }
265
266// fn test_case<F>(p: &[f64], n: u64, expected: &[f64], eval: F)
267// where F: Fn(Multinomial) -> Vec<f64>
268// {
269// let dist = try_create(p, n);
270// let x = eval(dist);
271// assert_eq!(*expected, *x);
272// }
273
274// fn test_almost<F>(p: &[f64], n: u64, expected: &[f64], acc: f64, eval: F)
275// where F: Fn(Multinomial) -> Vec<f64>
276// {
277// let dist = try_create(p, n);
278// let x = eval(dist);
279// assert_eq!(expected.len(), x.len());
280// for i in 0..expected.len() {
281// assert_almost_eq!(expected[i], x[i], acc);
282// }
283// }
284
285// fn test_almost_sr<F>(p: &[f64], n: u64, expected: f64, acc:f64, eval: F)
286// where F: Fn(Multinomial) -> f64
287// {
288// let dist = try_create(p, n);
289// let x = eval(dist);
290// assert_almost_eq!(expected, x, acc);
291// }
292
293// #[test]
294// fn test_create() {
295// create_case(&[1.0, 1.0, 1.0], 4);
296// create_case(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0], 4);
297// }
298
299// #[test]
300// fn test_bad_create() {
301// bad_create_case(&[-1.0, 1.0], 4);
302// bad_create_case(&[0.0, 0.0], 4);
303// }
304
305// #[test]
306// fn test_mean() {
307// let mean = |x: Multinomial| x.mean().unwrap();
308// test_case(&[0.3, 0.7], 5, &[1.5, 3.5], mean);
309// test_case(&[0.1, 0.3, 0.6], 10, &[1.0, 3.0, 6.0], mean);
310// test_case(&[0.15, 0.35, 0.3, 0.2], 20, &[3.0, 7.0, 6.0, 4.0], mean);
311// }
312
313// #[test]
314// fn test_variance() {
315// let variance = |x: Multinomial| x.variance().unwrap();
316// test_almost(&[0.3, 0.7], 5, &[1.05, 1.05], 1e-15, variance);
317// test_almost(&[0.1, 0.3, 0.6], 10, &[0.9, 2.1, 2.4], 1e-15, variance);
318// test_almost(&[0.15, 0.35, 0.3, 0.2], 20, &[2.55, 4.55, 4.2, 3.2], 1e-15, variance);
319// }
320
321// // #[test]
322// // fn test_skewness() {
323// // let skewness = |x: Multinomial| x.skewness().unwrap();
324// // test_almost(&[0.3, 0.7], 5, &[0.390360029179413, -0.390360029179413], 1e-15, skewness);
325// // test_almost(&[0.1, 0.3, 0.6], 10, &[0.843274042711568, 0.276026223736942, -0.12909944487358], 1e-15, skewness);
326// // test_almost(&[0.15, 0.35, 0.3, 0.2], 20, &[0.438357003759605, 0.140642169281549, 0.195180014589707, 0.335410196624968], 1e-15, skewness);
327// // }
328
329// #[test]
330// fn test_pmf() {
331// let pmf = |arg: &[u64]| move |x: Multinomial| x.pmf(arg);
332// test_almost_sr(&[0.3, 0.7], 10, 0.121060821, 1e-15, pmf(&[1, 9]));
333// test_almost_sr(&[0.1, 0.3, 0.6], 10, 0.105815808, 1e-15, pmf(&[1, 3, 6]));
334// test_almost_sr(&[0.15, 0.35, 0.3, 0.2], 10, 0.000145152, 1e-15, pmf(&[1, 1, 1, 7]));
335// }
336
337// #[test]
338// #[should_panic]
339// fn test_pmf_x_wrong_length() {
340// let pmf = |arg: &[u64]| move |x: Multinomial| x.pmf(arg);
341// let n = Multinomial::new(&[0.3, 0.7], 10).unwrap();
342// n.pmf(&[1]);
343// }
344
345// #[test]
346// #[should_panic]
347// fn test_pmf_x_wrong_sum() {
348// let pmf = |arg: &[u64]| move |x: Multinomial| x.pmf(arg);
349// let n = Multinomial::new(&[0.3, 0.7], 10).unwrap();
350// n.pmf(&[1, 3]);
351// }
352
353// #[test]
354// fn test_ln_pmf() {
355// let large_p = &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
356// let n = Multinomial::new(large_p, 45).unwrap();
357// let x = &[1, 2, 3, 4, 5, 6, 7, 8, 9];
358// assert_almost_eq!(n.pmf(x).ln(), n.ln_pmf(x), 1e-13);
359// let n2 = Multinomial::new(large_p, 18).unwrap();
360// let x2 = &[1, 1, 1, 2, 2, 2, 3, 3, 3];
361// assert_almost_eq!(n2.pmf(x2).ln(), n2.ln_pmf(x2), 1e-13);
362// let n3 = Multinomial::new(large_p, 51).unwrap();
363// let x3 = &[5, 6, 7, 8, 7, 6, 5, 4, 3];
364// assert_almost_eq!(n3.pmf(x3).ln(), n3.ln_pmf(x3), 1e-13);
365// }
366
367// #[test]
368// #[should_panic]
369// fn test_ln_pmf_x_wrong_length() {
370// let n = Multinomial::new(&[0.3, 0.7], 10).unwrap();
371// n.ln_pmf(&[1]);
372// }
373
374// #[test]
375// #[should_panic]
376// fn test_ln_pmf_x_wrong_sum() {
377// let n = Multinomial::new(&[0.3, 0.7], 10).unwrap();
378// n.ln_pmf(&[1, 3]);
379// }
380// }