statrs/distribution/beta.rs

1use crate::distribution::{Continuous, ContinuousCDF};
2use crate::function::{beta, gamma};
3use crate::is_zero;
4use crate::statistics::*;
5use crate::{Result, StatsError};
6use core::f64::INFINITY as INF;
7use rand::Rng;
8
/// The [Beta](https://en.wikipedia.org/wiki/Beta_distribution) distribution,
/// a continuous distribution on `[0, 1]` parameterized by two positive shape
/// parameters α (`shape_a`) and β (`shape_b`).
///
/// # Examples
///
/// ```
/// use statrs::distribution::{Beta, Continuous};
/// use statrs::statistics::*;
/// use statrs::prec;
///
/// let dist = Beta::new(2.0, 2.0).unwrap();
/// assert_eq!(dist.mean().unwrap(), 0.5);
/// assert!(prec::almost_eq(dist.pdf(0.5), 1.5, 1e-14));
/// ```
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Beta {
    /// First shape parameter α; validated by `Beta::new`.
    shape_a: f64,
    /// Second shape parameter β; validated by `Beta::new`.
    shape_b: f64,
}
28
29impl Beta {
30    /// Constructs a new beta distribution with shapeA (α) of `shape_a`
31    /// and shapeB (β) of `shape_b`
32    ///
33    /// # Errors
34    ///
35    /// Returns an error if `shape_a` or `shape_b` are `NaN`.
36    /// Also returns an error if `shape_a <= 0.0` or `shape_b <= 0.0`
37    ///
38    /// # Examples
39    ///
40    /// ```
41    /// use statrs::distribution::Beta;
42    ///
43    /// let mut result = Beta::new(2.0, 2.0);
44    /// assert!(result.is_ok());
45    ///
46    /// result = Beta::new(0.0, 0.0);
47    /// assert!(result.is_err());
48    /// ```
49    pub fn new(shape_a: f64, shape_b: f64) -> Result<Beta> {
50        if shape_a.is_nan()
51            || shape_b.is_nan()
52            || shape_a.is_infinite() && shape_b.is_infinite()
53            || shape_a <= 0.0
54            || shape_b <= 0.0
55        {
56            return Err(StatsError::BadParams);
57        };
58        Ok(Beta { shape_a, shape_b })
59    }
60
61    /// Returns the shapeA (α) of the beta distribution
62    ///
63    /// # Examples
64    ///
65    /// ```
66    /// use statrs::distribution::Beta;
67    ///
68    /// let n = Beta::new(2.0, 2.0).unwrap();
69    /// assert_eq!(n.shape_a(), 2.0);
70    /// ```
71    pub fn shape_a(&self) -> f64 {
72        self.shape_a
73    }
74
75    /// Returns the shapeB (β) of the beta distributionβ
76    ///
77    /// # Examples
78    ///
79    /// ```
80    /// use statrs::distribution::Beta;
81    ///
82    /// let n = Beta::new(2.0, 2.0).unwrap();
83    /// assert_eq!(n.shape_b(), 2.0);
84    /// ```
85    pub fn shape_b(&self) -> f64 {
86        self.shape_b
87    }
88}
89
90impl ::rand::distributions::Distribution<f64> for Beta {
91    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
92        // Generated by sampling two gamma distributions and normalizing.
93        let x = super::gamma::sample_unchecked(rng, self.shape_a, 1.0);
94        let y = super::gamma::sample_unchecked(rng, self.shape_b, 1.0);
95        x / (x + y)
96    }
97}
98
99impl ContinuousCDF<f64, f64> for Beta {
100    /// Calculates the cumulative distribution function for the beta
101    /// distribution
102    /// at `x`
103    ///
104    /// # Formula
105    ///
106    /// ```ignore
107    /// I_x(α, β)
108    /// ```
109    ///
110    /// where `α` is shapeA, `β` is shapeB, and `I_x` is the regularized
111    /// lower incomplete beta function
112    fn cdf(&self, x: f64) -> f64 {
113        if x < 0.0 {
114            0.0
115        } else if x >= 1.0 {
116            1.0
117        } else if self.shape_a.is_infinite() {
118            if x < 1.0 {
119                0.0
120            } else {
121                1.0
122            }
123        } else if self.shape_b.is_infinite() {
124            1.0
125        } else if ulps_eq!(self.shape_a, 1.0) && ulps_eq!(self.shape_b, 1.0) {
126            x
127        } else {
128            beta::beta_reg(self.shape_a, self.shape_b, x)
129        }
130    }
131
132    /// Calculates the survival function for the beta
133    /// distribution at `x`
134    ///
135    /// # Formula
136    ///
137    /// ```ignore
138    /// I_(1-x)(β, α)
139    /// ```
140    ///
141    /// where `α` is shapeA, `β` is shapeB, and `I_x` is the regularized
142    /// lower incomplete beta function
143    fn sf(&self, x: f64) -> f64 {
144        if x < 0.0 {
145            1.0
146        } else if x >= 1.0 {
147            0.0
148        } else if self.shape_a.is_infinite() {
149            if x < 1.0 {
150                1.0
151            } else {
152                0.0
153            }
154        } else if self.shape_b.is_infinite() {
155            0.0
156        } else if ulps_eq!(self.shape_a, 1.0) && ulps_eq!(self.shape_b, 1.0) {
157            1. - x
158        } else {
159            beta::beta_reg(self.shape_b, self.shape_a, 1.0 - x) 
160        }
161    }
162}
163
164impl Min<f64> for Beta {
165    /// Returns the minimum value in the domain of the
166    /// beta distribution representable by a double precision
167    /// float
168    ///
169    /// # Formula
170    ///
171    /// ```ignore
172    /// 0
173    /// ```
174    fn min(&self) -> f64 {
175        0.0
176    }
177}
178
179impl Max<f64> for Beta {
180    /// Returns the maximum value in the domain of the
181    /// beta distribution representable by a double precision
182    /// float
183    ///
184    /// # Formula
185    ///
186    /// ```ignore
187    /// 1
188    /// ```
189    fn max(&self) -> f64 {
190        1.0
191    }
192}
193
194impl Distribution<f64> for Beta {
195    /// Returns the mean of the beta distribution
196    ///
197    /// # Formula
198    ///
199    /// ```ignore
200    /// α / (α + β)
201    /// ```
202    ///
203    /// where `α` is shapeA and `β` is shapeB
204    fn mean(&self) -> Option<f64> {
205        let mean = if self.shape_a.is_infinite() {
206            1.0
207        } else {
208            self.shape_a / (self.shape_a + self.shape_b)
209        };
210        Some(mean)
211    }
212    /// Returns the variance of the beta distribution
213    ///
214    /// # Remarks
215    ///
216    /// # Formula
217    ///
218    /// ```ignore
219    /// (α * β) / ((α + β)^2 * (α + β + 1))
220    /// ```
221    ///
222    /// where `α` is shapeA and `β` is shapeB
223    fn variance(&self) -> Option<f64> {
224        let var = if self.shape_a.is_infinite() || self.shape_b.is_infinite() {
225            0.0
226        } else {
227            self.shape_a * self.shape_b
228                / ((self.shape_a + self.shape_b)
229                    * (self.shape_a + self.shape_b)
230                    * (self.shape_a + self.shape_b + 1.0))
231        };
232        Some(var)
233    }
234    /// Returns the entropy of the beta distribution
235    ///
236    /// # Formula
237    ///
238    /// ```ignore
239    /// ln(B(α, β)) - (α - 1)ψ(α) - (β - 1)ψ(β) + (α + β - 2)ψ(α + β)
240    /// ```
241    ///
242    /// where `α` is shapeA, `β` is shapeB and `ψ` is the digamma function
243    fn entropy(&self) -> Option<f64> {
244        let entr = if self.shape_a.is_infinite() || self.shape_b.is_infinite() {
245            // unsupported limit
246            return None;
247        } else {
248            beta::ln_beta(self.shape_a, self.shape_b)
249                - (self.shape_a - 1.0) * gamma::digamma(self.shape_a)
250                - (self.shape_b - 1.0) * gamma::digamma(self.shape_b)
251                + (self.shape_a + self.shape_b - 2.0) * gamma::digamma(self.shape_a + self.shape_b)
252        };
253        Some(entr)
254    }
255    /// Returns the skewness of the Beta distribution
256    ///
257    /// # Formula
258    ///
259    /// ```ignore
260    /// 2(β - α) * sqrt(α + β + 1) / ((α + β + 2) * sqrt(αβ))
261    /// ```
262    ///
263    /// where `α` is shapeA and `β` is shapeB
264    fn skewness(&self) -> Option<f64> {
265        let skew = if self.shape_a.is_infinite() {
266            -2.0
267        } else if self.shape_b.is_infinite() {
268            2.0
269        } else {
270            2.0 * (self.shape_b - self.shape_a) * (self.shape_a + self.shape_b + 1.0).sqrt()
271                / ((self.shape_a + self.shape_b + 2.0) * (self.shape_a * self.shape_b).sqrt())
272        };
273        Some(skew)
274    }
275}
276
277impl Mode<Option<f64>> for Beta {
278    /// Returns the mode of the Beta distribution.
279    ///
280    /// # Remarks
281    ///
282    /// Since the mode is technically only calculate for `α > 1, β > 1`, those
283    /// are the only values we allow. We may consider relaxing this constraint
284    /// in
285    /// the future.
286    ///
287    /// # Panics
288    ///
289    /// If `α <= 1` or `β <= 1`
290    ///
291    /// # Formula
292    ///
293    /// ```ignore
294    /// (α - 1) / (α + β - 2)
295    /// ```
296    ///
297    /// where `α` is shapeA and `β` is shapeB
298    fn mode(&self) -> Option<f64> {
299        // TODO: perhaps relax constraint in order to allow calculation
300        // of 'anti-mode;
301        if self.shape_a <= 1.0 || self.shape_b <= 1.0 {
302            None
303        } else if self.shape_a.is_infinite() {
304            Some(1.0)
305        } else {
306            Some((self.shape_a - 1.0) / (self.shape_a + self.shape_b - 2.0))
307        }
308    }
309}
310
311impl Continuous<f64, f64> for Beta {
312    /// Calculates the probability density function for the beta distribution
313    /// at `x`.
314    ///
315    /// # Formula
316    ///
317    /// ```ignore
318    /// let B(α, β) = Γ(α)Γ(β)/Γ(α + β)
319    ///
320    /// x^(α - 1) * (1 - x)^(β - 1) / B(α, β)
321    /// ```
322    ///
323    /// where `α` is shapeA, `β` is shapeB, and `Γ` is the gamma function
324    fn pdf(&self, x: f64) -> f64 {
325        if !(0.0..=1.0).contains(&x) {
326            0.0
327        } else if self.shape_a.is_infinite() {
328            if ulps_eq!(x, 1.0) {
329                INF
330            } else {
331                0.0
332            }
333        } else if self.shape_b.is_infinite() {
334            if is_zero(x) {
335                INF
336            } else {
337                0.0
338            }
339        } else if ulps_eq!(self.shape_a, 1.0) && ulps_eq!(self.shape_b, 1.0) {
340            1.0
341        } else if self.shape_a > 80.0 || self.shape_b > 80.0 {
342            self.ln_pdf(x).exp()
343        } else {
344            let bb = gamma::gamma(self.shape_a + self.shape_b)
345                / (gamma::gamma(self.shape_a) * gamma::gamma(self.shape_b));
346            bb * x.powf(self.shape_a - 1.0) * (1.0 - x).powf(self.shape_b - 1.0)
347        }
348    }
349
350    /// Calculates the log probability density function for the beta
351    /// distribution at `x`.
352    ///
353    /// # Formula
354    ///
355    /// ```ignore
356    /// let B(α, β) = Γ(α)Γ(β)/Γ(α + β)
357    ///
358    /// ln(x^(α - 1) * (1 - x)^(β - 1) / B(α, β))
359    /// ```
360    ///
361    /// where `α` is shapeA, `β` is shapeB, and `Γ` is the gamma function
362    fn ln_pdf(&self, x: f64) -> f64 {
363        if !(0.0..=1.0).contains(&x) {
364            -INF
365        } else if self.shape_a.is_infinite() {
366            if ulps_eq!(x, 1.0) {
367                INF
368            } else {
369                -INF
370            }
371        } else if self.shape_b.is_infinite() {
372            if is_zero(x) {
373                INF
374            } else {
375                -INF
376            }
377        } else if ulps_eq!(self.shape_a, 1.0) && ulps_eq!(self.shape_b, 1.0) {
378            0.0
379        } else {
380            let aa = gamma::ln_gamma(self.shape_a + self.shape_b)
381                - gamma::ln_gamma(self.shape_a)
382                - gamma::ln_gamma(self.shape_b);
383            let bb = if ulps_eq!(self.shape_a, 1.0) && is_zero(x) {
384                0.0
385            } else if is_zero(x) {
386                -INF
387            } else {
388                (self.shape_a - 1.0) * x.ln()
389            };
390            let cc = if ulps_eq!(self.shape_b, 1.0) && ulps_eq!(x, 1.0) {
391                0.0
392            } else if ulps_eq!(x, 1.0) {
393                -INF
394            } else {
395                (self.shape_b - 1.0) * (1.0 - x).ln()
396            };
397            aa + bb + cc
398        }
399    }
400}
401
#[rustfmt::skip]
#[cfg(all(test, feature = "nightly"))]
mod tests {
    use super::*;
    use crate::consts::ACC;
    use super::super::internal::*;
    use crate::statistics::*;
    use crate::testing_boiler;

    testing_boiler!((f64, f64), Beta);

    #[test]
    fn test_create() {
        let valid = [(1.0, 1.0), (9.0, 1.0), (5.0, 100.0), (1.0, INF), (INF, 1.0)];
        for &arg in valid.iter() {
            try_create(arg);
        }
    }

    #[test]
    fn test_bad_create() {
        let invalid = [
            (0.0, 0.0),
            (0.0, 0.1),
            (1.0, 0.0),
            (0.0, INF),
            (INF, 0.0),
            (f64::NAN, 1.0),
            (1.0, f64::NAN),
            (f64::NAN, f64::NAN),
            (1.0, -1.0),
            (-1.0, 1.0),
            (-1.0, -1.0),
            (INF, INF),
        ];
        for &arg in invalid.iter() {
            bad_create_case(arg);
        }
    }

    #[test]
    fn test_mean() {
        let f = |x: Beta| x.mean().unwrap();
        let test = [
            ((1.0, 1.0), 0.5),
            ((9.0, 1.0), 0.9),
            ((5.0, 100.0), 0.047619047619047619047616),
            ((1.0, INF), 0.0),
            ((INF, 1.0), 1.0),
        ];
        for &(arg, res) in test.iter() {
            test_case(arg, res, f);
        }
    }

    #[test]
    fn test_variance() {
        let f = |x: Beta| x.variance().unwrap();
        let test = [
            ((1.0, 1.0), 1.0 / 12.0),
            ((9.0, 1.0), 9.0 / 1100.0),
            ((5.0, 100.0), 500.0 / 1168650.0),
            ((1.0, INF), 0.0),
            ((INF, 1.0), 0.0),
        ];
        for &(arg, res) in test.iter() {
            test_case(arg, res, f);
        }
    }

    #[test]
    fn test_entropy() {
        let f = |x: Beta| x.entropy().unwrap();
        let test = [
            ((9.0, 1.0), -1.3083356884473304939016015),
            ((5.0, 100.0), -2.52016231876027436794592),
        ];
        for &(arg, res) in test.iter() {
            test_case(arg, res, f);
        }
        test_case_special((1.0, 1.0), 0.0, 1e-14, f);
        let entropy = |x: Beta| x.entropy();
        test_none((1.0, INF), entropy);
        test_none((INF, 1.0), entropy);
    }

    #[test]
    fn test_skewness() {
        let skewness = |x: Beta| x.skewness().unwrap();
        test_case((1.0, 1.0), 0.0, skewness);
        test_case((9.0, 1.0), -1.4740554623801777107177478829, skewness);
        test_case((5.0, 100.0), 0.817594109275534303545831591, skewness);
        test_case((1.0, INF), 2.0, skewness);
        test_case((INF, 1.0), -2.0, skewness);
    }

    #[test]
    fn test_mode() {
        let mode = |x: Beta| x.mode().unwrap();
        test_case((5.0, 100.0), 0.038834951456310676243255386, mode);
        test_case((92.0, INF), 0.0, mode);
        test_case((INF, 2.0), 1.0, mode);
    }

    #[test]
    #[should_panic]
    fn test_mode_shape_a_lte_1() {
        let mode = |x: Beta| x.mode().unwrap();
        get_value((1.0, 5.0), mode);
    }

    #[test]
    #[should_panic]
    fn test_mode_shape_b_lte_1() {
        let mode = |x: Beta| x.mode().unwrap();
        get_value((5.0, 1.0), mode);
    }

    #[test]
    fn test_min_max() {
        let min = |x: Beta| x.min();
        let max = |x: Beta| x.max();
        test_case((1.0, 1.0), 0.0, min);
        test_case((1.0, 1.0), 1.0, max);
    }

    #[test]
    fn test_pdf() {
        let f = |arg: f64| move |x: Beta| x.pdf(arg);
        let test = [
            ((1.0, 1.0), 0.0, 1.0),
            ((1.0, 1.0), 0.5, 1.0),
            ((1.0, 1.0), 1.0, 1.0),
            ((9.0, 1.0), 0.0, 0.0),
            ((9.0, 1.0), 0.5, 0.03515625),
            ((9.0, 1.0), 1.0, 9.0),
            ((5.0, 100.0), 0.0, 0.0),
            ((5.0, 100.0), 0.5, 4.534102298350337661e-23),
            ((5.0, 100.0), 1.0, 0.0),
            ((1.0, INF), 0.0, INF),
            ((1.0, INF), 0.5, 0.0),
            ((1.0, INF), 1.0, 0.0),
            ((INF, 1.0), 0.0, 0.0),
            ((INF, 1.0), 0.5, 0.0),
            ((INF, 1.0), 1.0, INF),
        ];
        for &(arg, x, expect) in test.iter() {
            test_case(arg, expect, f(x));
        }
    }

    #[test]
    fn test_pdf_input_lt_0() {
        let pdf = |arg: f64| move |x: Beta| x.pdf(arg);
        test_case((1.0, 1.0), 0.0, pdf(-1.0));
    }

    #[test]
    fn test_pdf_input_gt_1() {
        let pdf = |arg: f64| move |x: Beta| x.pdf(arg);
        test_case((1.0, 1.0), 0.0, pdf(2.0));
    }

    #[test]
    fn test_ln_pdf() {
        let f = |arg: f64| move |x: Beta| x.ln_pdf(arg);
        let test = [
            ((1.0, 1.0), 0.0, 0.0),
            ((1.0, 1.0), 0.5, 0.0),
            ((1.0, 1.0), 1.0, 0.0),
            ((9.0, 1.0), 0.0, -INF),
            ((9.0, 1.0), 0.5, -3.347952867143343092547366497),
            ((9.0, 1.0), 1.0, 2.1972245773362193827904904738),
            ((5.0, 100.0), 0.0, -INF),
            ((5.0, 100.0), 0.5, -51.447830024537682154565870),
            ((5.0, 100.0), 1.0, -INF),
            ((1.0, INF), 0.0, INF),
            ((1.0, INF), 0.5, -INF),
            ((1.0, INF), 1.0, -INF),
            ((INF, 1.0), 0.0, -INF),
            ((INF, 1.0), 0.5, -INF),
            ((INF, 1.0), 1.0, INF),
        ];
        for &(arg, x, expect) in test.iter() {
            test_case(arg, expect, f(x));
        }
    }

    #[test]
    fn test_ln_pdf_input_lt_0() {
        let ln_pdf = |arg: f64| move |x: Beta| x.ln_pdf(arg);
        test_case((1.0, 1.0), -INF, ln_pdf(-1.0));
    }

    #[test]
    fn test_ln_pdf_input_gt_1() {
        let ln_pdf = |arg: f64| move |x: Beta| x.ln_pdf(arg);
        test_case((1.0, 1.0), -INF, ln_pdf(2.0));
    }

    #[test]
    fn test_cdf() {
        let cdf = |arg: f64| move |x: Beta| x.cdf(arg);
        let test = [
            ((1.0, 1.0), 0.0, 0.0),
            ((1.0, 1.0), 0.5, 0.5),
            ((1.0, 1.0), 1.0, 1.0),
            ((9.0, 1.0), 0.0, 0.0),
            ((9.0, 1.0), 0.5, 0.001953125),
            ((9.0, 1.0), 1.0, 1.0),
            ((5.0, 100.0), 0.0, 0.0),
            ((5.0, 100.0), 0.5, 1.0),
            ((5.0, 100.0), 1.0, 1.0),
            ((1.0, INF), 0.0, 1.0),
            ((1.0, INF), 0.5, 1.0),
            ((1.0, INF), 1.0, 1.0),
            ((INF, 1.0), 0.0, 0.0),
            ((INF, 1.0), 0.5, 0.0),
            ((INF, 1.0), 1.0, 1.0),
        ];
        for &(arg, x, expect) in test.iter() {
            test_case(arg, expect, cdf(x));
        }
    }

    #[test]
    fn test_sf() {
        let sf = |arg: f64| move |x: Beta| x.sf(arg);
        let test = [
            ((1.0, 1.0), 0.0, 1.0),
            ((1.0, 1.0), 0.5, 0.5),
            ((1.0, 1.0), 1.0, 0.0),
            ((9.0, 1.0), 0.0, 1.0),
            ((9.0, 1.0), 0.5, 0.998046875),
            ((9.0, 1.0), 1.0, 0.0),
            ((5.0, 100.0), 0.0, 1.0),
            ((5.0, 100.0), 0.5, 0.0),
            ((5.0, 100.0), 1.0, 0.0),
            ((1.0, INF), 0.0, 0.0),
            ((1.0, INF), 0.5, 0.0),
            ((1.0, INF), 1.0, 0.0),
            ((INF, 1.0), 0.0, 1.0),
            ((INF, 1.0), 0.5, 1.0),
            ((INF, 1.0), 1.0, 0.0),
        ];
        for &(arg, x, expect) in test.iter() {
            test_case(arg, expect, sf(x));
        }
    }

    #[test]
    fn test_cdf_input_lt_0() {
        let cdf = |arg: f64| move |x: Beta| x.cdf(arg);
        test_case((1.0, 1.0), 0.0, cdf(-1.0));
    }

    #[test]
    fn test_cdf_input_gt_1() {
        let cdf = |arg: f64| move |x: Beta| x.cdf(arg);
        test_case((1.0, 1.0), 1.0, cdf(2.0));
    }

    #[test]
    fn test_sf_input_lt_0() {
        let sf = |arg: f64| move |x: Beta| x.sf(arg);
        test_case((1.0, 1.0), 1.0, sf(-1.0));
    }

    #[test]
    fn test_sf_input_gt_1() {
        let sf = |arg: f64| move |x: Beta| x.sf(arg);
        test_case((1.0, 1.0), 0.0, sf(2.0));
    }

    #[test]
    fn test_continuous() {
        test::check_continuous_distribution(&try_create((1.2, 3.4)), 0.0, 1.0);
        test::check_continuous_distribution(&try_create((4.5, 6.7)), 0.0, 1.0);
    }
}
680        test::check_continuous_distribution(&try_create((4.5, 6.7)), 0.0, 1.0);
681    }
682}