1use crate::distribution::{Discrete, DiscreteCDF};
2use crate::function::{beta, factorial};
3use crate::statistics::*;
4use std::f64;
5
/// The binomial distribution: the number of successes in `n` independent
/// trials, each of which succeeds with probability `p`.
///
/// Instances are created through [`Binomial::new`], which rejects a `p` that
/// is NaN or outside `[0, 1]`.
#[derive(Copy, Clone, PartialEq, Debug)]
pub struct Binomial {
    // Probability of success of a single trial (validated by `new` to lie in [0, 1]).
    p: f64,
    // Number of trials.
    n: u64,
}
26
/// Errors that can occur when constructing a [`Binomial`] via [`Binomial::new`].
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
#[non_exhaustive]
pub enum BinomialError {
    /// The success probability `p` is NaN or not in `[0, 1]`.
    ProbabilityInvalid,
}
34
35impl std::fmt::Display for BinomialError {
36 #[cfg_attr(coverage_nightly, coverage(off))]
37 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
38 match self {
39 BinomialError::ProbabilityInvalid => write!(f, "Probability is NaN or not in [0, 1]"),
40 }
41 }
42}
43
44impl std::error::Error for BinomialError {}
45
46impl Binomial {
47 pub fn new(p: f64, n: u64) -> Result<Binomial, BinomialError> {
68 if p.is_nan() || !(0.0..=1.0).contains(&p) {
69 Err(BinomialError::ProbabilityInvalid)
70 } else {
71 Ok(Binomial { p, n })
72 }
73 }
74
75 pub fn p(&self) -> f64 {
87 self.p
88 }
89
90 pub fn n(&self) -> u64 {
102 self.n
103 }
104}
105
106impl std::fmt::Display for Binomial {
107 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
108 write!(f, "Bin({},{})", self.p, self.n)
109 }
110}
111
#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<u64> for Binomial {
    /// Draws a sample by simulating `n` independent trials and counting the
    /// successes.
    ///
    /// Each trial draws one `f64` via `rng.gen()` and counts as a success when
    /// that draw is strictly below `p`. The cost is therefore O(n) per sample.
    fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> u64 {
        (0..self.n)
            .map(|_| {
                let draw: f64 = rng.gen();
                u64::from(draw < self.p)
            })
            .sum()
    }
}
126
#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<f64> for Binomial {
    /// Draws an integer sample (the success count) and widens it to `f64`.
    fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 {
        // The `u64` annotation selects the integer sampler above.
        let successes: u64 = rng.sample(self);
        successes as f64
    }
}
134
135impl DiscreteCDF<u64, f64> for Binomial {
136 fn cdf(&self, x: u64) -> f64 {
147 if x >= self.n {
148 1.0
149 } else {
150 let k = x;
151 beta::beta_reg((self.n - k) as f64, k as f64 + 1.0, 1.0 - self.p)
152 }
153 }
154
155 fn sf(&self, x: u64) -> f64 {
166 if x >= self.n {
167 0.0
168 } else {
169 let k = x;
170 beta::beta_reg(k as f64 + 1.0, (self.n - k) as f64, self.p)
171 }
172 }
173}
174
175impl Min<u64> for Binomial {
176 fn min(&self) -> u64 {
186 0
187 }
188}
189
190impl Max<u64> for Binomial {
191 fn max(&self) -> u64 {
201 self.n
202 }
203}
204
205impl Distribution<f64> for Binomial {
206 fn mean(&self) -> Option<f64> {
214 Some(self.p * self.n as f64)
215 }
216
217 fn variance(&self) -> Option<f64> {
225 Some(self.p * (1.0 - self.p) * self.n as f64)
226 }
227
228 fn entropy(&self) -> Option<f64> {
236 let entr = if self.p == 0.0 || ulps_eq!(self.p, 1.0) {
237 0.0
238 } else {
239 (0..self.n + 1).fold(0.0, |acc, x| {
240 let p = self.pmf(x);
241 acc - p * p.ln()
242 })
243 };
244 Some(entr)
245 }
246
247 fn skewness(&self) -> Option<f64> {
255 Some((1.0 - 2.0 * self.p) / (self.n as f64 * self.p * (1.0 - self.p)).sqrt())
256 }
257}
258
259impl Median<f64> for Binomial {
260 fn median(&self) -> f64 {
268 (self.p * self.n as f64).floor()
269 }
270}
271
272impl Mode<Option<u64>> for Binomial {
273 fn mode(&self) -> Option<u64> {
281 let mode = if self.p == 0.0 {
282 0
283 } else if ulps_eq!(self.p, 1.0) {
284 self.n
285 } else {
286 ((self.n as f64 + 1.0) * self.p).floor() as u64
287 };
288 Some(mode)
289 }
290}
291
292impl Discrete<u64, f64> for Binomial {
293 fn pmf(&self, x: u64) -> f64 {
302 if x > self.n {
303 0.0
304 } else if self.p == 0.0 {
305 if x == 0 {
306 1.0
307 } else {
308 0.0
309 }
310 } else if ulps_eq!(self.p, 1.0) {
311 if x == self.n {
312 1.0
313 } else {
314 0.0
315 }
316 } else {
317 (factorial::ln_binomial(self.n, x)
318 + x as f64 * self.p.ln()
319 + (self.n - x) as f64 * (1.0 - self.p).ln())
320 .exp()
321 }
322 }
323
324 fn ln_pmf(&self, x: u64) -> f64 {
333 if x > self.n {
334 f64::NEG_INFINITY
335 } else if self.p == 0.0 {
336 if x == 0 {
337 0.0
338 } else {
339 f64::NEG_INFINITY
340 }
341 } else if ulps_eq!(self.p, 1.0) {
342 if x == self.n {
343 0.0
344 } else {
345 f64::NEG_INFINITY
346 }
347 } else {
348 factorial::ln_binomial(self.n, x)
349 + x as f64 * self.p.ln()
350 + (self.n - x) as f64 * (1.0 - self.p).ln()
351 }
352 }
353}
354
// Unit tests for `Binomial`: construction, summary statistics, pmf/ln_pmf,
// cdf/sf, inverse cdf, and a generic discrete-distribution consistency check.
#[rustfmt::skip]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distribution::internal::*;
    use crate::testing_boiler;

    // NOTE(review): presumably expands to the `create_ok` / `create_err` /
    // `test_exact` / `test_absolute` helpers used below, taking the arguments
    // `(p, n)` of `Binomial::new` — confirm against `crate::testing_boiler`.
    testing_boiler!(p: f64, n: u64; Binomial; BinomialError);

    #[test]
    fn test_create() {
        // Boundary probabilities 0 and 1 are valid.
        create_ok(0.0, 4);
        create_ok(0.3, 3);
        create_ok(1.0, 2);
    }

    #[test]
    fn test_bad_create() {
        // NaN and out-of-range probabilities are rejected.
        create_err(f64::NAN, 1);
        create_err(-1.0, 1);
        create_err(2.0, 1);
    }

    #[test]
    fn test_mean() {
        let mean = |x: Binomial| x.mean().unwrap();
        test_exact(0.0, 4, 0.0, mean);
        test_absolute(0.3, 3, 0.9, 1e-15, mean);
        test_exact(1.0, 2, 2.0, mean);
    }

    #[test]
    fn test_variance() {
        let variance = |x: Binomial| x.variance().unwrap();
        test_exact(0.0, 4, 0.0, variance);
        test_exact(0.3, 3, 0.63, variance);
        test_exact(1.0, 2, 0.0, variance);
    }

    #[test]
    fn test_entropy() {
        let entropy = |x: Binomial| x.entropy().unwrap();
        // Degenerate distributions have zero entropy.
        test_exact(0.0, 4, 0.0, entropy);
        test_absolute(0.3, 3, 1.1404671643037712668976423399228972051669206536461, 1e-15, entropy);
        test_exact(1.0, 2, 0.0, entropy);
    }

    #[test]
    fn test_skewness() {
        let skewness = |x: Binomial| x.skewness().unwrap();
        // p = 0 and p = 1 divide by a zero standard deviation.
        test_exact(0.0, 4, f64::INFINITY, skewness);
        test_exact(0.3, 3, 0.503952630678969636286, skewness);
        test_exact(1.0, 2, f64::NEG_INFINITY, skewness);
    }

    #[test]
    fn test_median() {
        let median = |x: Binomial| x.median();
        test_exact(0.0, 4, 0.0, median);
        test_exact(0.3, 3, 0.0, median);
        test_exact(1.0, 2, 2.0, median);
    }

    #[test]
    fn test_mode() {
        let mode = |x: Binomial| x.mode().unwrap();
        test_exact(0.0, 4, 0, mode);
        test_exact(0.3, 3, 1, mode);
        test_exact(1.0, 2, 2, mode);
    }

    #[test]
    fn test_min_max() {
        let min = |x: Binomial| x.min();
        let max = |x: Binomial| x.max();
        test_exact(0.3, 10, 0, min);
        test_exact(0.3, 10, 10, max);
    }

    #[test]
    fn test_pmf() {
        let pmf = |arg: u64| move |x: Binomial| x.pmf(arg);
        // p = 0: all mass at 0.
        test_exact(0.0, 1, 1.0, pmf(0));
        test_exact(0.0, 1, 0.0, pmf(1));
        test_exact(0.0, 3, 1.0, pmf(0));
        test_exact(0.0, 3, 0.0, pmf(1));
        test_exact(0.0, 3, 0.0, pmf(3));
        test_exact(0.0, 10, 1.0, pmf(0));
        test_exact(0.0, 10, 0.0, pmf(1));
        test_exact(0.0, 10, 0.0, pmf(10));
        // Non-degenerate p: high-precision reference values.
        test_exact(0.3, 1, 0.69999999999999995559107901499373838305473327636719, pmf(0));
        test_exact(0.3, 1, 0.2999999999999999888977697537484345957636833190918, pmf(1));
        test_exact(0.3, 3, 0.34299999999999993471888615204079956461021032657166, pmf(0));
        test_absolute(0.3, 3, 0.44099999999999992772448109690231306411849135972008, 1e-15, pmf(1));
        test_absolute(0.3, 3, 0.026999999999999997002397833512077451789759292859569, 1e-16, pmf(3));
        test_absolute(0.3, 10, 0.02824752489999998207939855277004937778546385011091, 1e-17, pmf(0));
        test_absolute(0.3, 10, 0.12106082099999992639752977030555903089040470780077, 1e-15, pmf(1));
        test_absolute(0.3, 10, 0.0000059048999999999978147480206303047454017251032868501, 1e-20, pmf(10));
        // p = 1: all mass at n.
        test_exact(1.0, 1, 0.0, pmf(0));
        test_exact(1.0, 1, 1.0, pmf(1));
        test_exact(1.0, 3, 0.0, pmf(0));
        test_exact(1.0, 3, 0.0, pmf(1));
        test_exact(1.0, 3, 1.0, pmf(3));
        test_exact(1.0, 10, 0.0, pmf(0));
        test_exact(1.0, 10, 0.0, pmf(1));
        test_exact(1.0, 10, 1.0, pmf(10));
    }

    #[test]
    fn test_ln_pmf() {
        let ln_pmf = |arg: u64| move |x: Binomial| x.ln_pmf(arg);
        // p = 0: log-mass 0 at x = 0, -inf elsewhere.
        test_exact(0.0, 1, 0.0, ln_pmf(0));
        test_exact(0.0, 1, f64::NEG_INFINITY, ln_pmf(1));
        test_exact(0.0, 3, 0.0, ln_pmf(0));
        test_exact(0.0, 3, f64::NEG_INFINITY, ln_pmf(1));
        test_exact(0.0, 3, f64::NEG_INFINITY, ln_pmf(3));
        test_exact(0.0, 10, 0.0, ln_pmf(0));
        test_exact(0.0, 10, f64::NEG_INFINITY, ln_pmf(1));
        test_exact(0.0, 10, f64::NEG_INFINITY, ln_pmf(10));
        // Non-degenerate p: high-precision reference values.
        test_exact(0.3, 1, -0.3566749439387324423539544041072745145718090708995, ln_pmf(0));
        test_exact(0.3, 1, -1.2039728043259360296301803719337238685164245381839, ln_pmf(1));
        test_exact(0.3, 3, -1.0700248318161973270618632123218235437154272126985, ln_pmf(0));
        test_absolute(0.3, 3, -0.81871040353529122294284394322574719301255212216016, 1e-15, ln_pmf(1));
        test_absolute(0.3, 3, -3.6119184129778080888905411158011716055492736145517, 1e-15, ln_pmf(3));
        test_exact(0.3, 10, -3.566749439387324423539544041072745145718090708995, ln_pmf(0));
        test_absolute(0.3, 10, -2.1114622067804823267977785542148302920616046876506, 1e-14, ln_pmf(1));
        test_exact(0.3, 10, -12.039728043259360296301803719337238685164245381839, ln_pmf(10));
        // p = 1: log-mass 0 at x = n, -inf elsewhere.
        test_exact(1.0, 1, f64::NEG_INFINITY, ln_pmf(0));
        test_exact(1.0, 1, 0.0, ln_pmf(1));
        test_exact(1.0, 3, f64::NEG_INFINITY, ln_pmf(0));
        test_exact(1.0, 3, f64::NEG_INFINITY, ln_pmf(1));
        test_exact(1.0, 3, 0.0, ln_pmf(3));
        test_exact(1.0, 10, f64::NEG_INFINITY, ln_pmf(0));
        test_exact(1.0, 10, f64::NEG_INFINITY, ln_pmf(1));
        test_exact(1.0, 10, 0.0, ln_pmf(10));
    }

    #[test]
    fn test_cdf() {
        let cdf = |arg: u64| move |x: Binomial| x.cdf(arg);
        test_exact(0.0, 1, 1.0, cdf(0));
        test_exact(0.0, 1, 1.0, cdf(1));
        test_exact(0.0, 3, 1.0, cdf(0));
        test_exact(0.0, 3, 1.0, cdf(1));
        test_exact(0.0, 3, 1.0, cdf(3));
        test_exact(0.0, 10, 1.0, cdf(0));
        test_exact(0.0, 10, 1.0, cdf(1));
        test_exact(0.0, 10, 1.0, cdf(10));
        test_absolute(0.3, 1, 0.7, 1e-15, cdf(0));
        test_exact(0.3, 1, 1.0, cdf(1));
        test_absolute(0.3, 3, 0.343, 1e-14, cdf(0));
        test_absolute(0.3, 3, 0.784, 1e-15, cdf(1));
        test_exact(0.3, 3, 1.0, cdf(3));
        test_absolute(0.3, 10, 0.0282475249, 1e-16, cdf(0));
        test_absolute(0.3, 10, 0.1493083459, 1e-14, cdf(1));
        test_exact(0.3, 10, 1.0, cdf(10));
        test_exact(1.0, 1, 0.0, cdf(0));
        test_exact(1.0, 1, 1.0, cdf(1));
        test_exact(1.0, 3, 0.0, cdf(0));
        test_exact(1.0, 3, 0.0, cdf(1));
        test_exact(1.0, 3, 1.0, cdf(3));
        test_exact(1.0, 10, 0.0, cdf(0));
        test_exact(1.0, 10, 0.0, cdf(1));
        test_exact(1.0, 10, 1.0, cdf(10));
    }

    #[test]
    fn test_sf() {
        let sf = |arg: u64| move |x: Binomial| x.sf(arg);
        test_exact(0.0, 1, 0.0, sf(0));
        test_exact(0.0, 1, 0.0, sf(1));
        test_exact(0.0, 3, 0.0, sf(0));
        test_exact(0.0, 3, 0.0, sf(1));
        test_exact(0.0, 3, 0.0, sf(3));
        test_exact(0.0, 10, 0.0, sf(0));
        test_exact(0.0, 10, 0.0, sf(1));
        test_exact(0.0, 10, 0.0, sf(10));
        test_absolute(0.3, 1, 0.3, 1e-15, sf(0));
        test_exact(0.3, 1, 0.0, sf(1));
        test_absolute(0.3, 3, 0.657, 1e-14, sf(0));
        test_absolute(0.3, 3, 0.216, 1e-15, sf(1));
        test_exact(0.3, 3, 0.0, sf(3));
        test_absolute(0.3, 10, 0.9717524751000001, 1e-16, sf(0));
        test_absolute(0.3, 10, 0.850691654100002, 1e-14, sf(1));
        test_exact(0.3, 10, 0.0, sf(10));
        test_exact(1.0, 1, 1.0, sf(0));
        test_exact(1.0, 1, 0.0, sf(1));
        test_exact(1.0, 3, 1.0, sf(0));
        test_exact(1.0, 3, 1.0, sf(1));
        test_exact(1.0, 3, 0.0, sf(3));
        test_exact(1.0, 10, 1.0, sf(0));
        test_exact(1.0, 10, 1.0, sf(1));
        test_exact(1.0, 10, 0.0, sf(10));
    }

    #[test]
    fn test_cdf_upper_bound() {
        let cdf = |arg: u64| move |x: Binomial| x.cdf(arg);
        // x beyond n saturates at 1.
        test_exact(0.5, 3, 1.0, cdf(5));
    }

    #[test]
    fn test_sf_upper_bound() {
        let sf = |arg: u64| move |x: Binomial| x.sf(arg);
        // x beyond n has no mass above it.
        test_exact(0.5, 3, 0.0, sf(5));
    }

    #[test]
    fn test_inverse_cdf() {
        let invcdf = |arg: f64| move |x: Binomial| x.inverse_cdf(arg);
        test_exact(0.4, 5, 2, invcdf(0.3456));

        test_exact(0.018, 465, 1, invcdf(3.472e-4));
        test_exact(0.5, 6, 4, invcdf(0.75));
    }

    #[test]
    fn test_cdf_inverse_cdf() {
        // Round trip: inverse_cdf(cdf(k)) should recover k.
        let cdf_invcdf = |arg: u64| move |x: Binomial| x.inverse_cdf(x.cdf(arg));
        test_exact(0.3, 10, 3, cdf_invcdf(3));
        test_exact(0.3, 10, 4, cdf_invcdf(4));
        test_exact(0.5, 6, 4, cdf_invcdf(4));
    }

    #[test]
    fn test_discrete() {
        // NOTE(review): `check_discrete_distribution` comes from
        // `crate::distribution::internal::test`; the second argument is
        // presumably the upper bound of values checked — confirm there.
        test::check_discrete_distribution(&create_ok(0.3, 5), 5);
        test::check_discrete_distribution(&create_ok(0.7, 10), 10);
    }
}