1use crate::distribution::{Continuous, ContinuousCDF};
2use crate::function::{beta, gamma};
3use crate::statistics::*;
4
/// The [Beta distribution](https://en.wikipedia.org/wiki/Beta_distribution)
/// on the unit interval `[0, 1]`, parameterised by two shape parameters
/// `shape_a` (α) and `shape_b` (β).
#[derive(Copy, Clone, PartialEq, Debug)]
pub struct Beta {
    shape_a: f64,
    shape_b: f64,
}

/// Reasons why [`Beta::new`] can reject its arguments.
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
#[non_exhaustive]
pub enum BetaError {
    /// `shape_a` was NaN, infinite, zero or negative.
    ShapeAInvalid,

    /// `shape_b` was NaN, infinite, zero or negative.
    ShapeBInvalid,
}

impl std::fmt::Display for BetaError {
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let message = match self {
            BetaError::ShapeAInvalid => "Shape A is NaN, infinite, zero or negative",
            BetaError::ShapeBInvalid => "Shape B is NaN, infinite, zero or negative",
        };
        f.write_str(message)
    }
}

impl std::error::Error for BetaError {}

impl Beta {
    /// Constructs a beta distribution with shape parameters `shape_a` (α) and
    /// `shape_b` (β).
    ///
    /// # Errors
    ///
    /// Returns [`BetaError::ShapeAInvalid`] if `shape_a` is not a finite,
    /// strictly positive number. Otherwise returns
    /// [`BetaError::ShapeBInvalid`] if `shape_b` is not. `shape_a` is always
    /// checked first.
    pub fn new(shape_a: f64, shape_b: f64) -> Result<Beta, BetaError> {
        // `is_finite` rejects NaN and both infinities; `> 0.0` then rejects
        // zero (including -0.0) and negative values.
        let is_valid_shape = |shape: f64| shape.is_finite() && shape > 0.0;

        match (is_valid_shape(shape_a), is_valid_shape(shape_b)) {
            (false, _) => Err(BetaError::ShapeAInvalid),
            (true, false) => Err(BetaError::ShapeBInvalid),
            (true, true) => Ok(Beta { shape_a, shape_b }),
        }
    }

    /// Returns the first shape parameter, α.
    pub fn shape_a(&self) -> f64 {
        self.shape_a
    }

    /// Returns the second shape parameter, β.
    pub fn shape_b(&self) -> f64 {
        self.shape_b
    }
}

impl std::fmt::Display for Beta {
    /// Formats the distribution as `Beta(a=<shape_a>, b=<shape_b>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Beta { shape_a, shape_b } = self;
        write!(f, "Beta(a={}, b={})", shape_a, shape_b)
    }
}
114
#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<f64> for Beta {
    /// Draws one sample using the ratio-of-gammas construction.
    ///
    /// If `X ~ Gamma(a, 1)` and `Y ~ Gamma(b, 1)` are independent, then
    /// `X / (X + Y) ~ Beta(a, b)`.
    fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 {
        // Draw for `shape_a` first, then `shape_b`: keeping this order means
        // a given RNG stream always yields the same sample.
        let from_a = super::gamma::sample_unchecked(rng, self.shape_a, 1.0);
        let from_b = super::gamma::sample_unchecked(rng, self.shape_b, 1.0);
        let total = from_a + from_b;
        from_a / total
    }
}
125
126impl ContinuousCDF<f64, f64> for Beta {
127 fn cdf(&self, x: f64) -> f64 {
139 if x < 0.0 {
140 0.0
141 } else if x >= 1.0 {
142 1.0
143 } else if ulps_eq!(self.shape_a, 1.0) && ulps_eq!(self.shape_b, 1.0) {
144 x
145 } else {
146 beta::beta_reg(self.shape_a, self.shape_b, x)
147 }
148 }
149
150 fn sf(&self, x: f64) -> f64 {
161 if x < 0.0 {
162 1.0
163 } else if x >= 1.0 {
164 0.0
165 } else if ulps_eq!(self.shape_a, 1.0) && ulps_eq!(self.shape_b, 1.0) {
166 1. - x
167 } else {
168 beta::beta_reg(self.shape_b, self.shape_a, 1.0 - x)
169 }
170 }
171
172 fn inverse_cdf(&self, x: f64) -> f64 {
188 if !(0.0..=1.0).contains(&x) {
189 panic!("x must be in [0, 1]");
190 } else {
191 beta::inv_beta_reg(self.shape_a, self.shape_b, x)
192 }
193 }
194}
195
196impl Min<f64> for Beta {
197 fn min(&self) -> f64 {
206 0.0
207 }
208}
209
210impl Max<f64> for Beta {
211 fn max(&self) -> f64 {
220 1.0
221 }
222}
223
224impl Distribution<f64> for Beta {
225 fn mean(&self) -> Option<f64> {
235 Some(self.shape_a / (self.shape_a + self.shape_b))
236 }
237
238 fn variance(&self) -> Option<f64> {
248 Some(
249 self.shape_a * self.shape_b
250 / ((self.shape_a + self.shape_b)
251 * (self.shape_a + self.shape_b)
252 * (self.shape_a + self.shape_b + 1.0)),
253 )
254 }
255
256 fn entropy(&self) -> Option<f64> {
266 Some(
267 beta::ln_beta(self.shape_a, self.shape_b)
268 - (self.shape_a - 1.0) * gamma::digamma(self.shape_a)
269 - (self.shape_b - 1.0) * gamma::digamma(self.shape_b)
270 + (self.shape_a + self.shape_b - 2.0) * gamma::digamma(self.shape_a + self.shape_b),
271 )
272 }
273
274 fn skewness(&self) -> Option<f64> {
284 Some(
285 2.0 * (self.shape_b - self.shape_a) * (self.shape_a + self.shape_b + 1.0).sqrt()
286 / ((self.shape_a + self.shape_b + 2.0) * (self.shape_a * self.shape_b).sqrt()),
287 )
288 }
289}
290
291impl Mode<Option<f64>> for Beta {
292 fn mode(&self) -> Option<f64> {
309 if self.shape_a <= 1.0 || self.shape_b <= 1.0 {
312 None
313 } else {
314 Some((self.shape_a - 1.0) / (self.shape_a + self.shape_b - 2.0))
315 }
316 }
317}
318
319impl Continuous<f64, f64> for Beta {
320 fn pdf(&self, x: f64) -> f64 {
333 if !(0.0..=1.0).contains(&x) {
334 0.0
335 } else if ulps_eq!(self.shape_a, 1.0) && ulps_eq!(self.shape_b, 1.0) {
336 1.0
337 } else if self.shape_a > 80.0 || self.shape_b > 80.0 {
338 self.ln_pdf(x).exp()
339 } else {
340 let bb = gamma::gamma(self.shape_a + self.shape_b)
341 / (gamma::gamma(self.shape_a) * gamma::gamma(self.shape_b));
342 bb * x.powf(self.shape_a - 1.0) * (1.0 - x).powf(self.shape_b - 1.0)
343 }
344 }
345
346 fn ln_pdf(&self, x: f64) -> f64 {
359 if !(0.0..=1.0).contains(&x) {
360 f64::NEG_INFINITY
361 } else if ulps_eq!(self.shape_a, 1.0) && ulps_eq!(self.shape_b, 1.0) {
362 0.0
363 } else {
364 let aa = gamma::ln_gamma(self.shape_a + self.shape_b)
365 - gamma::ln_gamma(self.shape_a)
366 - gamma::ln_gamma(self.shape_b);
367 let bb = if ulps_eq!(self.shape_a, 1.0) && x == 0.0 {
368 0.0
369 } else if x == 0.0 {
370 f64::NEG_INFINITY
371 } else {
372 (self.shape_a - 1.0) * x.ln()
373 };
374 let cc = if ulps_eq!(self.shape_b, 1.0) && ulps_eq!(x, 1.0) {
375 0.0
376 } else if ulps_eq!(x, 1.0) {
377 f64::NEG_INFINITY
378 } else {
379 (self.shape_b - 1.0) * (1.0 - x).ln()
380 };
381 aa + bb + cc
382 }
383 }
384}
385
#[rustfmt::skip]
#[cfg(test)]
mod tests {
    use super::*;
    use super::super::internal::*;
    use crate::testing_boiler;

    testing_boiler!(a: f64, b: f64; Beta; BetaError);

    #[test]
    fn test_create() {
        let valid = [(1.0, 1.0), (9.0, 1.0), (5.0, 100.0)];
        for (a, b) in valid {
            create_ok(a, b);
        }
    }

    #[test]
    fn test_bad_create() {
        let invalid = [
            (0.0, 0.0),
            (0.0, 0.1),
            (1.0, 0.0),
            (0.5, f64::INFINITY),
            (f64::INFINITY, 0.5),
            (f64::NAN, 1.0),
            (1.0, f64::NAN),
            (f64::NAN, f64::NAN),
            (1.0, -1.0),
            (-1.0, 1.0),
            (-1.0, -1.0),
            (f64::INFINITY, f64::INFINITY),
        ];
        for (a, b) in invalid {
            create_err(a, b);
        }
    }

    #[test]
    fn test_mean() {
        let f = |x: Beta| x.mean().unwrap();
        let test = [
            ((1.0, 1.0), 0.5),
            ((9.0, 1.0), 0.9),
            ((5.0, 100.0), 0.047619047619047619047616),
        ];
        for ((a, b), res) in test {
            test_relative(a, b, res, f);
        }
    }

    #[test]
    fn test_variance() {
        let f = |x: Beta| x.variance().unwrap();
        let test = [
            ((1.0, 1.0), 1.0 / 12.0),
            ((9.0, 1.0), 9.0 / 1100.0),
            ((5.0, 100.0), 500.0 / 1168650.0),
        ];
        for ((a, b), res) in test {
            test_relative(a, b, res, f);
        }
    }

    #[test]
    fn test_entropy() {
        let f = |x: Beta| x.entropy().unwrap();
        let test = [
            ((9.0, 1.0), -1.3083356884473304939016015),
            ((5.0, 100.0), -2.52016231876027436794592),
        ];
        for ((a, b), res) in test {
            test_relative(a, b, res, f);
        }
        // The exact value is 0, so a relative comparison is meaningless here.
        test_absolute(1.0, 1.0, 0.0, 1e-14, f);
    }

    #[test]
    fn test_skewness() {
        let skewness = |x: Beta| x.skewness().unwrap();
        test_relative(1.0, 1.0, 0.0, skewness);
        test_relative(9.0, 1.0, -1.4740554623801777107177478829, skewness);
        test_relative(5.0, 100.0, 0.817594109275534303545831591, skewness);
    }

    #[test]
    fn test_mode() {
        let mode = |x: Beta| x.mode().unwrap();
        test_relative(5.0, 100.0, 0.038834951456310676243255386, mode);
    }

    #[test]
    fn test_mode_shape_a_lte_1() {
        test_none(1.0, 5.0, |dist| dist.mode());
    }

    #[test]
    fn test_mode_shape_b_lte_1() {
        test_none(5.0, 1.0, |dist| dist.mode());
    }

    #[test]
    fn test_min_max() {
        let min = |x: Beta| x.min();
        let max = |x: Beta| x.max();
        test_relative(1.0, 1.0, 0.0, min);
        test_relative(1.0, 1.0, 1.0, max);
    }

    #[test]
    fn test_pdf() {
        let f = |arg: f64| move |x: Beta| x.pdf(arg);
        let test = [
            ((1.0, 1.0), 0.0, 1.0),
            ((1.0, 1.0), 0.5, 1.0),
            ((1.0, 1.0), 1.0, 1.0),
            ((9.0, 1.0), 0.0, 0.0),
            ((9.0, 1.0), 0.5, 0.03515625),
            ((9.0, 1.0), 1.0, 9.0),
            ((5.0, 100.0), 0.0, 0.0),
            ((5.0, 100.0), 0.5, 4.534102298350337661e-23),
            ((5.0, 100.0), 1.0, 0.0),
        ];
        for ((a, b), x, expect) in test {
            test_relative(a, b, expect, f(x));
        }
    }

    #[test]
    fn test_pdf_input_lt_0() {
        let pdf = |arg: f64| move |x: Beta| x.pdf(arg);
        test_relative(1.0, 1.0, 0.0, pdf(-1.0));
    }

    #[test]
    fn test_pdf_input_gt_1() {
        let pdf = |arg: f64| move |x: Beta| x.pdf(arg);
        test_relative(1.0, 1.0, 0.0, pdf(2.0));
    }

    #[test]
    fn test_ln_pdf() {
        let f = |arg: f64| move |x: Beta| x.ln_pdf(arg);
        let test = [
            ((1.0, 1.0), 0.0, 0.0),
            ((1.0, 1.0), 0.5, 0.0),
            ((1.0, 1.0), 1.0, 0.0),
            ((9.0, 1.0), 0.0, f64::NEG_INFINITY),
            ((9.0, 1.0), 0.5, -3.347952867143343092547366497),
            ((9.0, 1.0), 1.0, 2.1972245773362193827904904738),
            ((5.0, 100.0), 0.0, f64::NEG_INFINITY),
            ((5.0, 100.0), 0.5, -51.447830024537682154565870),
            ((5.0, 100.0), 1.0, f64::NEG_INFINITY),
        ];
        for ((a, b), x, expect) in test {
            test_relative(a, b, expect, f(x));
        }
    }

    #[test]
    fn test_ln_pdf_input_lt_0() {
        let ln_pdf = |arg: f64| move |x: Beta| x.ln_pdf(arg);
        test_relative(1.0, 1.0, f64::NEG_INFINITY, ln_pdf(-1.0));
    }

    #[test]
    fn test_ln_pdf_input_gt_1() {
        let ln_pdf = |arg: f64| move |x: Beta| x.ln_pdf(arg);
        test_relative(1.0, 1.0, f64::NEG_INFINITY, ln_pdf(2.0));
    }

    #[test]
    fn test_cdf() {
        let cdf = |arg: f64| move |x: Beta| x.cdf(arg);
        let test = [
            ((1.0, 1.0), 0.0, 0.0),
            ((1.0, 1.0), 0.5, 0.5),
            ((1.0, 1.0), 1.0, 1.0),
            ((9.0, 1.0), 0.0, 0.0),
            ((9.0, 1.0), 0.5, 0.001953125),
            ((9.0, 1.0), 1.0, 1.0),
            ((5.0, 100.0), 0.0, 0.0),
            ((5.0, 100.0), 0.5, 1.0),
            ((5.0, 100.0), 1.0, 1.0),
        ];
        for ((a, b), x, expect) in test {
            test_relative(a, b, expect, cdf(x));
        }
    }

    #[test]
    fn test_sf() {
        let sf = |arg: f64| move |x: Beta| x.sf(arg);
        let test = [
            ((1.0, 1.0), 0.0, 1.0),
            ((1.0, 1.0), 0.5, 0.5),
            ((1.0, 1.0), 1.0, 0.0),
            ((9.0, 1.0), 0.0, 1.0),
            ((9.0, 1.0), 0.5, 0.998046875),
            ((9.0, 1.0), 1.0, 0.0),
            ((5.0, 100.0), 0.0, 1.0),
            ((5.0, 100.0), 0.5, 0.0),
            ((5.0, 100.0), 1.0, 0.0),
        ];
        for ((a, b), x, expect) in test {
            test_relative(a, b, expect, sf(x));
        }
    }

    #[test]
    fn test_inverse_cdf() {
        // Round trip: inverse_cdf(cdf(x)) should recover x.
        let func = |arg: f64| move |x: Beta| x.inverse_cdf(x.cdf(arg));
        let test = [
            ((1.0, 1.0), 0.0, 0.0),
            ((1.0, 1.0), 0.5, 0.5),
            ((1.0, 1.0), 1.0, 1.0),
            ((9.0, 1.0), 0.0, 0.0),
            ((9.0, 1.0), 0.001953125, 0.001953125),
            ((9.0, 1.0), 0.5, 0.5),
            ((9.0, 1.0), 1.0, 1.0),
            ((5.0, 100.0), 0.0, 0.0),
            ((5.0, 100.0), 0.01, 0.01),
            ((5.0, 100.0), 1.0, 1.0),
        ];
        for ((a, b), x, expect) in test {
            test_relative(a, b, expect, func(x));
        }
    }

    #[test]
    fn test_cdf_input_lt_0() {
        let cdf = |arg: f64| move |x: Beta| x.cdf(arg);
        test_relative(1.0, 1.0, 0.0, cdf(-1.0));
    }

    #[test]
    fn test_cdf_input_gt_1() {
        let cdf = |arg: f64| move |x: Beta| x.cdf(arg);
        test_relative(1.0, 1.0, 1.0, cdf(2.0));
    }

    #[test]
    fn test_sf_input_lt_0() {
        let sf = |arg: f64| move |x: Beta| x.sf(arg);
        test_relative(1.0, 1.0, 1.0, sf(-1.0));
    }

    #[test]
    fn test_sf_input_gt_1() {
        let sf = |arg: f64| move |x: Beta| x.sf(arg);
        test_relative(1.0, 1.0, 0.0, sf(2.0));
    }

    #[test]
    fn test_continuous() {
        test::check_continuous_distribution(&create_ok(1.2, 3.4), 0.0, 1.0);
        test::check_continuous_distribution(&create_ok(4.5, 6.7), 0.0, 1.0);
    }
}