1use crate::distribution::{Discrete, DiscreteCDF};
2use crate::function::{beta, factorial};
3use crate::is_zero;
4use crate::statistics::*;
5use crate::{Result, StatsError};
6use rand::Rng;
7use std::f64;
8
/// A binomial distribution: the number of successes observed across `n`
/// independent trials that each succeed with probability `p`.
///
/// Values are small and `Copy`, so they can be passed around by value.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Binomial {
    /// Probability that a single trial succeeds.
    p: f64,
    /// Number of trials.
    n: u64,
}
29
30impl Binomial {
31 pub fn new(p: f64, n: u64) -> Result<Binomial> {
52 if p.is_nan() || p < 0.0 || p > 1.0 {
53 Err(StatsError::BadParams)
54 } else {
55 Ok(Binomial { p, n })
56 }
57 }
58
59 pub fn p(&self) -> f64 {
71 self.p
72 }
73
74 pub fn n(&self) -> u64 {
86 self.n
87 }
88}
89
90impl ::rand::distributions::Distribution<f64> for Binomial {
91 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
92 (0..self.n).fold(0.0, |acc, _| {
93 let n: f64 = rng.gen();
94 if n < self.p {
95 acc + 1.0
96 } else {
97 acc
98 }
99 })
100 }
101}
102
103impl DiscreteCDF<u64, f64> for Binomial {
104 fn cdf(&self, x: u64) -> f64 {
115 if x >= self.n {
116 1.0
117 } else {
118 let k = x;
119 beta::beta_reg((self.n - k) as f64, k as f64 + 1.0, 1.0 - self.p)
120 }
121 }
122
123 fn sf(&self, x: u64) -> f64 {
134 if x >= self.n {
135 0.0
136 } else {
137 let k = x;
138 beta::beta_reg(k as f64 + 1.0, (self.n - k) as f64, self.p)
139 }
140 }
141}
142
143impl Min<u64> for Binomial {
144 fn min(&self) -> u64 {
154 0
155 }
156}
157
158impl Max<u64> for Binomial {
159 fn max(&self) -> u64 {
169 self.n
170 }
171}
172
173impl Distribution<f64> for Binomial {
174 fn mean(&self) -> Option<f64> {
182 Some(self.p * self.n as f64)
183 }
184 fn variance(&self) -> Option<f64> {
192 Some(self.p * (1.0 - self.p) * self.n as f64)
193 }
194 fn entropy(&self) -> Option<f64> {
202 let entr = if is_zero(self.p) || ulps_eq!(self.p, 1.0) {
203 0.0
204 } else {
205 (0..self.n + 1).fold(0.0, |acc, x| {
206 let p = self.pmf(x);
207 acc - p * p.ln()
208 })
209 };
210 Some(entr)
211 }
212 fn skewness(&self) -> Option<f64> {
220 Some((1.0 - 2.0 * self.p) / (self.n as f64 * self.p * (1.0 - self.p)).sqrt())
221 }
222}
223
224impl Median<f64> for Binomial {
225 fn median(&self) -> f64 {
233 (self.p * self.n as f64).floor()
234 }
235}
236
237impl Mode<Option<u64>> for Binomial {
238 fn mode(&self) -> Option<u64> {
246 let mode = if is_zero(self.p) {
247 0
248 } else if ulps_eq!(self.p, 1.0) {
249 self.n
250 } else {
251 ((self.n as f64 + 1.0) * self.p).floor() as u64
252 };
253 Some(mode)
254 }
255}
256
257impl Discrete<u64, f64> for Binomial {
258 fn pmf(&self, x: u64) -> f64 {
267 if x > self.n {
268 0.0
269 } else if is_zero(self.p) {
270 if x == 0 {
271 1.0
272 } else {
273 0.0
274 }
275 } else if ulps_eq!(self.p, 1.0) {
276 if x == self.n {
277 1.0
278 } else {
279 0.0
280 }
281 } else {
282 (factorial::ln_binomial(self.n as u64, x as u64)
283 + x as f64 * self.p.ln()
284 + (self.n - x) as f64 * (1.0 - self.p).ln())
285 .exp()
286 }
287 }
288
289 fn ln_pmf(&self, x: u64) -> f64 {
298 if x > self.n {
299 f64::NEG_INFINITY
300 } else if is_zero(self.p) {
301 if x == 0 {
302 0.0
303 } else {
304 f64::NEG_INFINITY
305 }
306 } else if ulps_eq!(self.p, 1.0) {
307 if x == self.n {
308 0.0
309 } else {
310 f64::NEG_INFINITY
311 }
312 } else {
313 factorial::ln_binomial(self.n as u64, x as u64)
314 + x as f64 * self.p.ln()
315 + (self.n - x) as f64 * (1.0 - self.p).ln()
316 }
317 }
318}
319
#[rustfmt::skip]
#[cfg(all(test, feature = "nightly"))]
mod tests {
    use std::fmt::Debug;
    use crate::statistics::*;
    use crate::distribution::{DiscreteCDF, Discrete, Binomial};
    use crate::distribution::internal::*;
    use crate::consts::ACC;

    /// Builds a distribution, failing the test if construction is rejected.
    fn try_create(p: f64, n: u64) -> Binomial {
        let created = Binomial::new(p, n);
        assert!(created.is_ok());
        created.unwrap()
    }

    /// Checks that valid parameters round-trip through the accessors.
    fn create_case(p: f64, n: u64) {
        let dist = try_create(p, n);
        assert_eq!(p, dist.p());
        assert_eq!(n, dist.n());
    }

    /// Checks that the constructor rejects the given parameters.
    fn bad_create_case(p: f64, n: u64) {
        assert!(Binomial::new(p, n).is_err());
    }

    /// Evaluates `eval` on a freshly built distribution.
    fn get_value<T, F>(p: f64, n: u64, eval: F) -> T
        where T: PartialEq + Debug,
              F: Fn(Binomial) -> T
    {
        eval(try_create(p, n))
    }

    /// Asserts that `eval` yields exactly `expected`.
    fn test_case<T, F>(p: f64, n: u64, expected: T, eval: F)
        where T: PartialEq + Debug,
              F: Fn(Binomial) -> T
    {
        let actual = get_value(p, n, eval);
        println!("{} {} {:?}", p, n, expected);
        assert_eq!(expected, actual);
    }

    /// Asserts that `eval` yields `expected` to within `acc`.
    fn test_almost<F>(p: f64, n: u64, expected: f64, acc: f64, eval: F)
        where F: Fn(Binomial) -> f64
    {
        let actual = get_value(p, n, eval);
        assert_almost_eq!(expected, actual, acc);
    }

    #[test]
    fn test_create() {
        let valid: [(f64, u64); 3] = [(0.0, 4), (0.3, 3), (1.0, 2)];
        for &(p, n) in valid.iter() {
            create_case(p, n);
        }
    }

    #[test]
    fn test_bad_create() {
        let invalid: [(f64, u64); 3] = [(f64::NAN, 1), (-1.0, 1), (2.0, 1)];
        for &(p, n) in invalid.iter() {
            bad_create_case(p, n);
        }
    }

    #[test]
    fn test_mean() {
        let mean = |x: Binomial| x.mean().unwrap();
        test_case(0.0, 4, 0.0, mean);
        test_almost(0.3, 3, 0.9, 1e-15, mean);
        test_case(1.0, 2, 2.0, mean);
    }

    #[test]
    fn test_variance() {
        let variance = |x: Binomial| x.variance().unwrap();
        let cases: [(f64, u64, f64); 3] = [(0.0, 4, 0.0), (0.3, 3, 0.63), (1.0, 2, 0.0)];
        for &(p, n, expected) in cases.iter() {
            test_case(p, n, expected, variance);
        }
    }

    #[test]
    fn test_entropy() {
        let entropy = |x: Binomial| x.entropy().unwrap();
        test_case(0.0, 4, 0.0, entropy);
        test_almost(0.3, 3, 1.1404671643037712668976423399228972051669206536461, 1e-15, entropy);
        test_case(1.0, 2, 0.0, entropy);
    }

    #[test]
    fn test_skewness() {
        let skewness = |x: Binomial| x.skewness().unwrap();
        let cases: [(f64, u64, f64); 3] = [
            (0.0, 4, f64::INFINITY),
            (0.3, 3, 0.503952630678969636286),
            (1.0, 2, f64::NEG_INFINITY),
        ];
        for &(p, n, expected) in cases.iter() {
            test_case(p, n, expected, skewness);
        }
    }

    #[test]
    fn test_median() {
        let median = |x: Binomial| x.median();
        let cases: [(f64, u64, f64); 3] = [(0.0, 4, 0.0), (0.3, 3, 0.0), (1.0, 2, 2.0)];
        for &(p, n, expected) in cases.iter() {
            test_case(p, n, expected, median);
        }
    }

    #[test]
    fn test_mode() {
        let mode = |x: Binomial| x.mode().unwrap();
        let cases: [(f64, u64, u64); 3] = [(0.0, 4, 0), (0.3, 3, 1), (1.0, 2, 2)];
        for &(p, n, expected) in cases.iter() {
            test_case(p, n, expected, mode);
        }
    }

    #[test]
    fn test_min_max() {
        let min = |x: Binomial| x.min();
        let max = |x: Binomial| x.max();
        test_case(0.3, 10, 0, min);
        test_case(0.3, 10, 10, max);
    }

    #[test]
    fn test_pmf() {
        let pmf = |arg: u64| move |x: Binomial| x.pmf(arg);

        // p = 0: all mass on 0.
        test_case(0.0, 1, 1.0, pmf(0));
        test_case(0.0, 1, 0.0, pmf(1));
        test_case(0.0, 3, 1.0, pmf(0));
        test_case(0.0, 3, 0.0, pmf(1));
        test_case(0.0, 3, 0.0, pmf(3));
        test_case(0.0, 10, 1.0, pmf(0));
        test_case(0.0, 10, 0.0, pmf(1));
        test_case(0.0, 10, 0.0, pmf(10));

        // p = 0.3: general case.
        test_case(0.3, 1, 0.69999999999999995559107901499373838305473327636719, pmf(0));
        test_case(0.3, 1, 0.2999999999999999888977697537484345957636833190918, pmf(1));
        test_case(0.3, 3, 0.34299999999999993471888615204079956461021032657166, pmf(0));
        test_almost(0.3, 3, 0.44099999999999992772448109690231306411849135972008, 1e-15, pmf(1));
        test_almost(0.3, 3, 0.026999999999999997002397833512077451789759292859569, 1e-16, pmf(3));
        test_almost(0.3, 10, 0.02824752489999998207939855277004937778546385011091, 1e-17, pmf(0));
        test_almost(0.3, 10, 0.12106082099999992639752977030555903089040470780077, 1e-15, pmf(1));
        test_almost(0.3, 10, 0.0000059048999999999978147480206303047454017251032868501, 1e-20, pmf(10));

        // p = 1: all mass on n.
        test_case(1.0, 1, 0.0, pmf(0));
        test_case(1.0, 1, 1.0, pmf(1));
        test_case(1.0, 3, 0.0, pmf(0));
        test_case(1.0, 3, 0.0, pmf(1));
        test_case(1.0, 3, 1.0, pmf(3));
        test_case(1.0, 10, 0.0, pmf(0));
        test_case(1.0, 10, 0.0, pmf(1));
        test_case(1.0, 10, 1.0, pmf(10));
    }

    #[test]
    fn test_ln_pmf() {
        let ln_pmf = |arg: u64| move |x: Binomial| x.ln_pmf(arg);

        // p = 0: all mass on 0.
        test_case(0.0, 1, 0.0, ln_pmf(0));
        test_case(0.0, 1, f64::NEG_INFINITY, ln_pmf(1));
        test_case(0.0, 3, 0.0, ln_pmf(0));
        test_case(0.0, 3, f64::NEG_INFINITY, ln_pmf(1));
        test_case(0.0, 3, f64::NEG_INFINITY, ln_pmf(3));
        test_case(0.0, 10, 0.0, ln_pmf(0));
        test_case(0.0, 10, f64::NEG_INFINITY, ln_pmf(1));
        test_case(0.0, 10, f64::NEG_INFINITY, ln_pmf(10));

        // p = 0.3: general case.
        test_case(0.3, 1, -0.3566749439387324423539544041072745145718090708995, ln_pmf(0));
        test_case(0.3, 1, -1.2039728043259360296301803719337238685164245381839, ln_pmf(1));
        test_case(0.3, 3, -1.0700248318161973270618632123218235437154272126985, ln_pmf(0));
        test_almost(0.3, 3, -0.81871040353529122294284394322574719301255212216016, 1e-15, ln_pmf(1));
        test_almost(0.3, 3, -3.6119184129778080888905411158011716055492736145517, 1e-15, ln_pmf(3));
        test_case(0.3, 10, -3.566749439387324423539544041072745145718090708995, ln_pmf(0));
        test_almost(0.3, 10, -2.1114622067804823267977785542148302920616046876506, 1e-14, ln_pmf(1));
        test_case(0.3, 10, -12.039728043259360296301803719337238685164245381839, ln_pmf(10));

        // p = 1: all mass on n.
        test_case(1.0, 1, f64::NEG_INFINITY, ln_pmf(0));
        test_case(1.0, 1, 0.0, ln_pmf(1));
        test_case(1.0, 3, f64::NEG_INFINITY, ln_pmf(0));
        test_case(1.0, 3, f64::NEG_INFINITY, ln_pmf(1));
        test_case(1.0, 3, 0.0, ln_pmf(3));
        test_case(1.0, 10, f64::NEG_INFINITY, ln_pmf(0));
        test_case(1.0, 10, f64::NEG_INFINITY, ln_pmf(1));
        test_case(1.0, 10, 0.0, ln_pmf(10));
    }

    #[test]
    fn test_cdf() {
        let cdf = |arg: u64| move |x: Binomial| x.cdf(arg);

        // p = 0: everything is already accumulated at 0.
        test_case(0.0, 1, 1.0, cdf(0));
        test_case(0.0, 1, 1.0, cdf(1));
        test_case(0.0, 3, 1.0, cdf(0));
        test_case(0.0, 3, 1.0, cdf(1));
        test_case(0.0, 3, 1.0, cdf(3));
        test_case(0.0, 10, 1.0, cdf(0));
        test_case(0.0, 10, 1.0, cdf(1));
        test_case(0.0, 10, 1.0, cdf(10));

        // p = 0.3: general case.
        test_almost(0.3, 1, 0.7, 1e-15, cdf(0));
        test_case(0.3, 1, 1.0, cdf(1));
        test_almost(0.3, 3, 0.343, 1e-14, cdf(0));
        test_almost(0.3, 3, 0.784, 1e-15, cdf(1));
        test_case(0.3, 3, 1.0, cdf(3));
        test_almost(0.3, 10, 0.0282475249, 1e-16, cdf(0));
        test_almost(0.3, 10, 0.1493083459, 1e-14, cdf(1));
        test_case(0.3, 10, 1.0, cdf(10));

        // p = 1: nothing accumulates before n.
        test_case(1.0, 1, 0.0, cdf(0));
        test_case(1.0, 1, 1.0, cdf(1));
        test_case(1.0, 3, 0.0, cdf(0));
        test_case(1.0, 3, 0.0, cdf(1));
        test_case(1.0, 3, 1.0, cdf(3));
        test_case(1.0, 10, 0.0, cdf(0));
        test_case(1.0, 10, 0.0, cdf(1));
        test_case(1.0, 10, 1.0, cdf(10));
    }

    #[test]
    fn test_sf() {
        let sf = |arg: u64| move |x: Binomial| x.sf(arg);

        // p = 0: nothing survives past 0.
        test_case(0.0, 1, 0.0, sf(0));
        test_case(0.0, 1, 0.0, sf(1));
        test_case(0.0, 3, 0.0, sf(0));
        test_case(0.0, 3, 0.0, sf(1));
        test_case(0.0, 3, 0.0, sf(3));
        test_case(0.0, 10, 0.0, sf(0));
        test_case(0.0, 10, 0.0, sf(1));
        test_case(0.0, 10, 0.0, sf(10));

        // p = 0.3: general case.
        test_almost(0.3, 1, 0.3, 1e-15, sf(0));
        test_case(0.3, 1, 0.0, sf(1));
        test_almost(0.3, 3, 0.657, 1e-14, sf(0));
        test_almost(0.3, 3, 0.216, 1e-15, sf(1));
        test_case(0.3, 3, 0.0, sf(3));
        test_almost(0.3, 10, 0.9717524751000001, 1e-16, sf(0));
        test_almost(0.3, 10, 0.850691654100002, 1e-14, sf(1));
        test_case(0.3, 10, 0.0, sf(10));

        // p = 1: everything survives until n.
        test_case(1.0, 1, 1.0, sf(0));
        test_case(1.0, 1, 0.0, sf(1));
        test_case(1.0, 3, 1.0, sf(0));
        test_case(1.0, 3, 1.0, sf(1));
        test_case(1.0, 3, 0.0, sf(3));
        test_case(1.0, 10, 1.0, sf(0));
        test_case(1.0, 10, 1.0, sf(1));
        test_case(1.0, 10, 0.0, sf(10));
    }

    #[test]
    fn test_cdf_upper_bound() {
        // Arguments beyond `n` are outside the support on the high side.
        let cdf = |arg: u64| move |x: Binomial| x.cdf(arg);
        test_case(0.5, 3, 1.0, cdf(5));
    }

    #[test]
    fn test_sf_upper_bound() {
        // Arguments beyond `n` are outside the support on the high side.
        let sf = |arg: u64| move |x: Binomial| x.sf(arg);
        test_case(0.5, 3, 0.0, sf(5));
    }

    #[test]
    fn test_inverse_cdf() {
        let invcdf = |arg: f64| move |x: Binomial| x.inverse_cdf(arg);
        test_case(0.4, 5, 2, invcdf(0.3456));

        test_case(0.018, 465, 1, invcdf(3.472e-4));
        test_case(0.5, 6, 4, invcdf(0.75));
    }

    #[test]
    fn test_cdf_inverse_cdf() {
        // `inverse_cdf(cdf(k))` should hand back `k`.
        let cdf_invcdf = |arg: u64| move |x: Binomial| x.inverse_cdf(x.cdf(arg));
        test_case(0.3, 10, 3, cdf_invcdf(3));
        test_case(0.3, 10, 4, cdf_invcdf(4));
        test_case(0.5, 6, 4, cdf_invcdf(4));
    }

    #[test]
    fn test_discrete() {
        test::check_discrete_distribution(&try_create(0.3, 5), 5);
        test::check_discrete_distribution(&try_create(0.7, 10), 10);
    }
}