statrs/distribution/
categorical.rs

1use crate::distribution::{Discrete, DiscreteCDF};
2use crate::statistics::*;
3use crate::{Result, StatsError};
4use rand::Rng;
5use std::f64;
6
/// A [categorical](https://en.wikipedia.org/wiki/Categorical_distribution)
/// distribution, which is also referred to as the generalized Bernoulli
/// or discrete distribution. Categories are identified by their index
/// into the probability masses the distribution was built from.
///
/// # Examples
///
/// ```
///
/// use statrs::distribution::{Categorical, Discrete};
/// use statrs::statistics::Distribution;
/// use statrs::prec;
///
/// let n = Categorical::new(&[0.0, 1.0, 2.0]).unwrap();
/// assert!(prec::almost_eq(n.mean().unwrap(), 5.0 / 3.0, 1e-15));
/// assert_eq!(n.pmf(1), 1.0 / 3.0);
/// ```
#[derive(Debug, Clone, PartialEq)]
pub struct Categorical {
    // probability masses divided by their total, so they sum to one
    norm_pmf: Vec<f64>,
    // running sums of the raw (un-normalized) masses; the last entry is the total
    cdf: Vec<f64>,
    // raw upper-tail masses: the total minus the matching `cdf` entry
    sf: Vec<f64>,
}
30
31impl Categorical {
32    /// Constructs a new categorical distribution
33    /// with the probabilities masses defined by `prob_mass`
34    ///
35    /// # Errors
36    ///
37    /// Returns an error if `prob_mass` is empty, the sum of
38    /// the elements in `prob_mass` is 0, or any element is less than
39    /// 0 or is `f64::NAN`
40    ///
41    /// # Note
42    ///
43    /// The elements in `prob_mass` do not need to be normalized
44    ///
45    /// # Examples
46    ///
47    /// ```
48    /// use statrs::distribution::Categorical;
49    ///
50    /// let mut result = Categorical::new(&[0.0, 1.0, 2.0]);
51    /// assert!(result.is_ok());
52    ///
53    /// result = Categorical::new(&[0.0, -1.0, 2.0]);
54    /// assert!(result.is_err());
55    /// ```
56    pub fn new(prob_mass: &[f64]) -> Result<Categorical> {
57        if !super::internal::is_valid_multinomial(prob_mass, true) {
58            Err(StatsError::BadParams)
59        } else {
60            // extract un-normalized cdf
61            let cdf = prob_mass_to_cdf(prob_mass);
62            // extract un-normalized sf
63            let sf = cdf_to_sf(&cdf);
64            // extract normalized probability mass
65            let sum = cdf[cdf.len() - 1];
66            let mut norm_pmf = vec![0.0; prob_mass.len()];
67            norm_pmf
68                .iter_mut()
69                .zip(prob_mass.iter())
70                .for_each(|(np, pm)| *np = *pm / sum);
71            Ok(Categorical { norm_pmf, cdf, sf })
72        }
73    }
74
75    fn cdf_max(&self) -> f64 {
76        *self.cdf.last().unwrap()
77    }
78}
79
80impl ::rand::distributions::Distribution<f64> for Categorical {
81    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
82        sample_unchecked(rng, &self.cdf)
83    }
84}
85
86impl DiscreteCDF<u64, f64> for Categorical {
87    /// Calculates the cumulative distribution function for the categorical
88    /// distribution at `x`
89    ///
90    /// # Formula
91    ///
92    /// ```ignore
93    /// sum(p_j) from 0..x
94    /// ```
95    ///
96    /// where `p_j` is the probability mass for the `j`th category
97    fn cdf(&self, x: u64) -> f64 {
98        if x >= self.cdf.len() as u64 {
99            1.0
100        } else {
101            self.cdf.get(x as usize).unwrap() / self.cdf_max()
102        }
103    }
104
105    /// Calculates the survival function for the categorical distribution
106    /// at `x`
107    ///
108    /// # Formula
109    ///
110    /// ```ignore
111    /// [ sum(p_j) from x..end ]
112    /// ```
113    fn sf(&self, x: u64) -> f64 {
114        if x >= self.sf.len() as u64 {
115            0.0
116        } else {
117            self.sf.get(x as usize).unwrap() / self.cdf_max()
118        }
119    }
120
121    /// Calculates the inverse cumulative distribution function for the
122    /// categorical
123    /// distribution at `x`
124    ///
125    /// # Panics
126    ///
127    /// If `x <= 0.0` or `x >= 1.0`
128    ///
129    /// # Formula
130    ///
131    /// ```ignore
132    /// i
133    /// ```
134    ///
135    /// where `i` is the first index such that `x < f(i)`
136    /// and `f(x)` is defined as `p_x + f(x - 1)` and `f(0) = p_0` where
137    /// `p_x` is the `x`th probability mass
138    fn inverse_cdf(&self, x: f64) -> u64 {
139        if x >= 1.0 || x <= 0.0 {
140            panic!("x must be in [0, 1]")
141        }
142        let denorm_prob = x * self.cdf_max();
143        binary_index(&self.cdf, denorm_prob) as u64
144    }
145}
146
147impl Min<u64> for Categorical {
148    /// Returns the minimum value in the domain of the
149    /// categorical distribution representable by a 64-bit
150    /// integer
151    ///
152    /// # Formula
153    ///
154    /// ```ignore
155    /// 0
156    /// ```
157    fn min(&self) -> u64 {
158        0
159    }
160}
161
162impl Max<u64> for Categorical {
163    /// Returns the maximum value in the domain of the
164    /// categorical distribution representable by a 64-bit
165    /// integer
166    ///
167    /// # Formula
168    ///
169    /// ```ignore
170    /// n
171    /// ```
172    fn max(&self) -> u64 {
173        self.cdf.len() as u64 - 1
174    }
175}
176
177impl Distribution<f64> for Categorical {
178    /// Returns the mean of the categorical distribution
179    ///
180    /// # Formula
181    ///
182    /// ```ignore
183    /// Σ(j * p_j)
184    /// ```
185    ///
186    /// where `p_j` is the `j`th probability mass,
187    /// `Σ` is the sum from `0` to `k - 1`,
188    /// and `k` is the number of categories
189    fn mean(&self) -> Option<f64> {
190        Some(
191            self.norm_pmf
192                .iter()
193                .enumerate()
194                .fold(0.0, |acc, (idx, &val)| acc + idx as f64 * val),
195        )
196    }
197    /// Returns the variance of the categorical distribution
198    ///
199    /// # Formula
200    ///
201    /// ```ignore
202    /// Σ(p_j * (j - μ)^2)
203    /// ```
204    ///
205    /// where `p_j` is the `j`th probability mass, `μ` is the mean,
206    /// `Σ` is the sum from `0` to `k - 1`,
207    /// and `k` is the number of categories
208    fn variance(&self) -> Option<f64> {
209        let mu = self.mean()?;
210        let var = self
211            .norm_pmf
212            .iter()
213            .enumerate()
214            .fold(0.0, |acc, (idx, &val)| {
215                let r = idx as f64 - mu;
216                acc + r * r * val
217            });
218        Some(var)
219    }
220    /// Returns the entropy of the categorical distribution
221    ///
222    /// # Formula
223    ///
224    /// ```ignore
225    /// -Σ(p_j * ln(p_j))
226    /// ```
227    ///
228    /// where `p_j` is the `j`th probability mass,
229    /// `Σ` is the sum from `0` to `k - 1`,
230    /// and `k` is the number of categories
231    fn entropy(&self) -> Option<f64> {
232        let entr = -self
233            .norm_pmf
234            .iter()
235            .filter(|&&p| p > 0.0)
236            .map(|p| p * p.ln())
237            .sum::<f64>();
238        Some(entr)
239    }
240}
241impl Median<f64> for Categorical {
242    /// Returns the median of the categorical distribution
243    ///
244    /// # Formula
245    ///
246    /// ```ignore
247    /// CDF^-1(0.5)
248    /// ```
249    fn median(&self) -> f64 {
250        self.inverse_cdf(0.5) as f64
251    }
252}
253
254impl Discrete<u64, f64> for Categorical {
255    /// Calculates the probability mass function for the categorical
256    /// distribution at `x`
257    ///
258    /// # Formula
259    ///
260    /// ```ignore
261    /// p_x
262    /// ```
263    fn pmf(&self, x: u64) -> f64 {
264        *self.norm_pmf.get(x as usize).unwrap_or(&0.0)
265    }
266
267    /// Calculates the log probability mass function for the categorical
268    /// distribution at `x`
269    fn ln_pmf(&self, x: u64) -> f64 {
270        self.pmf(x).ln()
271    }
272}
273
274/// Draws a sample from the categorical distribution described by `cdf`
275/// without doing any bounds checking
276pub fn sample_unchecked<R: Rng + ?Sized>(rng: &mut R, cdf: &[f64]) -> f64 {
277    let draw = rng.gen::<f64>() * cdf.last().unwrap();
278    cdf.iter()
279        .enumerate()
280        .find(|(_, val)| **val >= draw)
281        .map(|(i, _)| i)
282        .unwrap() as f64
283}
284
/// Turns probability masses into their running (cumulative) sums, so
/// entry `i` of the result is the sum of `prob_mass[0..=i]`. The input
/// is neither validated nor normalized.
pub fn prob_mass_to_cdf(prob_mass: &[f64]) -> Vec<f64> {
    let mut cumulative = Vec::with_capacity(prob_mass.len());
    let mut running = 0.0;
    for &mass in prob_mass {
        running += mass;
        cumulative.push(running);
    }
    cumulative
}
296
/// Derives the survival function values from cumulative masses by
/// subtracting every entry of `cdf` from its last entry.
/// The input is not validated; an empty `cdf` causes a panic.
pub fn cdf_to_sf(cdf: &[f64]) -> Vec<f64> {
    let total = *cdf.last().unwrap();
    let mut tail = Vec::with_capacity(cdf.len());
    for &cumulative in cdf {
        tail.push(total - cumulative);
    }
    tail
}
303
/// Returns the index at which `val` would be inserted into the sorted
/// `search` slice to keep it sorted: the index of the first element that
/// is not less than `val`.
///
/// If `val` is greater than all elements this is the length of the slice,
/// and if `val` is less than all elements it is `0`. When `val` equals a
/// run of duplicate elements the index of the first one is returned, so
/// the result never depends on where a bisection happens to land.
/// A NaN `val` compares false against every element and yields `0`.
fn binary_index(search: &[f64], val: f64) -> usize {
    // `search` is sorted ascending, so "element < val" is true for a prefix
    // and false for the rest; `partition_point` bisects to that boundary.
    search.partition_point(|&el| el < val)
}
327
// Checks that every output entry is the running total of the inputs so far.
#[test]
fn test_prob_mass_to_cdf() {
    let masses = [0.0, 0.5, 0.5, 3.0, 1.1];
    assert_eq!(prob_mass_to_cdf(&masses), [0.0, 0.5, 1.0, 4.0, 5.1]);
}
334
// Probes a value below the range, an exact hit, a value between two
// elements, and a value above the range.
#[test]
fn test_binary_index() {
    let sorted = [0.0, 3.0, 5.0, 9.0, 10.0];
    let cases = [(-1.0, 0), (5.0, 2), (5.2, 3), (10.1, 5)];
    for &(val, expected) in cases.iter() {
        assert_eq!(expected, binary_index(&sorted, val));
    }
}
343
#[rustfmt::skip]
#[cfg(all(test, feature = "nightly"))]
mod tests {
    use std::fmt::Debug;
    use crate::statistics::*;
    use crate::distribution::{Categorical, Discrete, DiscreteCDF};
    use crate::distribution::internal::*;
    use crate::consts::ACC;

    // Builds a distribution from `prob_mass`, failing the test if the
    // constructor rejects it
    fn try_create(prob_mass: &[f64]) -> Categorical {
        let built = Categorical::new(prob_mass);
        assert!(built.is_ok());
        built.unwrap()
    }

    // Asserts that `prob_mass` is accepted by the constructor
    fn create_case(prob_mass: &[f64]) {
        let _ = try_create(prob_mass);
    }

    // Asserts that `prob_mass` is rejected by the constructor
    fn bad_create_case(prob_mass: &[f64]) {
        assert!(Categorical::new(prob_mass).is_err());
    }

    // Builds a distribution from `prob_mass` and returns `eval` applied to it
    fn get_value<T, F>(prob_mass: &[f64], eval: F) -> T
        where T: PartialEq + Debug,
              F: Fn(Categorical) -> T
    {
        eval(try_create(prob_mass))
    }

    // Asserts that `eval` applied to the distribution equals `expected` exactly
    fn test_case<T, F>(prob_mass: &[f64], expected: T, eval: F)
        where T: PartialEq + Debug,
              F: Fn(Categorical) -> T
    {
        let actual = get_value(prob_mass, eval);
        assert_eq!(expected, actual);
    }

    // Asserts that `eval` applied to the distribution is within `acc` of `expected`
    fn test_almost<F>(prob_mass: &[f64], expected: f64, acc: f64, eval: F)
        where F: Fn(Categorical) -> f64
    {
        let actual = get_value(prob_mass, eval);
        assert_almost_eq!(expected, actual, acc);
    }

    #[test]
    fn test_create() {
        let masses = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        create_case(&masses);
    }

    #[test]
    fn test_bad_create() {
        bad_create_case(&[-1.0, 1.0]);
        bad_create_case(&[0.0, 0.0]);
    }

    #[test]
    fn test_mean() {
        let mean = |dist: Categorical| dist.mean().unwrap();
        let uniform_eleven = [1.0; 11];
        test_case(&[0.0, 0.25, 0.5, 0.25], 2.0, mean);
        test_case(&[0.0, 1.0, 2.0, 1.0], 2.0, mean);
        test_case(&[0.0, 0.5, 0.5], 1.5, mean);
        test_case(&[0.75, 0.25], 0.25, mean);
        test_case(&uniform_eleven, 5.0, mean);
    }

    #[test]
    fn test_variance() {
        let variance = |dist: Categorical| dist.variance().unwrap();
        test_case(&[0.0, 0.25, 0.5, 0.25], 0.5, variance);
        test_case(&[0.0, 1.0, 2.0, 1.0], 0.5, variance);
        test_case(&[0.0, 0.5, 0.5], 0.25, variance);
        test_case(&[0.75, 0.25], 0.1875, variance);
        test_case(&[1.0, 0.0, 1.0], 1.0, variance);
    }

    #[test]
    fn test_entropy() {
        let entropy = |dist: Categorical| dist.entropy().unwrap();
        let uniform_hundred = [1.0; 100];
        test_case(&[0.0, 1.0], 0.0, entropy);
        test_almost(&[0.0, 1.0, 1.0], 2f64.ln(), 1e-15, entropy);
        test_almost(&[1.0, 1.0, 1.0], 3f64.ln(), 1e-15, entropy);
        test_almost(&uniform_hundred, 100f64.ln(), 1e-14, entropy);
        test_almost(&[0.0, 0.25, 0.5, 0.25], 1.0397207708399179, 1e-15, entropy);
    }

    #[test]
    fn test_median() {
        let median = |dist: Categorical| dist.median();
        test_case(&[0.0, 3.0, 1.0, 1.0], 1.0, median);
        test_case(&[4.0, 2.5, 2.5, 1.0], 1.0, median);
    }

    #[test]
    fn test_min_max() {
        let masses = [4.0, 2.5, 2.5, 1.0];
        let min = |dist: Categorical| dist.min();
        let max = |dist: Categorical| dist.max();
        test_case(&masses, 0, min);
        test_case(&masses, 3, max);
    }

    #[test]
    fn test_pmf() {
        let masses = [0.0, 0.25, 0.5, 0.25];
        let pmf = |arg: u64| move |dist: Categorical| dist.pmf(arg);
        test_case(&masses, 0.0, pmf(0));
        test_case(&masses, 0.25, pmf(1));
        test_case(&masses, 0.25, pmf(3));
    }

    #[test]
    fn test_pmf_x_too_high() {
        let pmf = |arg: u64| move |dist: Categorical| dist.pmf(arg);
        test_case(&[4.0, 2.5, 2.5, 1.0], 0.0, pmf(4));
    }

    #[test]
    fn test_ln_pmf() {
        let masses = [0.0, 0.25, 0.5, 0.25];
        let ln_pmf = |arg: u64| move |dist: Categorical| dist.ln_pmf(arg);
        test_case(&masses, 0f64.ln(), ln_pmf(0));
        test_case(&masses, 0.25f64.ln(), ln_pmf(1));
        test_case(&masses, 0.25f64.ln(), ln_pmf(3));
    }

    #[test]
    fn test_ln_pmf_x_too_high() {
        let ln_pmf = |arg: u64| move |dist: Categorical| dist.ln_pmf(arg);
        test_case(&[4.0, 2.5, 2.5, 1.0], f64::NEG_INFINITY, ln_pmf(4));
    }

    #[test]
    fn test_cdf() {
        let masses = [4.0, 2.5, 2.5, 1.0];
        let cdf = |arg: u64| move |dist: Categorical| dist.cdf(arg);
        test_case(&[0.0, 3.0, 1.0, 1.0], 3.0 / 5.0, cdf(1));
        test_case(&[1.0, 1.0, 1.0, 1.0], 0.25, cdf(0));
        test_case(&masses, 0.4, cdf(0));
        test_case(&masses, 1.0, cdf(3));
        test_case(&masses, 1.0, cdf(4));
    }

    #[test]
    fn test_sf() {
        let masses = [4.0, 2.5, 2.5, 1.0];
        let sf = |arg: u64| move |dist: Categorical| dist.sf(arg);
        test_case(&[0.0, 3.0, 1.0, 1.0], 2.0 / 5.0, sf(1));
        test_case(&[1.0, 1.0, 1.0, 1.0], 0.75, sf(0));
        test_case(&masses, 0.6, sf(0));
        test_case(&masses, 0.0, sf(3));
        test_case(&masses, 0.0, sf(4));
    }

    #[test]
    fn test_cdf_input_high() {
        let cdf = |arg: u64| move |dist: Categorical| dist.cdf(arg);
        test_case(&[4.0, 2.5, 2.5, 1.0], 1.0, cdf(4));
    }

    #[test]
    fn test_sf_input_high() {
        let sf = |arg: u64| move |dist: Categorical| dist.sf(arg);
        test_case(&[4.0, 2.5, 2.5, 1.0], 0.0, sf(4));
    }

    #[test]
    fn test_cdf_sf_mirror() {
        let cat = Categorical::new(&[4.0, 2.5, 2.5, 1.0]).unwrap();
        // cdf and sf must be complements at every category index
        for x in 0..4 {
            assert_eq!(cat.cdf(x), 1.-cat.sf(x));
        }
    }

    #[test]
    fn test_inverse_cdf() {
        let leading_zero = [0.0, 3.0, 1.0, 1.0];
        let masses = [4.0, 2.5, 2.5, 1.0];
        let inverse_cdf = |arg: f64| move |dist: Categorical| dist.inverse_cdf(arg);
        test_case(&leading_zero, 1, inverse_cdf(0.2));
        test_case(&leading_zero, 1, inverse_cdf(0.5));
        test_case(&leading_zero, 3, inverse_cdf(0.95));
        test_case(&masses, 0, inverse_cdf(0.2));
        test_case(&masses, 1, inverse_cdf(0.5));
        test_case(&masses, 3, inverse_cdf(0.95));
    }

    #[test]
    #[should_panic]
    fn test_inverse_cdf_input_low() {
        let inverse_cdf = |arg: f64| move |dist: Categorical| dist.inverse_cdf(arg);
        get_value(&[4.0, 2.5, 2.5, 1.0], inverse_cdf(0.0));
    }

    #[test]
    #[should_panic]
    fn test_inverse_cdf_input_high() {
        let inverse_cdf = |arg: f64| move |dist: Categorical| dist.inverse_cdf(arg);
        get_value(&[4.0, 2.5, 2.5, 1.0], inverse_cdf(1.0));
    }

    #[test]
    fn test_discrete() {
        let four_categories = try_create(&[1.0, 2.0, 3.0, 4.0]);
        test::check_discrete_distribution(&four_categories, 4);
        let five_categories = try_create(&[0.0, 1.0, 2.0, 3.0, 4.0]);
        test::check_discrete_distribution(&five_categories, 5);
    }
}