statrs/distribution/
multinomial.rs

1use crate::distribution::Discrete;
2use crate::function::factorial;
3use crate::statistics::*;
4use nalgebra::{Dim, Dyn, OMatrix, OVector};
5
6/// Implements the
7/// [Multinomial](https://en.wikipedia.org/wiki/Multinomial_distribution)
8/// distribution which is a generalization of the
9/// [Binomial](https://en.wikipedia.org/wiki/Binomial_distribution)
10/// distribution
11///
12/// # Examples
13///
14/// ```
15/// use statrs::distribution::Multinomial;
16/// use statrs::statistics::MeanN;
17/// use nalgebra::vector;
18///
19/// let n = Multinomial::new_from_nalgebra(vector![0.3, 0.7], 5).unwrap();
20/// assert_eq!(n.mean().unwrap(), (vector![1.5, 3.5]));
21/// ```
22#[derive(Debug, Clone, PartialEq)]
23pub struct Multinomial<D>
24where
25    D: Dim,
26    nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>,
27{
28    /// normalized probabilities for each species
29    p: OVector<f64, D>,
30    /// count of trials
31    n: u64,
32}
33
/// Errors that the [`Multinomial`] constructors can report.
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
#[non_exhaustive]
pub enum MultinomialError {
    /// The probability vector has fewer than two entries.
    NotEnoughProbabilities,

    /// The probabilities add up to zero, so they cannot be normalized.
    ProbabilitySumZero,

    /// Some probability is NaN, infinite, or negative.
    ProbabilityInvalid,
}

impl std::fmt::Display for MultinomialError {
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        // Each variant maps to a fixed, human-readable message.
        let message = match self {
            Self::NotEnoughProbabilities => "Fewer than two probabilities",
            Self::ProbabilitySumZero => "The probabilities sum up to zero",
            Self::ProbabilityInvalid => {
                "At least one probability is NaN, infinity or less than zero"
            }
        };
        f.write_str(message)
    }
}

impl std::error::Error for MultinomialError {}
63
64impl Multinomial<Dyn> {
65    /// Constructs a new multinomial distribution with probabilities `p`
66    /// and `n` number of trials.
67    ///
68    /// # Errors
69    ///
70    /// Returns an error if `p` is empty, the sum of the elements
71    /// in `p` is 0, or any element in `p` is less than 0 or is `f64::NAN`
72    ///
73    /// # Note
74    ///
75    /// The elements in `p` do not need to be normalized
76    ///
77    /// # Examples
78    ///
79    /// ```
80    /// use statrs::distribution::Multinomial;
81    ///
82    /// let mut result = Multinomial::new(vec![0.0, 1.0, 2.0], 3);
83    /// assert!(result.is_ok());
84    ///
85    /// result = Multinomial::new(vec![0.0, -1.0, 2.0], 3);
86    /// assert!(result.is_err());
87    /// ```
88    pub fn new(p: Vec<f64>, n: u64) -> Result<Self, MultinomialError> {
89        Self::new_from_nalgebra(p.into(), n)
90    }
91}
92
93impl<D> Multinomial<D>
94where
95    D: Dim,
96    nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>,
97{
98    pub fn new_from_nalgebra(mut p: OVector<f64, D>, n: u64) -> Result<Self, MultinomialError> {
99        if p.len() < 2 {
100            return Err(MultinomialError::NotEnoughProbabilities);
101        }
102
103        let mut sum = 0.0;
104        for &val in &p {
105            if val.is_nan() || val < 0.0 {
106                return Err(MultinomialError::ProbabilityInvalid);
107            }
108
109            sum += val;
110        }
111
112        if sum == 0.0 {
113            return Err(MultinomialError::ProbabilitySumZero);
114        }
115
116        p.unscale_mut(p.lp_norm(1));
117        Ok(Self { p, n })
118    }
119
120    /// Returns the probabilities of the multinomial
121    /// distribution as a slice
122    ///
123    /// # Examples
124    ///
125    /// ```
126    /// use statrs::distribution::Multinomial;
127    /// use nalgebra::dvector;
128    ///
129    /// let n = Multinomial::new(vec![0.0, 1.0, 2.0], 3).unwrap();
130    /// assert_eq!(*n.p(), dvector![0.0, 1.0/3.0, 2.0/3.0]);
131    /// ```
132    pub fn p(&self) -> &OVector<f64, D> {
133        &self.p
134    }
135
136    /// Returns the number of trials of the multinomial
137    /// distribution
138    ///
139    /// # Examples
140    ///
141    /// ```
142    /// use statrs::distribution::Multinomial;
143    ///
144    /// let n = Multinomial::new(vec![0.0, 1.0, 2.0], 3).unwrap();
145    /// assert_eq!(n.n(), 3);
146    /// ```
147    pub fn n(&self) -> u64 {
148        self.n
149    }
150}
151
152impl<D> std::fmt::Display for Multinomial<D>
153where
154    D: Dim,
155    nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>,
156{
157    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
158        write!(f, "Multinom({:#?},{})", self.p, self.n)
159    }
160}
161
#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl<D> ::rand::distributions::Distribution<OVector<u64, D>> for Multinomial<D>
where
    D: Dim,
    nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>,
{
    /// Draws one vector of per-category counts as `u64` values; the entries
    /// add up to the number of trials `n`.
    fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> OVector<u64, D> {
        sample_generic(self, rng)
    }
}
174
#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl<D> ::rand::distributions::Distribution<OVector<f64, D>> for Multinomial<D>
where
    D: Dim,
    nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>,
{
    /// Draws one vector of per-category counts, stored as `f64` values.
    fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> OVector<f64, D> {
        sample_generic::<D, R, f64>(self, rng)
    }
}
186
/// Shared sampling routine behind both `Distribution` impls of [`Multinomial`].
///
/// Performs `n` categorical draws from the cumulative probabilities of `dist`
/// and tallies how many draws fell into each category. The counter type `T`
/// lets callers receive integer or floating-point counts.
#[cfg(feature = "rand")]
fn sample_generic<D, R, T>(dist: &Multinomial<D>, rng: &mut R) -> OVector<T, D>
where
    D: Dim,
    nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>,
    R: ::rand::Rng + ?Sized,
    T: ::num_traits::Num + ::nalgebra::Scalar + ::std::ops::AddAssign<T>,
{
    use nalgebra::Const;

    let cdf = super::categorical::prob_mass_to_cdf(dist.p().as_slice());
    let rows = dist.p().shape_generic().0;
    let mut counts = OVector::<T, D>::zeros_generic(rows, Const::<1>);
    let mut remaining = dist.n();
    while remaining > 0 {
        let category = super::categorical::sample_unchecked(rng, &cdf);
        counts[category] += T::one();
        remaining -= 1;
    }
    counts
}
205
206impl<D> MeanN<OVector<f64, D>> for Multinomial<D>
207where
208    D: Dim,
209    nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>,
210{
211    /// Returns the mean of the multinomial distribution
212    ///
213    /// # Formula
214    ///
215    /// ```text
216    /// n * p_i for i in 1...k
217    /// ```
218    ///
219    /// where `n` is the number of trials, `p_i` is the `i`th probability,
220    /// and `k` is the total number of probabilities
221    fn mean(&self) -> Option<OVector<f64, D>> {
222        Some(self.p.map(|x| x * self.n as f64))
223    }
224}
225
226impl<D> VarianceN<OMatrix<f64, D, D>> for Multinomial<D>
227where
228    D: Dim,
229    nalgebra::DefaultAllocator:
230        nalgebra::allocator::Allocator<D> + nalgebra::allocator::Allocator<D, D>,
231{
232    /// Returns the variance of the multinomial distribution
233    ///
234    /// # Formula
235    ///
236    /// ```text
237    /// n * p_i * (1 - p_i) for i in 1...k
238    /// ```
239    ///
240    /// where `n` is the number of trials, `p_i` is the `i`th probability,
241    /// and `k` is the total number of probabilities
242    fn variance(&self) -> Option<OMatrix<f64, D, D>> {
243        let mut cov = OMatrix::from_diagonal(&self.p.map(|x| x * (1.0 - x)));
244        let mut offdiag = |x: usize, y: usize| {
245            let elt = -self.p[x] * self.p[y];
246            // cov[(x, y)] = elt;
247            cov[(y, x)] = elt;
248        };
249
250        for i in 0..self.p.len() {
251            for j in 0..i {
252                offdiag(i, j);
253            }
254        }
255        cov.fill_lower_triangle_with_upper_triangle();
256        Some(cov.scale(self.n as f64))
257    }
258}
259
260// impl Skewness<Vec<f64>> for Multinomial {
261//     /// Returns the skewness of the multinomial distribution
262//     ///
263//     /// # Formula
264//     ///
265//     /// ```text
266//     /// (1 - 2 * p_i) / (n * p_i * (1 - p_i)) for i in 1...k
267//     /// ```
268//     ///
269//     /// where `n` is the number of trials, `p_i` is the `i`th probability,
270//     /// and `k` is the total number of probabilities
271//     fn skewness(&self) -> Option<Vec<f64>> {
272//         Some(
273//             self.p
274//                 .iter()
275//                 .map(|x| (1.0 - 2.0 * x) / (self.n as f64 * (1.0 - x) * x).sqrt())
276//                 .collect(),
277//         )
278//     }
279// }
280
281impl<D> Discrete<&OVector<u64, D>, f64> for Multinomial<D>
282where
283    D: Dim,
284    nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>,
285{
286    /// Calculates the probability mass function for the multinomial
287    /// distribution
288    /// with the given `x`'s corresponding to the probabilities for this
289    /// distribution
290    ///
291    /// # Panics
292    ///
293    /// If length of `x` is not equal to length of `p`
294    ///
295    /// # Formula
296    ///
297    /// ```text
298    /// (n! / x_1!...x_k!) * p_i^x_i for i in 1...k
299    /// ```
300    ///
301    /// where `n` is the number of trials, `p_i` is the `i`th probability,
302    /// `x_i` is the `i`th `x` value, and `k` is the total number of
303    /// probabilities
304    fn pmf(&self, x: &OVector<u64, D>) -> f64 {
305        if self.p.len() != x.len() {
306            panic!("Expected x and p to have equal lengths.");
307        }
308        if x.iter().sum::<u64>() != self.n {
309            return 0.0;
310        }
311        let coeff = factorial::multinomial(self.n, x.as_slice());
312        let val = coeff
313            * self
314                .p
315                .iter()
316                .zip(x.iter())
317                .fold(1.0, |acc, (pi, xi)| acc * pi.powf(*xi as f64));
318        val
319    }
320
321    /// Calculates the log probability mass function for the multinomial
322    /// distribution
323    /// with the given `x`'s corresponding to the probabilities for this
324    /// distribution
325    ///
326    /// # Panics
327    ///
328    /// If length of `x` is not equal to length of `p`
329    ///
330    /// # Formula
331    ///
332    /// ```text
333    /// ln((n! / x_1!...x_k!) * p_i^x_i) for i in 1...k
334    /// ```
335    ///
336    /// where `n` is the number of trials, `p_i` is the `i`th probability,
337    /// `x_i` is the `i`th `x` value, and `k` is the total number of
338    /// probabilities
339    fn ln_pmf(&self, x: &OVector<u64, D>) -> f64 {
340        if self.p.len() != x.len() {
341            panic!("Expected x and p to have equal lengths.");
342        }
343        if x.iter().sum::<u64>() != self.n {
344            return f64::NEG_INFINITY;
345        }
346        let coeff = factorial::multinomial(self.n, x.as_slice()).ln();
347        let val = coeff
348            + self
349                .p
350                .iter()
351                .zip(x.iter())
352                .map(|(pi, xi)| *xi as f64 * pi.ln())
353                .fold(0.0, |acc, x| acc + x);
354        val
355    }
356}
357
#[rustfmt::skip]
#[cfg(test)]
mod tests {
    use crate::{
        distribution::{Discrete, Multinomial, MultinomialError},
        statistics::{MeanN, VarianceN},
    };
    use nalgebra::{dmatrix, dvector, vector, DimMin, Dyn, OVector};
    use std::fmt::{Debug, Display};

    /// Constructs a distribution that is expected to be valid; fails the test otherwise.
    fn try_create<D>(p: OVector<f64, D>, n: u64) -> Multinomial<D>
    where
        D: DimMin<D, Output = D>,
        nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>,
    {
        let mvn = Multinomial::new_from_nalgebra(p, n);
        assert!(mvn.is_ok());
        mvn.unwrap()
    }

    /// Attempts to construct an invalid distribution and returns the reported error.
    fn bad_create_case<D>(p: OVector<f64, D>, n: u64) -> MultinomialError
    where
        D: DimMin<D, Output = D>,
        nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>,
    {
        let dd = Multinomial::new_from_nalgebra(p, n);
        assert!(dd.is_err());
        dd.unwrap_err()
    }

    /// Builds a valid distribution, evaluates `eval` on it and checks the
    /// result against `expected` to within `acc`.
    fn test_almost<F, T, D>(p: OVector<f64, D>, n: u64, expected: T, acc: f64, eval: F)
    where
        T: Debug + Display + approx::RelativeEq<Epsilon = f64>,
        F: FnOnce(Multinomial<D>) -> T,
        D: DimMin<D, Output = D>,
        nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>,
    {
        let dd = try_create(p, n);
        let x = eval(dd);
        assert_relative_eq!(expected, x, epsilon = acc);
    }

    #[test]
    fn test_create() {
        assert_relative_eq!(
            *try_create(vector![1.0, 1.0, 1.0], 4).p(),
            vector![1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0]
        );
        try_create(dvector![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0], 4);
    }

    #[test]
    fn test_bad_create() {
        assert_eq!(
            bad_create_case(vector![0.5], 4),
            MultinomialError::NotEnoughProbabilities,
        );

        assert_eq!(
            bad_create_case(vector![-1.0, 2.0], 4),
            MultinomialError::ProbabilityInvalid,
        );

        assert_eq!(
            bad_create_case(vector![0.0, 0.0], 4),
            MultinomialError::ProbabilitySumZero,
        );
        assert_eq!(
            bad_create_case(vector![1.0, f64::NAN], 4),
            MultinomialError::ProbabilityInvalid,
        );
    }

    #[test]
    fn test_mean() {
        let mean = |x: Multinomial<_>| x.mean().unwrap();
        test_almost(dvector![0.3, 0.7], 5, dvector![1.5, 3.5], 1e-12, mean);
        test_almost(
            dvector![0.1, 0.3, 0.6],
            10,
            dvector![1.0, 3.0, 6.0],
            1e-12,
            mean,
        );
        test_almost(
            dvector![1.0, 3.0, 6.0],
            10,
            dvector![1.0, 3.0, 6.0],
            1e-12,
            mean,
        );
        test_almost(
            dvector![0.15, 0.35, 0.3, 0.2],
            20,
            dvector![3.0, 7.0, 6.0, 4.0],
            1e-12,
            mean,
        );
    }

    #[test]
    fn test_variance() {
        let variance = |x: Multinomial<_>| x.variance().unwrap();
        test_almost(
            dvector![0.3, 0.7],
            5,
            dmatrix![1.05, -1.05;
                    -1.05,  1.05],
            1e-15,
            variance,
        );
        test_almost(
            dvector![0.1, 0.3, 0.6],
            10,
            dmatrix![0.9, -0.3, -0.6;
                    -0.3,  2.1, -1.8;
                    -0.6, -1.8,  2.4;
            ],
            1e-15,
            variance,
        );
        test_almost(
            dvector![0.15, 0.35, 0.3, 0.2],
            20,
            dmatrix![2.55, -1.05, -0.90, -0.60;
                    -1.05,  4.55, -2.10, -1.40;
                    -0.90, -2.10,  4.20, -1.20;
                    -0.60, -1.40, -1.20,  3.20;
            ],
            1e-15,
            variance,
        );
    }

    // Skewness is not implemented for the multinomial yet; the expected values
    // below are kept for when it is.
    // #[test]
    // fn test_skewness() {
    //     let skewness = |x: Multinomial| x.skewness().unwrap();
    //     test_almost(&[0.3, 0.7], 5, &[0.390360029179413, -0.390360029179413], 1e-15, skewness);
    //     test_almost(&[0.1, 0.3, 0.6], 10, &[0.843274042711568, 0.276026223736942, -0.12909944487358], 1e-15, skewness);
    //     test_almost(&[0.15, 0.35, 0.3, 0.2], 20, &[0.438357003759605, 0.140642169281549, 0.195180014589707, 0.335410196624968], 1e-15, skewness);
    // }

    #[test]
    fn test_pmf() {
        let pmf = |arg: OVector<u64, Dyn>| move |x: Multinomial<_>| x.pmf(&arg);
        test_almost(
            dvector![0.3, 0.7],
            10,
            0.121060821,
            1e-15,
            pmf(dvector![1, 9]),
        );
        test_almost(
            dvector![0.1, 0.3, 0.6],
            10,
            0.105815808,
            1e-15,
            pmf(dvector![1, 3, 6]),
        );
        test_almost(
            dvector![0.15, 0.35, 0.3, 0.2],
            10,
            0.000145152,
            1e-15,
            pmf(dvector![1, 1, 1, 7]),
        );
    }

    #[test]
    fn test_error_is_sync_send() {
        fn assert_sync_send<T: Sync + Send>() {}
        assert_sync_send::<MultinomialError>();
    }

    #[test]
    #[should_panic]
    fn test_pmf_x_wrong_length() {
        let n = Multinomial::new(vec![0.3, 0.7], 10).unwrap();
        let _ = n.pmf(&dvector![1u64]);
    }

    #[test]
    fn test_pmf_x_wrong_sum() {
        // Counts that do not add up to `n` are impossible outcomes.
        let n = Multinomial::new(vec![0.3, 0.7], 10).unwrap();
        assert_eq!(n.pmf(&dvector![1u64, 3]), 0.0);
    }

    #[test]
    fn test_ln_pmf() {
        let large_p = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let cases: [(u64, OVector<u64, Dyn>); 3] = [
            (45, dvector![1, 2, 3, 4, 5, 6, 7, 8, 9]),
            (18, dvector![1, 1, 1, 2, 2, 2, 3, 3, 3]),
            (51, dvector![5, 6, 7, 8, 7, 6, 5, 4, 3]),
        ];
        for (n, x) in cases.iter() {
            let dist = Multinomial::new(large_p.clone(), *n).unwrap();
            assert_relative_eq!(dist.pmf(x).ln(), dist.ln_pmf(x), epsilon = 1e-13);
        }
    }

    #[test]
    #[should_panic]
    fn test_ln_pmf_x_wrong_length() {
        let n = Multinomial::new(vec![0.3, 0.7], 10).unwrap();
        let _ = n.ln_pmf(&dvector![1u64]);
    }

    #[test]
    fn test_ln_pmf_x_wrong_sum() {
        let n = Multinomial::new(vec![0.3, 0.7], 10).unwrap();
        assert_eq!(n.ln_pmf(&dvector![1u64, 3]), f64::NEG_INFINITY);
    }
}