// statrs/distribution/categorical.rs

1use crate::distribution::{Discrete, DiscreteCDF};
2use crate::statistics::*;
3use std::f64;
4
/// A [Categorical](https://en.wikipedia.org/wiki/Categorical_distribution)
/// distribution over a finite set of categories indexed from zero. It is
/// also known as the generalized Bernoulli or discrete distribution.
///
/// # Examples
///
/// ```
/// use statrs::distribution::{Categorical, Discrete};
/// use statrs::statistics::Distribution;
/// use statrs::prec;
///
/// let n = Categorical::new(&[0.0, 1.0, 2.0]).unwrap();
/// assert!(prec::almost_eq(n.mean().unwrap(), 5.0 / 3.0, 1e-15));
/// assert_eq!(n.pmf(1), 1.0 / 3.0);
/// ```
#[derive(Debug, Clone, PartialEq)]
pub struct Categorical {
    /// Probability mass of each category, divided by the total mass.
    norm_pmf: Vec<f64>,
    /// Running sums of the masses as supplied (not normalized).
    cdf: Vec<f64>,
    /// Total mass minus the running sum, per category (not normalized).
    sf: Vec<f64>,
}
27
/// The ways in which constructing a [`Categorical`] can fail.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum CategoricalError {
    /// No probability masses were supplied.
    ProbMassEmpty,

    /// The supplied probability masses sum to zero.
    ProbMassSumZero,

    /// At least one supplied probability mass is NaN or less than zero.
    ProbMassHasInvalidElements,
}
41
42impl std::fmt::Display for CategoricalError {
43    #[cfg_attr(coverage_nightly, coverage(off))]
44    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
45        match self {
46            CategoricalError::ProbMassEmpty => write!(f, "Probability mass is empty"),
47            CategoricalError::ProbMassSumZero => write!(f, "Probabilities sum up to zero"),
48            CategoricalError::ProbMassHasInvalidElements => write!(
49                f,
50                "Probability mass contains at least one element which is NaN or less than zero"
51            ),
52        }
53    }
54}
55
56impl std::error::Error for CategoricalError {}
57
58impl Categorical {
59    /// Constructs a new categorical distribution
60    /// with the probabilities masses defined by `prob_mass`
61    ///
62    /// # Errors
63    ///
64    /// Returns an error if `prob_mass` is empty, the sum of
65    /// the elements in `prob_mass` is 0, or any element is less than
66    /// 0 or is `f64::NAN`
67    ///
68    /// # Note
69    ///
70    /// The elements in `prob_mass` do not need to be normalized
71    ///
72    /// # Examples
73    ///
74    /// ```
75    /// use statrs::distribution::Categorical;
76    ///
77    /// let mut result = Categorical::new(&[0.0, 1.0, 2.0]);
78    /// assert!(result.is_ok());
79    ///
80    /// result = Categorical::new(&[0.0, -1.0, 2.0]);
81    /// assert!(result.is_err());
82    /// ```
83    pub fn new(prob_mass: &[f64]) -> Result<Categorical, CategoricalError> {
84        if prob_mass.is_empty() {
85            return Err(CategoricalError::ProbMassEmpty);
86        }
87
88        let mut prob_sum = 0.0;
89        for &p in prob_mass {
90            if p.is_nan() || p < 0.0 {
91                return Err(CategoricalError::ProbMassHasInvalidElements);
92            }
93
94            prob_sum += p;
95        }
96
97        if prob_sum == 0.0 {
98            return Err(CategoricalError::ProbMassSumZero);
99        }
100
101        // extract un-normalized cdf
102        let cdf = prob_mass_to_cdf(prob_mass);
103        // extract un-normalized sf
104        let sf = cdf_to_sf(&cdf);
105        // extract normalized probability mass
106        let sum = cdf[cdf.len() - 1];
107        let mut norm_pmf = vec![0.0; prob_mass.len()];
108        norm_pmf
109            .iter_mut()
110            .zip(prob_mass.iter())
111            .for_each(|(np, pm)| *np = *pm / sum);
112        Ok(Categorical { norm_pmf, cdf, sf })
113    }
114
115    fn cdf_max(&self) -> f64 {
116        *self.cdf.last().unwrap()
117    }
118}
119
120impl std::fmt::Display for Categorical {
121    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122        write!(f, "Cat({:#?})", self.norm_pmf)
123    }
124}
125
#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<usize> for Categorical {
    /// Draws a category index by passing the un-normalized cdf table to
    /// `sample_unchecked`.
    fn sample<R>(&self, rng: &mut R) -> usize
    where
        R: ::rand::Rng + ?Sized,
    {
        let cumulative: &[f64] = &self.cdf;
        sample_unchecked(rng, cumulative)
    }
}
133
#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<u64> for Categorical {
    /// Draws a category index with `sample_unchecked` and returns it as a
    /// `u64`.
    fn sample<R>(&self, rng: &mut R) -> u64
    where
        R: ::rand::Rng + ?Sized,
    {
        let category: usize = sample_unchecked(rng, &self.cdf);
        category as u64
    }
}
141
#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<f64> for Categorical {
    /// Draws a category index with `sample_unchecked` and returns it
    /// converted to `f64`.
    fn sample<R>(&self, rng: &mut R) -> f64
    where
        R: ::rand::Rng + ?Sized,
    {
        let category: usize = sample_unchecked(rng, &self.cdf);
        category as f64
    }
}
149
150impl DiscreteCDF<u64, f64> for Categorical {
151    /// Calculates the cumulative distribution function for the categorical
152    /// distribution at `x`
153    ///
154    /// # Formula
155    ///
156    /// ```text
157    /// sum(p_j) from 0..x
158    /// ```
159    ///
160    /// where `p_j` is the probability mass for the `j`th category
161    fn cdf(&self, x: u64) -> f64 {
162        if x >= self.cdf.len() as u64 {
163            1.0
164        } else {
165            self.cdf.get(x as usize).unwrap() / self.cdf_max()
166        }
167    }
168
169    /// Calculates the survival function for the categorical distribution
170    /// at `x`
171    ///
172    /// # Formula
173    ///
174    /// ```text
175    /// [ sum(p_j) from x..end ]
176    /// ```
177    fn sf(&self, x: u64) -> f64 {
178        if x >= self.sf.len() as u64 {
179            0.0
180        } else {
181            self.sf.get(x as usize).unwrap() / self.cdf_max()
182        }
183    }
184
185    /// Calculates the inverse cumulative distribution function for the
186    /// categorical
187    /// distribution at `x`
188    ///
189    /// # Panics
190    ///
191    /// If `x <= 0.0` or `x >= 1.0`
192    ///
193    /// # Formula
194    ///
195    /// ```text
196    /// i
197    /// ```
198    ///
199    /// where `i` is the first index such that `x < f(i)`
200    /// and `f(x)` is defined as `p_x + f(x - 1)` and `f(0) = p_0` where
201    /// `p_x` is the `x`th probability mass
202    fn inverse_cdf(&self, x: f64) -> u64 {
203        if x >= 1.0 || x <= 0.0 {
204            panic!("x must be in [0, 1]")
205        }
206        let denorm_prob = x * self.cdf_max();
207        binary_index(&self.cdf, denorm_prob) as u64
208    }
209}
210
211impl Min<u64> for Categorical {
212    /// Returns the minimum value in the domain of the
213    /// categorical distribution representable by a 64-bit
214    /// integer
215    ///
216    /// # Formula
217    ///
218    /// ```text
219    /// 0
220    /// ```
221    fn min(&self) -> u64 {
222        0
223    }
224}
225
226impl Max<u64> for Categorical {
227    /// Returns the maximum value in the domain of the
228    /// categorical distribution representable by a 64-bit
229    /// integer
230    ///
231    /// # Formula
232    ///
233    /// ```text
234    /// n
235    /// ```
236    fn max(&self) -> u64 {
237        self.cdf.len() as u64 - 1
238    }
239}
240
241impl Distribution<f64> for Categorical {
242    /// Returns the mean of the categorical distribution
243    ///
244    /// # Formula
245    ///
246    /// ```text
247    /// Σ(j * p_j)
248    /// ```
249    ///
250    /// where `p_j` is the `j`th probability mass,
251    /// `Σ` is the sum from `0` to `k - 1`,
252    /// and `k` is the number of categories
253    fn mean(&self) -> Option<f64> {
254        Some(
255            self.norm_pmf
256                .iter()
257                .enumerate()
258                .fold(0.0, |acc, (idx, &val)| acc + idx as f64 * val),
259        )
260    }
261
262    /// Returns the variance of the categorical distribution
263    ///
264    /// # Formula
265    ///
266    /// ```text
267    /// Σ(p_j * (j - μ)^2)
268    /// ```
269    ///
270    /// where `p_j` is the `j`th probability mass, `μ` is the mean,
271    /// `Σ` is the sum from `0` to `k - 1`,
272    /// and `k` is the number of categories
273    fn variance(&self) -> Option<f64> {
274        let mu = self.mean()?;
275        let var = self
276            .norm_pmf
277            .iter()
278            .enumerate()
279            .fold(0.0, |acc, (idx, &val)| {
280                let r = idx as f64 - mu;
281                acc + r * r * val
282            });
283        Some(var)
284    }
285
286    /// Returns the entropy of the categorical distribution
287    ///
288    /// # Formula
289    ///
290    /// ```text
291    /// -Σ(p_j * ln(p_j))
292    /// ```
293    ///
294    /// where `p_j` is the `j`th probability mass,
295    /// `Σ` is the sum from `0` to `k - 1`,
296    /// and `k` is the number of categories
297    fn entropy(&self) -> Option<f64> {
298        let entr = -self
299            .norm_pmf
300            .iter()
301            .filter(|&&p| p > 0.0)
302            .map(|p| p * p.ln())
303            .sum::<f64>();
304        Some(entr)
305    }
306}
307impl Median<f64> for Categorical {
308    /// Returns the median of the categorical distribution
309    ///
310    /// # Formula
311    ///
312    /// ```text
313    /// CDF^-1(0.5)
314    /// ```
315    fn median(&self) -> f64 {
316        self.inverse_cdf(0.5) as f64
317    }
318}
319
320impl Discrete<u64, f64> for Categorical {
321    /// Calculates the probability mass function for the categorical
322    /// distribution at `x`
323    ///
324    /// # Formula
325    ///
326    /// ```text
327    /// p_x
328    /// ```
329    fn pmf(&self, x: u64) -> f64 {
330        *self.norm_pmf.get(x as usize).unwrap_or(&0.0)
331    }
332
333    /// Calculates the log probability mass function for the categorical
334    /// distribution at `x`
335    fn ln_pmf(&self, x: u64) -> f64 {
336        self.pmf(x).ln()
337    }
338}
339
/// Draws a sample (a category index) from the categorical distribution
/// described by the un-normalized cumulative masses in `cdf`, without
/// validating `cdf`
///
/// A random `f64` from `rng` is scaled by the last entry of `cdf`. The
/// result is the first category whose cumulative mass is positive and at
/// least that draw. Requiring positive mass keeps a draw of exactly `0.0`
/// from selecting leading categories that carry no mass.
///
/// # Panics
///
/// If `cdf` is empty, or if no entry of `cdf` is both positive and at least
/// the draw (for example when every entry is zero or NaN)
#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
pub fn sample_unchecked<R: ::rand::Rng + ?Sized>(rng: &mut R, cdf: &[f64]) -> usize {
    // presumably uniform in [0, 1), so `draw` may be exactly 0.0 — see rand docs
    let draw = rng.gen::<f64>() * cdf.last().unwrap();
    cdf.iter()
        .position(|&cumulative| cumulative > 0.0 && cumulative >= draw)
        .unwrap()
}
348
/// Turns probability masses into their running (cumulative) sums, so that
/// entry `i` of the result is `prob_mass[0] + ... + prob_mass[i]`. The input
/// is used as given: nothing is validated or normalized.
pub fn prob_mass_to_cdf(prob_mass: &[f64]) -> Vec<f64> {
    let mut cumulative = Vec::with_capacity(prob_mass.len());
    let mut running = 0.0;
    for &mass in prob_mass {
        running += mass;
        cumulative.push(running);
    }
    cumulative
}
360
/// Derives survival-function values from cumulative masses: each entry of
/// the result is the last entry of `cdf` minus the corresponding entry.
/// Nothing is validated; an empty `cdf` panics because it has no last entry.
pub fn cdf_to_sf(cdf: &[f64]) -> Vec<f64> {
    let total = cdf.last().copied().unwrap();
    let mut survival = Vec::with_capacity(cdf.len());
    survival.extend(cdf.iter().map(|&cumulative| total - cumulative));
    survival
}
367
/// Returns the leftmost index at which `val` could be inserted into the
/// ascending `search` slice while keeping it sorted, i.e. the index of the
/// first element that is `>= val`.
///
/// - If `val` is less than all elements (or `search` is empty), returns 0.
/// - If `val` is greater than all elements, returns `search.len()`.
/// - If `val` equals one or more elements, returns the index of the first
///   of them.
///
/// Picking the first of several equal elements matters for the cdf table:
/// a category with zero mass repeats the cumulative value before it, and the
/// inverse cdf must not land on such a category.
fn binary_index(search: &[f64], val: f64) -> usize {
    search.partition_point(|&el| el < val)
}
391
#[test]
fn test_prob_mass_to_cdf() {
    // The masses are deliberately not normalized; each cdf entry is the
    // running total up to and including that position.
    let masses = [0.0, 0.5, 0.5, 3.0, 1.1];
    let expected = [0.0, 0.5, 1.0, 4.0, 5.1];
    assert_eq!(prob_mass_to_cdf(&masses), expected);
}
398
#[test]
fn test_binary_index() {
    let sorted = [0.0, 3.0, 5.0, 9.0, 10.0];
    // (value searched for, expected index): below every element, an exact
    // hit, between two elements, above every element.
    let cases = [(-1.0, 0), (5.0, 2), (5.2, 3), (10.1, 5)];
    for &(needle, expected) in cases.iter() {
        assert_eq!(expected, binary_index(&sorted, needle));
    }
}
407
#[rustfmt::skip]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distribution::internal::*;
    use crate::testing_boiler;

    // The helpers used below (`create_ok`, `test_create_err`, `test_exact`,
    // `test_absolute`) are not defined in this file; presumably this macro
    // invocation generates them for `Categorical`.
    testing_boiler!(prob_mass: &[f64]; Categorical; CategoricalError);

    #[test]
    fn test_create() {
        create_ok(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);
    }

    // One case per rejection path of `Categorical::new`, with NaN and
    // negative elements covered separately.
    #[test]
    fn test_bad_create() {
        let invalid: &[(&[f64], CategoricalError)] = &[
            (&[], CategoricalError::ProbMassEmpty),
            (&[-1.0, 1.0], CategoricalError::ProbMassHasInvalidElements),
            (&[f64::NAN, 1.0], CategoricalError::ProbMassHasInvalidElements),
            (&[0.0, 0.0, 0.0], CategoricalError::ProbMassSumZero),
        ];

        for &(prob_mass, err) in invalid {
            test_create_err(prob_mass, err);
        }
    }

    #[test]
    fn test_mean() {
        let mean = |x: Categorical| x.mean().unwrap();
        test_exact(&[0.0, 0.25, 0.5, 0.25], 2.0, mean);
        test_exact(&[0.0, 1.0, 2.0, 1.0], 2.0, mean);
        test_exact(&[0.0, 0.5, 0.5], 1.5, mean);
        test_exact(&[0.75, 0.25], 0.25, mean);
        test_exact(&[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 5.0, mean);
    }

    #[test]
    fn test_variance() {
        let variance = |x: Categorical| x.variance().unwrap();
        test_exact(&[0.0, 0.25, 0.5, 0.25], 0.5, variance);
        test_exact(&[0.0, 1.0, 2.0, 1.0], 0.5, variance);
        test_exact(&[0.0, 0.5, 0.5], 0.25, variance);
        test_exact(&[0.75, 0.25], 0.1875, variance);
        test_exact(&[1.0, 0.0, 1.0], 1.0, variance);
    }

    #[test]
    fn test_entropy() {
        let entropy = |x: Categorical| x.entropy().unwrap();
        test_exact(&[0.0, 1.0], 0.0, entropy);
        test_absolute(&[0.0, 1.0, 1.0], 2f64.ln(), 1e-15, entropy);
        test_absolute(&[1.0, 1.0, 1.0], 3f64.ln(), 1e-15, entropy);
        test_absolute(&vec![1.0; 100], 100f64.ln(), 1e-14, entropy);
        test_absolute(&[0.0, 0.25, 0.5, 0.25], 1.0397207708399179, 1e-15, entropy);
    }

    #[test]
    fn test_median() {
        let median = |x: Categorical| x.median();
        test_exact(&[0.0, 3.0, 1.0, 1.0], 1.0, median);
        test_exact(&[4.0, 2.5, 2.5, 1.0], 1.0, median);
    }

    #[test]
    fn test_min_max() {
        let min = |x: Categorical| x.min();
        let max = |x: Categorical| x.max();
        test_exact(&[4.0, 2.5, 2.5, 1.0], 0, min);
        test_exact(&[4.0, 2.5, 2.5, 1.0], 3, max);
    }

    #[test]
    fn test_pmf() {
        let pmf = |arg: u64| move |x: Categorical| x.pmf(arg);
        test_exact(&[0.0, 0.25, 0.5, 0.25], 0.0, pmf(0));
        test_exact(&[0.0, 0.25, 0.5, 0.25], 0.25, pmf(1));
        test_exact(&[0.0, 0.25, 0.5, 0.25], 0.25, pmf(3));
    }

    // An index past the last category has zero mass.
    #[test]
    fn test_pmf_x_too_high() {
        let pmf = |arg: u64| move |x: Categorical| x.pmf(arg);
        test_exact(&[4.0, 2.5, 2.5, 1.0], 0.0, pmf(4));
    }

    #[test]
    fn test_ln_pmf() {
        let ln_pmf = |arg: u64| move |x: Categorical| x.ln_pmf(arg);
        test_exact(&[0.0, 0.25, 0.5, 0.25], 0f64.ln(), ln_pmf(0));
        test_exact(&[0.0, 0.25, 0.5, 0.25], 0.25f64.ln(), ln_pmf(1));
        test_exact(&[0.0, 0.25, 0.5, 0.25], 0.25f64.ln(), ln_pmf(3));
    }

    // ln(0) for an index past the last category.
    #[test]
    fn test_ln_pmf_x_too_high() {
        let ln_pmf = |arg: u64| move |x: Categorical| x.ln_pmf(arg);
        test_exact(&[4.0, 2.5, 2.5, 1.0], f64::NEG_INFINITY, ln_pmf(4));
    }

    #[test]
    fn test_cdf() {
        let cdf = |arg: u64| move |x: Categorical| x.cdf(arg);
        test_exact(&[0.0, 3.0, 1.0, 1.0], 3.0 / 5.0, cdf(1));
        test_exact(&[1.0, 1.0, 1.0, 1.0], 0.25, cdf(0));
        test_exact(&[4.0, 2.5, 2.5, 1.0], 0.4, cdf(0));
        test_exact(&[4.0, 2.5, 2.5, 1.0], 1.0, cdf(3));
        test_exact(&[4.0, 2.5, 2.5, 1.0], 1.0, cdf(4));
    }

    #[test]
    fn test_sf() {
        let sf = |arg: u64| move |x: Categorical| x.sf(arg);
        test_exact(&[0.0, 3.0, 1.0, 1.0], 2.0 / 5.0, sf(1));
        test_exact(&[1.0, 1.0, 1.0, 1.0], 0.75, sf(0));
        test_exact(&[4.0, 2.5, 2.5, 1.0], 0.6, sf(0));
        test_exact(&[4.0, 2.5, 2.5, 1.0], 0.0, sf(3));
        test_exact(&[4.0, 2.5, 2.5, 1.0], 0.0, sf(4));
    }

    #[test]
    fn test_cdf_input_high() {
        let cdf = |arg: u64| move |x: Categorical| x.cdf(arg);
        test_exact(&[4.0, 2.5, 2.5, 1.0], 1.0, cdf(4));
    }

    #[test]
    fn test_sf_input_high() {
        let sf = |arg: u64| move |x: Categorical| x.sf(arg);
        test_exact(&[4.0, 2.5, 2.5, 1.0], 0.0, sf(4));
    }

    // cdf and sf must be complementary at every category.
    #[test]
    fn test_cdf_sf_mirror() {
        let mass = [4.0, 2.5, 2.5, 1.0];
        let cat = Categorical::new(&mass).unwrap();
        assert_eq!(cat.cdf(0), 1.-cat.sf(0));
        assert_eq!(cat.cdf(1), 1.-cat.sf(1));
        assert_eq!(cat.cdf(2), 1.-cat.sf(2));
        assert_eq!(cat.cdf(3), 1.-cat.sf(3));
    }

    #[test]
    fn test_inverse_cdf() {
        let inverse_cdf = |arg: f64| move |x: Categorical| x.inverse_cdf(arg);
        test_exact(&[0.0, 3.0, 1.0, 1.0], 1, inverse_cdf(0.2));
        test_exact(&[0.0, 3.0, 1.0, 1.0], 1, inverse_cdf(0.5));
        test_exact(&[0.0, 3.0, 1.0, 1.0], 3, inverse_cdf(0.95));
        test_exact(&[4.0, 2.5, 2.5, 1.0], 0, inverse_cdf(0.2));
        test_exact(&[4.0, 2.5, 2.5, 1.0], 1, inverse_cdf(0.5));
        test_exact(&[4.0, 2.5, 2.5, 1.0], 3, inverse_cdf(0.95));
    }

    // The inverse cdf panics at both endpoints of the unit interval.
    #[test]
    #[should_panic]
    fn test_inverse_cdf_input_low() {
        let dist = create_ok(&[4.0, 2.5, 2.5, 1.0]);
        dist.inverse_cdf(0.0);
    }

    #[test]
    #[should_panic]
    fn test_inverse_cdf_input_high() {
        let dist = create_ok(&[4.0, 2.5, 2.5, 1.0]);
        dist.inverse_cdf(1.0);
    }

    #[test]
    fn test_discrete() {
        test::check_discrete_distribution(&create_ok(&[1.0, 2.0, 3.0, 4.0]), 4);
        test::check_discrete_distribution(&create_ok(&[0.0, 1.0, 2.0, 3.0, 4.0]), 5);
    }
}