1use crate::distribution::{Continuous, ContinuousCDF};
2use crate::function::{beta, gamma};
3use crate::is_zero;
4use crate::statistics::*;
5use crate::{Result, StatsError};
6use core::f64::INFINITY as INF;
7use rand::Rng;
8
/// The [Beta distribution](https://en.wikipedia.org/wiki/Beta_distribution)
/// on the unit interval `[0, 1]`, described by two shape parameters
/// `shape_a` (α) and `shape_b` (β).
///
/// Values are built through `Beta::new`, which rejects non-positive or `NaN`
/// shapes and allows at most one of the two shapes to be infinite.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Beta {
    /// First shape parameter α (`> 0`, possibly `+inf`).
    shape_a: f64,
    /// Second shape parameter β (`> 0`, possibly `+inf`).
    shape_b: f64,
}
28
29impl Beta {
30 pub fn new(shape_a: f64, shape_b: f64) -> Result<Beta> {
50 if shape_a.is_nan()
51 || shape_b.is_nan()
52 || shape_a.is_infinite() && shape_b.is_infinite()
53 || shape_a <= 0.0
54 || shape_b <= 0.0
55 {
56 return Err(StatsError::BadParams);
57 };
58 Ok(Beta { shape_a, shape_b })
59 }
60
61 pub fn shape_a(&self) -> f64 {
72 self.shape_a
73 }
74
75 pub fn shape_b(&self) -> f64 {
86 self.shape_b
87 }
88}
89
90impl ::rand::distributions::Distribution<f64> for Beta {
91 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
92 let x = super::gamma::sample_unchecked(rng, self.shape_a, 1.0);
94 let y = super::gamma::sample_unchecked(rng, self.shape_b, 1.0);
95 x / (x + y)
96 }
97}
98
99impl ContinuousCDF<f64, f64> for Beta {
100 fn cdf(&self, x: f64) -> f64 {
113 if x < 0.0 {
114 0.0
115 } else if x >= 1.0 {
116 1.0
117 } else if self.shape_a.is_infinite() {
118 if x < 1.0 {
119 0.0
120 } else {
121 1.0
122 }
123 } else if self.shape_b.is_infinite() {
124 1.0
125 } else if ulps_eq!(self.shape_a, 1.0) && ulps_eq!(self.shape_b, 1.0) {
126 x
127 } else {
128 beta::beta_reg(self.shape_a, self.shape_b, x)
129 }
130 }
131
132 fn sf(&self, x: f64) -> f64 {
144 if x < 0.0 {
145 1.0
146 } else if x >= 1.0 {
147 0.0
148 } else if self.shape_a.is_infinite() {
149 if x < 1.0 {
150 1.0
151 } else {
152 0.0
153 }
154 } else if self.shape_b.is_infinite() {
155 0.0
156 } else if ulps_eq!(self.shape_a, 1.0) && ulps_eq!(self.shape_b, 1.0) {
157 1. - x
158 } else {
159 beta::beta_reg(self.shape_b, self.shape_a, 1.0 - x)
160 }
161 }
162}
163
164impl Min<f64> for Beta {
165 fn min(&self) -> f64 {
175 0.0
176 }
177}
178
179impl Max<f64> for Beta {
180 fn max(&self) -> f64 {
190 1.0
191 }
192}
193
194impl Distribution<f64> for Beta {
195 fn mean(&self) -> Option<f64> {
205 let mean = if self.shape_a.is_infinite() {
206 1.0
207 } else {
208 self.shape_a / (self.shape_a + self.shape_b)
209 };
210 Some(mean)
211 }
212 fn variance(&self) -> Option<f64> {
224 let var = if self.shape_a.is_infinite() || self.shape_b.is_infinite() {
225 0.0
226 } else {
227 self.shape_a * self.shape_b
228 / ((self.shape_a + self.shape_b)
229 * (self.shape_a + self.shape_b)
230 * (self.shape_a + self.shape_b + 1.0))
231 };
232 Some(var)
233 }
234 fn entropy(&self) -> Option<f64> {
244 let entr = if self.shape_a.is_infinite() || self.shape_b.is_infinite() {
245 return None;
247 } else {
248 beta::ln_beta(self.shape_a, self.shape_b)
249 - (self.shape_a - 1.0) * gamma::digamma(self.shape_a)
250 - (self.shape_b - 1.0) * gamma::digamma(self.shape_b)
251 + (self.shape_a + self.shape_b - 2.0) * gamma::digamma(self.shape_a + self.shape_b)
252 };
253 Some(entr)
254 }
255 fn skewness(&self) -> Option<f64> {
265 let skew = if self.shape_a.is_infinite() {
266 -2.0
267 } else if self.shape_b.is_infinite() {
268 2.0
269 } else {
270 2.0 * (self.shape_b - self.shape_a) * (self.shape_a + self.shape_b + 1.0).sqrt()
271 / ((self.shape_a + self.shape_b + 2.0) * (self.shape_a * self.shape_b).sqrt())
272 };
273 Some(skew)
274 }
275}
276
277impl Mode<Option<f64>> for Beta {
278 fn mode(&self) -> Option<f64> {
299 if self.shape_a <= 1.0 || self.shape_b <= 1.0 {
302 None
303 } else if self.shape_a.is_infinite() {
304 Some(1.0)
305 } else {
306 Some((self.shape_a - 1.0) / (self.shape_a + self.shape_b - 2.0))
307 }
308 }
309}
310
311impl Continuous<f64, f64> for Beta {
312 fn pdf(&self, x: f64) -> f64 {
325 if !(0.0..=1.0).contains(&x) {
326 0.0
327 } else if self.shape_a.is_infinite() {
328 if ulps_eq!(x, 1.0) {
329 INF
330 } else {
331 0.0
332 }
333 } else if self.shape_b.is_infinite() {
334 if is_zero(x) {
335 INF
336 } else {
337 0.0
338 }
339 } else if ulps_eq!(self.shape_a, 1.0) && ulps_eq!(self.shape_b, 1.0) {
340 1.0
341 } else if self.shape_a > 80.0 || self.shape_b > 80.0 {
342 self.ln_pdf(x).exp()
343 } else {
344 let bb = gamma::gamma(self.shape_a + self.shape_b)
345 / (gamma::gamma(self.shape_a) * gamma::gamma(self.shape_b));
346 bb * x.powf(self.shape_a - 1.0) * (1.0 - x).powf(self.shape_b - 1.0)
347 }
348 }
349
350 fn ln_pdf(&self, x: f64) -> f64 {
363 if !(0.0..=1.0).contains(&x) {
364 -INF
365 } else if self.shape_a.is_infinite() {
366 if ulps_eq!(x, 1.0) {
367 INF
368 } else {
369 -INF
370 }
371 } else if self.shape_b.is_infinite() {
372 if is_zero(x) {
373 INF
374 } else {
375 -INF
376 }
377 } else if ulps_eq!(self.shape_a, 1.0) && ulps_eq!(self.shape_b, 1.0) {
378 0.0
379 } else {
380 let aa = gamma::ln_gamma(self.shape_a + self.shape_b)
381 - gamma::ln_gamma(self.shape_a)
382 - gamma::ln_gamma(self.shape_b);
383 let bb = if ulps_eq!(self.shape_a, 1.0) && is_zero(x) {
384 0.0
385 } else if is_zero(x) {
386 -INF
387 } else {
388 (self.shape_a - 1.0) * x.ln()
389 };
390 let cc = if ulps_eq!(self.shape_b, 1.0) && ulps_eq!(x, 1.0) {
391 0.0
392 } else if ulps_eq!(x, 1.0) {
393 -INF
394 } else {
395 (self.shape_b - 1.0) * (1.0 - x).ln()
396 };
397 aa + bb + cc
398 }
399 }
400}
401
#[rustfmt::skip]
#[cfg(all(test, feature = "nightly"))]
mod tests {
    //! Unit tests for the `Beta` distribution: construction, moments,
    //! mode, support bounds, density, log-density, CDF and survival function.

    use super::*;
    use crate::consts::ACC;
    use super::super::internal::*;
    use crate::statistics::*;
    use crate::testing_boiler;

    testing_boiler!((f64, f64), Beta);

    #[test]
    fn test_create() {
        let valid = [(1.0, 1.0), (9.0, 1.0), (5.0, 100.0), (1.0, INF), (INF, 1.0)];
        for &arg in valid.iter() {
            try_create(arg);
        }
    }

    #[test]
    fn test_bad_create() {
        let invalid = [
            (0.0, 0.0),
            (0.0, 0.1),
            (1.0, 0.0),
            (0.0, INF),
            (INF, 0.0),
            (f64::NAN, 1.0),
            (1.0, f64::NAN),
            (f64::NAN, f64::NAN),
            (1.0, -1.0),
            (-1.0, 1.0),
            (-1.0, -1.0),
            (INF, INF),
        ];
        for &arg in invalid.iter() {
            bad_create_case(arg);
        }
    }

    #[test]
    fn test_mean() {
        let f = |x: Beta| x.mean().unwrap();
        let test = [
            ((1.0, 1.0), 0.5),
            ((9.0, 1.0), 0.9),
            ((5.0, 100.0), 0.047619047619047619047616),
            ((1.0, INF), 0.0),
            ((INF, 1.0), 1.0),
        ];
        for &(arg, res) in test.iter() {
            test_case(arg, res, f);
        }
    }

    #[test]
    fn test_variance() {
        let f = |x: Beta| x.variance().unwrap();
        let test = [
            ((1.0, 1.0), 1.0 / 12.0),
            ((9.0, 1.0), 9.0 / 1100.0),
            ((5.0, 100.0), 500.0 / 1168650.0),
            ((1.0, INF), 0.0),
            ((INF, 1.0), 0.0),
        ];
        for &(arg, res) in test.iter() {
            test_case(arg, res, f);
        }
    }

    #[test]
    fn test_entropy() {
        let f = |x: Beta| x.entropy().unwrap();
        let test = [
            ((9.0, 1.0), -1.3083356884473304939016015),
            ((5.0, 100.0), -2.52016231876027436794592),
        ];
        for &(arg, res) in test.iter() {
            test_case(arg, res, f);
        }
        test_case_special((1.0, 1.0), 0.0, 1e-14, f);
        let entropy = |x: Beta| x.entropy();
        test_none((1.0, INF), entropy);
        test_none((INF, 1.0), entropy);
    }

    #[test]
    fn test_skewness() {
        let skewness = |x: Beta| x.skewness().unwrap();
        test_case((1.0, 1.0), 0.0, skewness);
        test_case((9.0, 1.0), -1.4740554623801777107177478829, skewness);
        test_case((5.0, 100.0), 0.817594109275534303545831591, skewness);
        test_case((1.0, INF), 2.0, skewness);
        test_case((INF, 1.0), -2.0, skewness);
    }

    #[test]
    fn test_mode() {
        let mode = |x: Beta| x.mode().unwrap();
        test_case((5.0, 100.0), 0.038834951456310676243255386, mode);
        test_case((92.0, INF), 0.0, mode);
        test_case((INF, 2.0), 1.0, mode);
    }

    #[test]
    #[should_panic]
    fn test_mode_shape_a_lte_1() {
        let mode = |x: Beta| x.mode().unwrap();
        get_value((1.0, 5.0), mode);
    }

    #[test]
    #[should_panic]
    fn test_mode_shape_b_lte_1() {
        let mode = |x: Beta| x.mode().unwrap();
        get_value((5.0, 1.0), mode);
    }

    #[test]
    fn test_min_max() {
        let min = |x: Beta| x.min();
        let max = |x: Beta| x.max();
        test_case((1.0, 1.0), 0.0, min);
        test_case((1.0, 1.0), 1.0, max);
    }

    #[test]
    fn test_pdf() {
        let f = |arg: f64| move |x: Beta| x.pdf(arg);
        let test = [
            ((1.0, 1.0), 0.0, 1.0),
            ((1.0, 1.0), 0.5, 1.0),
            ((1.0, 1.0), 1.0, 1.0),
            ((9.0, 1.0), 0.0, 0.0),
            ((9.0, 1.0), 0.5, 0.03515625),
            ((9.0, 1.0), 1.0, 9.0),
            ((5.0, 100.0), 0.0, 0.0),
            ((5.0, 100.0), 0.5, 4.534102298350337661e-23),
            ((5.0, 100.0), 1.0, 0.0),
            ((1.0, INF), 0.0, INF),
            ((1.0, INF), 0.5, 0.0),
            ((1.0, INF), 1.0, 0.0),
            ((INF, 1.0), 0.0, 0.0),
            ((INF, 1.0), 0.5, 0.0),
            ((INF, 1.0), 1.0, INF),
        ];
        for &(arg, x, expect) in test.iter() {
            test_case(arg, expect, f(x));
        }
    }

    #[test]
    fn test_pdf_input_lt_0() {
        let pdf = |arg: f64| move |x: Beta| x.pdf(arg);
        test_case((1.0, 1.0), 0.0, pdf(-1.0));
    }

    #[test]
    fn test_pdf_input_gt_1() {
        let pdf = |arg: f64| move |x: Beta| x.pdf(arg);
        test_case((1.0, 1.0), 0.0, pdf(2.0));
    }

    #[test]
    fn test_ln_pdf() {
        let f = |arg: f64| move |x: Beta| x.ln_pdf(arg);
        let test = [
            ((1.0, 1.0), 0.0, 0.0),
            ((1.0, 1.0), 0.5, 0.0),
            ((1.0, 1.0), 1.0, 0.0),
            ((9.0, 1.0), 0.0, -INF),
            ((9.0, 1.0), 0.5, -3.347952867143343092547366497),
            ((9.0, 1.0), 1.0, 2.1972245773362193827904904738),
            ((5.0, 100.0), 0.0, -INF),
            ((5.0, 100.0), 0.5, -51.447830024537682154565870),
            ((5.0, 100.0), 1.0, -INF),
            ((1.0, INF), 0.0, INF),
            ((1.0, INF), 0.5, -INF),
            ((1.0, INF), 1.0, -INF),
            ((INF, 1.0), 0.0, -INF),
            ((INF, 1.0), 0.5, -INF),
            ((INF, 1.0), 1.0, INF),
        ];
        for &(arg, x, expect) in test.iter() {
            test_case(arg, expect, f(x));
        }
    }

    #[test]
    fn test_ln_pdf_input_lt_0() {
        let ln_pdf = |arg: f64| move |x: Beta| x.ln_pdf(arg);
        test_case((1.0, 1.0), -INF, ln_pdf(-1.0));
    }

    #[test]
    fn test_ln_pdf_input_gt_1() {
        let ln_pdf = |arg: f64| move |x: Beta| x.ln_pdf(arg);
        test_case((1.0, 1.0), -INF, ln_pdf(2.0));
    }

    #[test]
    fn test_cdf() {
        let cdf = |arg: f64| move |x: Beta| x.cdf(arg);
        let test = [
            ((1.0, 1.0), 0.0, 0.0),
            ((1.0, 1.0), 0.5, 0.5),
            ((1.0, 1.0), 1.0, 1.0),
            ((9.0, 1.0), 0.0, 0.0),
            ((9.0, 1.0), 0.5, 0.001953125),
            ((9.0, 1.0), 1.0, 1.0),
            ((5.0, 100.0), 0.0, 0.0),
            ((5.0, 100.0), 0.5, 1.0),
            ((5.0, 100.0), 1.0, 1.0),
            ((1.0, INF), 0.0, 1.0),
            ((1.0, INF), 0.5, 1.0),
            ((1.0, INF), 1.0, 1.0),
            ((INF, 1.0), 0.0, 0.0),
            ((INF, 1.0), 0.5, 0.0),
            ((INF, 1.0), 1.0, 1.0),
        ];
        for &(arg, x, expect) in test.iter() {
            test_case(arg, expect, cdf(x));
        }
    }

    #[test]
    fn test_sf() {
        let sf = |arg: f64| move |x: Beta| x.sf(arg);
        let test = [
            ((1.0, 1.0), 0.0, 1.0),
            ((1.0, 1.0), 0.5, 0.5),
            ((1.0, 1.0), 1.0, 0.0),
            ((9.0, 1.0), 0.0, 1.0),
            ((9.0, 1.0), 0.5, 0.998046875),
            ((9.0, 1.0), 1.0, 0.0),
            ((5.0, 100.0), 0.0, 1.0),
            ((5.0, 100.0), 0.5, 0.0),
            ((5.0, 100.0), 1.0, 0.0),
            ((1.0, INF), 0.0, 0.0),
            ((1.0, INF), 0.5, 0.0),
            ((1.0, INF), 1.0, 0.0),
            ((INF, 1.0), 0.0, 1.0),
            ((INF, 1.0), 0.5, 1.0),
            ((INF, 1.0), 1.0, 0.0),
        ];
        for &(arg, x, expect) in test.iter() {
            test_case(arg, expect, sf(x));
        }
    }

    #[test]
    fn test_cdf_input_lt_0() {
        let cdf = |arg: f64| move |x: Beta| x.cdf(arg);
        test_case((1.0, 1.0), 0.0, cdf(-1.0));
    }

    #[test]
    fn test_cdf_input_gt_1() {
        let cdf = |arg: f64| move |x: Beta| x.cdf(arg);
        test_case((1.0, 1.0), 1.0, cdf(2.0));
    }

    #[test]
    fn test_sf_input_lt_0() {
        let sf = |arg: f64| move |x: Beta| x.sf(arg);
        test_case((1.0, 1.0), 1.0, sf(-1.0));
    }

    #[test]
    fn test_sf_input_gt_1() {
        let sf = |arg: f64| move |x: Beta| x.sf(arg);
        test_case((1.0, 1.0), 0.0, sf(2.0));
    }

    #[test]
    fn test_continuous() {
        test::check_continuous_distribution(&try_create((1.2, 3.4)), 0.0, 1.0);
        test::check_continuous_distribution(&try_create((4.5, 6.7)), 0.0, 1.0);
    }
}