1use crate::distribution::{Discrete, DiscreteCDF};
2use crate::function::factorial;
3use crate::statistics::*;
4use std::cmp;
5use std::f64;
6
/// A hypergeometric distribution: the number of successes seen when drawing
/// `draws` items, without replacement, from a population of `population`
/// items of which `successes` count as successes.
///
/// Build instances with [`Hypergeometric::new`], which rejects parameter
/// combinations where `successes` or `draws` exceed `population`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Hypergeometric {
    /// Total population size (`N`).
    population: u64,
    /// Number of success items in the population (`K`).
    successes: u64,
    /// Number of items drawn (`n`).
    draws: u64,
}
29
/// Reasons [`Hypergeometric::new`] can refuse its parameters.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum HypergeometricError {
    /// `successes` is larger than `population`.
    TooManySuccesses,

    /// `draws` is larger than `population`.
    TooManyDraws,
}
40
41impl std::fmt::Display for HypergeometricError {
42 #[cfg_attr(coverage_nightly, coverage(off))]
43 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
44 match self {
45 HypergeometricError::TooManySuccesses => write!(f, "successes > population"),
46 HypergeometricError::TooManyDraws => write!(f, "draws > population"),
47 }
48 }
49}
50
51impl std::error::Error for HypergeometricError {}
52
53impl Hypergeometric {
54 pub fn new(
75 population: u64,
76 successes: u64,
77 draws: u64,
78 ) -> Result<Hypergeometric, HypergeometricError> {
79 if successes > population {
80 return Err(HypergeometricError::TooManySuccesses);
81 }
82
83 if draws > population {
84 return Err(HypergeometricError::TooManyDraws);
85 }
86
87 Ok(Hypergeometric {
88 population,
89 successes,
90 draws,
91 })
92 }
93
94 pub fn population(&self) -> u64 {
106 self.population
107 }
108
109 pub fn successes(&self) -> u64 {
121 self.successes
122 }
123
124 pub fn draws(&self) -> u64 {
136 self.draws
137 }
138
139 fn values_f64(&self) -> (f64, f64, f64) {
142 (
143 self.population as f64,
144 self.successes as f64,
145 self.draws as f64,
146 )
147 }
148}
149
150impl std::fmt::Display for Hypergeometric {
151 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
152 write!(
153 f,
154 "Hypergeometric({},{},{})",
155 self.population, self.successes, self.draws
156 )
157 }
158}
159
#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<u64> for Hypergeometric {
    /// Samples the number of successes by simulating `draws` draws without
    /// replacement.
    ///
    /// Each simulated draw succeeds with probability
    /// `remaining_successes / remaining_population`; a success removes one
    /// item from both counts, a failure only from the population. With
    /// `draws == 0` no draws are simulated and the result is `0`.
    fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> u64 {
        let mut population = self.population as f64;
        let mut successes = self.successes as f64;
        let mut x = 0;
        // Iterate exactly `draws` times. The previous decrement-then-test loop
        // always ran at least once, and underflowed `draws` when it was 0.
        // Because `new` guarantees `draws <= population`, `population` stays
        // >= 1 inside the loop, so the division below never divides by zero.
        for _ in 0..self.draws {
            let p = successes / population;
            let next: f64 = rng.gen();
            if next < p {
                x += 1;
                successes -= 1.0;
            }
            population -= 1.0;
        }
        x
    }
}
184
#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<f64> for Hypergeometric {
    /// Draws an integer sample through the `u64` implementation and widens it
    /// to `f64`.
    fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 {
        let successes_drawn: u64 = rng.sample(self);
        successes_drawn as f64
    }
}
192
193impl DiscreteCDF<u64, f64> for Hypergeometric {
194 fn cdf(&self, x: u64) -> f64 {
211 if x < self.min() {
212 0.0
213 } else if x >= self.max() {
214 1.0
215 } else {
216 let k = x;
217 let ln_denom = factorial::ln_binomial(self.population, self.draws);
218 (0..k + 1).fold(0.0, |acc, i| {
219 acc + (factorial::ln_binomial(self.successes, i)
220 + factorial::ln_binomial(self.population - self.successes, self.draws - i)
221 - ln_denom)
222 .exp()
223 })
224 }
225 }
226
227 fn sf(&self, x: u64) -> f64 {
244 if x < self.min() {
245 1.0
246 } else if x >= self.max() {
247 0.0
248 } else {
249 let k = x;
250 let ln_denom = factorial::ln_binomial(self.population, self.draws);
251 (k + 1..=self.max()).fold(0.0, |acc, i| {
252 acc + (factorial::ln_binomial(self.successes, i)
253 + factorial::ln_binomial(self.population - self.successes, self.draws - i)
254 - ln_denom)
255 .exp()
256 })
257 }
258 }
259}
260
261impl Min<u64> for Hypergeometric {
262 fn min(&self) -> u64 {
274 (self.draws + self.successes).saturating_sub(self.population)
275 }
276}
277
278impl Max<u64> for Hypergeometric {
279 fn max(&self) -> u64 {
291 cmp::min(self.successes, self.draws)
292 }
293}
294
295impl Distribution<f64> for Hypergeometric {
296 fn mean(&self) -> Option<f64> {
310 if self.population == 0 {
311 None
312 } else {
313 Some(self.successes as f64 * self.draws as f64 / self.population as f64)
314 }
315 }
316
317 fn variance(&self) -> Option<f64> {
331 if self.population <= 1 {
332 None
333 } else {
334 let (population, successes, draws) = self.values_f64();
335 let val = draws * successes * (population - draws) * (population - successes)
336 / (population * population * (population - 1.0));
337 Some(val)
338 }
339 }
340
341 fn skewness(&self) -> Option<f64> {
356 if self.population <= 2 {
357 None
358 } else {
359 let (population, successes, draws) = self.values_f64();
360 let val = (population - 1.0).sqrt()
361 * (population - 2.0 * draws)
362 * (population - 2.0 * successes)
363 / ((draws * successes * (population - successes) * (population - draws)).sqrt()
364 * (population - 2.0));
365 Some(val)
366 }
367 }
368}
369
370impl Mode<Option<u64>> for Hypergeometric {
371 fn mode(&self) -> Option<u64> {
381 Some(((self.draws + 1) * (self.successes + 1)) / (self.population + 2))
382 }
383}
384
385impl Discrete<u64, f64> for Hypergeometric {
386 fn pmf(&self, x: u64) -> f64 {
397 if x > self.draws {
398 0.0
399 } else {
400 factorial::binomial(self.successes, x)
401 * factorial::binomial(self.population - self.successes, self.draws - x)
402 / factorial::binomial(self.population, self.draws)
403 }
404 }
405
406 fn ln_pmf(&self, x: u64) -> f64 {
417 factorial::ln_binomial(self.successes, x)
418 + factorial::ln_binomial(self.population - self.successes, self.draws - x)
419 - factorial::ln_binomial(self.population, self.draws)
420 }
421}
422
// Unit tests for `Hypergeometric`. `#[rustfmt::skip]` keeps the hand-written
// layout of the test tables below.
#[rustfmt::skip]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distribution::internal::*;
    use crate::testing_boiler;

    // NOTE(review): inferred from usage below. `testing_boiler!` (defined
    // elsewhere in the crate) is expected to generate the `create_ok`,
    // `create_err`, `test_create_err`, `test_exact`, `test_absolute` and
    // `test_none` helpers, each taking (population, successes, draws) first.
    testing_boiler!(population: u64, successes: u64, draws: u64; Hypergeometric; HypergeometricError);

    // Valid parameter sets, including the empty population (0, 0, 0).
    #[test]
    fn test_create() {
        create_ok(0, 0, 0);
        create_ok(1, 1, 1,);
        create_ok(2, 1, 1);
        create_ok(2, 2, 2);
        create_ok(10, 1, 1);
        create_ok(10, 5, 3);
    }

    // Each constraint violation maps to its specific error variant.
    #[test]
    fn test_bad_create() {
        test_create_err(2, 3, 2, HypergeometricError::TooManySuccesses);
        test_create_err(10, 5, 20, HypergeometricError::TooManyDraws);
        create_err(0, 1, 1);
    }

    // Mean is K * n / N.
    #[test]
    fn test_mean() {
        let mean = |x: Hypergeometric| x.mean().unwrap();
        test_exact(1, 1, 1, 1.0, mean);
        test_exact(2, 1, 1, 0.5, mean);
        test_exact(2, 2, 2, 2.0, mean);
        test_exact(10, 1, 1, 0.1, mean);
        test_exact(10, 5, 3, 15.0 / 10.0, mean);
    }

    // An empty population has no defined mean.
    #[test]
    fn test_mean_with_population_0() {
        test_none(0, 0, 0, |dist| dist.mean());
    }

    // Expected values follow n * K * (N - n) * (N - K) / (N^2 * (N - 1)).
    #[test]
    fn test_variance() {
        let variance = |x: Hypergeometric| x.variance().unwrap();
        test_exact(2, 1, 1, 0.25, variance);
        test_exact(2, 2, 2, 0.0, variance);
        test_exact(10, 1, 1, 81.0 / 900.0, variance);
        test_exact(10, 5, 3, 525.0 / 900.0, variance);
    }

    // The variance formula divides by N - 1, so N <= 1 yields None.
    #[test]
    fn test_variance_with_pop_lte_1() {
        test_none(1, 1, 1, |dist| dist.variance());
    }

    // (10, 5, 3) has K = N / 2, so the (N - 2K) factor makes the skewness 0.
    #[test]
    fn test_skewness() {
        let skewness = |x: Hypergeometric| x.skewness().unwrap();
        test_exact(10, 1, 1, 8.0 / 3.0, skewness);
        test_exact(10, 5, 3, 0.0, skewness);
    }

    // The skewness formula divides by N - 2, so N <= 2 yields None.
    #[test]
    fn test_skewness_with_pop_lte_2() {
        test_none(2, 2, 2, |dist| dist.skewness());
    }

    // Mode is floor((n + 1) * (K + 1) / (N + 2)).
    #[test]
    fn test_mode() {
        let mode = |x: Hypergeometric| x.mode().unwrap();
        test_exact(0, 0, 0, 0, mode);
        test_exact(1, 1, 1, 1, mode);
        test_exact(2, 1, 1, 1, mode);
        test_exact(2, 2, 2, 2, mode);
        test_exact(10, 1, 1, 0, mode);
        test_exact(10, 5, 3, 2, mode);
    }

    // Lower end of the support, max(0, n + K - N).
    #[test]
    fn test_min() {
        let min = |x: Hypergeometric| x.min();
        test_exact(0, 0, 0, 0, min);
        test_exact(1, 1, 1, 1, min);
        test_exact(2, 1, 1, 0, min);
        test_exact(2, 2, 2, 2, min);
        test_exact(10, 1, 1, 0, min);
        test_exact(10, 5, 3, 0, min);
    }

    // Upper end of the support, min(K, n).
    #[test]
    fn test_max() {
        let max = |x: Hypergeometric| x.max();
        test_exact(0, 0, 0, 0, max);
        test_exact(1, 1, 1, 1, max);
        test_exact(2, 1, 1, 1, max);
        test_exact(2, 2, 2, 2, max);
        test_exact(10, 1, 1, 1, max);
        test_exact(10, 5, 3, 3, max);
    }

    // `pmf(arg)` builds a closure evaluating the pmf at `arg` for a given distribution.
    #[test]
    fn test_pmf() {
        let pmf = |arg: u64| move |x: Hypergeometric| x.pmf(arg);
        test_exact(0, 0, 0, 1.0, pmf(0));
        test_exact(1, 1, 1, 1.0, pmf(1));
        test_exact(2, 1, 1, 0.5, pmf(0));
        test_exact(2, 1, 1, 0.5, pmf(1));
        test_exact(2, 2, 2, 1.0, pmf(2));
        test_exact(10, 1, 1, 0.9, pmf(0));
        test_exact(10, 1, 1, 0.1, pmf(1));
        test_exact(10, 5, 3, 0.41666666666666666667, pmf(1));
        test_exact(10, 5, 3, 0.083333333333333333333, pmf(3));
    }

    // Same points as `test_pmf`, in log space; non-trivial cases use an
    // absolute tolerance of 1e-14.
    #[test]
    fn test_ln_pmf() {
        let ln_pmf = |arg: u64| move |x: Hypergeometric| x.ln_pmf(arg);
        test_exact(0, 0, 0, 0.0, ln_pmf(0));
        test_exact(1, 1, 1, 0.0, ln_pmf(1));
        test_exact(2, 1, 1, -0.6931471805599453094172, ln_pmf(0));
        test_exact(2, 1, 1, -0.6931471805599453094172, ln_pmf(1));
        test_exact(2, 2, 2, 0.0, ln_pmf(2));
        test_absolute(10, 1, 1, -0.1053605156578263012275, 1e-14, ln_pmf(0));
        test_absolute(10, 1, 1, -2.302585092994045684018, 1e-14, ln_pmf(1));
        test_absolute(10, 5, 3, -0.875468737353899935621, 1e-14, ln_pmf(1));
        test_absolute(10, 5, 3, -2.484906649788000310234, 1e-14, ln_pmf(3));
    }

    // The (10000, 2, 9800) cases use large binomial coefficients; expected
    // values are exact rationals.
    #[test]
    fn test_cdf() {
        let cdf = |arg: u64| move |x: Hypergeometric| x.cdf(arg);
        test_exact(2, 1, 1, 0.5, cdf(0));
        test_absolute(10, 1, 1, 0.9, 1e-14, cdf(0));
        test_absolute(10, 5, 3, 0.5, 1e-15, cdf(1));
        test_absolute(10, 5, 3, 11.0 / 12.0, 1e-14, cdf(2));
        test_absolute(10000, 2, 9800, 199.0 / 499950.0, 1e-14, cdf(0));
        test_absolute(10000, 2, 9800, 19799.0 / 499950.0, 1e-12, cdf(1));
    }

    // Complements of the `test_cdf` points (sf = 1 - cdf).
    #[test]
    fn test_sf() {
        let sf = |arg: u64| move |x: Hypergeometric| x.sf(arg);
        test_exact(2, 1, 1, 0.5, sf(0));
        test_absolute(10, 1, 1, 0.1, 1e-14, sf(0));
        test_absolute(10, 5, 3, 0.5, 1e-15, sf(1));
        test_absolute(10, 5, 3, 1.0 / 12.0, 1e-14, sf(2));
        test_absolute(10000, 2, 9800, 499751. / 499950.0, 1e-10, sf(0));
        test_absolute(10000, 2, 9800, 480151. / 499950.0, 1e-10, sf(1));
    }

    // x >= max() short-circuits the cdf to 1.0.
    #[test]
    fn test_cdf_arg_too_big() {
        let cdf = |arg: u64| move |x: Hypergeometric| x.cdf(arg);
        test_exact(0, 0, 0, 1.0, cdf(0));
    }

    // x < min() short-circuits the cdf to 0.0 (min() is 2 for (2, 2, 2)).
    #[test]
    fn test_cdf_arg_too_small() {
        let cdf = |arg: u64| move |x: Hypergeometric| x.cdf(arg);
        test_exact(2, 2, 2, 0.0, cdf(0));
    }

    // x >= max() short-circuits the sf to 0.0.
    #[test]
    fn test_sf_arg_too_big() {
        let sf = |arg: u64| move |x: Hypergeometric| x.sf(arg);
        test_exact(0, 0, 0, 0.0, sf(0));
    }

    // x < min() short-circuits the sf to 1.0.
    #[test]
    fn test_sf_arg_too_small() {
        let sf = |arg: u64| move |x: Hypergeometric| x.sf(arg);
        test_exact(2, 2, 2, 1.0, sf(0));
    }

    // NOTE(review): `check_discrete_distribution` is defined elsewhere; the
    // second argument appears to be an upper bound on x to check — confirm.
    #[test]
    fn test_discrete() {
        test::check_discrete_distribution(&create_ok(5, 4, 3), 4);
        test::check_discrete_distribution(&create_ok(3, 2, 1), 2);
    }
}