statrs/distribution/
binomial.rs

1use crate::distribution::{Discrete, DiscreteCDF};
2use crate::function::{beta, factorial};
3use crate::statistics::*;
4use std::f64;
5
/// The [Binomial](https://en.wikipedia.org/wiki/Binomial_distribution)
/// distribution: the number of successes observed in `n` independent
/// trials, each of which succeeds with probability `p`.
///
/// # Examples
///
/// ```
/// use statrs::distribution::{Binomial, Discrete};
/// use statrs::statistics::Distribution;
///
/// let n = Binomial::new(0.5, 5).unwrap();
/// assert_eq!(n.mean().unwrap(), 2.5);
/// assert_eq!(n.pmf(0), 0.03125);
/// assert_eq!(n.pmf(3), 0.3125);
/// ```
#[derive(Copy, Clone, PartialEq, Debug)]
pub struct Binomial {
    /// Per-trial success probability; `Binomial::new` only admits values in `[0, 1]`.
    p: f64,
    /// Number of trials.
    n: u64,
}
26
/// Errors that [`Binomial::new`] reports when handed invalid parameters.
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
#[non_exhaustive]
pub enum BinomialError {
    /// The probability is NaN or not in `[0, 1]`.
    ProbabilityInvalid,
}

impl std::fmt::Display for BinomialError {
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        // Each variant maps to a fixed, human-readable description.
        let message = match *self {
            Self::ProbabilityInvalid => "Probability is NaN or not in [0, 1]",
        };
        f.write_str(message)
    }
}

// No wrapped cause exists, so the default `Error` methods are sufficient.
impl std::error::Error for BinomialError {}
45
46impl Binomial {
47    /// Constructs a new binomial distribution
48    /// with a given `p` probability of success of `n`
49    /// trials.
50    ///
51    /// # Errors
52    ///
53    /// Returns an error if `p` is `NaN`, less than `0.0`,
54    /// greater than `1.0`, or if `n` is less than `0`
55    ///
56    /// # Examples
57    ///
58    /// ```
59    /// use statrs::distribution::Binomial;
60    ///
61    /// let mut result = Binomial::new(0.5, 5);
62    /// assert!(result.is_ok());
63    ///
64    /// result = Binomial::new(-0.5, 5);
65    /// assert!(result.is_err());
66    /// ```
67    pub fn new(p: f64, n: u64) -> Result<Binomial, BinomialError> {
68        if p.is_nan() || !(0.0..=1.0).contains(&p) {
69            Err(BinomialError::ProbabilityInvalid)
70        } else {
71            Ok(Binomial { p, n })
72        }
73    }
74
75    /// Returns the probability of success `p` of
76    /// the binomial distribution.
77    ///
78    /// # Examples
79    ///
80    /// ```
81    /// use statrs::distribution::Binomial;
82    ///
83    /// let n = Binomial::new(0.5, 5).unwrap();
84    /// assert_eq!(n.p(), 0.5);
85    /// ```
86    pub fn p(&self) -> f64 {
87        self.p
88    }
89
90    /// Returns the number of trials `n` of the
91    /// binomial distribution.
92    ///
93    /// # Examples
94    ///
95    /// ```
96    /// use statrs::distribution::Binomial;
97    ///
98    /// let n = Binomial::new(0.5, 5).unwrap();
99    /// assert_eq!(n.n(), 5);
100    /// ```
101    pub fn n(&self) -> u64 {
102        self.n
103    }
104}
105
106impl std::fmt::Display for Binomial {
107    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
108        write!(f, "Bin({},{})", self.p, self.n)
109    }
110}
111
#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<u64> for Binomial {
    /// Draws a variate by simulating `n` Bernoulli trials directly: each
    /// trial consumes one uniform `f64` from `rng` and counts as a success
    /// when that draw is below `p`. Cost is linear in `n`.
    fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> u64 {
        let mut successes: u64 = 0;
        for _ in 0..self.n {
            let draw: f64 = rng.gen();
            if draw < self.p {
                successes += 1;
            }
        }
        successes
    }
}
126
#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<f64> for Binomial {
    /// Draws an integer-valued variate via the `u64` sampler and widens it
    /// to `f64`.
    fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 {
        let successes: u64 = rng.sample(self);
        successes as f64
    }
}
134
135impl DiscreteCDF<u64, f64> for Binomial {
136    /// Calculates the cumulative distribution function for the
137    /// binomial distribution at `x`
138    ///
139    /// # Formula
140    ///
141    /// ```text
142    /// I_(1 - p)(n - x, 1 + x)
143    /// ```
144    ///
145    /// where `I_(x)(a, b)` is the regularized incomplete beta function
146    fn cdf(&self, x: u64) -> f64 {
147        if x >= self.n {
148            1.0
149        } else {
150            let k = x;
151            beta::beta_reg((self.n - k) as f64, k as f64 + 1.0, 1.0 - self.p)
152        }
153    }
154
155    /// Calculates the survival function for the
156    /// binomial distribution at `x`
157    ///
158    /// # Formula
159    ///
160    /// ```text
161    /// I_(p)(x + 1, n - x)
162    /// ```
163    ///
164    /// where `I_(x)(a, b)` is the regularized incomplete beta function
165    fn sf(&self, x: u64) -> f64 {
166        if x >= self.n {
167            0.0
168        } else {
169            let k = x;
170            beta::beta_reg(k as f64 + 1.0, (self.n - k) as f64, self.p)
171        }
172    }
173}
174
175impl Min<u64> for Binomial {
176    /// Returns the minimum value in the domain of the
177    /// binomial distribution representable by a 64-bit
178    /// integer
179    ///
180    /// # Formula
181    ///
182    /// ```text
183    /// 0
184    /// ```
185    fn min(&self) -> u64 {
186        0
187    }
188}
189
190impl Max<u64> for Binomial {
191    /// Returns the maximum value in the domain of the
192    /// binomial distribution representable by a 64-bit
193    /// integer
194    ///
195    /// # Formula
196    ///
197    /// ```text
198    /// n
199    /// ```
200    fn max(&self) -> u64 {
201        self.n
202    }
203}
204
205impl Distribution<f64> for Binomial {
206    /// Returns the mean of the binomial distribution
207    ///
208    /// # Formula
209    ///
210    /// ```text
211    /// p * n
212    /// ```
213    fn mean(&self) -> Option<f64> {
214        Some(self.p * self.n as f64)
215    }
216
217    /// Returns the variance of the binomial distribution
218    ///
219    /// # Formula
220    ///
221    /// ```text
222    /// n * p * (1 - p)
223    /// ```
224    fn variance(&self) -> Option<f64> {
225        Some(self.p * (1.0 - self.p) * self.n as f64)
226    }
227
228    /// Returns the entropy of the binomial distribution
229    ///
230    /// # Formula
231    ///
232    /// ```text
233    /// (1 / 2) * ln (2 * π * e * n * p * (1 - p))
234    /// ```
235    fn entropy(&self) -> Option<f64> {
236        let entr = if self.p == 0.0 || ulps_eq!(self.p, 1.0) {
237            0.0
238        } else {
239            (0..self.n + 1).fold(0.0, |acc, x| {
240                let p = self.pmf(x);
241                acc - p * p.ln()
242            })
243        };
244        Some(entr)
245    }
246
247    /// Returns the skewness of the binomial distribution
248    ///
249    /// # Formula
250    ///
251    /// ```text
252    /// (1 - 2p) / sqrt(n * p * (1 - p)))
253    /// ```
254    fn skewness(&self) -> Option<f64> {
255        Some((1.0 - 2.0 * self.p) / (self.n as f64 * self.p * (1.0 - self.p)).sqrt())
256    }
257}
258
259impl Median<f64> for Binomial {
260    /// Returns the median of the binomial distribution
261    ///
262    /// # Formula
263    ///
264    /// ```text
265    /// floor(n * p)
266    /// ```
267    fn median(&self) -> f64 {
268        (self.p * self.n as f64).floor()
269    }
270}
271
272impl Mode<Option<u64>> for Binomial {
273    /// Returns the mode for the binomial distribution
274    ///
275    /// # Formula
276    ///
277    /// ```text
278    /// floor((n + 1) * p)
279    /// ```
280    fn mode(&self) -> Option<u64> {
281        let mode = if self.p == 0.0 {
282            0
283        } else if ulps_eq!(self.p, 1.0) {
284            self.n
285        } else {
286            ((self.n as f64 + 1.0) * self.p).floor() as u64
287        };
288        Some(mode)
289    }
290}
291
292impl Discrete<u64, f64> for Binomial {
293    /// Calculates the probability mass function for the binomial
294    /// distribution at `x`
295    ///
296    /// # Formula
297    ///
298    /// ```text
299    /// (n choose k) * p^k * (1 - p)^(n - k)
300    /// ```
301    fn pmf(&self, x: u64) -> f64 {
302        if x > self.n {
303            0.0
304        } else if self.p == 0.0 {
305            if x == 0 {
306                1.0
307            } else {
308                0.0
309            }
310        } else if ulps_eq!(self.p, 1.0) {
311            if x == self.n {
312                1.0
313            } else {
314                0.0
315            }
316        } else {
317            (factorial::ln_binomial(self.n, x)
318                + x as f64 * self.p.ln()
319                + (self.n - x) as f64 * (1.0 - self.p).ln())
320            .exp()
321        }
322    }
323
324    /// Calculates the log probability mass function for the binomial
325    /// distribution at `x`
326    ///
327    /// # Formula
328    ///
329    /// ```text
330    /// ln((n choose k) * p^k * (1 - p)^(n - k))
331    /// ```
332    fn ln_pmf(&self, x: u64) -> f64 {
333        if x > self.n {
334            f64::NEG_INFINITY
335        } else if self.p == 0.0 {
336            if x == 0 {
337                0.0
338            } else {
339                f64::NEG_INFINITY
340            }
341        } else if ulps_eq!(self.p, 1.0) {
342            if x == self.n {
343                0.0
344            } else {
345                f64::NEG_INFINITY
346            }
347        } else {
348            factorial::ln_binomial(self.n, x)
349                + x as f64 * self.p.ln()
350                + (self.n - x) as f64 * (1.0 - self.p).ln()
351        }
352    }
353}
354
// Unit tests for `Binomial`. rustfmt is skipped for this module, keeping each
// long assertion with its high-precision expected value on a single line.
#[rustfmt::skip]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distribution::internal::*;
    use crate::testing_boiler;

    // Generates the `create_ok` / `create_err` / `test_exact` / `test_absolute`
    // helpers used below for `(p: f64, n: u64)` parameters.
    // NOTE(review): helper semantics live in the macro definition elsewhere;
    // presumably `test_exact` asserts equality and `test_absolute` checks the
    // result within the given absolute tolerance — confirm there.
    testing_boiler!(p: f64, n: u64; Binomial; BinomialError);

    #[test]
    fn test_create() {
        // Both endpoints of [0, 1] are valid probabilities.
        create_ok(0.0, 4);
        create_ok(0.3, 3);
        create_ok(1.0, 2);
    }

    #[test]
    fn test_bad_create() {
        // NaN and values outside [0, 1] must be rejected.
        create_err(f64::NAN, 1);
        create_err(-1.0, 1);
        create_err(2.0, 1);
    }

    #[test]
    fn test_mean() {
        let mean = |x: Binomial| x.mean().unwrap();
        test_exact(0.0, 4, 0.0, mean);
        test_absolute(0.3, 3, 0.9, 1e-15, mean);
        test_exact(1.0, 2, 2.0, mean);
    }

    #[test]
    fn test_variance() {
        let variance = |x: Binomial| x.variance().unwrap();
        test_exact(0.0, 4, 0.0, variance);
        test_exact(0.3, 3, 0.63, variance);
        test_exact(1.0, 2, 0.0, variance);
    }

    #[test]
    fn test_entropy() {
        let entropy = |x: Binomial| x.entropy().unwrap();
        // Degenerate distributions (p = 0 or p = 1) have zero entropy.
        test_exact(0.0, 4, 0.0, entropy);
        test_absolute(0.3, 3, 1.1404671643037712668976423399228972051669206536461, 1e-15, entropy);
        test_exact(1.0, 2, 0.0, entropy);
    }

    #[test]
    fn test_skewness() {
        let skewness = |x: Binomial| x.skewness().unwrap();
        // p = 0 and p = 1 make the variance zero, so the skewness formula
        // divides by zero and yields +/- infinity.
        test_exact(0.0, 4, f64::INFINITY, skewness);
        test_exact(0.3, 3, 0.503952630678969636286, skewness);
        test_exact(1.0, 2, f64::NEG_INFINITY, skewness);
    }

    #[test]
    fn test_median() {
        let median = |x: Binomial| x.median();
        // These values pin the current floor(n * p) implementation.
        test_exact(0.0, 4, 0.0, median);
        test_exact(0.3, 3, 0.0, median);
        test_exact(1.0, 2, 2.0, median);
    }

    #[test]
    fn test_mode() {
        let mode = |x: Binomial| x.mode().unwrap();
        test_exact(0.0, 4, 0, mode);
        test_exact(0.3, 3, 1, mode);
        test_exact(1.0, 2, 2, mode);
    }

    #[test]
    fn test_min_max() {
        let min = |x: Binomial| x.min();
        let max = |x: Binomial| x.max();
        test_exact(0.3, 10, 0, min);
        test_exact(0.3, 10, 10, max);
    }

    #[test]
    fn test_pmf() {
        // Builds an evaluator closure that computes `pmf(arg)` for a given distribution.
        let pmf = |arg: u64| move |x: Binomial| x.pmf(arg);
        // p = 0: all mass at 0.
        test_exact(0.0, 1, 1.0, pmf(0));
        test_exact(0.0, 1, 0.0, pmf(1));
        test_exact(0.0, 3, 1.0, pmf(0));
        test_exact(0.0, 3, 0.0, pmf(1));
        test_exact(0.0, 3, 0.0, pmf(3));
        test_exact(0.0, 10, 1.0, pmf(0));
        test_exact(0.0, 10, 0.0, pmf(1));
        test_exact(0.0, 10, 0.0, pmf(10));
        // Non-degenerate p = 0.3.
        test_exact(0.3, 1, 0.69999999999999995559107901499373838305473327636719, pmf(0));
        test_exact(0.3, 1, 0.2999999999999999888977697537484345957636833190918, pmf(1));
        test_exact(0.3, 3, 0.34299999999999993471888615204079956461021032657166, pmf(0));
        test_absolute(0.3, 3, 0.44099999999999992772448109690231306411849135972008, 1e-15, pmf(1));
        test_absolute(0.3, 3, 0.026999999999999997002397833512077451789759292859569, 1e-16, pmf(3));
        test_absolute(0.3, 10, 0.02824752489999998207939855277004937778546385011091, 1e-17, pmf(0));
        test_absolute(0.3, 10, 0.12106082099999992639752977030555903089040470780077, 1e-15, pmf(1));
        test_absolute(0.3, 10, 0.0000059048999999999978147480206303047454017251032868501, 1e-20, pmf(10));
        // p = 1: all mass at n.
        test_exact(1.0, 1, 0.0, pmf(0));
        test_exact(1.0, 1, 1.0, pmf(1));
        test_exact(1.0, 3, 0.0, pmf(0));
        test_exact(1.0, 3, 0.0, pmf(1));
        test_exact(1.0, 3, 1.0, pmf(3));
        test_exact(1.0, 10, 0.0, pmf(0));
        test_exact(1.0, 10, 0.0, pmf(1));
        test_exact(1.0, 10, 1.0, pmf(10));
    }

    #[test]
    fn test_ln_pmf() {
        // Builds an evaluator closure that computes `ln_pmf(arg)` for a given distribution.
        let ln_pmf = |arg: u64| move |x: Binomial| x.ln_pmf(arg);
        // p = 0: log-mass 0 at 0, -inf elsewhere.
        test_exact(0.0, 1, 0.0, ln_pmf(0));
        test_exact(0.0, 1, f64::NEG_INFINITY, ln_pmf(1));
        test_exact(0.0, 3, 0.0, ln_pmf(0));
        test_exact(0.0, 3, f64::NEG_INFINITY, ln_pmf(1));
        test_exact(0.0, 3, f64::NEG_INFINITY, ln_pmf(3));
        test_exact(0.0, 10, 0.0, ln_pmf(0));
        test_exact(0.0, 10, f64::NEG_INFINITY, ln_pmf(1));
        test_exact(0.0, 10, f64::NEG_INFINITY, ln_pmf(10));
        // Non-degenerate p = 0.3.
        test_exact(0.3, 1, -0.3566749439387324423539544041072745145718090708995, ln_pmf(0));
        test_exact(0.3, 1, -1.2039728043259360296301803719337238685164245381839, ln_pmf(1));
        test_exact(0.3, 3, -1.0700248318161973270618632123218235437154272126985, ln_pmf(0));
        test_absolute(0.3, 3, -0.81871040353529122294284394322574719301255212216016, 1e-15, ln_pmf(1));
        test_absolute(0.3, 3, -3.6119184129778080888905411158011716055492736145517, 1e-15, ln_pmf(3));
        test_exact(0.3, 10, -3.566749439387324423539544041072745145718090708995, ln_pmf(0));
        test_absolute(0.3, 10, -2.1114622067804823267977785542148302920616046876506, 1e-14, ln_pmf(1));
        test_exact(0.3, 10, -12.039728043259360296301803719337238685164245381839, ln_pmf(10));
        // p = 1: log-mass 0 at n, -inf elsewhere.
        test_exact(1.0, 1, f64::NEG_INFINITY, ln_pmf(0));
        test_exact(1.0, 1, 0.0, ln_pmf(1));
        test_exact(1.0, 3, f64::NEG_INFINITY, ln_pmf(0));
        test_exact(1.0, 3, f64::NEG_INFINITY, ln_pmf(1));
        test_exact(1.0, 3, 0.0, ln_pmf(3));
        test_exact(1.0, 10, f64::NEG_INFINITY, ln_pmf(0));
        test_exact(1.0, 10, f64::NEG_INFINITY, ln_pmf(1));
        test_exact(1.0, 10, 0.0, ln_pmf(10));
    }

    #[test]
    fn test_cdf() {
        let cdf = |arg: u64| move |x: Binomial| x.cdf(arg);
        test_exact(0.0, 1, 1.0, cdf(0));
        test_exact(0.0, 1, 1.0, cdf(1));
        test_exact(0.0, 3, 1.0, cdf(0));
        test_exact(0.0, 3, 1.0, cdf(1));
        test_exact(0.0, 3, 1.0, cdf(3));
        test_exact(0.0, 10, 1.0, cdf(0));
        test_exact(0.0, 10, 1.0, cdf(1));
        test_exact(0.0, 10, 1.0, cdf(10));
        test_absolute(0.3, 1, 0.7, 1e-15, cdf(0));
        // x >= n always yields exactly 1.0.
        test_exact(0.3, 1, 1.0, cdf(1));
        test_absolute(0.3, 3, 0.343, 1e-14, cdf(0));
        test_absolute(0.3, 3, 0.784, 1e-15, cdf(1));
        test_exact(0.3, 3, 1.0, cdf(3));
        test_absolute(0.3, 10, 0.0282475249, 1e-16, cdf(0));
        test_absolute(0.3, 10, 0.1493083459, 1e-14, cdf(1));
        test_exact(0.3, 10, 1.0, cdf(10));
        test_exact(1.0, 1, 0.0, cdf(0));
        test_exact(1.0, 1, 1.0, cdf(1));
        test_exact(1.0, 3, 0.0, cdf(0));
        test_exact(1.0, 3, 0.0, cdf(1));
        test_exact(1.0, 3, 1.0, cdf(3));
        test_exact(1.0, 10, 0.0, cdf(0));
        test_exact(1.0, 10, 0.0, cdf(1));
        test_exact(1.0, 10, 1.0, cdf(10));
    }

    #[test]
    fn test_sf() {
        let sf = |arg: u64| move |x: Binomial| x.sf(arg);
        test_exact(0.0, 1, 0.0, sf(0));
        test_exact(0.0, 1, 0.0, sf(1));
        test_exact(0.0, 3, 0.0, sf(0));
        test_exact(0.0, 3, 0.0, sf(1));
        test_exact(0.0, 3, 0.0, sf(3));
        test_exact(0.0, 10, 0.0, sf(0));
        test_exact(0.0, 10, 0.0, sf(1));
        test_exact(0.0, 10, 0.0, sf(10));
        test_absolute(0.3, 1, 0.3, 1e-15, sf(0));
        // x >= n always yields exactly 0.0.
        test_exact(0.3, 1, 0.0, sf(1));
        test_absolute(0.3, 3, 0.657, 1e-14, sf(0));
        test_absolute(0.3, 3, 0.216, 1e-15, sf(1));
        test_exact(0.3, 3, 0.0, sf(3));
        test_absolute(0.3, 10, 0.9717524751000001, 1e-16, sf(0));
        test_absolute(0.3, 10, 0.850691654100002, 1e-14, sf(1));
        test_exact(0.3, 10, 0.0, sf(10));
        test_exact(1.0, 1, 1.0, sf(0));
        test_exact(1.0, 1, 0.0, sf(1));
        test_exact(1.0, 3, 1.0, sf(0));
        test_exact(1.0, 3, 1.0, sf(1));
        test_exact(1.0, 3, 0.0, sf(3));
        test_exact(1.0, 10, 1.0, sf(0));
        test_exact(1.0, 10, 1.0, sf(1));
        test_exact(1.0, 10, 0.0, sf(10));
    }

    #[test]
    fn test_cdf_upper_bound() {
        // Arguments beyond the support (x > n) still give cdf = 1.
        let cdf = |arg: u64| move |x: Binomial| x.cdf(arg);
        test_exact(0.5, 3, 1.0, cdf(5));
    }

    #[test]
    fn test_sf_upper_bound() {
        // Arguments beyond the support (x > n) still give sf = 0.
        let sf = |arg: u64| move |x: Binomial| x.sf(arg);
        test_exact(0.5, 3, 0.0, sf(5));
    }

    #[test]
    fn test_inverse_cdf() {
        // `inverse_cdf` is not implemented in this file's `DiscreteCDF` impl,
        // so presumably these exercise the trait's default implementation.
        let invcdf = |arg: f64| move |x: Binomial| x.inverse_cdf(arg);
        test_exact(0.4, 5, 2, invcdf(0.3456));

        // cases in issue #185
        test_exact(0.018, 465, 1, invcdf(3.472e-4));
        test_exact(0.5, 6, 4, invcdf(0.75));
    }

    #[test]
    fn test_cdf_inverse_cdf() {
        // Round trip: inverse_cdf(cdf(k)) should recover k.
        let cdf_invcdf = |arg: u64| move |x: Binomial| x.inverse_cdf(x.cdf(arg));
        test_exact(0.3, 10, 3, cdf_invcdf(3));
        test_exact(0.3, 10, 4, cdf_invcdf(4));
        test_exact(0.5, 6, 4, cdf_invcdf(4));
    }

    #[test]
    fn test_discrete() {
        // NOTE(review): `check_discrete_distribution` is defined elsewhere;
        // presumably it checks pmf/cdf consistency up to the given bound.
        test::check_discrete_distribution(&create_ok(0.3, 5), 5);
        test::check_discrete_distribution(&create_ok(0.7, 10), 10);
    }
}