1use crate::distribution::{Discrete, DiscreteCDF};
2use crate::function::factorial;
3use crate::statistics::*;
4use crate::{Result, StatsError};
5use rand::Rng;
6use std::cmp;
7use std::f64;
8
/// Hypergeometric distribution: the number of successes observed when
/// drawing without replacement from a finite population.
///
/// The parameters are stored as given to the constructor:
///
/// * `population` - total number of items in the population
/// * `successes` - number of items in the population counted as a success
/// * `draws` - number of items drawn from the population
///
/// All fields are plain integers, so the type is `Eq` as well as
/// `PartialEq` and equality is a total equivalence relation.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct Hypergeometric {
    population: u64,
    successes: u64,
    draws: u64,
}
23
24impl Hypergeometric {
25 pub fn new(population: u64, successes: u64, draws: u64) -> Result<Hypergeometric> {
46 if successes > population || draws > population {
47 Err(StatsError::BadParams)
48 } else {
49 Ok(Hypergeometric {
50 population,
51 successes,
52 draws,
53 })
54 }
55 }
56
57 pub fn population(&self) -> u64 {
69 self.population
70 }
71
72 pub fn successes(&self) -> u64 {
84 self.successes
85 }
86
87 pub fn draws(&self) -> u64 {
99 self.draws
100 }
101
102 fn values_f64(&self) -> (f64, f64, f64) {
105 (
106 self.population as f64,
107 self.successes as f64,
108 self.draws as f64,
109 )
110 }
111}
112
113impl ::rand::distributions::Distribution<f64> for Hypergeometric {
114 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
115 let mut population = self.population as f64;
116 let mut successes = self.successes as f64;
117 let mut draws = self.draws;
118 let mut x = 0.0;
119 loop {
120 let p = successes / population;
121 let next: f64 = rng.gen();
122 if next < p {
123 x += 1.0;
124 successes -= 1.0;
125 }
126 population -= 1.0;
127 draws -= 1;
128 if draws == 0 {
129 break;
130 }
131 }
132 x
133 }
134}
135
136impl DiscreteCDF<u64, f64> for Hypergeometric {
137 fn cdf(&self, x: u64) -> f64 {
155 if x < self.min() {
156 0.0
157 } else if x >= self.max() {
158 1.0
159 } else {
160 let k = x;
161 let ln_denom = factorial::ln_binomial(self.population, self.draws);
162 (0..k + 1).fold(0.0, |acc, i| {
163 acc + (factorial::ln_binomial(self.successes, i)
164 + factorial::ln_binomial(self.population - self.successes, self.draws - i)
165 - ln_denom)
166 .exp()
167 })
168 }
169 }
170
171 fn sf(&self, x: u64) -> f64 {
189 if x < self.min() {
190 1.0
191 } else if x >= self.max() {
192 0.0
193 } else {
194 let k = x;
195 let ln_denom = factorial::ln_binomial(self.population, self.draws);
196 (k + 1 .. self.max() + 1).fold(0.0, |acc, i| {
197 acc + (factorial::ln_binomial(self.successes, i)
198 + factorial::ln_binomial(self.population - self.successes, self.draws - i)
199 - ln_denom)
200 .exp()
201 })
202 }
203 }
204}
205
206impl Min<u64> for Hypergeometric {
207 fn min(&self) -> u64 {
219 (self.draws + self.successes).saturating_sub(self.population)
220 }
221}
222
223impl Max<u64> for Hypergeometric {
224 fn max(&self) -> u64 {
236 cmp::min(self.successes, self.draws)
237 }
238}
239
240impl Distribution<f64> for Hypergeometric {
241 fn mean(&self) -> Option<f64> {
255 if self.population == 0 {
256 None
257 } else {
258 Some(self.successes as f64 * self.draws as f64 / self.population as f64)
259 }
260 }
261 fn variance(&self) -> Option<f64> {
275 if self.population <= 1 {
276 None
277 } else {
278 let (population, successes, draws) = self.values_f64();
279 let val = draws * successes * (population - draws) * (population - successes)
280 / (population * population * (population - 1.0));
281 Some(val)
282 }
283 }
284 fn skewness(&self) -> Option<f64> {
299 if self.population <= 2 {
300 None
301 } else {
302 let (population, successes, draws) = self.values_f64();
303 let val = (population - 1.0).sqrt()
304 * (population - 2.0 * draws)
305 * (population - 2.0 * successes)
306 / ((draws * successes * (population - successes) * (population - draws)).sqrt()
307 * (population - 2.0));
308 Some(val)
309 }
310 }
311}
312
313impl Mode<Option<u64>> for Hypergeometric {
314 fn mode(&self) -> Option<u64> {
324 Some(((self.draws + 1) * (self.successes + 1)) / (self.population + 2))
325 }
326}
327
328impl Discrete<u64, f64> for Hypergeometric {
329 fn pmf(&self, x: u64) -> f64 {
340 if x > self.draws {
341 0.0
342 } else {
343 factorial::binomial(self.successes, x)
344 * factorial::binomial(self.population - self.successes, self.draws - x)
345 / factorial::binomial(self.population, self.draws)
346 }
347 }
348
349 fn ln_pmf(&self, x: u64) -> f64 {
360 factorial::ln_binomial(self.successes, x)
361 + factorial::ln_binomial(self.population - self.successes, self.draws - x)
362 - factorial::ln_binomial(self.population, self.draws)
363 }
364}
365
// Unit tests for `Hypergeometric`. The module is only compiled when both the
// `test` cfg and the `nightly` feature are enabled, and it is excluded from
// rustfmt so the case tables keep their hand-written layout.
#[rustfmt::skip]
#[cfg(all(test, feature = "nightly"))]
mod tests {
    use std::fmt::Debug;
    use crate::statistics::*;
    use crate::distribution::{DiscreteCDF, Discrete, Hypergeometric};
    use crate::distribution::internal::*;
    use crate::consts::ACC;

    // Builds a distribution, asserting that construction succeeds.
    fn try_create(population: u64, successes: u64, draws: u64) -> Hypergeometric {
        let n = Hypergeometric::new(population, successes, draws);
        assert!(n.is_ok());
        n.unwrap()
    }

    // Builds a distribution and checks that each getter returns the value
    // that was passed to the constructor.
    fn create_case(population: u64, successes: u64, draws: u64) {
        let n = try_create(population, successes, draws);
        assert_eq!(population, n.population());
        assert_eq!(successes, n.successes());
        assert_eq!(draws, n.draws());
    }

    // Asserts that the constructor rejects the given parameters.
    fn bad_create_case(population: u64, successes: u64, draws: u64) {
        let n = Hypergeometric::new(population, successes, draws);
        assert!(n.is_err());
    }

    // Builds a distribution and returns the result of applying `eval` to it.
    fn get_value<T, F>(population: u64, successes: u64, draws: u64, eval: F) -> T
        where T: PartialEq + Debug,
              F: Fn(Hypergeometric) -> T
    {
        let n = try_create(population, successes, draws);
        eval(n)
    }

    // Asserts that `eval` applied to the distribution equals `expected`
    // exactly.
    fn test_case<T, F>(population: u64, successes: u64, draws: u64, expected: T, eval: F)
        where T: PartialEq + Debug,
              F: Fn(Hypergeometric) -> T
    {
        let x = get_value(population, successes, draws, eval);
        assert_eq!(expected, x);
    }

    // Compares `eval` applied to the distribution against `expected` using
    // `assert_almost_eq!` with the accuracy `acc`.
    fn test_almost<F>(population: u64, successes: u64, draws: u64, expected: f64, acc: f64, eval: F)
        where F: Fn(Hypergeometric) -> f64
    {
        let x = get_value(population, successes, draws, eval);
        assert_almost_eq!(expected, x, acc);
    }

    // Parameter sets the constructor must accept, including the empty
    // population.
    #[test]
    fn test_create() {
        create_case(0, 0, 0);
        create_case(1, 1, 1,);
        create_case(2, 1, 1);
        create_case(2, 2, 2);
        create_case(10, 1, 1);
        create_case(10, 5, 3);
    }

    // Parameter sets the constructor must reject: successes or draws larger
    // than the population.
    #[test]
    fn test_bad_create() {
        bad_create_case(2, 3, 2);
        bad_create_case(10, 5, 20);
        bad_create_case(0, 1, 1);
    }

    #[test]
    fn test_mean() {
        let mean = |x: Hypergeometric| x.mean().unwrap();
        test_case(1, 1, 1, 1.0, mean);
        test_case(2, 1, 1, 0.5, mean);
        test_case(2, 2, 2, 2.0, mean);
        test_case(10, 1, 1, 0.1, mean);
        test_case(10, 5, 3, 15.0 / 10.0, mean);
    }

    // `mean()` is `None` for an empty population, so the `unwrap` panics.
    #[test]
    #[should_panic]
    fn test_mean_with_population_0() {
        let mean = |x: Hypergeometric| x.mean().unwrap();
        get_value(0, 0, 0, mean);
    }

    #[test]
    fn test_variance() {
        let variance = |x: Hypergeometric| x.variance().unwrap();
        test_case(2, 1, 1, 0.25, variance);
        test_case(2, 2, 2, 0.0, variance);
        test_case(10, 1, 1, 81.0 / 900.0, variance);
        test_case(10, 5, 3, 525.0 / 900.0, variance);
    }

    // `variance()` is `None` when population <= 1, so the `unwrap` panics.
    #[test]
    #[should_panic]
    fn test_variance_with_pop_lte_1() {
        let variance = |x: Hypergeometric| x.variance().unwrap();
        get_value(1, 1, 1, variance);
    }

    #[test]
    fn test_skewness() {
        let skewness = |x: Hypergeometric| x.skewness().unwrap();
        test_case(10, 1, 1, 8.0 / 3.0, skewness);
        test_case(10, 5, 3, 0.0, skewness);
    }

    // `skewness()` is `None` when population <= 2, so the `unwrap` panics.
    #[test]
    #[should_panic]
    fn test_skewness_with_pop_lte_2() {
        let skewness = |x: Hypergeometric| x.skewness().unwrap();
        get_value(2, 2, 2, skewness);
    }

    #[test]
    fn test_mode() {
        let mode = |x: Hypergeometric| x.mode().unwrap();
        test_case(0, 0, 0, 0, mode);
        test_case(1, 1, 1, 1, mode);
        test_case(2, 1, 1, 1, mode);
        test_case(2, 2, 2, 2, mode);
        test_case(10, 1, 1, 0, mode);
        test_case(10, 5, 3, 2, mode);
    }

    #[test]
    fn test_min() {
        let min = |x: Hypergeometric| x.min();
        test_case(0, 0, 0, 0, min);
        test_case(1, 1, 1, 1, min);
        test_case(2, 1, 1, 0, min);
        test_case(2, 2, 2, 2, min);
        test_case(10, 1, 1, 0, min);
        test_case(10, 5, 3, 0, min);
    }

    #[test]
    fn test_max() {
        let max = |x: Hypergeometric| x.max();
        test_case(0, 0, 0, 0, max);
        test_case(1, 1, 1, 1, max);
        test_case(2, 1, 1, 1, max);
        test_case(2, 2, 2, 2, max);
        test_case(10, 1, 1, 1, max);
        test_case(10, 5, 3, 3, max);
    }

    // `pmf(arg)` yields a closure that evaluates the pmf at `arg`; every
    // case here is compared for exact equality.
    #[test]
    fn test_pmf() {
        let pmf = |arg: u64| move |x: Hypergeometric| x.pmf(arg);
        test_case(0, 0, 0, 1.0, pmf(0));
        test_case(1, 1, 1, 1.0, pmf(1));
        test_case(2, 1, 1, 0.5, pmf(0));
        test_case(2, 1, 1, 0.5, pmf(1));
        test_case(2, 2, 2, 1.0, pmf(2));
        test_case(10, 1, 1, 0.9, pmf(0));
        test_case(10, 1, 1, 0.1, pmf(1));
        test_case(10, 5, 3, 0.41666666666666666667, pmf(1));
        test_case(10, 5, 3, 0.083333333333333333333, pmf(3));
    }

    // The small cases are compared exactly; the population-10 cases are
    // compared with an accuracy of 1e-14.
    #[test]
    fn test_ln_pmf() {
        let ln_pmf = |arg: u64| move |x: Hypergeometric| x.ln_pmf(arg);
        test_case(0, 0, 0, 0.0, ln_pmf(0));
        test_case(1, 1, 1, 0.0, ln_pmf(1));
        test_case(2, 1, 1, -0.6931471805599453094172, ln_pmf(0));
        test_case(2, 1, 1, -0.6931471805599453094172, ln_pmf(1));
        test_case(2, 2, 2, 0.0, ln_pmf(2));
        test_almost(10, 1, 1, -0.1053605156578263012275, 1e-14, ln_pmf(0));
        test_almost(10, 1, 1, -2.302585092994045684018, 1e-14, ln_pmf(1));
        test_almost(10, 5, 3, -0.875468737353899935621, 1e-14, ln_pmf(1));
        test_almost(10, 5, 3, -2.484906649788000310234, 1e-14, ln_pmf(3));
    }

    // Includes a large-population case (10000 items, 9800 draws) with
    // expected values written as exact fractions.
    #[test]
    fn test_cdf() {
        let cdf = |arg: u64| move |x: Hypergeometric| x.cdf(arg);
        test_case(2, 1, 1, 0.5, cdf(0));
        test_almost(10, 1, 1, 0.9, 1e-14, cdf(0));
        test_almost(10, 5, 3, 0.5, 1e-15, cdf(1));
        test_almost(10, 5, 3, 11.0 / 12.0, 1e-14, cdf(2));
        test_almost(10000, 2, 9800, 199.0 / 499950.0, 1e-14, cdf(0));
        test_almost(10000, 2, 9800, 19799.0 / 499950.0, 1e-12, cdf(1));
    }

    // Same parameter sets as `test_cdf`; the large-population cases use a
    // looser accuracy of 1e-10.
    #[test]
    fn test_sf() {
        let sf = |arg: u64| move |x: Hypergeometric| x.sf(arg);
        test_case(2, 1, 1, 0.5, sf(0));
        test_almost(10, 1, 1, 0.1, 1e-14, sf(0));
        test_almost(10, 5, 3, 0.5, 1e-15, sf(1));
        test_almost(10, 5, 3, 1.0 / 12.0, 1e-14, sf(2));
        test_almost(10000, 2, 9800, 499751. / 499950.0, 1e-10, sf(0));
        test_almost(10000, 2, 9800, 480151. / 499950.0, 1e-10, sf(1));
    }

    // With all parameters zero, max() is 0, so x = 0 takes the
    // `x >= max()` branch and the cdf is exactly 1.
    #[test]
    fn test_cdf_arg_too_big() {
        let cdf = |arg: u64| move |x: Hypergeometric| x.cdf(arg);
        test_case(0, 0, 0, 1.0, cdf(0));
    }

    // For (2, 2, 2), min() is 2, so x = 0 takes the `x < min()` branch and
    // the cdf is exactly 0.
    #[test]
    fn test_cdf_arg_too_small() {
        let cdf = |arg: u64| move |x: Hypergeometric| x.cdf(arg);
        test_case(2, 2, 2, 0.0, cdf(0));
    }

    // With all parameters zero, max() is 0, so x = 0 takes the
    // `x >= max()` branch and the sf is exactly 0.
    #[test]
    fn test_sf_arg_too_big() {
        let sf = |arg: u64| move |x: Hypergeometric| x.sf(arg);
        test_case(0, 0, 0, 0.0, sf(0));
    }

    // For (2, 2, 2), min() is 2, so x = 0 takes the `x < min()` branch and
    // the sf is exactly 1.
    #[test]
    fn test_sf_arg_too_small() {
        let sf = |arg: u64| move |x: Hypergeometric| x.sf(arg);
        test_case(2, 2, 2, 1.0, sf(0));
    }

    // Runs the shared discrete-distribution check from
    // `crate::distribution::internal`.
    // NOTE(review): the second argument looks like an upper bound on the
    // support to scan - confirm against the helper's definition.
    #[test]
    fn test_discrete() {
        test::check_discrete_distribution(&try_create(5, 4, 3), 4);
        test::check_discrete_distribution(&try_create(3, 2, 1), 2);
    }
}