// statrs/distribution/binomial.rs

1use crate::distribution::{Discrete, DiscreteCDF};
2use crate::function::{beta, factorial};
3use crate::is_zero;
4use crate::statistics::*;
5use crate::{Result, StatsError};
6use rand::Rng;
7use std::f64;
8
/// Implements the
/// [Binomial](https://en.wikipedia.org/wiki/Binomial_distribution)
/// distribution: the number of successes observed in `n` independent
/// trials, each of which succeeds with probability `p`.
///
/// # Examples
///
/// ```
/// use statrs::distribution::{Binomial, Discrete};
/// use statrs::statistics::Distribution;
///
/// let dist = Binomial::new(0.5, 5).unwrap();
/// assert_eq!(dist.mean().unwrap(), 2.5);
/// assert_eq!(dist.pmf(0), 0.03125);
/// assert_eq!(dist.pmf(3), 0.3125);
/// ```
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct Binomial {
    /// Probability of success on a single trial (validated to `[0, 1]` by `Binomial::new`).
    p: f64,
    /// Number of independent trials.
    n: u64,
}
29
30impl Binomial {
31    /// Constructs a new binomial distribution
32    /// with a given `p` probability of success of `n`
33    /// trials.
34    ///
35    /// # Errors
36    ///
37    /// Returns an error if `p` is `NaN`, less than `0.0`,
38    /// greater than `1.0`, or if `n` is less than `0`
39    ///
40    /// # Examples
41    ///
42    /// ```
43    /// use statrs::distribution::Binomial;
44    ///
45    /// let mut result = Binomial::new(0.5, 5);
46    /// assert!(result.is_ok());
47    ///
48    /// result = Binomial::new(-0.5, 5);
49    /// assert!(result.is_err());
50    /// ```
51    pub fn new(p: f64, n: u64) -> Result<Binomial> {
52        if p.is_nan() || p < 0.0 || p > 1.0 {
53            Err(StatsError::BadParams)
54        } else {
55            Ok(Binomial { p, n })
56        }
57    }
58
59    /// Returns the probability of success `p` of
60    /// the binomial distribution.
61    ///
62    /// # Examples
63    ///
64    /// ```
65    /// use statrs::distribution::Binomial;
66    ///
67    /// let n = Binomial::new(0.5, 5).unwrap();
68    /// assert_eq!(n.p(), 0.5);
69    /// ```
70    pub fn p(&self) -> f64 {
71        self.p
72    }
73
74    /// Returns the number of trials `n` of the
75    /// binomial distribution.
76    ///
77    /// # Examples
78    ///
79    /// ```
80    /// use statrs::distribution::Binomial;
81    ///
82    /// let n = Binomial::new(0.5, 5).unwrap();
83    /// assert_eq!(n.n(), 5);
84    /// ```
85    pub fn n(&self) -> u64 {
86        self.n
87    }
88}
89
90impl ::rand::distributions::Distribution<f64> for Binomial {
91    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
92        (0..self.n).fold(0.0, |acc, _| {
93            let n: f64 = rng.gen();
94            if n < self.p {
95                acc + 1.0
96            } else {
97                acc
98            }
99        })
100    }
101}
102
103impl DiscreteCDF<u64, f64> for Binomial {
104    /// Calculates the cumulative distribution function for the
105    /// binomial distribution at `x`
106    ///
107    /// # Formula
108    ///
109    /// ```ignore
110    /// I_(1 - p)(n - x, 1 + x)
111    /// ```
112    ///
113    /// where `I_(x)(a, b)` is the regularized incomplete beta function
114    fn cdf(&self, x: u64) -> f64 {
115        if x >= self.n {
116            1.0
117        } else {
118            let k = x;
119            beta::beta_reg((self.n - k) as f64, k as f64 + 1.0, 1.0 - self.p)
120        }
121    }
122
123    /// Calculates the survival function for the
124    /// binomial distribution at `x`
125    ///
126    /// # Formula
127    ///
128    /// ```ignore
129    /// I_(p)(x + 1, n - x)
130    /// ```
131    ///
132    /// where `I_(x)(a, b)` is the regularized incomplete beta function
133    fn sf(&self, x: u64) -> f64 {
134        if x >= self.n {
135            0.0
136        } else {
137            let k = x;
138            beta::beta_reg(k as f64 + 1.0, (self.n - k) as f64, self.p)
139        }
140    }
141}
142
143impl Min<u64> for Binomial {
144    /// Returns the minimum value in the domain of the
145    /// binomial distribution representable by a 64-bit
146    /// integer
147    ///
148    /// # Formula
149    ///
150    /// ```ignore
151    /// 0
152    /// ```
153    fn min(&self) -> u64 {
154        0
155    }
156}
157
158impl Max<u64> for Binomial {
159    /// Returns the maximum value in the domain of the
160    /// binomial distribution representable by a 64-bit
161    /// integer
162    ///
163    /// # Formula
164    ///
165    /// ```ignore
166    /// n
167    /// ```
168    fn max(&self) -> u64 {
169        self.n
170    }
171}
172
173impl Distribution<f64> for Binomial {
174    /// Returns the mean of the binomial distribution
175    ///
176    /// # Formula
177    ///
178    /// ```ignore
179    /// p * n
180    /// ```
181    fn mean(&self) -> Option<f64> {
182        Some(self.p * self.n as f64)
183    }
184    /// Returns the variance of the binomial distribution
185    ///
186    /// # Formula
187    ///
188    /// ```ignore
189    /// n * p * (1 - p)
190    /// ```
191    fn variance(&self) -> Option<f64> {
192        Some(self.p * (1.0 - self.p) * self.n as f64)
193    }
194    /// Returns the entropy of the binomial distribution
195    ///
196    /// # Formula
197    ///
198    /// ```ignore
199    /// (1 / 2) * ln (2 * π * e * n * p * (1 - p))
200    /// ```
201    fn entropy(&self) -> Option<f64> {
202        let entr = if is_zero(self.p) || ulps_eq!(self.p, 1.0) {
203            0.0
204        } else {
205            (0..self.n + 1).fold(0.0, |acc, x| {
206                let p = self.pmf(x);
207                acc - p * p.ln()
208            })
209        };
210        Some(entr)
211    }
212    /// Returns the skewness of the binomial distribution
213    ///
214    /// # Formula
215    ///
216    /// ```ignore
217    /// (1 - 2p) / sqrt(n * p * (1 - p)))
218    /// ```
219    fn skewness(&self) -> Option<f64> {
220        Some((1.0 - 2.0 * self.p) / (self.n as f64 * self.p * (1.0 - self.p)).sqrt())
221    }
222}
223
224impl Median<f64> for Binomial {
225    /// Returns the median of the binomial distribution
226    ///
227    /// # Formula
228    ///
229    /// ```ignore
230    /// floor(n * p)
231    /// ```
232    fn median(&self) -> f64 {
233        (self.p * self.n as f64).floor()
234    }
235}
236
237impl Mode<Option<u64>> for Binomial {
238    /// Returns the mode for the binomial distribution
239    ///
240    /// # Formula
241    ///
242    /// ```ignore
243    /// floor((n + 1) * p)
244    /// ```
245    fn mode(&self) -> Option<u64> {
246        let mode = if is_zero(self.p) {
247            0
248        } else if ulps_eq!(self.p, 1.0) {
249            self.n
250        } else {
251            ((self.n as f64 + 1.0) * self.p).floor() as u64
252        };
253        Some(mode)
254    }
255}
256
257impl Discrete<u64, f64> for Binomial {
258    /// Calculates the probability mass function for the binomial
259    /// distribution at `x`
260    ///
261    /// # Formula
262    ///
263    /// ```ignore
264    /// (n choose k) * p^k * (1 - p)^(n - k)
265    /// ```
266    fn pmf(&self, x: u64) -> f64 {
267        if x > self.n {
268            0.0
269        } else if is_zero(self.p) {
270            if x == 0 {
271                1.0
272            } else {
273                0.0
274            }
275        } else if ulps_eq!(self.p, 1.0) {
276            if x == self.n {
277                1.0
278            } else {
279                0.0
280            }
281        } else {
282            (factorial::ln_binomial(self.n as u64, x as u64)
283                + x as f64 * self.p.ln()
284                + (self.n - x) as f64 * (1.0 - self.p).ln())
285            .exp()
286        }
287    }
288
289    /// Calculates the log probability mass function for the binomial
290    /// distribution at `x`
291    ///
292    /// # Formula
293    ///
294    /// ```ignore
295    /// ln((n choose k) * p^k * (1 - p)^(n - k))
296    /// ```
297    fn ln_pmf(&self, x: u64) -> f64 {
298        if x > self.n {
299            f64::NEG_INFINITY
300        } else if is_zero(self.p) {
301            if x == 0 {
302                0.0
303            } else {
304                f64::NEG_INFINITY
305            }
306        } else if ulps_eq!(self.p, 1.0) {
307            if x == self.n {
308                0.0
309            } else {
310                f64::NEG_INFINITY
311            }
312        } else {
313            factorial::ln_binomial(self.n as u64, x as u64)
314                + x as f64 * self.p.ln()
315                + (self.n - x) as f64 * (1.0 - self.p).ln()
316        }
317    }
318}
319
// Unit tests for `Binomial`. They are only compiled in test builds with the
// "nightly" feature enabled.
#[rustfmt::skip]
#[cfg(all(test, feature = "nightly"))]
mod tests {
    use std::fmt::Debug;
    use crate::statistics::*;
    use crate::distribution::{DiscreteCDF, Discrete, Binomial};
    use crate::distribution::internal::*;
    use crate::consts::ACC;

    // Builds a distribution, failing the test if `Binomial::new` rejects the parameters.
    fn try_create(p: f64, n: u64) -> Binomial {
        let n = Binomial::new(p, n);
        assert!(n.is_ok());
        n.unwrap()
    }

    // Checks that valid parameters round-trip through the accessors.
    fn create_case(p: f64, n: u64) {
        let dist = try_create(p, n);
        assert_eq!(p, dist.p());
        assert_eq!(n, dist.n());
    }

    // Checks that `Binomial::new` returns an error for the given parameters.
    fn bad_create_case(p: f64, n: u64) {
        let n = Binomial::new(p, n);
        assert!(n.is_err());
    }

    // Builds a distribution and applies `eval` to it.
    fn get_value<T, F>(p: f64, n: u64, eval: F) -> T
        where T: PartialEq + Debug,
              F: Fn(Binomial) -> T
    {
        let n = try_create(p, n);
        eval(n)
    }

    // Asserts exact equality between `expected` and `eval` applied to Binomial(p, n).
    fn test_case<T, F>(p: f64, n: u64, expected: T, eval: F)
        where T: PartialEq + Debug,
              F: Fn(Binomial) -> T
    {
        let x = get_value(p, n, eval);
        println!("{} {} {:?}", p, n, expected);
        assert_eq!(expected, x);
    }

    // Asserts that `eval` applied to Binomial(p, n) is within `acc` of `expected`.
    fn test_almost<F>(p: f64, n: u64, expected: f64, acc: f64, eval: F)
        where F: Fn(Binomial) -> f64
    {
        let x = get_value(p, n, eval);
        assert_almost_eq!(expected, x, acc);
    }

    #[test]
    fn test_create() {
        create_case(0.0, 4);
        create_case(0.3, 3);
        create_case(1.0, 2);
    }

    #[test]
    fn test_bad_create() {
        bad_create_case(f64::NAN, 1);
        bad_create_case(-1.0, 1);
        bad_create_case(2.0, 1);
    }

    #[test]
    fn test_mean() {
        let mean = |x: Binomial| x.mean().unwrap();
        test_case(0.0, 4, 0.0, mean);
        test_almost(0.3, 3, 0.9, 1e-15, mean);
        test_case(1.0, 2, 2.0, mean);
    }

    #[test]
    fn test_variance() {
        let variance = |x: Binomial| x.variance().unwrap();
        test_case(0.0, 4, 0.0, variance);
        test_case(0.3, 3, 0.63, variance);
        test_case(1.0, 2, 0.0, variance);
    }

    #[test]
    fn test_entropy() {
        let entropy = |x: Binomial| x.entropy().unwrap();
        test_case(0.0, 4, 0.0, entropy);
        test_almost(0.3, 3, 1.1404671643037712668976423399228972051669206536461, 1e-15, entropy);
        test_case(1.0, 2, 0.0, entropy);
    }

    #[test]
    fn test_skewness() {
        let skewness = |x: Binomial| x.skewness().unwrap();
        // Degenerate p = 0 / p = 1 give a zero denominator, hence +/- infinity.
        test_case(0.0, 4, f64::INFINITY, skewness);
        test_case(0.3, 3, 0.503952630678969636286, skewness);
        test_case(1.0, 2, f64::NEG_INFINITY, skewness);
    }

    #[test]
    fn test_median() {
        let median = |x: Binomial| x.median();
        test_case(0.0, 4, 0.0, median);
        // Pins the floor(n * p) value that median() currently returns.
        test_case(0.3, 3, 0.0, median);
        test_case(1.0, 2, 2.0, median);
    }

    #[test]
    fn test_mode() {
        let mode = |x: Binomial| x.mode().unwrap();
        test_case(0.0, 4, 0, mode);
        test_case(0.3, 3, 1, mode);
        test_case(1.0, 2, 2, mode);
    }

    #[test]
    fn test_min_max() {
        let min = |x: Binomial| x.min();
        let max = |x: Binomial| x.max();
        test_case(0.3, 10, 0, min);
        test_case(0.3, 10, 10, max);
    }

    #[test]
    fn test_pmf() {
        // `pmf(arg)` returns a closure that evaluates the pmf at `arg`.
        let pmf = |arg: u64| move |x: Binomial| x.pmf(arg);
        test_case(0.0, 1, 1.0, pmf(0));
        test_case(0.0, 1, 0.0, pmf(1));
        test_case(0.0, 3, 1.0, pmf(0));
        test_case(0.0, 3, 0.0, pmf(1));
        test_case(0.0, 3, 0.0, pmf(3));
        test_case(0.0, 10, 1.0, pmf(0));
        test_case(0.0, 10, 0.0, pmf(1));
        test_case(0.0, 10, 0.0, pmf(10));
        test_case(0.3, 1, 0.69999999999999995559107901499373838305473327636719, pmf(0));
        test_case(0.3, 1, 0.2999999999999999888977697537484345957636833190918, pmf(1));
        test_case(0.3, 3, 0.34299999999999993471888615204079956461021032657166, pmf(0));
        test_almost(0.3, 3, 0.44099999999999992772448109690231306411849135972008, 1e-15, pmf(1));
        test_almost(0.3, 3, 0.026999999999999997002397833512077451789759292859569, 1e-16, pmf(3));
        test_almost(0.3, 10, 0.02824752489999998207939855277004937778546385011091, 1e-17, pmf(0));
        test_almost(0.3, 10, 0.12106082099999992639752977030555903089040470780077, 1e-15, pmf(1));
        test_almost(0.3, 10, 0.0000059048999999999978147480206303047454017251032868501, 1e-20, pmf(10));
        test_case(1.0, 1, 0.0, pmf(0));
        test_case(1.0, 1, 1.0, pmf(1));
        test_case(1.0, 3, 0.0, pmf(0));
        test_case(1.0, 3, 0.0, pmf(1));
        test_case(1.0, 3, 1.0, pmf(3));
        test_case(1.0, 10, 0.0, pmf(0));
        test_case(1.0, 10, 0.0, pmf(1));
        test_case(1.0, 10, 1.0, pmf(10));
    }

    #[test]
    fn test_ln_pmf() {
        let ln_pmf = |arg: u64| move |x: Binomial| x.ln_pmf(arg);
        test_case(0.0, 1, 0.0, ln_pmf(0));
        test_case(0.0, 1, f64::NEG_INFINITY, ln_pmf(1));
        test_case(0.0, 3, 0.0, ln_pmf(0));
        test_case(0.0, 3, f64::NEG_INFINITY, ln_pmf(1));
        test_case(0.0, 3, f64::NEG_INFINITY, ln_pmf(3));
        test_case(0.0, 10, 0.0, ln_pmf(0));
        test_case(0.0, 10, f64::NEG_INFINITY, ln_pmf(1));
        test_case(0.0, 10, f64::NEG_INFINITY, ln_pmf(10));
        test_case(0.3, 1, -0.3566749439387324423539544041072745145718090708995, ln_pmf(0));
        test_case(0.3, 1, -1.2039728043259360296301803719337238685164245381839, ln_pmf(1));
        test_case(0.3, 3, -1.0700248318161973270618632123218235437154272126985, ln_pmf(0));
        test_almost(0.3, 3, -0.81871040353529122294284394322574719301255212216016, 1e-15, ln_pmf(1));
        test_almost(0.3, 3, -3.6119184129778080888905411158011716055492736145517, 1e-15, ln_pmf(3));
        test_case(0.3, 10, -3.566749439387324423539544041072745145718090708995, ln_pmf(0));
        test_almost(0.3, 10, -2.1114622067804823267977785542148302920616046876506, 1e-14, ln_pmf(1));
        test_case(0.3, 10, -12.039728043259360296301803719337238685164245381839, ln_pmf(10));
        test_case(1.0, 1, f64::NEG_INFINITY, ln_pmf(0));
        test_case(1.0, 1, 0.0, ln_pmf(1));
        test_case(1.0, 3, f64::NEG_INFINITY, ln_pmf(0));
        test_case(1.0, 3, f64::NEG_INFINITY, ln_pmf(1));
        test_case(1.0, 3, 0.0, ln_pmf(3));
        test_case(1.0, 10, f64::NEG_INFINITY, ln_pmf(0));
        test_case(1.0, 10, f64::NEG_INFINITY, ln_pmf(1));
        test_case(1.0, 10, 0.0, ln_pmf(10));
    }

    #[test]
    fn test_cdf() {
        let cdf = |arg: u64| move |x: Binomial| x.cdf(arg);
        test_case(0.0, 1, 1.0, cdf(0));
        test_case(0.0, 1, 1.0, cdf(1));
        test_case(0.0, 3, 1.0, cdf(0));
        test_case(0.0, 3, 1.0, cdf(1));
        test_case(0.0, 3, 1.0, cdf(3));
        test_case(0.0, 10, 1.0, cdf(0));
        test_case(0.0, 10, 1.0, cdf(1));
        test_case(0.0, 10, 1.0, cdf(10));
        test_almost(0.3, 1, 0.7, 1e-15, cdf(0));
        test_case(0.3, 1, 1.0, cdf(1));
        test_almost(0.3, 3, 0.343, 1e-14, cdf(0));
        test_almost(0.3, 3, 0.784, 1e-15, cdf(1));
        test_case(0.3, 3, 1.0, cdf(3));
        test_almost(0.3, 10, 0.0282475249, 1e-16, cdf(0));
        test_almost(0.3, 10, 0.1493083459, 1e-14, cdf(1));
        test_case(0.3, 10, 1.0, cdf(10));
        test_case(1.0, 1, 0.0, cdf(0));
        test_case(1.0, 1, 1.0, cdf(1));
        test_case(1.0, 3, 0.0, cdf(0));
        test_case(1.0, 3, 0.0, cdf(1));
        test_case(1.0, 3, 1.0, cdf(3));
        test_case(1.0, 10, 0.0, cdf(0));
        test_case(1.0, 10, 0.0, cdf(1));
        test_case(1.0, 10, 1.0, cdf(10));
    }

    #[test]
    fn test_sf() {
        let sf = |arg: u64| move |x: Binomial| x.sf(arg);
        test_case(0.0, 1, 0.0, sf(0));
        test_case(0.0, 1, 0.0, sf(1));
        test_case(0.0, 3, 0.0, sf(0));
        test_case(0.0, 3, 0.0, sf(1));
        test_case(0.0, 3, 0.0, sf(3));
        test_case(0.0, 10, 0.0, sf(0));
        test_case(0.0, 10, 0.0, sf(1));
        test_case(0.0, 10, 0.0, sf(10));
        test_almost(0.3, 1, 0.3, 1e-15, sf(0));
        test_case(0.3, 1, 0.0, sf(1));
        test_almost(0.3, 3, 0.657, 1e-14, sf(0));
        test_almost(0.3, 3, 0.216, 1e-15, sf(1));
        test_case(0.3, 3, 0.0, sf(3));
        test_almost(0.3, 10, 0.9717524751000001, 1e-16, sf(0));
        test_almost(0.3, 10, 0.850691654100002, 1e-14, sf(1));
        test_case(0.3, 10, 0.0, sf(10));
        test_case(1.0, 1, 1.0, sf(0));
        test_case(1.0, 1, 0.0, sf(1));
        test_case(1.0, 3, 1.0, sf(0));
        test_case(1.0, 3, 1.0, sf(1));
        test_case(1.0, 3, 0.0, sf(3));
        test_case(1.0, 10, 1.0, sf(0));
        test_case(1.0, 10, 1.0, sf(1));
        test_case(1.0, 10, 0.0, sf(10));
    }

    // Arguments beyond `n` must clamp to the boundary values.
    #[test]
    fn test_cdf_upper_bound() {
        let cdf = |arg: u64| move |x: Binomial| x.cdf(arg);
        test_case(0.5, 3, 1.0, cdf(5));
    }

    #[test]
    fn test_sf_upper_bound() {
        let sf = |arg: u64| move |x: Binomial| x.sf(arg);
        test_case(0.5, 3, 0.0, sf(5));
    }

    #[test]
    fn test_inverse_cdf() {
        let invcdf = |arg: f64| move |x: Binomial| x.inverse_cdf(arg);
        test_case(0.4, 5, 2, invcdf(0.3456));

        // cases in issue #185
        test_case(0.018, 465, 1, invcdf(3.472e-4));
        test_case(0.5, 6, 4, invcdf(0.75));
    }

    // inverse_cdf(cdf(k)) should recover k.
    #[test]
    fn test_cdf_inverse_cdf() {
        let cdf_invcdf = |arg: u64| move |x: Binomial| x.inverse_cdf(x.cdf(arg));
        test_case(0.3, 10, 3, cdf_invcdf(3));
        test_case(0.3, 10, 4, cdf_invcdf(4));
        test_case(0.5, 6, 4, cdf_invcdf(4));
    }

    #[test]
    fn test_discrete() {
        test::check_discrete_distribution(&try_create(0.3, 5), 5);
        test::check_discrete_distribution(&try_create(0.7, 10), 10);
    }
}