1use crate::distribution::{Discrete, DiscreteCDF};
2use crate::statistics::*;
3use std::f64;
4
/// A categorical distribution over the category indices of a weight slice.
///
/// Instances are built by [`Categorical::new`], which derives all three
/// vectors below from the same probability-mass slice, so they always have
/// the same length.
#[derive(Clone, PartialEq, Debug)]
pub struct Categorical {
    /// Probability masses divided by their total.
    norm_pmf: Vec<f64>,
    /// Running sums of the input masses, *not* divided by the total.
    cdf: Vec<f64>,
    /// Total mass minus the matching `cdf` entry, *not* divided by the total.
    sf: Vec<f64>,
}
27
/// Reasons why [`Categorical::new`] can reject a probability-mass slice.
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
#[non_exhaustive]
pub enum CategoricalError {
    /// The probability-mass slice contains no elements.
    ProbMassEmpty,

    /// The probability masses sum up to zero.
    ProbMassSumZero,

    /// At least one probability mass is NaN or less than zero.
    ProbMassHasInvalidElements,
}
41
42impl std::fmt::Display for CategoricalError {
43 #[cfg_attr(coverage_nightly, coverage(off))]
44 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
45 match self {
46 CategoricalError::ProbMassEmpty => write!(f, "Probability mass is empty"),
47 CategoricalError::ProbMassSumZero => write!(f, "Probabilities sum up to zero"),
48 CategoricalError::ProbMassHasInvalidElements => write!(
49 f,
50 "Probability mass contains at least one element which is NaN or less than zero"
51 ),
52 }
53 }
54}
55
// Marker impl: only the default `std::error::Error` methods are used, so no
// underlying source error is reported.
impl std::error::Error for CategoricalError {}
57
58impl Categorical {
59 pub fn new(prob_mass: &[f64]) -> Result<Categorical, CategoricalError> {
84 if prob_mass.is_empty() {
85 return Err(CategoricalError::ProbMassEmpty);
86 }
87
88 let mut prob_sum = 0.0;
89 for &p in prob_mass {
90 if p.is_nan() || p < 0.0 {
91 return Err(CategoricalError::ProbMassHasInvalidElements);
92 }
93
94 prob_sum += p;
95 }
96
97 if prob_sum == 0.0 {
98 return Err(CategoricalError::ProbMassSumZero);
99 }
100
101 let cdf = prob_mass_to_cdf(prob_mass);
103 let sf = cdf_to_sf(&cdf);
105 let sum = cdf[cdf.len() - 1];
107 let mut norm_pmf = vec![0.0; prob_mass.len()];
108 norm_pmf
109 .iter_mut()
110 .zip(prob_mass.iter())
111 .for_each(|(np, pm)| *np = *pm / sum);
112 Ok(Categorical { norm_pmf, cdf, sf })
113 }
114
115 fn cdf_max(&self) -> f64 {
116 *self.cdf.last().unwrap()
117 }
118}
119
120impl std::fmt::Display for Categorical {
121 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122 write!(f, "Cat({:#?})", self.norm_pmf)
123 }
124}
125
#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<usize> for Categorical {
    /// Draws a category index by handing the stored un-normalized CDF to
    /// `sample_unchecked`.
    fn sample<R>(&self, rng: &mut R) -> usize
    where
        R: ::rand::Rng + ?Sized,
    {
        let cumulative: &[f64] = &self.cdf;
        sample_unchecked(rng, cumulative)
    }
}
133
#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<u64> for Categorical {
    /// Draws a category index via `sample_unchecked` and returns it as `u64`.
    fn sample<R>(&self, rng: &mut R) -> u64
    where
        R: ::rand::Rng + ?Sized,
    {
        let index: usize = sample_unchecked(rng, &self.cdf);
        index as u64
    }
}
141
#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<f64> for Categorical {
    /// Draws a category index via `sample_unchecked` and returns it as `f64`.
    fn sample<R>(&self, rng: &mut R) -> f64
    where
        R: ::rand::Rng + ?Sized,
    {
        let index: usize = sample_unchecked(rng, &self.cdf);
        index as f64
    }
}
149
150impl DiscreteCDF<u64, f64> for Categorical {
151 fn cdf(&self, x: u64) -> f64 {
162 if x >= self.cdf.len() as u64 {
163 1.0
164 } else {
165 self.cdf.get(x as usize).unwrap() / self.cdf_max()
166 }
167 }
168
169 fn sf(&self, x: u64) -> f64 {
178 if x >= self.sf.len() as u64 {
179 0.0
180 } else {
181 self.sf.get(x as usize).unwrap() / self.cdf_max()
182 }
183 }
184
185 fn inverse_cdf(&self, x: f64) -> u64 {
203 if x >= 1.0 || x <= 0.0 {
204 panic!("x must be in [0, 1]")
205 }
206 let denorm_prob = x * self.cdf_max();
207 binary_index(&self.cdf, denorm_prob) as u64
208 }
209}
210
impl Min<u64> for Categorical {
    /// Returns the smallest category index, which is always `0`.
    fn min(&self) -> u64 {
        0
    }
}
225
226impl Max<u64> for Categorical {
227 fn max(&self) -> u64 {
237 self.cdf.len() as u64 - 1
238 }
239}
240
241impl Distribution<f64> for Categorical {
242 fn mean(&self) -> Option<f64> {
254 Some(
255 self.norm_pmf
256 .iter()
257 .enumerate()
258 .fold(0.0, |acc, (idx, &val)| acc + idx as f64 * val),
259 )
260 }
261
262 fn variance(&self) -> Option<f64> {
274 let mu = self.mean()?;
275 let var = self
276 .norm_pmf
277 .iter()
278 .enumerate()
279 .fold(0.0, |acc, (idx, &val)| {
280 let r = idx as f64 - mu;
281 acc + r * r * val
282 });
283 Some(var)
284 }
285
286 fn entropy(&self) -> Option<f64> {
298 let entr = -self
299 .norm_pmf
300 .iter()
301 .filter(|&&p| p > 0.0)
302 .map(|p| p * p.ln())
303 .sum::<f64>();
304 Some(entr)
305 }
306}
307impl Median<f64> for Categorical {
308 fn median(&self) -> f64 {
316 self.inverse_cdf(0.5) as f64
317 }
318}
319
320impl Discrete<u64, f64> for Categorical {
321 fn pmf(&self, x: u64) -> f64 {
330 *self.norm_pmf.get(x as usize).unwrap_or(&0.0)
331 }
332
333 fn ln_pmf(&self, x: u64) -> f64 {
336 self.pmf(x).ln()
337 }
338}
339
/// Draws a category index from the un-normalized cumulative sums in `cdf`.
///
/// A uniform `f64` from `rng` is scaled by the last CDF entry, and the index
/// of the first entry that is greater than or equal to the scaled draw is
/// returned. No validation of `cdf` is performed.
///
/// # Panics
///
/// Panics if `cdf` is empty or if no entry is `>=` the scaled draw.
///
/// NOTE(review): if the draw is exactly `0.0` and the first CDF entry is
/// `0.0`, index 0 is returned even though it carries no mass — confirm
/// whether that is acceptable.
#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
pub fn sample_unchecked<R: ::rand::Rng + ?Sized>(rng: &mut R, cdf: &[f64]) -> usize {
    // Draw before touching `cdf`, so the generator advances even when the
    // slice turns out to be empty.
    let unit = rng.gen::<f64>();
    let draw = unit * cdf.last().unwrap();
    cdf.iter()
        .enumerate()
        .find(|&(_, &cumulative)| cumulative >= draw)
        .map(|(idx, _)| idx)
        .unwrap()
}
348
/// Computes the running sums of `prob_mass`.
///
/// Element `i` of the result is the sum of `prob_mass[0..=i]`. The result is
/// not normalized and the input is not validated; an empty input yields an
/// empty vector.
pub fn prob_mass_to_cdf(prob_mass: &[f64]) -> Vec<f64> {
    let mut cumulative = Vec::with_capacity(prob_mass.len());
    let mut running = 0.0;
    for &mass in prob_mass {
        running += mass;
        cumulative.push(running);
    }
    cumulative
}
360
/// Converts un-normalized cumulative sums into un-normalized survival values.
///
/// Each output element is the last element of `cdf` (the total mass) minus
/// the matching `cdf` element. The input is not validated.
///
/// An empty `cdf` yields an empty vector, matching how `prob_mass_to_cdf`
/// treats an empty input, instead of panicking.
pub fn cdf_to_sf(cdf: &[f64]) -> Vec<f64> {
    match cdf.last() {
        Some(&total) => cdf.iter().map(|&cumulative| total - cumulative).collect(),
        None => Vec::new(),
    }
}
367
/// Returns the index of the first element of `search` that is greater than or
/// equal to `val`.
///
/// `search` must be sorted in ascending order. The result is `0` when `val`
/// is not greater than any element (including for an empty slice) and
/// `search.len()` when `val` is greater than every element.
///
/// When several elements equal `val`, the *first* of them is returned. An
/// equality-terminated binary search would return whichever tied element it
/// lands on; on a CDF with repeated values (zero-mass categories) that can
/// select a category that has no probability mass.
fn binary_index(search: &[f64], val: f64) -> usize {
    // Lower-bound search: every element strictly below `val` lies left of the
    // returned position.
    search.partition_point(|&el| el < val)
}
391
// Checks that `prob_mass_to_cdf` returns the running sums of its input and
// leaves them un-normalized (the last entry is 5.1, not 1.0).
#[test]
fn test_prob_mass_to_cdf() {
    let arr = [0.0, 0.5, 0.5, 3.0, 1.1];
    let res = prob_mass_to_cdf(&arr);
    assert_eq!(res, [0.0, 0.5, 1.0, 4.0, 5.1]);
}
398
// Exercises `binary_index` on a sorted array with distinct elements.
#[test]
fn test_binary_index() {
    let arr = [0.0, 3.0, 5.0, 9.0, 10.0];
    // Value below every element.
    assert_eq!(0, binary_index(&arr, -1.0));
    // Exact hit on an element.
    assert_eq!(2, binary_index(&arr, 5.0));
    // Value strictly between two elements.
    assert_eq!(3, binary_index(&arr, 5.2));
    // Value above every element yields the slice length.
    assert_eq!(5, binary_index(&arr, 10.1));
}
407
// Unit tests for `Categorical`.
//
// NOTE(review): the helpers `create_ok`, `test_create_err`, `test_exact`,
// `test_absolute` and the `test` module are not defined in this file;
// presumably they come from the `testing_boiler!` invocation and/or the
// `internal` glob import below — confirm against those definitions.
#[rustfmt::skip]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distribution::internal::*;
    use crate::testing_boiler;

    testing_boiler!(prob_mass: &[f64]; Categorical; CategoricalError);

    // Construction succeeds for strictly positive, un-normalized masses.
    #[test]
    fn test_create() {
        create_ok(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);
    }

    // Each invalid input is paired with the error variant it must produce.
    #[test]
    fn test_bad_create() {
        let invalid: &[(&[f64], CategoricalError)] = &[
            (&[], CategoricalError::ProbMassEmpty),
            (&[-1.0, 1.0], CategoricalError::ProbMassHasInvalidElements),
            (&[0.0, 0.0, 0.0], CategoricalError::ProbMassSumZero),
        ];

        for &(prob_mass, err) in invalid {
            test_create_err(prob_mass, err);
        }
    }

    // The first two cases differ only by a scale factor and expect the same mean.
    #[test]
    fn test_mean() {
        let mean = |x: Categorical| x.mean().unwrap();
        test_exact(&[0.0, 0.25, 0.5, 0.25], 2.0, mean);
        test_exact(&[0.0, 1.0, 2.0, 1.0], 2.0, mean);
        test_exact(&[0.0, 0.5, 0.5], 1.5, mean);
        test_exact(&[0.75, 0.25], 0.25, mean);
        test_exact(&[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 5.0, mean);
    }

    // Same scaled pair as in `test_mean`, plus a case with a zero-mass middle category.
    #[test]
    fn test_variance() {
        let variance = |x: Categorical| x.variance().unwrap();
        test_exact(&[0.0, 0.25, 0.5, 0.25], 0.5, variance);
        test_exact(&[0.0, 1.0, 2.0, 1.0], 0.5, variance);
        test_exact(&[0.0, 0.5, 0.5], 0.25, variance);
        test_exact(&[0.75, 0.25], 0.1875, variance);
        test_exact(&[1.0, 0.0, 1.0], 1.0, variance);
    }

    // Equal masses over k categories are expected to give ln(k); zero-mass
    // categories do not contribute.
    #[test]
    fn test_entropy() {
        let entropy = |x: Categorical| x.entropy().unwrap();
        test_exact(&[0.0, 1.0], 0.0, entropy);
        test_absolute(&[0.0, 1.0, 1.0], 2f64.ln(), 1e-15, entropy);
        test_absolute(&[1.0, 1.0, 1.0], 3f64.ln(), 1e-15, entropy);
        test_absolute(&vec![1.0; 100], 100f64.ln(), 1e-14, entropy);
        test_absolute(&[0.0, 0.25, 0.5, 0.25], 1.0397207708399179, 1e-15, entropy);
    }

    #[test]
    fn test_median() {
        let median = |x: Categorical| x.median();
        test_exact(&[0.0, 3.0, 1.0, 1.0], 1.0, median);
        test_exact(&[4.0, 2.5, 2.5, 1.0], 1.0, median);
    }

    // For four categories the index range is 0..=3.
    #[test]
    fn test_min_max() {
        let min = |x: Categorical| x.min();
        let max = |x: Categorical| x.max();
        test_exact(&[4.0, 2.5, 2.5, 1.0], 0, min);
        test_exact(&[4.0, 2.5, 2.5, 1.0], 3, max);
    }

    #[test]
    fn test_pmf() {
        let pmf = |arg: u64| move |x: Categorical| x.pmf(arg);
        test_exact(&[0.0, 0.25, 0.5, 0.25], 0.0, pmf(0));
        test_exact(&[0.0, 0.25, 0.5, 0.25], 0.25, pmf(1));
        test_exact(&[0.0, 0.25, 0.5, 0.25], 0.25, pmf(3));
    }

    // An index one past the last category has zero mass.
    #[test]
    fn test_pmf_x_too_high() {
        let pmf = |arg: u64| move |x: Categorical| x.pmf(arg);
        test_exact(&[4.0, 2.5, 2.5, 1.0], 0.0, pmf(4));
    }

    // `0f64.ln()` is negative infinity, the expected log-mass of a zero-mass category.
    #[test]
    fn test_ln_pmf() {
        let ln_pmf = |arg: u64| move |x: Categorical| x.ln_pmf(arg);
        test_exact(&[0.0, 0.25, 0.5, 0.25], 0f64.ln(), ln_pmf(0));
        test_exact(&[0.0, 0.25, 0.5, 0.25], 0.25f64.ln(), ln_pmf(1));
        test_exact(&[0.0, 0.25, 0.5, 0.25], 0.25f64.ln(), ln_pmf(3));
    }

    #[test]
    fn test_ln_pmf_x_too_high() {
        let ln_pmf = |arg: u64| move |x: Categorical| x.ln_pmf(arg);
        test_exact(&[4.0, 2.5, 2.5, 1.0], f64::NEG_INFINITY, ln_pmf(4));
    }

    // The last case uses an index one past the final category.
    #[test]
    fn test_cdf() {
        let cdf = |arg: u64| move |x: Categorical| x.cdf(arg);
        test_exact(&[0.0, 3.0, 1.0, 1.0], 3.0 / 5.0, cdf(1));
        test_exact(&[1.0, 1.0, 1.0, 1.0], 0.25, cdf(0));
        test_exact(&[4.0, 2.5, 2.5, 1.0], 0.4, cdf(0));
        test_exact(&[4.0, 2.5, 2.5, 1.0], 1.0, cdf(3));
        test_exact(&[4.0, 2.5, 2.5, 1.0], 1.0, cdf(4));
    }

    // Mirror of `test_cdf` for the survival function.
    #[test]
    fn test_sf() {
        let sf = |arg: u64| move |x: Categorical| x.sf(arg);
        test_exact(&[0.0, 3.0, 1.0, 1.0], 2.0 / 5.0, sf(1));
        test_exact(&[1.0, 1.0, 1.0, 1.0], 0.75, sf(0));
        test_exact(&[4.0, 2.5, 2.5, 1.0], 0.6, sf(0));
        test_exact(&[4.0, 2.5, 2.5, 1.0], 0.0, sf(3));
        test_exact(&[4.0, 2.5, 2.5, 1.0], 0.0, sf(4));
    }

    // Repeats the out-of-range case that also closes `test_cdf`.
    #[test]
    fn test_cdf_input_high() {
        let cdf = |arg: u64| move |x: Categorical| x.cdf(arg);
        test_exact(&[4.0, 2.5, 2.5, 1.0], 1.0, cdf(4));
    }

    // Repeats the out-of-range case that also closes `test_sf`.
    #[test]
    fn test_sf_input_high() {
        let sf = |arg: u64| move |x: Categorical| x.sf(arg);
        test_exact(&[4.0, 2.5, 2.5, 1.0], 0.0, sf(4));
    }

    // `cdf(k)` and `1 - sf(k)` must agree exactly for every valid index.
    #[test]
    fn test_cdf_sf_mirror() {
        let mass = [4.0, 2.5, 2.5, 1.0];
        let cat = Categorical::new(&mass).unwrap();
        assert_eq!(cat.cdf(0), 1.-cat.sf(0));
        assert_eq!(cat.cdf(1), 1.-cat.sf(1));
        assert_eq!(cat.cdf(2), 1.-cat.sf(2));
        assert_eq!(cat.cdf(3), 1.-cat.sf(3));
    }

    #[test]
    fn test_inverse_cdf() {
        let inverse_cdf = |arg: f64| move |x: Categorical| x.inverse_cdf(arg);
        test_exact(&[0.0, 3.0, 1.0, 1.0], 1, inverse_cdf(0.2));
        test_exact(&[0.0, 3.0, 1.0, 1.0], 1, inverse_cdf(0.5));
        test_exact(&[0.0, 3.0, 1.0, 1.0], 3, inverse_cdf(0.95));
        test_exact(&[4.0, 2.5, 2.5, 1.0], 0, inverse_cdf(0.2));
        test_exact(&[4.0, 2.5, 2.5, 1.0], 1, inverse_cdf(0.5));
        test_exact(&[4.0, 2.5, 2.5, 1.0], 3, inverse_cdf(0.95));
    }

    // The lower endpoint 0.0 is rejected with a panic.
    #[test]
    #[should_panic]
    fn test_inverse_cdf_input_low() {
        let dist = create_ok(&[4.0, 2.5, 2.5, 1.0]);
        dist.inverse_cdf(0.0);
    }

    // The upper endpoint 1.0 is rejected with a panic.
    #[test]
    #[should_panic]
    fn test_inverse_cdf_input_high() {
        let dist = create_ok(&[4.0, 2.5, 2.5, 1.0]);
        dist.inverse_cdf(1.0);
    }

    // NOTE(review): the second argument equals the number of categories in
    // each case — presumably an upper bound on the support; confirm against
    // `check_discrete_distribution`.
    #[test]
    fn test_discrete() {
        test::check_discrete_distribution(&create_ok(&[1.0, 2.0, 3.0, 4.0]), 4);
        test::check_discrete_distribution(&create_ok(&[0.0, 1.0, 2.0, 3.0, 4.0]), 5);
    }
}