statrs/distribution/beta.rs

1use crate::distribution::{Continuous, ContinuousCDF};
2use crate::function::{beta, gamma};
3use crate::statistics::*;
4
/// The [Beta distribution](https://en.wikipedia.org/wiki/Beta_distribution),
/// a continuous distribution supported on `[0, 1]` and parameterised by two
/// shape parameters α (`shape_a`) and β (`shape_b`).
///
/// # Examples
///
/// ```
/// use statrs::distribution::{Beta, Continuous};
/// use statrs::statistics::*;
/// use statrs::prec;
///
/// let n = Beta::new(2.0, 2.0).unwrap();
/// assert_eq!(n.mean().unwrap(), 0.5);
/// assert!(prec::almost_eq(n.pdf(0.5), 1.5, 1e-14));
/// ```
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Beta {
    /// First shape parameter α; `Beta::new` only accepts finite values > 0.
    shape_a: f64,
    /// Second shape parameter β; `Beta::new` only accepts finite values > 0.
    shape_b: f64,
}
24
/// Reasons a [`Beta`] distribution could not be constructed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum BetaError {
    /// `shape_a` (α) was NaN, infinite, zero or negative.
    ShapeAInvalid,

    /// `shape_b` (β) was NaN, infinite, zero or negative.
    ShapeBInvalid,
}
35
36impl std::fmt::Display for BetaError {
37    #[cfg_attr(coverage_nightly, coverage(off))]
38    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
39        match self {
40            BetaError::ShapeAInvalid => write!(f, "Shape A is NaN, infinite, zero or negative"),
41            BetaError::ShapeBInvalid => write!(f, "Shape B is NaN, infinite, zero or negative"),
42        }
43    }
44}
45
46impl std::error::Error for BetaError {}
47
48impl Beta {
49    /// Constructs a new beta distribution with shapeA (α) of `shape_a`
50    /// and shapeB (β) of `shape_b`
51    ///
52    /// # Errors
53    ///
54    /// Returns an error if `shape_a` or `shape_b` are `NaN` or infinite.
55    /// Also returns an error if `shape_a <= 0.0` or `shape_b <= 0.0`
56    ///
57    /// # Examples
58    ///
59    /// ```
60    /// use statrs::distribution::Beta;
61    ///
62    /// let mut result = Beta::new(2.0, 2.0);
63    /// assert!(result.is_ok());
64    ///
65    /// result = Beta::new(0.0, 0.0);
66    /// assert!(result.is_err());
67    /// ```
68    pub fn new(shape_a: f64, shape_b: f64) -> Result<Beta, BetaError> {
69        if shape_a.is_nan() || shape_a.is_infinite() || shape_a <= 0.0 {
70            return Err(BetaError::ShapeAInvalid);
71        }
72
73        if shape_b.is_nan() || shape_b.is_infinite() || shape_b <= 0.0 {
74            return Err(BetaError::ShapeBInvalid);
75        }
76
77        Ok(Beta { shape_a, shape_b })
78    }
79
80    /// Returns the shapeA (α) of the beta distribution
81    ///
82    /// # Examples
83    ///
84    /// ```
85    /// use statrs::distribution::Beta;
86    ///
87    /// let n = Beta::new(1.0, 2.0).unwrap();
88    /// assert_eq!(n.shape_a(), 1.0);
89    /// ```
90    pub fn shape_a(&self) -> f64 {
91        self.shape_a
92    }
93
94    /// Returns the shapeB (β) of the beta distributionβ
95    ///
96    /// # Examples
97    ///
98    /// ```
99    /// use statrs::distribution::Beta;
100    ///
101    /// let n = Beta::new(1.0, 2.0).unwrap();
102    /// assert_eq!(n.shape_b(), 2.0);
103    /// ```
104    pub fn shape_b(&self) -> f64 {
105        self.shape_b
106    }
107}
108
109impl std::fmt::Display for Beta {
110    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
111        write!(f, "Beta(a={}, b={})", self.shape_a, self.shape_b)
112    }
113}
114
#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<f64> for Beta {
    /// Draws a sample as `X / (X + Y)`, where `X` and `Y` are independent
    /// gamma draws with shapes α and β (second gamma parameter fixed at 1.0).
    fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 {
        // The α draw must come first so the RNG stream is consumed in the same order.
        let draw_a = super::gamma::sample_unchecked(rng, self.shape_a, 1.0);
        let draw_b = super::gamma::sample_unchecked(rng, self.shape_b, 1.0);
        // NOTE(review): if both draws are 0.0 (possible for very small shapes
        // if the gamma sampler underflows — confirm), this is 0/0 = NaN.
        let total = draw_a + draw_b;
        draw_a / total
    }
}
125
126impl ContinuousCDF<f64, f64> for Beta {
127    /// Calculates the cumulative distribution function for the beta
128    /// distribution at `x`.
129    ///
130    /// # Formula
131    ///
132    /// ```text
133    /// I_x(α, β)
134    /// ```
135    ///
136    /// where `α` is shapeA, `β` is shapeB, and `I_x` is the regularized
137    /// lower incomplete beta function.
138    fn cdf(&self, x: f64) -> f64 {
139        if x < 0.0 {
140            0.0
141        } else if x >= 1.0 {
142            1.0
143        } else if ulps_eq!(self.shape_a, 1.0) && ulps_eq!(self.shape_b, 1.0) {
144            x
145        } else {
146            beta::beta_reg(self.shape_a, self.shape_b, x)
147        }
148    }
149
150    /// Calculates the survival function for the beta distribution at `x`.
151    ///
152    /// # Formula
153    ///
154    /// ```text
155    /// I_(1-x)(β, α)
156    /// ```
157    ///
158    /// where `α` is shapeA, `β` is shapeB, and `I_x` is the regularized
159    /// lower incomplete beta function.
160    fn sf(&self, x: f64) -> f64 {
161        if x < 0.0 {
162            1.0
163        } else if x >= 1.0 {
164            0.0
165        } else if ulps_eq!(self.shape_a, 1.0) && ulps_eq!(self.shape_b, 1.0) {
166            1. - x
167        } else {
168            beta::beta_reg(self.shape_b, self.shape_a, 1.0 - x)
169        }
170    }
171
172    /// Calculates the inverse cumulative distribution function for the beta
173    /// distribution at `x`.
174    ///
175    /// # Panics
176    ///
177    /// If x is not in `[0, 1]`.
178    ///
179    /// # Formula
180    ///
181    /// ```text
182    /// I^{-1}_x(α, β)
183    /// ```
184    ///
185    /// where `α` is shapeA, `β` is shapeB, and `I_x` is the inverse of the
186    /// regularized lower incomplete beta function.
187    fn inverse_cdf(&self, x: f64) -> f64 {
188        if !(0.0..=1.0).contains(&x) {
189            panic!("x must be in [0, 1]");
190        } else {
191            beta::inv_beta_reg(self.shape_a, self.shape_b, x)
192        }
193    }
194}
195
196impl Min<f64> for Beta {
197    /// Returns the minimum value in the domain of the beta distribution
198    /// representable by a double precision float.
199    ///
200    /// # Formula
201    ///
202    /// ```text
203    /// 0
204    /// ```
205    fn min(&self) -> f64 {
206        0.0
207    }
208}
209
210impl Max<f64> for Beta {
211    /// Returns the maximum value in the domain of the beta distribution
212    /// representable by a double precision float.
213    ///
214    /// # Formula
215    ///
216    /// ```text
217    /// 1
218    /// ```
219    fn max(&self) -> f64 {
220        1.0
221    }
222}
223
224impl Distribution<f64> for Beta {
225    /// Returns the mean of the beta distribution.
226    ///
227    /// # Formula
228    ///
229    /// ```text
230    /// α / (α + β)
231    /// ```
232    ///
233    /// where `α` is shapeA and `β` is shapeB.
234    fn mean(&self) -> Option<f64> {
235        Some(self.shape_a / (self.shape_a + self.shape_b))
236    }
237
238    /// Returns the variance of the beta distribution.
239    ///
240    /// # Formula
241    ///
242    /// ```text
243    /// (α * β) / ((α + β)^2 * (α + β + 1))
244    /// ```
245    ///
246    /// where `α` is shapeA and `β` is shapeB.
247    fn variance(&self) -> Option<f64> {
248        Some(
249            self.shape_a * self.shape_b
250                / ((self.shape_a + self.shape_b)
251                    * (self.shape_a + self.shape_b)
252                    * (self.shape_a + self.shape_b + 1.0)),
253        )
254    }
255
256    /// Returns the entropy of the beta distribution.
257    ///
258    /// # Formula
259    ///
260    /// ```text
261    /// ln(B(α, β)) - (α - 1)ψ(α) - (β - 1)ψ(β) + (α + β - 2)ψ(α + β)
262    /// ```
263    ///
264    /// where `α` is shapeA, `β` is shapeB and `ψ` is the digamma function.
265    fn entropy(&self) -> Option<f64> {
266        Some(
267            beta::ln_beta(self.shape_a, self.shape_b)
268                - (self.shape_a - 1.0) * gamma::digamma(self.shape_a)
269                - (self.shape_b - 1.0) * gamma::digamma(self.shape_b)
270                + (self.shape_a + self.shape_b - 2.0) * gamma::digamma(self.shape_a + self.shape_b),
271        )
272    }
273
274    /// Returns the skewness of the Beta distribution.
275    ///
276    /// # Formula
277    ///
278    /// ```text
279    /// 2(β - α) * sqrt(α + β + 1) / ((α + β + 2) * sqrt(αβ))
280    /// ```
281    ///
282    /// where `α` is shapeA and `β` is shapeB.
283    fn skewness(&self) -> Option<f64> {
284        Some(
285            2.0 * (self.shape_b - self.shape_a) * (self.shape_a + self.shape_b + 1.0).sqrt()
286                / ((self.shape_a + self.shape_b + 2.0) * (self.shape_a * self.shape_b).sqrt()),
287        )
288    }
289}
290
291impl Mode<Option<f64>> for Beta {
292    /// Returns the mode of the Beta distribution. Returns `None` if `α <= 1`
293    /// or `β <= 1`.
294    ///
295    /// # Remarks
296    ///
297    /// Since the mode is technically only calculated for `α > 1, β > 1`, those
298    /// are the only values we allow. We may consider relaxing this constraint
299    /// in the future.
300    ///
301    /// # Formula
302    ///
303    /// ```text
304    /// (α - 1) / (α + β - 2)
305    /// ```
306    ///
307    /// where `α` is shapeA and `β` is shapeB
308    fn mode(&self) -> Option<f64> {
309        // TODO: perhaps relax constraint in order to allow calculation
310        // of 'anti-mode;
311        if self.shape_a <= 1.0 || self.shape_b <= 1.0 {
312            None
313        } else {
314            Some((self.shape_a - 1.0) / (self.shape_a + self.shape_b - 2.0))
315        }
316    }
317}
318
319impl Continuous<f64, f64> for Beta {
320    /// Calculates the probability density function for the beta distribution
321    /// at `x`.
322    ///
323    /// # Formula
324    ///
325    /// ```text
326    /// let B(α, β) = Γ(α)Γ(β)/Γ(α + β)
327    ///
328    /// x^(α - 1) * (1 - x)^(β - 1) / B(α, β)
329    /// ```
330    ///
331    /// where `α` is shapeA, `β` is shapeB, and `Γ` is the gamma function
332    fn pdf(&self, x: f64) -> f64 {
333        if !(0.0..=1.0).contains(&x) {
334            0.0
335        } else if ulps_eq!(self.shape_a, 1.0) && ulps_eq!(self.shape_b, 1.0) {
336            1.0
337        } else if self.shape_a > 80.0 || self.shape_b > 80.0 {
338            self.ln_pdf(x).exp()
339        } else {
340            let bb = gamma::gamma(self.shape_a + self.shape_b)
341                / (gamma::gamma(self.shape_a) * gamma::gamma(self.shape_b));
342            bb * x.powf(self.shape_a - 1.0) * (1.0 - x).powf(self.shape_b - 1.0)
343        }
344    }
345
346    /// Calculates the log probability density function for the beta
347    /// distribution at `x`.
348    ///
349    /// # Formula
350    ///
351    /// ```text
352    /// let B(α, β) = Γ(α)Γ(β)/Γ(α + β)
353    ///
354    /// ln(x^(α - 1) * (1 - x)^(β - 1) / B(α, β))
355    /// ```
356    ///
357    /// where `α` is shapeA, `β` is shapeB, and `Γ` is the gamma function.
358    fn ln_pdf(&self, x: f64) -> f64 {
359        if !(0.0..=1.0).contains(&x) {
360            f64::NEG_INFINITY
361        } else if ulps_eq!(self.shape_a, 1.0) && ulps_eq!(self.shape_b, 1.0) {
362            0.0
363        } else {
364            let aa = gamma::ln_gamma(self.shape_a + self.shape_b)
365                - gamma::ln_gamma(self.shape_a)
366                - gamma::ln_gamma(self.shape_b);
367            let bb = if ulps_eq!(self.shape_a, 1.0) && x == 0.0 {
368                0.0
369            } else if x == 0.0 {
370                f64::NEG_INFINITY
371            } else {
372                (self.shape_a - 1.0) * x.ln()
373            };
374            let cc = if ulps_eq!(self.shape_b, 1.0) && ulps_eq!(x, 1.0) {
375                0.0
376            } else if ulps_eq!(x, 1.0) {
377                f64::NEG_INFINITY
378            } else {
379                (self.shape_b - 1.0) * (1.0 - x).ln()
380            };
381            aa + bb + cc
382        }
383    }
384}
385
#[rustfmt::skip]
#[cfg(test)]
mod tests {
    use super::*;
    use super::super::internal::*;
    use crate::testing_boiler;

    testing_boiler!(a: f64, b: f64; Beta; BetaError);

    #[test]
    fn test_create() {
        let valid = [(1.0, 1.0), (9.0, 1.0), (5.0, 100.0)];
        for (a, b) in valid {
            create_ok(a, b);
        }
    }

    #[test]
    fn test_bad_create() {
        let invalid = [
            (0.0, 0.0),
            (0.0, 0.1),
            (1.0, 0.0),
            (0.5, f64::INFINITY),
            (f64::INFINITY, 0.5),
            (f64::NAN, 1.0),
            (1.0, f64::NAN),
            (f64::NAN, f64::NAN),
            (1.0, -1.0),
            (-1.0, 1.0),
            (-1.0, -1.0),
            (f64::INFINITY, f64::INFINITY),
        ];
        for (a, b) in invalid {
            create_err(a, b);
        }
    }

    #[test]
    fn test_bad_create_reports_offending_shape() {
        // Shape A is validated first, so it is reported when both are invalid.
        assert_eq!(Beta::new(0.0, 1.0), Err(BetaError::ShapeAInvalid));
        assert_eq!(Beta::new(f64::NAN, f64::NAN), Err(BetaError::ShapeAInvalid));
        assert_eq!(Beta::new(1.0, -1.0), Err(BetaError::ShapeBInvalid));
        assert_eq!(Beta::new(1.0, f64::INFINITY), Err(BetaError::ShapeBInvalid));
    }

    #[test]
    fn test_display() {
        assert_eq!(create_ok(1.0, 2.5).to_string(), "Beta(a=1, b=2.5)");
    }

    #[test]
    fn test_mean() {
        let f = |x: Beta| x.mean().unwrap();
        let test = [
            ((1.0, 1.0), 0.5),
            ((9.0, 1.0), 0.9),
            ((5.0, 100.0), 0.047619047619047619047616),
        ];
        for ((a, b), res) in test {
            test_relative(a, b, res, f);
        }
    }

    #[test]
    fn test_variance() {
        let f = |x: Beta| x.variance().unwrap();
        let test = [
            ((1.0, 1.0), 1.0 / 12.0),
            ((9.0, 1.0), 9.0 / 1100.0),
            ((5.0, 100.0), 500.0 / 1168650.0),
        ];
        for ((a, b), res) in test {
            test_relative(a, b, res, f);
        }
    }

    #[test]
    fn test_entropy() {
        let f = |x: Beta| x.entropy().unwrap();
        let test = [
            ((9.0, 1.0), -1.3083356884473304939016015),
            ((5.0, 100.0), -2.52016231876027436794592),
        ];
        for ((a, b), res) in test {
            test_relative(a, b, res, f);
        }
        test_absolute(1.0, 1.0, 0.0, 1e-14, f);
    }

    #[test]
    fn test_skewness() {
        let skewness = |x: Beta| x.skewness().unwrap();
        test_relative(1.0, 1.0, 0.0, skewness);
        test_relative(9.0, 1.0, -1.4740554623801777107177478829, skewness);
        test_relative(5.0, 100.0, 0.817594109275534303545831591, skewness);
    }

    #[test]
    fn test_mode() {
        let mode = |x: Beta| x.mode().unwrap();
        test_relative(5.0, 100.0, 0.038834951456310676243255386, mode);
    }

    #[test]
    fn test_mode_shape_a_lte_1() {
        test_none(1.0, 5.0, |dist| dist.mode());
    }

    #[test]
    fn test_mode_shape_b_lte_1() {
        test_none(5.0, 1.0, |dist| dist.mode());
    }

    #[test]
    fn test_min_max() {
        let min = |x: Beta| x.min();
        let max = |x: Beta| x.max();
        test_relative(1.0, 1.0, 0.0, min);
        test_relative(1.0, 1.0, 1.0, max);
    }

    #[test]
    fn test_pdf() {
        let f = |arg: f64| move |x: Beta| x.pdf(arg);
        let test = [
            ((1.0, 1.0), 0.0, 1.0),
            ((1.0, 1.0), 0.5, 1.0),
            ((1.0, 1.0), 1.0, 1.0),
            ((9.0, 1.0), 0.0, 0.0),
            ((9.0, 1.0), 0.5, 0.03515625),
            ((9.0, 1.0), 1.0, 9.0),
            ((5.0, 100.0), 0.0, 0.0),
            ((5.0, 100.0), 0.5, 4.534102298350337661e-23),
            ((5.0, 100.0), 1.0, 0.0),
        ];
        for ((a, b), x, expect) in test {
            test_relative(a, b, expect, f(x));
        }
    }

    #[test]
    fn test_pdf_input_lt_0() {
        let pdf = |arg: f64| move |x: Beta| x.pdf(arg);
        test_relative(1.0, 1.0, 0.0, pdf(-1.0));
    }

    #[test]
    fn test_pdf_input_gt_1() {
        let pdf = |arg: f64| move |x: Beta| x.pdf(arg);
        test_relative(1.0, 1.0, 0.0, pdf(2.0));
    }

    #[test]
    fn test_ln_pdf() {
        let f = |arg: f64| move |x: Beta| x.ln_pdf(arg);
        let test = [
            ((1.0, 1.0), 0.0, 0.0),
            ((1.0, 1.0), 0.5, 0.0),
            ((1.0, 1.0), 1.0, 0.0),
            ((9.0, 1.0), 0.0, f64::NEG_INFINITY),
            ((9.0, 1.0), 0.5, -3.347952867143343092547366497),
            ((9.0, 1.0), 1.0, 2.1972245773362193827904904738),
            ((5.0, 100.0), 0.0, f64::NEG_INFINITY),
            ((5.0, 100.0), 0.5, -51.447830024537682154565870),
            ((5.0, 100.0), 1.0, f64::NEG_INFINITY),
        ];
        for ((a, b), x, expect) in test {
            test_relative(a, b, expect, f(x));
        }
    }

    #[test]
    fn test_ln_pdf_input_lt_0() {
        let ln_pdf = |arg: f64| move |x: Beta| x.ln_pdf(arg);
        test_relative(1.0, 1.0, f64::NEG_INFINITY, ln_pdf(-1.0));
    }

    #[test]
    fn test_ln_pdf_input_gt_1() {
        let ln_pdf = |arg: f64| move |x: Beta| x.ln_pdf(arg);
        test_relative(1.0, 1.0, f64::NEG_INFINITY, ln_pdf(2.0));
    }

    #[test]
    fn test_cdf() {
        let cdf = |arg: f64| move |x: Beta| x.cdf(arg);
        let test = [
            ((1.0, 1.0), 0.0, 0.0),
            ((1.0, 1.0), 0.5, 0.5),
            ((1.0, 1.0), 1.0, 1.0),
            ((9.0, 1.0), 0.0, 0.0),
            ((9.0, 1.0), 0.5, 0.001953125),
            ((9.0, 1.0), 1.0, 1.0),
            ((5.0, 100.0), 0.0, 0.0),
            ((5.0, 100.0), 0.5, 1.0),
            ((5.0, 100.0), 1.0, 1.0),
        ];
        for ((a, b), x, expect) in test {
            test_relative(a, b, expect, cdf(x));
        }
    }

    #[test]
    fn test_sf() {
        let sf = |arg: f64| move |x: Beta| x.sf(arg);
        let test = [
            ((1.0, 1.0), 0.0, 1.0),
            ((1.0, 1.0), 0.5, 0.5),
            ((1.0, 1.0), 1.0, 0.0),
            ((9.0, 1.0), 0.0, 1.0),
            ((9.0, 1.0), 0.5, 0.998046875),
            ((9.0, 1.0), 1.0, 0.0),
            ((5.0, 100.0), 0.0, 1.0),
            ((5.0, 100.0), 0.5, 0.0),
            ((5.0, 100.0), 1.0, 0.0),
        ];
        for ((a, b), x, expect) in test {
            test_relative(a, b, expect, sf(x));
        }
    }

    #[test]
    fn test_inverse_cdf() {
        // Round trip: inverse_cdf(cdf(x)) should recover x.
        let func = |arg: f64| move |x: Beta| x.inverse_cdf(x.cdf(arg));
        let test = [
            ((1.0, 1.0), 0.0, 0.0),
            ((1.0, 1.0), 0.5, 0.5),
            ((1.0, 1.0), 1.0, 1.0),
            ((9.0, 1.0), 0.0, 0.0),
            ((9.0, 1.0), 0.001953125, 0.001953125),
            ((9.0, 1.0), 0.5, 0.5),
            ((9.0, 1.0), 1.0, 1.0),
            ((5.0, 100.0), 0.0, 0.0),
            ((5.0, 100.0), 0.01, 0.01),
            ((5.0, 100.0), 1.0, 1.0),
        ];
        for ((a, b), x, expect) in test {
            test_relative(a, b, expect, func(x));
        }
    }

    #[test]
    fn test_cdf_input_lt_0() {
        let cdf = |arg: f64| move |x: Beta| x.cdf(arg);
        test_relative(1.0, 1.0, 0.0, cdf(-1.0));
    }

    #[test]
    fn test_cdf_input_gt_1() {
        let cdf = |arg: f64| move |x: Beta| x.cdf(arg);
        test_relative(1.0, 1.0, 1.0, cdf(2.0));
    }

    #[test]
    fn test_sf_input_lt_0() {
        let sf = |arg: f64| move |x: Beta| x.sf(arg);
        test_relative(1.0, 1.0, 1.0, sf(-1.0));
    }

    #[test]
    fn test_sf_input_gt_1() {
        let sf = |arg: f64| move |x: Beta| x.sf(arg);
        test_relative(1.0, 1.0, 0.0, sf(2.0));
    }

    #[test]
    fn test_continuous() {
        test::check_continuous_distribution(&create_ok(1.2, 3.4), 0.0, 1.0);
        test::check_continuous_distribution(&create_ok(4.5, 6.7), 0.0, 1.0);
    }
}
646}