statrs/distribution/
hypergeometric.rs

1use crate::distribution::{Discrete, DiscreteCDF};
2use crate::function::factorial;
3use crate::statistics::*;
4use crate::{Result, StatsError};
5use rand::Rng;
6use std::cmp;
7use std::f64;
8
/// Implements the
/// [Hypergeometric](http://en.wikipedia.org/wiki/Hypergeometric_distribution)
/// distribution
///
/// The distribution is parameterized by a population size, the number of
/// successes contained in that population, and the number of draws.
///
/// # Examples
///
/// ```
/// use statrs::distribution::Hypergeometric;
///
/// let n = Hypergeometric::new(10, 5, 3).unwrap();
/// assert_eq!(n.population(), 10);
/// assert_eq!(n.successes(), 5);
/// assert_eq!(n.draws(), 3);
/// ```
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct Hypergeometric {
    /// Population size (N)
    population: u64,
    /// Number of successes in the population (K)
    successes: u64,
    /// Number of draws (n)
    draws: u64,
}
23
24impl Hypergeometric {
25    /// Constructs a new hypergeometric distribution
26    /// with a population (N) of `population`, number
27    /// of successes (K) of `successes`, and number of draws
28    /// (n) of `draws`
29    ///
30    /// # Errors
31    ///
32    /// If `successes > population` or `draws > population`
33    ///
34    /// # Examples
35    ///
36    /// ```
37    /// use statrs::distribution::Hypergeometric;
38    ///
39    /// let mut result = Hypergeometric::new(2, 2, 2);
40    /// assert!(result.is_ok());
41    ///
42    /// result = Hypergeometric::new(2, 3, 2);
43    /// assert!(result.is_err());
44    /// ```
45    pub fn new(population: u64, successes: u64, draws: u64) -> Result<Hypergeometric> {
46        if successes > population || draws > population {
47            Err(StatsError::BadParams)
48        } else {
49            Ok(Hypergeometric {
50                population,
51                successes,
52                draws,
53            })
54        }
55    }
56
57    /// Returns the population size of the hypergeometric
58    /// distribution
59    ///
60    /// # Examples
61    ///
62    /// ```
63    /// use statrs::distribution::Hypergeometric;
64    ///
65    /// let n = Hypergeometric::new(10, 5, 3).unwrap();
66    /// assert_eq!(n.population(), 10);
67    /// ```
68    pub fn population(&self) -> u64 {
69        self.population
70    }
71
72    /// Returns the number of observed successes of the hypergeometric
73    /// distribution
74    ///
75    /// # Examples
76    ///
77    /// ```
78    /// use statrs::distribution::Hypergeometric;
79    ///
80    /// let n = Hypergeometric::new(10, 5, 3).unwrap();
81    /// assert_eq!(n.successes(), 5);
82    /// ```
83    pub fn successes(&self) -> u64 {
84        self.successes
85    }
86
87    /// Returns the number of draws of the hypergeometric
88    /// distribution
89    ///
90    /// # Examples
91    ///
92    /// ```
93    /// use statrs::distribution::Hypergeometric;
94    ///
95    /// let n = Hypergeometric::new(10, 5, 3).unwrap();
96    /// assert_eq!(n.draws(), 3);
97    /// ```
98    pub fn draws(&self) -> u64 {
99        self.draws
100    }
101
102    /// Returns population, successes, and draws in that order
103    /// as a tuple of doubles
104    fn values_f64(&self) -> (f64, f64, f64) {
105        (
106            self.population as f64,
107            self.successes as f64,
108            self.draws as f64,
109        )
110    }
111}
112
113impl ::rand::distributions::Distribution<f64> for Hypergeometric {
114    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
115        let mut population = self.population as f64;
116        let mut successes = self.successes as f64;
117        let mut draws = self.draws;
118        let mut x = 0.0;
119        loop {
120            let p = successes / population;
121            let next: f64 = rng.gen();
122            if next < p {
123                x += 1.0;
124                successes -= 1.0;
125            }
126            population -= 1.0;
127            draws -= 1;
128            if draws == 0 {
129                break;
130            }
131        }
132        x
133    }
134}
135
136impl DiscreteCDF<u64, f64> for Hypergeometric {
137    /// Calculates the cumulative distribution function for the hypergeometric
138    /// distribution at `x`
139    ///
140    /// # Formula
141    ///
142    /// ```ignore
143    /// 1 - ((n choose k+1) * (N-n choose K-k-1)) / (N choose K) * 3_F_2(1,
144    /// k+1-K, k+1-n; k+2, N+k+2-K-n; 1)
145    /// ```
146    ///
147    /// where `N` is population, `K` is successes, `n` is draws,
148    /// and `p_F_q` is the [generalized hypergeometric
149    /// function](https://en.wikipedia.
150    /// org/wiki/Generalized_hypergeometric_function)
151    ///
152    /// Calculated as a discrete integral over the probability mass
153    /// function evaluated from 0..k+1
154    fn cdf(&self, x: u64) -> f64 {
155        if x < self.min() {
156            0.0
157        } else if x >= self.max() {
158            1.0
159        } else {
160            let k = x;
161            let ln_denom = factorial::ln_binomial(self.population, self.draws);
162            (0..k + 1).fold(0.0, |acc, i| {
163                acc + (factorial::ln_binomial(self.successes, i)
164                    + factorial::ln_binomial(self.population - self.successes, self.draws - i)
165                    - ln_denom)
166                    .exp()
167            })
168        }
169    }
170
171    /// Calculates the survival function for the hypergeometric
172    /// distribution at `x`
173    ///
174    /// # Formula
175    ///
176    /// ```ignore
177    /// 1 - ((n choose k+1) * (N-n choose K-k-1)) / (N choose K) * 3_F_2(1,
178    /// k+1-K, k+1-n; k+2, N+k+2-K-n; 1)
179    /// ```
180    ///
181    /// where `N` is population, `K` is successes, `n` is draws,
182    /// and `p_F_q` is the [generalized hypergeometric
183    /// function](https://en.wikipedia.
184    /// org/wiki/Generalized_hypergeometric_function)
185    ///
186    /// Calculated as a discrete integral over the probability mass
187    /// function evaluated from (k+1)..max
188    fn sf(&self, x: u64) -> f64 {
189        if x < self.min() {
190            1.0
191        } else if x >= self.max() {
192            0.0
193        } else {
194            let k = x;
195            let ln_denom = factorial::ln_binomial(self.population, self.draws);
196            (k + 1 .. self.max() + 1).fold(0.0, |acc, i| {
197                acc + (factorial::ln_binomial(self.successes, i)
198                    + factorial::ln_binomial(self.population - self.successes, self.draws - i)
199                    - ln_denom)
200                    .exp()
201            })
202        }
203    }
204}
205
206impl Min<u64> for Hypergeometric {
207    /// Returns the minimum value in the domain of the
208    /// hypergeometric distribution representable by a 64-bit
209    /// integer
210    ///
211    /// # Formula
212    ///
213    /// ```ignore
214    /// max(0, n + K - N)
215    /// ```
216    ///
217    /// where `N` is population, `K` is successes, and `n` is draws
218    fn min(&self) -> u64 {
219        (self.draws + self.successes).saturating_sub(self.population)
220    }
221}
222
223impl Max<u64> for Hypergeometric {
224    /// Returns the maximum value in the domain of the
225    /// hypergeometric distribution representable by a 64-bit
226    /// integer
227    ///
228    /// # Formula
229    ///
230    /// ```ignore
231    /// min(K, n)
232    /// ```
233    ///
234    /// where `K` is successes and `n` is draws
235    fn max(&self) -> u64 {
236        cmp::min(self.successes, self.draws)
237    }
238}
239
240impl Distribution<f64> for Hypergeometric {
241    /// Returns the mean of the hypergeometric distribution
242    ///
243    /// # None
244    ///
245    /// If `N` is `0`
246    ///
247    /// # Formula
248    ///
249    /// ```ignore
250    /// K * n / N
251    /// ```
252    ///
253    /// where `N` is population, `K` is successes, and `n` is draws
254    fn mean(&self) -> Option<f64> {
255        if self.population == 0 {
256            None
257        } else {
258            Some(self.successes as f64 * self.draws as f64 / self.population as f64)
259        }
260    }
261    /// Returns the variance of the hypergeometric distribution
262    ///
263    /// # None
264    ///
265    /// If `N <= 1`
266    ///
267    /// # Formula
268    ///
269    /// ```ignore
270    /// n * (K / N) * ((N - K) / N) * ((N - n) / (N - 1))
271    /// ```
272    ///
273    /// where `N` is population, `K` is successes, and `n` is draws
274    fn variance(&self) -> Option<f64> {
275        if self.population <= 1 {
276            None
277        } else {
278            let (population, successes, draws) = self.values_f64();
279            let val = draws * successes * (population - draws) * (population - successes)
280                / (population * population * (population - 1.0));
281            Some(val)
282        }
283    }
284    /// Returns the skewness of the hypergeometric distribution
285    ///
286    /// # None
287    ///
288    /// If `N <= 2`
289    ///
290    /// # Formula
291    ///
292    /// ```ignore
293    /// ((N - 2K) * (N - 1)^(1 / 2) * (N - 2n)) / ([n * K * (N - K) * (N -
294    /// n)]^(1 / 2) * (N - 2))
295    /// ```
296    ///
297    /// where `N` is population, `K` is successes, and `n` is draws
298    fn skewness(&self) -> Option<f64> {
299        if self.population <= 2 {
300            None
301        } else {
302            let (population, successes, draws) = self.values_f64();
303            let val = (population - 1.0).sqrt()
304                * (population - 2.0 * draws)
305                * (population - 2.0 * successes)
306                / ((draws * successes * (population - successes) * (population - draws)).sqrt()
307                    * (population - 2.0));
308            Some(val)
309        }
310    }
311}
312
313impl Mode<Option<u64>> for Hypergeometric {
314    /// Returns the mode of the hypergeometric distribution
315    ///
316    /// # Formula
317    ///
318    /// ```ignore
319    /// floor((n + 1) * (k + 1) / (N + 2))
320    /// ```
321    ///
322    /// where `N` is population, `K` is successes, and `n` is draws
323    fn mode(&self) -> Option<u64> {
324        Some(((self.draws + 1) * (self.successes + 1)) / (self.population + 2))
325    }
326}
327
328impl Discrete<u64, f64> for Hypergeometric {
329    /// Calculates the probability mass function for the hypergeometric
330    /// distribution at `x`
331    ///
332    /// # Formula
333    ///
334    /// ```ignore
335    /// (K choose x) * (N-K choose n-x) / (N choose n)
336    /// ```
337    ///
338    /// where `N` is population, `K` is successes, and `n` is draws
339    fn pmf(&self, x: u64) -> f64 {
340        if x > self.draws {
341            0.0
342        } else {
343            factorial::binomial(self.successes, x)
344                * factorial::binomial(self.population - self.successes, self.draws - x)
345                / factorial::binomial(self.population, self.draws)
346        }
347    }
348
349    /// Calculates the log probability mass function for the hypergeometric
350    /// distribution at `x`
351    ///
352    /// # Formula
353    ///
354    /// ```ignore
355    /// ln((K choose x) * (N-K choose n-x) / (N choose n))
356    /// ```
357    ///
358    /// where `N` is population, `K` is successes, and `n` is draws
359    fn ln_pmf(&self, x: u64) -> f64 {
360        factorial::ln_binomial(self.successes, x)
361            + factorial::ln_binomial(self.population - self.successes, self.draws - x)
362            - factorial::ln_binomial(self.population, self.draws)
363    }
364}
365
#[rustfmt::skip]
#[cfg(all(test, feature = "nightly"))]
mod tests {
    use std::fmt::Debug;
    use crate::statistics::*;
    use crate::distribution::{DiscreteCDF, Discrete, Hypergeometric};
    use crate::distribution::internal::*;
    use crate::consts::ACC;

    /// Builds a distribution from parameters the caller expects to be valid,
    /// failing the test if construction is rejected.
    fn try_create(population: u64, successes: u64, draws: u64) -> Hypergeometric {
        let n = Hypergeometric::new(population, successes, draws);
        assert!(n.is_ok());
        n.unwrap()
    }

    /// Checks that a valid distribution reports back the parameters it was
    /// constructed with.
    fn create_case(population: u64, successes: u64, draws: u64) {
        let dist = try_create(population, successes, draws);
        assert_eq!(population, dist.population());
        assert_eq!(successes, dist.successes());
        assert_eq!(draws, dist.draws());
    }

    /// Checks that construction is rejected for the given parameters.
    fn bad_create_case(population: u64, successes: u64, draws: u64) {
        let n = Hypergeometric::new(population, successes, draws);
        assert!(n.is_err());
    }

    /// Builds a distribution and returns the result of applying `eval` to it.
    fn get_value<T, F>(population: u64, successes: u64, draws: u64, eval: F) -> T
        where T: PartialEq + Debug,
              F: Fn(Hypergeometric) -> T
    {
        eval(try_create(population, successes, draws))
    }

    /// Asserts that `eval` applied to the distribution equals `expected`
    /// exactly.
    fn test_case<T, F>(population: u64, successes: u64, draws: u64, expected: T, eval: F)
        where T: PartialEq + Debug,
              F: Fn(Hypergeometric) -> T
    {
        let actual = get_value(population, successes, draws, eval);
        assert_eq!(expected, actual);
    }

    /// Asserts that `eval` applied to the distribution is within `acc` of
    /// `expected`.
    fn test_almost<F>(population: u64, successes: u64, draws: u64, expected: f64, acc: f64, eval: F)
        where F: Fn(Hypergeometric) -> f64
    {
        let x = get_value(population, successes, draws, eval);
        assert_almost_eq!(expected, x, acc);
    }

    #[test]
    fn test_create() {
        let valid = [
            (0, 0, 0),
            (1, 1, 1),
            (2, 1, 1),
            (2, 2, 2),
            (10, 1, 1),
            (10, 5, 3),
        ];
        for &(population, successes, draws) in valid.iter() {
            create_case(population, successes, draws);
        }
    }

    #[test]
    fn test_bad_create() {
        let invalid = [
            (2, 3, 2),
            (10, 5, 20),
            (0, 1, 1),
        ];
        for &(population, successes, draws) in invalid.iter() {
            bad_create_case(population, successes, draws);
        }
    }

    #[test]
    fn test_mean() {
        let mean = |x: Hypergeometric| x.mean().unwrap();
        let cases = [
            (1, 1, 1, 1.0),
            (2, 1, 1, 0.5),
            (2, 2, 2, 2.0),
            (10, 1, 1, 0.1),
            (10, 5, 3, 15.0 / 10.0),
        ];
        for &(population, successes, draws, expected) in cases.iter() {
            test_case(population, successes, draws, expected, mean);
        }
    }

    #[test]
    #[should_panic]
    fn test_mean_with_population_0() {
        get_value(0, 0, 0, |x: Hypergeometric| x.mean().unwrap());
    }

    #[test]
    fn test_variance() {
        let variance = |x: Hypergeometric| x.variance().unwrap();
        let cases = [
            (2, 1, 1, 0.25),
            (2, 2, 2, 0.0),
            (10, 1, 1, 81.0 / 900.0),
            (10, 5, 3, 525.0 / 900.0),
        ];
        for &(population, successes, draws, expected) in cases.iter() {
            test_case(population, successes, draws, expected, variance);
        }
    }

    #[test]
    #[should_panic]
    fn test_variance_with_pop_lte_1() {
        get_value(1, 1, 1, |x: Hypergeometric| x.variance().unwrap());
    }

    #[test]
    fn test_skewness() {
        let skewness = |x: Hypergeometric| x.skewness().unwrap();
        let cases = [
            (10, 1, 1, 8.0 / 3.0),
            (10, 5, 3, 0.0),
        ];
        for &(population, successes, draws, expected) in cases.iter() {
            test_case(population, successes, draws, expected, skewness);
        }
    }

    #[test]
    #[should_panic]
    fn test_skewness_with_pop_lte_2() {
        get_value(2, 2, 2, |x: Hypergeometric| x.skewness().unwrap());
    }

    #[test]
    fn test_mode() {
        let mode = |x: Hypergeometric| x.mode().unwrap();
        let cases = [
            (0, 0, 0, 0),
            (1, 1, 1, 1),
            (2, 1, 1, 1),
            (2, 2, 2, 2),
            (10, 1, 1, 0),
            (10, 5, 3, 2),
        ];
        for &(population, successes, draws, expected) in cases.iter() {
            test_case(population, successes, draws, expected, mode);
        }
    }

    #[test]
    fn test_min() {
        let min = |x: Hypergeometric| x.min();
        let cases = [
            (0, 0, 0, 0),
            (1, 1, 1, 1),
            (2, 1, 1, 0),
            (2, 2, 2, 2),
            (10, 1, 1, 0),
            (10, 5, 3, 0),
        ];
        for &(population, successes, draws, expected) in cases.iter() {
            test_case(population, successes, draws, expected, min);
        }
    }

    #[test]
    fn test_max() {
        let max = |x: Hypergeometric| x.max();
        let cases = [
            (0, 0, 0, 0),
            (1, 1, 1, 1),
            (2, 1, 1, 1),
            (2, 2, 2, 2),
            (10, 1, 1, 1),
            (10, 5, 3, 3),
        ];
        for &(population, successes, draws, expected) in cases.iter() {
            test_case(population, successes, draws, expected, max);
        }
    }

    #[test]
    fn test_pmf() {
        let pmf = |arg: u64| move |x: Hypergeometric| x.pmf(arg);
        // (population, successes, draws, expected, argument)
        let cases = [
            (0, 0, 0, 1.0, 0),
            (1, 1, 1, 1.0, 1),
            (2, 1, 1, 0.5, 0),
            (2, 1, 1, 0.5, 1),
            (2, 2, 2, 1.0, 2),
            (10, 1, 1, 0.9, 0),
            (10, 1, 1, 0.1, 1),
            (10, 5, 3, 0.41666666666666666667, 1),
            (10, 5, 3, 0.083333333333333333333, 3),
        ];
        for &(population, successes, draws, expected, arg) in cases.iter() {
            test_case(population, successes, draws, expected, pmf(arg));
        }
    }

    #[test]
    fn test_ln_pmf() {
        let ln_pmf = |arg: u64| move |x: Hypergeometric| x.ln_pmf(arg);
        // (population, successes, draws, expected, argument), compared exactly
        let exact = [
            (0, 0, 0, 0.0, 0),
            (1, 1, 1, 0.0, 1),
            (2, 1, 1, -0.6931471805599453094172, 0),
            (2, 1, 1, -0.6931471805599453094172, 1),
            (2, 2, 2, 0.0, 2),
        ];
        for &(population, successes, draws, expected, arg) in exact.iter() {
            test_case(population, successes, draws, expected, ln_pmf(arg));
        }
        // (population, successes, draws, expected, argument), compared to 1e-14
        let approximate = [
            (10, 1, 1, -0.1053605156578263012275, 0),
            (10, 1, 1, -2.302585092994045684018, 1),
            (10, 5, 3, -0.875468737353899935621, 1),
            (10, 5, 3, -2.484906649788000310234, 3),
        ];
        for &(population, successes, draws, expected, arg) in approximate.iter() {
            test_almost(population, successes, draws, expected, 1e-14, ln_pmf(arg));
        }
    }

    #[test]
    fn test_cdf() {
        let cdf = |arg: u64| move |x: Hypergeometric| x.cdf(arg);
        test_case(2, 1, 1, 0.5, cdf(0));
        // (population, successes, draws, expected, accuracy, argument)
        let approximate = [
            (10, 1, 1, 0.9, 1e-14, 0),
            (10, 5, 3, 0.5, 1e-15, 1),
            (10, 5, 3, 11.0 / 12.0, 1e-14, 2),
            (10000, 2, 9800, 199.0 / 499950.0, 1e-14, 0),
            (10000, 2, 9800, 19799.0 / 499950.0, 1e-12, 1),
        ];
        for &(population, successes, draws, expected, acc, arg) in approximate.iter() {
            test_almost(population, successes, draws, expected, acc, cdf(arg));
        }
    }

    #[test]
    fn test_sf() {
        let sf = |arg: u64| move |x: Hypergeometric| x.sf(arg);
        test_case(2, 1, 1, 0.5, sf(0));
        // (population, successes, draws, expected, accuracy, argument)
        let approximate = [
            (10, 1, 1, 0.1, 1e-14, 0),
            (10, 5, 3, 0.5, 1e-15, 1),
            (10, 5, 3, 1.0 / 12.0, 1e-14, 2),
            (10000, 2, 9800, 499751. / 499950.0, 1e-10, 0),
            (10000, 2, 9800, 480151. / 499950.0, 1e-10, 1),
        ];
        for &(population, successes, draws, expected, acc, arg) in approximate.iter() {
            test_almost(population, successes, draws, expected, acc, sf(arg));
        }
    }

    #[test]
    fn test_cdf_arg_too_big() {
        test_case(0, 0, 0, 1.0, |x: Hypergeometric| x.cdf(0));
    }

    #[test]
    fn test_cdf_arg_too_small() {
        test_case(2, 2, 2, 0.0, |x: Hypergeometric| x.cdf(0));
    }

    #[test]
    fn test_sf_arg_too_big() {
        test_case(0, 0, 0, 0.0, |x: Hypergeometric| x.sf(0));
    }

    #[test]
    fn test_sf_arg_too_small() {
        test_case(2, 2, 2, 1.0, |x: Hypergeometric| x.sf(0));
    }

    #[test]
    fn test_discrete() {
        let first = try_create(5, 4, 3);
        test::check_discrete_distribution(&first, 4);
        let second = try_create(3, 2, 1);
        test::check_discrete_distribution(&second, 2);
    }
}