statrs/distribution/multinomial.rs
1use crate::distribution::Discrete;
2use crate::function::factorial;
3use crate::statistics::*;
4use nalgebra::{Dim, Dyn, OMatrix, OVector};
5
6/// Implements the
7/// [Multinomial](https://en.wikipedia.org/wiki/Multinomial_distribution)
8/// distribution which is a generalization of the
9/// [Binomial](https://en.wikipedia.org/wiki/Binomial_distribution)
10/// distribution
11///
12/// # Examples
13///
14/// ```
15/// use statrs::distribution::Multinomial;
16/// use statrs::statistics::MeanN;
17/// use nalgebra::vector;
18///
19/// let n = Multinomial::new_from_nalgebra(vector![0.3, 0.7], 5).unwrap();
20/// assert_eq!(n.mean().unwrap(), (vector![1.5, 3.5]));
21/// ```
22#[derive(Debug, Clone, PartialEq)]
23pub struct Multinomial<D>
24where
25 D: Dim,
26 nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>,
27{
28 /// normalized probabilities for each species
29 p: OVector<f64, D>,
30 /// count of trials
31 n: u64,
32}
33
/// Represents the errors that can occur when creating a [`Multinomial`].
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
#[non_exhaustive]
pub enum MultinomialError {
    /// The probability vector has fewer than two entries.
    NotEnoughProbabilities,

    /// Every probability is zero, so the vector cannot be normalized.
    ProbabilitySumZero,

    /// At least one probability is NaN, infinite or less than zero.
    ProbabilityInvalid,
}

impl std::fmt::Display for MultinomialError {
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        // Each variant maps to a fixed, human-readable message.
        let message = match *self {
            Self::NotEnoughProbabilities => "Fewer than two probabilities",
            Self::ProbabilitySumZero => "The probabilities sum up to zero",
            Self::ProbabilityInvalid => {
                "At least one probability is NaN, infinity or less than zero"
            }
        };
        f.write_str(message)
    }
}

/// Marker impl so the error composes with `?` and `Box<dyn Error>`.
impl std::error::Error for MultinomialError {}
63
64impl Multinomial<Dyn> {
65 /// Constructs a new multinomial distribution with probabilities `p`
66 /// and `n` number of trials.
67 ///
68 /// # Errors
69 ///
70 /// Returns an error if `p` is empty, the sum of the elements
71 /// in `p` is 0, or any element in `p` is less than 0 or is `f64::NAN`
72 ///
73 /// # Note
74 ///
75 /// The elements in `p` do not need to be normalized
76 ///
77 /// # Examples
78 ///
79 /// ```
80 /// use statrs::distribution::Multinomial;
81 ///
82 /// let mut result = Multinomial::new(vec![0.0, 1.0, 2.0], 3);
83 /// assert!(result.is_ok());
84 ///
85 /// result = Multinomial::new(vec![0.0, -1.0, 2.0], 3);
86 /// assert!(result.is_err());
87 /// ```
88 pub fn new(p: Vec<f64>, n: u64) -> Result<Self, MultinomialError> {
89 Self::new_from_nalgebra(p.into(), n)
90 }
91}
92
93impl<D> Multinomial<D>
94where
95 D: Dim,
96 nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>,
97{
98 pub fn new_from_nalgebra(mut p: OVector<f64, D>, n: u64) -> Result<Self, MultinomialError> {
99 if p.len() < 2 {
100 return Err(MultinomialError::NotEnoughProbabilities);
101 }
102
103 let mut sum = 0.0;
104 for &val in &p {
105 if val.is_nan() || val < 0.0 {
106 return Err(MultinomialError::ProbabilityInvalid);
107 }
108
109 sum += val;
110 }
111
112 if sum == 0.0 {
113 return Err(MultinomialError::ProbabilitySumZero);
114 }
115
116 p.unscale_mut(p.lp_norm(1));
117 Ok(Self { p, n })
118 }
119
120 /// Returns the probabilities of the multinomial
121 /// distribution as a slice
122 ///
123 /// # Examples
124 ///
125 /// ```
126 /// use statrs::distribution::Multinomial;
127 /// use nalgebra::dvector;
128 ///
129 /// let n = Multinomial::new(vec![0.0, 1.0, 2.0], 3).unwrap();
130 /// assert_eq!(*n.p(), dvector![0.0, 1.0/3.0, 2.0/3.0]);
131 /// ```
132 pub fn p(&self) -> &OVector<f64, D> {
133 &self.p
134 }
135
136 /// Returns the number of trials of the multinomial
137 /// distribution
138 ///
139 /// # Examples
140 ///
141 /// ```
142 /// use statrs::distribution::Multinomial;
143 ///
144 /// let n = Multinomial::new(vec![0.0, 1.0, 2.0], 3).unwrap();
145 /// assert_eq!(n.n(), 3);
146 /// ```
147 pub fn n(&self) -> u64 {
148 self.n
149 }
150}
151
152impl<D> std::fmt::Display for Multinomial<D>
153where
154 D: Dim,
155 nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>,
156{
157 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
158 write!(f, "Multinom({:#?},{})", self.p, self.n)
159 }
160}
161
#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl<D> ::rand::distributions::Distribution<OVector<u64, D>> for Multinomial<D>
where
    D: Dim,
    nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>,
{
    /// Draws one vector of per-category counts (as integers) whose entries
    /// add up to the number of trials `n`.
    fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> OVector<u64, D> {
        sample_generic(self, rng)
    }
}

#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl<D> ::rand::distributions::Distribution<OVector<f64, D>> for Multinomial<D>
where
    D: Dim,
    nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>,
{
    /// Draws one vector of per-category counts (as floats) whose entries
    /// add up to the number of trials `n`.
    fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> OVector<f64, D> {
        sample_generic(self, rng)
    }
}
186
/// Simulates `dist.n` independent categorical trials and tallies how many
/// landed in each category, returning the tallies as a vector of `T`.
#[cfg(feature = "rand")]
fn sample_generic<D, R, T>(dist: &Multinomial<D>, rng: &mut R) -> OVector<T, D>
where
    D: Dim,
    nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>,
    R: ::rand::Rng + ?Sized,
    T: ::num_traits::Num + ::nalgebra::Scalar + ::std::ops::AddAssign<T>,
{
    use nalgebra::Const;

    // The cumulative distribution is built once and reused for every trial.
    let cdf = super::categorical::prob_mass_to_cdf(dist.p().as_slice());
    let (rows, _) = dist.p.shape_generic();
    let mut counts: OVector<T, D> = OVector::zeros_generic(rows, Const::<1>);

    let mut remaining = dist.n;
    while remaining > 0 {
        let category = super::categorical::sample_unchecked(rng, &cdf);
        counts[category] += T::one();
        remaining -= 1;
    }
    counts
}
205
206impl<D> MeanN<OVector<f64, D>> for Multinomial<D>
207where
208 D: Dim,
209 nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>,
210{
211 /// Returns the mean of the multinomial distribution
212 ///
213 /// # Formula
214 ///
215 /// ```text
216 /// n * p_i for i in 1...k
217 /// ```
218 ///
219 /// where `n` is the number of trials, `p_i` is the `i`th probability,
220 /// and `k` is the total number of probabilities
221 fn mean(&self) -> Option<OVector<f64, D>> {
222 Some(self.p.map(|x| x * self.n as f64))
223 }
224}
225
226impl<D> VarianceN<OMatrix<f64, D, D>> for Multinomial<D>
227where
228 D: Dim,
229 nalgebra::DefaultAllocator:
230 nalgebra::allocator::Allocator<D> + nalgebra::allocator::Allocator<D, D>,
231{
232 /// Returns the variance of the multinomial distribution
233 ///
234 /// # Formula
235 ///
236 /// ```text
237 /// n * p_i * (1 - p_i) for i in 1...k
238 /// ```
239 ///
240 /// where `n` is the number of trials, `p_i` is the `i`th probability,
241 /// and `k` is the total number of probabilities
242 fn variance(&self) -> Option<OMatrix<f64, D, D>> {
243 let mut cov = OMatrix::from_diagonal(&self.p.map(|x| x * (1.0 - x)));
244 let mut offdiag = |x: usize, y: usize| {
245 let elt = -self.p[x] * self.p[y];
246 // cov[(x, y)] = elt;
247 cov[(y, x)] = elt;
248 };
249
250 for i in 0..self.p.len() {
251 for j in 0..i {
252 offdiag(i, j);
253 }
254 }
255 cov.fill_lower_triangle_with_upper_triangle();
256 Some(cov.scale(self.n as f64))
257 }
258}
259
260// impl Skewness<Vec<f64>> for Multinomial {
261// /// Returns the skewness of the multinomial distribution
262// ///
263// /// # Formula
264// ///
265// /// ```text
// /// (1 - 2 * p_i) / sqrt(n * p_i * (1 - p_i)) for i in 1...k
267// /// ```
268// ///
269// /// where `n` is the number of trials, `p_i` is the `i`th probability,
270// /// and `k` is the total number of probabilities
271// fn skewness(&self) -> Option<Vec<f64>> {
272// Some(
273// self.p
274// .iter()
275// .map(|x| (1.0 - 2.0 * x) / (self.n as f64 * (1.0 - x) * x).sqrt())
276// .collect(),
277// )
278// }
279// }
280
281impl<D> Discrete<&OVector<u64, D>, f64> for Multinomial<D>
282where
283 D: Dim,
284 nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>,
285{
286 /// Calculates the probability mass function for the multinomial
287 /// distribution
288 /// with the given `x`'s corresponding to the probabilities for this
289 /// distribution
290 ///
291 /// # Panics
292 ///
293 /// If length of `x` is not equal to length of `p`
294 ///
295 /// # Formula
296 ///
297 /// ```text
298 /// (n! / x_1!...x_k!) * p_i^x_i for i in 1...k
299 /// ```
300 ///
301 /// where `n` is the number of trials, `p_i` is the `i`th probability,
302 /// `x_i` is the `i`th `x` value, and `k` is the total number of
303 /// probabilities
304 fn pmf(&self, x: &OVector<u64, D>) -> f64 {
305 if self.p.len() != x.len() {
306 panic!("Expected x and p to have equal lengths.");
307 }
308 if x.iter().sum::<u64>() != self.n {
309 return 0.0;
310 }
311 let coeff = factorial::multinomial(self.n, x.as_slice());
312 let val = coeff
313 * self
314 .p
315 .iter()
316 .zip(x.iter())
317 .fold(1.0, |acc, (pi, xi)| acc * pi.powf(*xi as f64));
318 val
319 }
320
321 /// Calculates the log probability mass function for the multinomial
322 /// distribution
323 /// with the given `x`'s corresponding to the probabilities for this
324 /// distribution
325 ///
326 /// # Panics
327 ///
328 /// If length of `x` is not equal to length of `p`
329 ///
330 /// # Formula
331 ///
332 /// ```text
333 /// ln((n! / x_1!...x_k!) * p_i^x_i) for i in 1...k
334 /// ```
335 ///
336 /// where `n` is the number of trials, `p_i` is the `i`th probability,
337 /// `x_i` is the `i`th `x` value, and `k` is the total number of
338 /// probabilities
339 fn ln_pmf(&self, x: &OVector<u64, D>) -> f64 {
340 if self.p.len() != x.len() {
341 panic!("Expected x and p to have equal lengths.");
342 }
343 if x.iter().sum::<u64>() != self.n {
344 return f64::NEG_INFINITY;
345 }
346 let coeff = factorial::multinomial(self.n, x.as_slice()).ln();
347 let val = coeff
348 + self
349 .p
350 .iter()
351 .zip(x.iter())
352 .map(|(pi, xi)| *xi as f64 * pi.ln())
353 .fold(0.0, |acc, x| acc + x);
354 val
355 }
356}
357
#[rustfmt::skip]
#[cfg(test)]
mod tests {
    use crate::{
        distribution::{Discrete, Multinomial, MultinomialError},
        statistics::{MeanN, VarianceN},
    };
    use nalgebra::{dmatrix, dvector, vector, DimMin, Dyn, OVector};
    use std::fmt::{Debug, Display};

    fn try_create<D>(p: OVector<f64, D>, n: u64) -> Multinomial<D>
    where
        D: DimMin<D, Output = D>,
        nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>,
    {
        let mvn = Multinomial::new_from_nalgebra(p, n);
        assert!(mvn.is_ok());
        mvn.unwrap()
    }

    fn bad_create_case<D>(p: OVector<f64, D>, n: u64) -> MultinomialError
    where
        D: DimMin<D, Output = D>,
        nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>,
    {
        let dd = Multinomial::new_from_nalgebra(p, n);
        assert!(dd.is_err());
        dd.unwrap_err()
    }

    fn test_almost<F, T, D>(p: OVector<f64, D>, n: u64, expected: T, acc: f64, eval: F)
    where
        T: Debug + Display + approx::RelativeEq<Epsilon = f64>,
        F: FnOnce(Multinomial<D>) -> T,
        D: DimMin<D, Output = D>,
        nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>,
    {
        let dd = try_create(p, n);
        let x = eval(dd);
        assert_relative_eq!(expected, x, epsilon = acc);
    }

    #[test]
    fn test_create() {
        assert_relative_eq!(
            *try_create(vector![1.0, 1.0, 1.0], 4).p(),
            vector![1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0]
        );
        try_create(dvector![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0], 4);
    }

    #[test]
    fn test_bad_create() {
        assert_eq!(
            bad_create_case(vector![0.5], 4),
            MultinomialError::NotEnoughProbabilities,
        );

        assert_eq!(
            bad_create_case(vector![-1.0, 2.0], 4),
            MultinomialError::ProbabilityInvalid,
        );

        assert_eq!(
            bad_create_case(vector![0.0, 0.0], 4),
            MultinomialError::ProbabilitySumZero,
        );
        assert_eq!(
            bad_create_case(vector![1.0, f64::NAN], 4),
            MultinomialError::ProbabilityInvalid,
        );
    }

    #[test]
    fn test_mean() {
        let mean = |x: Multinomial<_>| x.mean().unwrap();
        test_almost(dvector![0.3, 0.7], 5, dvector![1.5, 3.5], 1e-12, mean);
        test_almost(
            dvector![0.1, 0.3, 0.6],
            10,
            dvector![1.0, 3.0, 6.0],
            1e-12,
            mean,
        );
        test_almost(
            dvector![1.0, 3.0, 6.0],
            10,
            dvector![1.0, 3.0, 6.0],
            1e-12,
            mean,
        );
        test_almost(
            dvector![0.15, 0.35, 0.3, 0.2],
            20,
            dvector![3.0, 7.0, 6.0, 4.0],
            1e-12,
            mean,
        );
    }

    #[test]
    fn test_variance() {
        let variance = |x: Multinomial<_>| x.variance().unwrap();
        test_almost(
            dvector![0.3, 0.7],
            5,
            dmatrix![1.05, -1.05;
                    -1.05,  1.05],
            1e-15,
            variance,
        );
        test_almost(
            dvector![0.1, 0.3, 0.6],
            10,
            dmatrix![0.9, -0.3, -0.6;
                    -0.3,  2.1, -1.8;
                    -0.6, -1.8,  2.4;
            ],
            1e-15,
            variance,
        );
        test_almost(
            dvector![0.15, 0.35, 0.3, 0.2],
            20,
            dmatrix![2.55, -1.05, -0.90, -0.60;
                    -1.05,  4.55, -2.10, -1.40;
                    -0.90, -2.10,  4.20, -1.20;
                    -0.60, -1.40, -1.20,  3.20;
            ],
            1e-15,
            variance,
        );
    }

    // Skewness is not implemented yet (its impl is commented out above); the
    // reference values below are kept for when it is.
    // #[test]
    // fn test_skewness() {
    //     let skewness = |x: Multinomial| x.skewness().unwrap();
    //     test_almost(&[0.3, 0.7], 5, &[0.390360029179413, -0.390360029179413], 1e-15, skewness);
    //     test_almost(&[0.1, 0.3, 0.6], 10, &[0.843274042711568, 0.276026223736942, -0.12909944487358], 1e-15, skewness);
    //     test_almost(&[0.15, 0.35, 0.3, 0.2], 20, &[0.438357003759605, 0.140642169281549, 0.195180014589707, 0.335410196624968], 1e-15, skewness);
    // }

    #[test]
    fn test_pmf() {
        let pmf = |arg: OVector<u64, Dyn>| move |x: Multinomial<_>| x.pmf(&arg);
        test_almost(
            dvector![0.3, 0.7],
            10,
            0.121060821,
            1e-15,
            pmf(dvector![1, 9]),
        );
        test_almost(
            dvector![0.1, 0.3, 0.6],
            10,
            0.105815808,
            1e-15,
            pmf(dvector![1, 3, 6]),
        );
        test_almost(
            dvector![0.15, 0.35, 0.3, 0.2],
            10,
            0.000145152,
            1e-15,
            pmf(dvector![1, 1, 1, 7]),
        );
    }

    #[test]
    fn test_error_is_sync_send() {
        fn assert_sync_send<T: Sync + Send>() {}
        assert_sync_send::<MultinomialError>();
    }

    #[test]
    #[should_panic(expected = "Expected x and p to have equal lengths.")]
    fn test_pmf_x_wrong_length() {
        let n = Multinomial::new(vec![0.3, 0.7], 10).unwrap();
        let x: OVector<u64, Dyn> = dvector![1];
        n.pmf(&x);
    }

    #[test]
    fn test_pmf_x_wrong_sum() {
        // Counts that do not add up to `n` have zero probability.
        let n = Multinomial::new(vec![0.3, 0.7], 10).unwrap();
        let x: OVector<u64, Dyn> = dvector![1, 3];
        assert_eq!(n.pmf(&x), 0.0);
    }

    #[test]
    fn test_ln_pmf() {
        let large_p = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];

        let n = Multinomial::new(large_p.clone(), 45).unwrap();
        let x: OVector<u64, Dyn> = dvector![1, 2, 3, 4, 5, 6, 7, 8, 9];
        assert_relative_eq!(n.pmf(&x).ln(), n.ln_pmf(&x), epsilon = 1e-13);

        let n2 = Multinomial::new(large_p.clone(), 18).unwrap();
        let x2: OVector<u64, Dyn> = dvector![1, 1, 1, 2, 2, 2, 3, 3, 3];
        assert_relative_eq!(n2.pmf(&x2).ln(), n2.ln_pmf(&x2), epsilon = 1e-13);

        let n3 = Multinomial::new(large_p, 51).unwrap();
        let x3: OVector<u64, Dyn> = dvector![5, 6, 7, 8, 7, 6, 5, 4, 3];
        assert_relative_eq!(n3.pmf(&x3).ln(), n3.ln_pmf(&x3), epsilon = 1e-13);
    }

    #[test]
    #[should_panic(expected = "Expected x and p to have equal lengths.")]
    fn test_ln_pmf_x_wrong_length() {
        let n = Multinomial::new(vec![0.3, 0.7], 10).unwrap();
        let x: OVector<u64, Dyn> = dvector![1];
        n.ln_pmf(&x);
    }

    #[test]
    fn test_ln_pmf_x_wrong_sum() {
        // Counts that do not add up to `n` have log-probability -inf.
        let n = Multinomial::new(vec![0.3, 0.7], 10).unwrap();
        let x: OVector<u64, Dyn> = dvector![1, 3];
        assert_eq!(n.ln_pmf(&x), f64::NEG_INFINITY);
    }
}