1use crate::distribution::{Discrete, DiscreteCDF};
2use crate::statistics::*;
3use crate::{Result, StatsError};
4use rand::Rng;
5use std::f64;
6
/// A categorical distribution over the category indices `0..n`, built from a
/// slice of (possibly unnormalized) probability masses.
///
/// Three tables are kept, each with one entry per category.
#[derive(Debug, Clone, PartialEq)]
pub struct Categorical {
    /// Input masses divided by their total, so they sum to 1.
    norm_pmf: Vec<f64>,
    /// Running cumulative sums of the input masses (NOT normalized; the last
    /// entry is the total mass).
    cdf: Vec<f64>,
    /// Total mass minus the matching `cdf` entry (also NOT normalized).
    sf: Vec<f64>
}
30
31impl Categorical {
32 pub fn new(prob_mass: &[f64]) -> Result<Categorical> {
57 if !super::internal::is_valid_multinomial(prob_mass, true) {
58 Err(StatsError::BadParams)
59 } else {
60 let cdf = prob_mass_to_cdf(prob_mass);
62 let sf = cdf_to_sf(&cdf);
64 let sum = cdf[cdf.len() - 1];
66 let mut norm_pmf = vec![0.0; prob_mass.len()];
67 norm_pmf
68 .iter_mut()
69 .zip(prob_mass.iter())
70 .for_each(|(np, pm)| *np = *pm / sum);
71 Ok(Categorical { norm_pmf, cdf, sf })
72 }
73 }
74
75 fn cdf_max(&self) -> f64 {
76 *self.cdf.last().unwrap()
77 }
78}
79
80impl ::rand::distributions::Distribution<f64> for Categorical {
81 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
82 sample_unchecked(rng, &self.cdf)
83 }
84}
85
86impl DiscreteCDF<u64, f64> for Categorical {
87 fn cdf(&self, x: u64) -> f64 {
98 if x >= self.cdf.len() as u64 {
99 1.0
100 } else {
101 self.cdf.get(x as usize).unwrap() / self.cdf_max()
102 }
103 }
104
105 fn sf(&self, x: u64) -> f64 {
114 if x >= self.sf.len() as u64 {
115 0.0
116 } else {
117 self.sf.get(x as usize).unwrap() / self.cdf_max()
118 }
119 }
120
121 fn inverse_cdf(&self, x: f64) -> u64 {
139 if x >= 1.0 || x <= 0.0 {
140 panic!("x must be in [0, 1]")
141 }
142 let denorm_prob = x * self.cdf_max();
143 binary_index(&self.cdf, denorm_prob) as u64
144 }
145}
146
147impl Min<u64> for Categorical {
148 fn min(&self) -> u64 {
158 0
159 }
160}
161
162impl Max<u64> for Categorical {
163 fn max(&self) -> u64 {
173 self.cdf.len() as u64 - 1
174 }
175}
176
177impl Distribution<f64> for Categorical {
178 fn mean(&self) -> Option<f64> {
190 Some(
191 self.norm_pmf
192 .iter()
193 .enumerate()
194 .fold(0.0, |acc, (idx, &val)| acc + idx as f64 * val),
195 )
196 }
197 fn variance(&self) -> Option<f64> {
209 let mu = self.mean()?;
210 let var = self
211 .norm_pmf
212 .iter()
213 .enumerate()
214 .fold(0.0, |acc, (idx, &val)| {
215 let r = idx as f64 - mu;
216 acc + r * r * val
217 });
218 Some(var)
219 }
220 fn entropy(&self) -> Option<f64> {
232 let entr = -self
233 .norm_pmf
234 .iter()
235 .filter(|&&p| p > 0.0)
236 .map(|p| p * p.ln())
237 .sum::<f64>();
238 Some(entr)
239 }
240}
241impl Median<f64> for Categorical {
242 fn median(&self) -> f64 {
250 self.inverse_cdf(0.5) as f64
251 }
252}
253
254impl Discrete<u64, f64> for Categorical {
255 fn pmf(&self, x: u64) -> f64 {
264 *self.norm_pmf.get(x as usize).unwrap_or(&0.0)
265 }
266
267 fn ln_pmf(&self, x: u64) -> f64 {
270 self.pmf(x).ln()
271 }
272}
273
274pub fn sample_unchecked<R: Rng + ?Sized>(rng: &mut R, cdf: &[f64]) -> f64 {
277 let draw = rng.gen::<f64>() * cdf.last().unwrap();
278 cdf.iter()
279 .enumerate()
280 .find(|(_, val)| **val >= draw)
281 .map(|(i, _)| i)
282 .unwrap() as f64
283}
284
/// Converts per-category masses into their running cumulative sums.
///
/// Element `i` of the result is `prob_mass[0] + ... + prob_mass[i]`,
/// accumulated left to right. No normalization is applied.
pub fn prob_mass_to_cdf(prob_mass: &[f64]) -> Vec<f64> {
    let mut running_total = 0.0;
    prob_mass
        .iter()
        .map(|&mass| {
            running_total += mass;
            running_total
        })
        .collect()
}
296
/// Converts cumulative sums into survival values: each output entry is the
/// total mass (the last element of `cdf`) minus the matching `cdf` entry.
///
/// # Panics
///
/// If `cdf` is empty.
pub fn cdf_to_sf(cdf: &[f64]) -> Vec<f64> {
    let &total = cdf.last().unwrap();
    let mut sf = Vec::with_capacity(cdf.len());
    for &cumulative in cdf {
        sf.push(total - cumulative);
    }
    sf
}
303
/// Returns the index of the first element of `search` that is greater than or
/// equal to `val` (a "lower bound"), or `search.len()` if every element is
/// smaller.
///
/// `search` must be sorted in non-decreasing order, which holds for a
/// cumulative-sum array over non-negative masses.
///
/// Returning the *first* such index matters when `search` contains runs of
/// equal values (categories with zero mass). Any later index in the run
/// would be a category that can never occur.
fn binary_index(search: &[f64], val: f64) -> usize {
    let mut low = 0;
    let mut high = search.len();
    // Invariant: everything before `low` is `< val`; everything at or after
    // `high` is `>= val` (or not less than `val`).
    while low < high {
        let mid = low + (high - low) / 2;
        if search[mid] < val {
            low = mid + 1;
        } else {
            high = mid;
        }
    }
    low
}
327
// Cumulative sums of unnormalized masses, including a leading zero mass.
#[test]
fn test_prob_mass_to_cdf() {
    let masses = [0.0, 0.5, 0.5, 3.0, 1.1];
    let expected = [0.0, 0.5, 1.0, 4.0, 5.1];
    assert_eq!(prob_mass_to_cdf(&masses), expected);
}
334
// Lookups below the range, on an exact element, between elements, and past
// the end of a strictly increasing array.
#[test]
fn test_binary_index() {
    let sorted = [0.0, 3.0, 5.0, 9.0, 10.0];
    let cases = [(-1.0, 0), (5.0, 2), (5.2, 3), (10.1, 5)];
    for &(needle, want) in cases.iter() {
        assert_eq!(want, binary_index(&sorted, needle));
    }
}
343
// Unit tests for `Categorical`; compiled only for test builds with the
// `nightly` feature enabled.
#[rustfmt::skip]
#[cfg(all(test, feature = "nightly"))]
mod tests {
    use std::fmt::Debug;
    use crate::statistics::*;
    use crate::distribution::{Categorical, Discrete, DiscreteCDF};
    // NOTE(review): presumably supplies the `test` helper module used by
    // `test_discrete` (and possibly `assert_almost_eq!`) — confirm.
    use crate::distribution::internal::*;
    // NOTE(review): `ACC` is not referenced anywhere in this module — confirm
    // before removing.
    use crate::consts::ACC;

    // Builds a distribution from `prob_mass`, asserting construction succeeds.
    fn try_create(prob_mass: &[f64]) -> Categorical {
        let n = Categorical::new(prob_mass);
        assert!(n.is_ok());
        n.unwrap()
    }

    // Asserts that `prob_mass` is accepted by `Categorical::new`.
    fn create_case(prob_mass: &[f64]) {
        try_create(prob_mass);
    }

    // Asserts that `prob_mass` is rejected by `Categorical::new`.
    fn bad_create_case(prob_mass: &[f64]) {
        let n = Categorical::new(prob_mass);
        assert!(n.is_err());
    }

    // Builds a distribution from `prob_mass` and returns `eval` applied to it.
    fn get_value<T, F>(prob_mass: &[f64], eval: F) -> T
        where T: PartialEq + Debug,
              F: Fn(Categorical) -> T
    {
        let n = try_create(prob_mass);
        eval(n)
    }

    // Asserts exact equality between `expected` and `eval` applied to the
    // distribution built from `prob_mass`.
    fn test_case<T, F>(prob_mass: &[f64], expected: T, eval: F)
        where T: PartialEq + Debug,
              F: Fn(Categorical) -> T
    {
        let x = get_value(prob_mass, eval);
        assert_eq!(expected, x);
    }

    // Like `test_case`, but compares through `assert_almost_eq!` with
    // tolerance `acc` (macro defined elsewhere in the crate).
    fn test_almost<F>(prob_mass: &[f64], expected: f64, acc: f64, eval: F)
        where F: Fn(Categorical) -> f64
    {
        let x = get_value(prob_mass, eval);
        assert_almost_eq!(expected, x, acc);
    }

    #[test]
    fn test_create() {
        create_case(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);
    }

    // Negative masses and all-zero masses must be rejected.
    #[test]
    fn test_bad_create() {
        bad_create_case(&[-1.0, 1.0]);
        bad_create_case(&[0.0, 0.0]);
    }

    // Masses need not be normalized: [0, 1, 2, 1] and [0, .25, .5, .25]
    // describe the same distribution.
    #[test]
    fn test_mean() {
        let mean = |x: Categorical| x.mean().unwrap();
        test_case(&[0.0, 0.25, 0.5, 0.25], 2.0, mean);
        test_case(&[0.0, 1.0, 2.0, 1.0], 2.0, mean);
        test_case(&[0.0, 0.5, 0.5], 1.5, mean);
        test_case(&[0.75, 0.25], 0.25, mean);
        test_case(&[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 5.0, mean);
    }

    #[test]
    fn test_variance() {
        let variance = |x: Categorical| x.variance().unwrap();
        test_case(&[0.0, 0.25, 0.5, 0.25], 0.5, variance);
        test_case(&[0.0, 1.0, 2.0, 1.0], 0.5, variance);
        test_case(&[0.0, 0.5, 0.5], 0.25, variance);
        test_case(&[0.75, 0.25], 0.1875, variance);
        test_case(&[1.0, 0.0, 1.0], 1.0, variance);
    }

    // Entropy in nats; zero-mass categories contribute nothing. The first case
    // yields -0.0, which compares equal to 0.0.
    #[test]
    fn test_entropy() {
        let entropy = |x: Categorical| x.entropy().unwrap();
        test_case(&[0.0, 1.0], 0.0, entropy);
        test_almost(&[0.0, 1.0, 1.0], 2f64.ln(), 1e-15, entropy);
        test_almost(&[1.0, 1.0, 1.0], 3f64.ln(), 1e-15, entropy);
        test_almost(&vec![1.0; 100], 100f64.ln(), 1e-14, entropy);
        test_almost(&[0.0, 0.25, 0.5, 0.25], 1.0397207708399179, 1e-15, entropy);
    }

    #[test]
    fn test_median() {
        let median = |x: Categorical| x.median();
        test_case(&[0.0, 3.0, 1.0, 1.0], 1.0, median);
        test_case(&[4.0, 2.5, 2.5, 1.0], 1.0, median);
    }

    #[test]
    fn test_min_max() {
        let min = |x: Categorical| x.min();
        let max = |x: Categorical| x.max();
        test_case(&[4.0, 2.5, 2.5, 1.0], 0, min);
        test_case(&[4.0, 2.5, 2.5, 1.0], 3, max);
    }

    #[test]
    fn test_pmf() {
        let pmf = |arg: u64| move |x: Categorical| x.pmf(arg);
        test_case(&[0.0, 0.25, 0.5, 0.25], 0.0, pmf(0));
        test_case(&[0.0, 0.25, 0.5, 0.25], 0.25, pmf(1));
        test_case(&[0.0, 0.25, 0.5, 0.25], 0.25, pmf(3));
    }

    // Out-of-range categories have zero mass.
    #[test]
    fn test_pmf_x_too_high() {
        let pmf = |arg: u64| move |x: Categorical| x.pmf(arg);
        test_case(&[4.0, 2.5, 2.5, 1.0], 0.0, pmf(4));
    }

    // A zero-mass category gives ln(0) = -inf.
    #[test]
    fn test_ln_pmf() {
        let ln_pmf = |arg: u64| move |x: Categorical| x.ln_pmf(arg);
        test_case(&[0.0, 0.25, 0.5, 0.25], 0f64.ln(), ln_pmf(0));
        test_case(&[0.0, 0.25, 0.5, 0.25], 0.25f64.ln(), ln_pmf(1));
        test_case(&[0.0, 0.25, 0.5, 0.25], 0.25f64.ln(), ln_pmf(3));
    }

    #[test]
    fn test_ln_pmf_x_too_high() {
        let ln_pmf = |arg: u64| move |x: Categorical| x.ln_pmf(arg);
        test_case(&[4.0, 2.5, 2.5, 1.0], f64::NEG_INFINITY, ln_pmf(4));
    }

    #[test]
    fn test_cdf() {
        let cdf = |arg: u64| move |x: Categorical| x.cdf(arg);
        test_case(&[0.0, 3.0, 1.0, 1.0], 3.0 / 5.0, cdf(1));
        test_case(&[1.0, 1.0, 1.0, 1.0], 0.25, cdf(0));
        test_case(&[4.0, 2.5, 2.5, 1.0], 0.4, cdf(0));
        test_case(&[4.0, 2.5, 2.5, 1.0], 1.0, cdf(3));
        test_case(&[4.0, 2.5, 2.5, 1.0], 1.0, cdf(4));
    }

    #[test]
    fn test_sf() {
        let sf = |arg: u64| move |x: Categorical| x.sf(arg);
        test_case(&[0.0, 3.0, 1.0, 1.0], 2.0 / 5.0, sf(1));
        test_case(&[1.0, 1.0, 1.0, 1.0], 0.75, sf(0));
        test_case(&[4.0, 2.5, 2.5, 1.0], 0.6, sf(0));
        test_case(&[4.0, 2.5, 2.5, 1.0], 0.0, sf(3));
        test_case(&[4.0, 2.5, 2.5, 1.0], 0.0, sf(4));
    }

    #[test]
    fn test_cdf_input_high() {
        let cdf = |arg: u64| move |x: Categorical| x.cdf(arg);
        test_case(&[4.0, 2.5, 2.5, 1.0], 1.0, cdf(4));
    }

    #[test]
    fn test_sf_input_high() {
        let sf = |arg: u64| move |x: Categorical| x.sf(arg);
        test_case(&[4.0, 2.5, 2.5, 1.0], 0.0, sf(4));
    }

    // NOTE(review): exact float equality. It holds for these masses, but sf is
    // computed as `total - cdf` before normalization, so rounding need not
    // cancel for arbitrary inputs.
    #[test]
    fn test_cdf_sf_mirror() {
        let mass = [4.0, 2.5, 2.5, 1.0];
        let cat = Categorical::new(&mass).unwrap();
        assert_eq!(cat.cdf(0), 1.-cat.sf(0));
        assert_eq!(cat.cdf(1), 1.-cat.sf(1));
        assert_eq!(cat.cdf(2), 1.-cat.sf(2));
        assert_eq!(cat.cdf(3), 1.-cat.sf(3));
    }

    #[test]
    fn test_inverse_cdf() {
        let inverse_cdf = |arg: f64| move |x: Categorical| x.inverse_cdf(arg);
        test_case(&[0.0, 3.0, 1.0, 1.0], 1, inverse_cdf(0.2));
        test_case(&[0.0, 3.0, 1.0, 1.0], 1, inverse_cdf(0.5));
        test_case(&[0.0, 3.0, 1.0, 1.0], 3, inverse_cdf(0.95));
        test_case(&[4.0, 2.5, 2.5, 1.0], 0, inverse_cdf(0.2));
        test_case(&[4.0, 2.5, 2.5, 1.0], 1, inverse_cdf(0.5));
        test_case(&[4.0, 2.5, 2.5, 1.0], 3, inverse_cdf(0.95));
    }

    // The endpoint 0.0 is rejected by `inverse_cdf`.
    #[test]
    #[should_panic]
    fn test_inverse_cdf_input_low() {
        let inverse_cdf = |arg: f64| move |x: Categorical| x.inverse_cdf(arg);
        get_value(&[4.0, 2.5, 2.5, 1.0], inverse_cdf(0.0));
    }

    // The endpoint 1.0 is rejected by `inverse_cdf`.
    #[test]
    #[should_panic]
    fn test_inverse_cdf_input_high() {
        let inverse_cdf = |arg: f64| move |x: Categorical| x.inverse_cdf(arg);
        get_value(&[4.0, 2.5, 2.5, 1.0], inverse_cdf(1.0));
    }

    // Delegates to the crate's shared discrete-distribution checker; the second
    // argument presumably bounds the points checked — confirm against
    // `test::check_discrete_distribution`.
    #[test]
    fn test_discrete() {
        test::check_discrete_distribution(&try_create(&[1.0, 2.0, 3.0, 4.0]), 4);
        test::check_discrete_distribution(&try_create(&[0.0, 1.0, 2.0, 3.0, 4.0]), 5);
    }
}