statrs/distribution/
hypergeometric.rs

1use crate::distribution::{Discrete, DiscreteCDF};
2use crate::function::factorial;
3use crate::statistics::*;
4use std::cmp;
5use std::f64;
6
/// Implements the
/// [Hypergeometric](http://en.wikipedia.org/wiki/Hypergeometric_distribution)
/// distribution
///
/// The distribution is described by three unsigned integers: the size of
/// the population, the number of success states it contains, and the
/// number of draws.
///
/// # Examples
///
/// ```
/// use statrs::distribution::{Hypergeometric, Discrete};
/// use statrs::statistics::Distribution;
/// use statrs::prec;
///
/// let n = Hypergeometric::new(500, 50, 100).unwrap();
/// assert_eq!(n.mean().unwrap(), 10.);
/// assert!(prec::almost_eq(n.pmf(10), 0.14736784, 1e-8));
/// assert!(prec::almost_eq(n.pmf(25), 3.537e-7, 1e-10));
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Hypergeometric {
    /// Population size (`N`).
    population: u64,
    /// Number of successes in the population (`K`).
    successes: u64,
    /// Number of draws (`n`).
    draws: u64,
}
29
/// Represents the errors that can occur when creating a [`Hypergeometric`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum HypergeometricError {
    /// The number of successes is greater than the population.
    TooManySuccesses,

    /// The number of draws is greater than the population.
    TooManyDraws,
}

impl std::fmt::Display for HypergeometricError {
    /// Writes a short description of the violated constraint.
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        // Pick the message first, then emit it in a single place.
        let message = match self {
            HypergeometricError::TooManySuccesses => "successes > population",
            HypergeometricError::TooManyDraws => "draws > population",
        };
        f.write_str(message)
    }
}

impl std::error::Error for HypergeometricError {}
52
53impl Hypergeometric {
54    /// Constructs a new hypergeometric distribution
55    /// with a population (N) of `population`, number
56    /// of successes (K) of `successes`, and number of draws
57    /// (n) of `draws`.
58    ///
59    /// # Errors
60    ///
61    /// If `successes > population` or `draws > population`.
62    ///
63    /// # Examples
64    ///
65    /// ```
66    /// use statrs::distribution::Hypergeometric;
67    ///
68    /// let mut result = Hypergeometric::new(2, 2, 2);
69    /// assert!(result.is_ok());
70    ///
71    /// result = Hypergeometric::new(2, 3, 2);
72    /// assert!(result.is_err());
73    /// ```
74    pub fn new(
75        population: u64,
76        successes: u64,
77        draws: u64,
78    ) -> Result<Hypergeometric, HypergeometricError> {
79        if successes > population {
80            return Err(HypergeometricError::TooManySuccesses);
81        }
82
83        if draws > population {
84            return Err(HypergeometricError::TooManyDraws);
85        }
86
87        Ok(Hypergeometric {
88            population,
89            successes,
90            draws,
91        })
92    }
93
94    /// Returns the population size of the hypergeometric
95    /// distribution
96    ///
97    /// # Examples
98    ///
99    /// ```
100    /// use statrs::distribution::Hypergeometric;
101    ///
102    /// let n = Hypergeometric::new(10, 5, 3).unwrap();
103    /// assert_eq!(n.population(), 10);
104    /// ```
105    pub fn population(&self) -> u64 {
106        self.population
107    }
108
109    /// Returns the number of observed successes of the hypergeometric
110    /// distribution
111    ///
112    /// # Examples
113    ///
114    /// ```
115    /// use statrs::distribution::Hypergeometric;
116    ///
117    /// let n = Hypergeometric::new(10, 5, 3).unwrap();
118    /// assert_eq!(n.successes(), 5);
119    /// ```
120    pub fn successes(&self) -> u64 {
121        self.successes
122    }
123
124    /// Returns the number of draws of the hypergeometric
125    /// distribution
126    ///
127    /// # Examples
128    ///
129    /// ```
130    /// use statrs::distribution::Hypergeometric;
131    ///
132    /// let n = Hypergeometric::new(10, 5, 3).unwrap();
133    /// assert_eq!(n.draws(), 3);
134    /// ```
135    pub fn draws(&self) -> u64 {
136        self.draws
137    }
138
139    /// Returns population, successes, and draws in that order
140    /// as a tuple of doubles
141    fn values_f64(&self) -> (f64, f64, f64) {
142        (
143            self.population as f64,
144            self.successes as f64,
145            self.draws as f64,
146        )
147    }
148}
149
150impl std::fmt::Display for Hypergeometric {
151    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
152        write!(
153            f,
154            "Hypergeometric({},{},{})",
155            self.population, self.successes, self.draws
156        )
157    }
158}
159
#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<u64> for Hypergeometric {
    /// Draws one sample by simulating `draws` sequential draws without
    /// replacement and counting how many of them were successes.
    ///
    /// A distribution with zero draws always yields `0` and consumes no
    /// randomness.
    fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> u64 {
        // Remaining population and remaining successes, tracked as floats
        // so the per-draw success probability is a plain division.
        let mut population = self.population as f64;
        let mut successes = self.successes as f64;
        let mut hits: u64 = 0;

        // A bounded loop (rather than "decrement, then test for zero")
        // keeps `draws == 0` from underflowing the counter and from
        // evaluating 0/0 on an empty population.
        for _ in 0..self.draws {
            let p = successes / population;
            let next: f64 = rng.gen();
            if next < p {
                hits += 1;
                successes -= 1.0;
            }
            population -= 1.0;
        }
        hits
    }
}
184
#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<f64> for Hypergeometric {
    /// Draws an integer sample through the `u64` implementation and
    /// returns it converted to `f64`.
    fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 {
        let count = <Self as ::rand::distributions::Distribution<u64>>::sample(self, rng);
        count as f64
    }
}
192
193impl DiscreteCDF<u64, f64> for Hypergeometric {
194    /// Calculates the cumulative distribution function for the hypergeometric
195    /// distribution at `x`
196    ///
197    /// # Formula
198    ///
199    /// ```text
200    /// 1 - ((n choose x+1) * (N-n choose K-x-1)) / (N choose K) * 3_F_2(1,
201    /// x+1-K, x+1-n; k+2, N+x+2-K-n; 1)
202    /// ```
203    ///
204    /// where `N` is population, `K` is successes, `n` is draws,
205    /// and `p_F_q` is the
206    /// [generalized hypergeometric function](https://en.wikipedia.org/wiki/Generalized_hypergeometric_function)
207    ///
208    /// Calculated as a discrete integral over the probability mass
209    /// function evaluated from 0..x+1
210    fn cdf(&self, x: u64) -> f64 {
211        if x < self.min() {
212            0.0
213        } else if x >= self.max() {
214            1.0
215        } else {
216            let k = x;
217            let ln_denom = factorial::ln_binomial(self.population, self.draws);
218            (0..k + 1).fold(0.0, |acc, i| {
219                acc + (factorial::ln_binomial(self.successes, i)
220                    + factorial::ln_binomial(self.population - self.successes, self.draws - i)
221                    - ln_denom)
222                    .exp()
223            })
224        }
225    }
226
227    /// Calculates the survival function for the hypergeometric
228    /// distribution at `x`
229    ///
230    /// # Formula
231    ///
232    /// ```text
233    /// 1 - ((n choose x+1) * (N-n choose K-x-1)) / (N choose K) * 3_F_2(1,
234    /// x+1-K, x+1-n; x+2, N+x+2-K-n; 1)
235    /// ```
236    ///
237    /// where `N` is population, `K` is successes, `n` is draws,
238    /// and `p_F_q` is the
239    /// [generalized hypergeometric function](https://en.wikipedia.org/wiki/Generalized_hypergeometric_function)
240    ///
241    /// Calculated as a discrete integral over the probability mass
242    /// function evaluated from (x+1)..max
243    fn sf(&self, x: u64) -> f64 {
244        if x < self.min() {
245            1.0
246        } else if x >= self.max() {
247            0.0
248        } else {
249            let k = x;
250            let ln_denom = factorial::ln_binomial(self.population, self.draws);
251            (k + 1..=self.max()).fold(0.0, |acc, i| {
252                acc + (factorial::ln_binomial(self.successes, i)
253                    + factorial::ln_binomial(self.population - self.successes, self.draws - i)
254                    - ln_denom)
255                    .exp()
256            })
257        }
258    }
259}
260
261impl Min<u64> for Hypergeometric {
262    /// Returns the minimum value in the domain of the
263    /// hypergeometric distribution representable by a 64-bit
264    /// integer
265    ///
266    /// # Formula
267    ///
268    /// ```text
269    /// max(0, n + K - N)
270    /// ```
271    ///
272    /// where `N` is population, `K` is successes, and `n` is draws
273    fn min(&self) -> u64 {
274        (self.draws + self.successes).saturating_sub(self.population)
275    }
276}
277
278impl Max<u64> for Hypergeometric {
279    /// Returns the maximum value in the domain of the
280    /// hypergeometric distribution representable by a 64-bit
281    /// integer
282    ///
283    /// # Formula
284    ///
285    /// ```text
286    /// min(K, n)
287    /// ```
288    ///
289    /// where `K` is successes and `n` is draws
290    fn max(&self) -> u64 {
291        cmp::min(self.successes, self.draws)
292    }
293}
294
295impl Distribution<f64> for Hypergeometric {
296    /// Returns the mean of the hypergeometric distribution
297    ///
298    /// # None
299    ///
300    /// If `N` is `0`
301    ///
302    /// # Formula
303    ///
304    /// ```text
305    /// K * n / N
306    /// ```
307    ///
308    /// where `N` is population, `K` is successes, and `n` is draws
309    fn mean(&self) -> Option<f64> {
310        if self.population == 0 {
311            None
312        } else {
313            Some(self.successes as f64 * self.draws as f64 / self.population as f64)
314        }
315    }
316
317    /// Returns the variance of the hypergeometric distribution
318    ///
319    /// # None
320    ///
321    /// If `N <= 1`
322    ///
323    /// # Formula
324    ///
325    /// ```text
326    /// n * (K / N) * ((N - K) / N) * ((N - n) / (N - 1))
327    /// ```
328    ///
329    /// where `N` is population, `K` is successes, and `n` is draws
330    fn variance(&self) -> Option<f64> {
331        if self.population <= 1 {
332            None
333        } else {
334            let (population, successes, draws) = self.values_f64();
335            let val = draws * successes * (population - draws) * (population - successes)
336                / (population * population * (population - 1.0));
337            Some(val)
338        }
339    }
340
341    /// Returns the skewness of the hypergeometric distribution
342    ///
343    /// # None
344    ///
345    /// If `N <= 2`
346    ///
347    /// # Formula
348    ///
349    /// ```text
350    /// ((N - 2K) * (N - 1)^(1 / 2) * (N - 2n)) / ([n * K * (N - K) * (N -
351    /// n)]^(1 / 2) * (N - 2))
352    /// ```
353    ///
354    /// where `N` is population, `K` is successes, and `n` is draws
355    fn skewness(&self) -> Option<f64> {
356        if self.population <= 2 {
357            None
358        } else {
359            let (population, successes, draws) = self.values_f64();
360            let val = (population - 1.0).sqrt()
361                * (population - 2.0 * draws)
362                * (population - 2.0 * successes)
363                / ((draws * successes * (population - successes) * (population - draws)).sqrt()
364                    * (population - 2.0));
365            Some(val)
366        }
367    }
368}
369
370impl Mode<Option<u64>> for Hypergeometric {
371    /// Returns the mode of the hypergeometric distribution
372    ///
373    /// # Formula
374    ///
375    /// ```text
376    /// floor((n + 1) * (k + 1) / (N + 2))
377    /// ```
378    ///
379    /// where `N` is population, `K` is successes, and `n` is draws
380    fn mode(&self) -> Option<u64> {
381        Some(((self.draws + 1) * (self.successes + 1)) / (self.population + 2))
382    }
383}
384
385impl Discrete<u64, f64> for Hypergeometric {
386    /// Calculates the probability mass function for the hypergeometric
387    /// distribution at `x`
388    ///
389    /// # Formula
390    ///
391    /// ```text
392    /// (K choose x) * (N-K choose n-x) / (N choose n)
393    /// ```
394    ///
395    /// where `N` is population, `K` is successes, and `n` is draws
396    fn pmf(&self, x: u64) -> f64 {
397        if x > self.draws {
398            0.0
399        } else {
400            factorial::binomial(self.successes, x)
401                * factorial::binomial(self.population - self.successes, self.draws - x)
402                / factorial::binomial(self.population, self.draws)
403        }
404    }
405
406    /// Calculates the log probability mass function for the hypergeometric
407    /// distribution at `x`
408    ///
409    /// # Formula
410    ///
411    /// ```text
412    /// ln((K choose x) * (N-K choose n-x) / (N choose n))
413    /// ```
414    ///
415    /// where `N` is population, `K` is successes, and `n` is draws
416    fn ln_pmf(&self, x: u64) -> f64 {
417        factorial::ln_binomial(self.successes, x)
418            + factorial::ln_binomial(self.population - self.successes, self.draws - x)
419            - factorial::ln_binomial(self.population, self.draws)
420    }
421}
422
#[rustfmt::skip]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distribution::internal::*;
    use crate::testing_boiler;

    // Generates the helpers used below for (population, successes, draws) triples.
    testing_boiler!(population: u64, successes: u64, draws: u64; Hypergeometric; HypergeometricError);

    // ---- construction ----

    #[test]
    fn test_create() {
        create_ok(0, 0, 0);
        create_ok(1, 1, 1);
        create_ok(2, 1, 1);
        create_ok(2, 2, 2);
        create_ok(10, 1, 1);
        create_ok(10, 5, 3);
    }

    #[test]
    fn test_bad_create() {
        test_create_err(2, 3, 2, HypergeometricError::TooManySuccesses);
        test_create_err(10, 5, 20, HypergeometricError::TooManyDraws);
        create_err(0, 1, 1);
    }

    // ---- moments ----

    #[test]
    fn test_mean() {
        let mean_of = |dist: Hypergeometric| dist.mean().unwrap();
        test_exact(1, 1, 1, 1.0, mean_of);
        test_exact(2, 1, 1, 0.5, mean_of);
        test_exact(2, 2, 2, 2.0, mean_of);
        test_exact(10, 1, 1, 0.1, mean_of);
        test_exact(10, 5, 3, 15.0 / 10.0, mean_of);
    }

    #[test]
    fn test_mean_with_population_0() {
        test_none(0, 0, 0, |dist| dist.mean());
    }

    #[test]
    fn test_variance() {
        let variance_of = |dist: Hypergeometric| dist.variance().unwrap();
        test_exact(2, 1, 1, 0.25, variance_of);
        test_exact(2, 2, 2, 0.0, variance_of);
        test_exact(10, 1, 1, 81.0 / 900.0, variance_of);
        test_exact(10, 5, 3, 525.0 / 900.0, variance_of);
    }

    #[test]
    fn test_variance_with_pop_lte_1() {
        test_none(1, 1, 1, |dist| dist.variance());
    }

    #[test]
    fn test_skewness() {
        let skewness_of = |dist: Hypergeometric| dist.skewness().unwrap();
        test_exact(10, 1, 1, 8.0 / 3.0, skewness_of);
        test_exact(10, 5, 3, 0.0, skewness_of);
    }

    #[test]
    fn test_skewness_with_pop_lte_2() {
        test_none(2, 2, 2, |dist| dist.skewness());
    }

    // ---- mode and support bounds ----

    #[test]
    fn test_mode() {
        let mode_of = |dist: Hypergeometric| dist.mode().unwrap();
        test_exact(0, 0, 0, 0, mode_of);
        test_exact(1, 1, 1, 1, mode_of);
        test_exact(2, 1, 1, 1, mode_of);
        test_exact(2, 2, 2, 2, mode_of);
        test_exact(10, 1, 1, 0, mode_of);
        test_exact(10, 5, 3, 2, mode_of);
    }

    #[test]
    fn test_min() {
        let min_of = |dist: Hypergeometric| dist.min();
        test_exact(0, 0, 0, 0, min_of);
        test_exact(1, 1, 1, 1, min_of);
        test_exact(2, 1, 1, 0, min_of);
        test_exact(2, 2, 2, 2, min_of);
        test_exact(10, 1, 1, 0, min_of);
        test_exact(10, 5, 3, 0, min_of);
    }

    #[test]
    fn test_max() {
        let max_of = |dist: Hypergeometric| dist.max();
        test_exact(0, 0, 0, 0, max_of);
        test_exact(1, 1, 1, 1, max_of);
        test_exact(2, 1, 1, 1, max_of);
        test_exact(2, 2, 2, 2, max_of);
        test_exact(10, 1, 1, 1, max_of);
        test_exact(10, 5, 3, 3, max_of);
    }

    // ---- probability mass ----

    #[test]
    fn test_pmf() {
        let pmf_at = |k: u64| move |dist: Hypergeometric| dist.pmf(k);
        test_exact(0, 0, 0, 1.0, pmf_at(0));
        test_exact(1, 1, 1, 1.0, pmf_at(1));
        test_exact(2, 1, 1, 0.5, pmf_at(0));
        test_exact(2, 1, 1, 0.5, pmf_at(1));
        test_exact(2, 2, 2, 1.0, pmf_at(2));
        test_exact(10, 1, 1, 0.9, pmf_at(0));
        test_exact(10, 1, 1, 0.1, pmf_at(1));
        test_exact(10, 5, 3, 0.41666666666666666667, pmf_at(1));
        test_exact(10, 5, 3, 0.083333333333333333333, pmf_at(3));
    }

    #[test]
    fn test_ln_pmf() {
        let ln_pmf_at = |k: u64| move |dist: Hypergeometric| dist.ln_pmf(k);
        test_exact(0, 0, 0, 0.0, ln_pmf_at(0));
        test_exact(1, 1, 1, 0.0, ln_pmf_at(1));
        test_exact(2, 1, 1, -0.6931471805599453094172, ln_pmf_at(0));
        test_exact(2, 1, 1, -0.6931471805599453094172, ln_pmf_at(1));
        test_exact(2, 2, 2, 0.0, ln_pmf_at(2));
        test_absolute(10, 1, 1, -0.1053605156578263012275, 1e-14, ln_pmf_at(0));
        test_absolute(10, 1, 1, -2.302585092994045684018, 1e-14, ln_pmf_at(1));
        test_absolute(10, 5, 3, -0.875468737353899935621, 1e-14, ln_pmf_at(1));
        test_absolute(10, 5, 3, -2.484906649788000310234, 1e-14, ln_pmf_at(3));
    }

    // ---- cumulative and survival functions ----

    #[test]
    fn test_cdf() {
        let cdf_at = |k: u64| move |dist: Hypergeometric| dist.cdf(k);
        test_exact(2, 1, 1, 0.5, cdf_at(0));
        test_absolute(10, 1, 1, 0.9, 1e-14, cdf_at(0));
        test_absolute(10, 5, 3, 0.5, 1e-15, cdf_at(1));
        test_absolute(10, 5, 3, 11.0 / 12.0, 1e-14, cdf_at(2));
        test_absolute(10000, 2, 9800, 199.0 / 499950.0, 1e-14, cdf_at(0));
        test_absolute(10000, 2, 9800, 19799.0 / 499950.0, 1e-12, cdf_at(1));
    }

    #[test]
    fn test_sf() {
        let sf_at = |k: u64| move |dist: Hypergeometric| dist.sf(k);
        test_exact(2, 1, 1, 0.5, sf_at(0));
        test_absolute(10, 1, 1, 0.1, 1e-14, sf_at(0));
        test_absolute(10, 5, 3, 0.5, 1e-15, sf_at(1));
        test_absolute(10, 5, 3, 1.0 / 12.0, 1e-14, sf_at(2));
        test_absolute(10000, 2, 9800, 499751. / 499950.0, 1e-10, sf_at(0));
        test_absolute(10000, 2, 9800, 480151. / 499950.0, 1e-10, sf_at(1));
    }

    // Argument at or above the maximum of the support.
    #[test]
    fn test_cdf_arg_too_big() {
        let cdf_at = |k: u64| move |dist: Hypergeometric| dist.cdf(k);
        test_exact(0, 0, 0, 1.0, cdf_at(0));
    }

    // Argument below the minimum of the support.
    #[test]
    fn test_cdf_arg_too_small() {
        let cdf_at = |k: u64| move |dist: Hypergeometric| dist.cdf(k);
        test_exact(2, 2, 2, 0.0, cdf_at(0));
    }

    // Argument at or above the maximum of the support.
    #[test]
    fn test_sf_arg_too_big() {
        let sf_at = |k: u64| move |dist: Hypergeometric| dist.sf(k);
        test_exact(0, 0, 0, 0.0, sf_at(0));
    }

    // Argument below the minimum of the support.
    #[test]
    fn test_sf_arg_too_small() {
        let sf_at = |k: u64| move |dist: Hypergeometric| dist.sf(k);
        test_exact(2, 2, 2, 1.0, sf_at(0));
    }

    // ---- generic discrete-distribution checks ----

    #[test]
    fn test_discrete() {
        test::check_discrete_distribution(&create_ok(5, 4, 3), 4);
        test::check_discrete_distribution(&create_ok(3, 2, 1), 2);
    }
}