statrs/distribution/gamma.rs

1use crate::distribution::{Continuous, ContinuousCDF};
2use crate::function::gamma;
3use crate::prec;
4use crate::statistics::*;
5
/// The [Gamma](https://en.wikipedia.org/wiki/Gamma_distribution) distribution,
/// parameterised by a shape (α) and a rate (β).
///
/// # Examples
///
/// ```
/// use statrs::distribution::{Gamma, Continuous};
/// use statrs::statistics::Distribution;
/// use statrs::prec;
///
/// let g = Gamma::new(3.0, 1.0).unwrap();
/// assert_eq!(g.mean().unwrap(), 3.0);
/// assert!(prec::almost_eq(g.pdf(2.0), 0.270670566473225383788, 1e-15));
/// ```
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Gamma {
    /// Shape parameter α (checked by `Gamma::new` to be positive and not NaN).
    shape: f64,
    /// Rate parameter β (checked by `Gamma::new` to be positive and not NaN).
    rate: f64,
}
25
/// Represents the errors that can occur when creating a [`Gamma`].
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
#[non_exhaustive]
pub enum GammaError {
    /// The shape is NaN, zero or less than zero.
    ShapeInvalid,

    /// The rate is NaN, zero or less than zero.
    RateInvalid,

    /// The shape and rate are both infinite.
    ShapeAndRateInfinite,
}

impl std::fmt::Display for GammaError {
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        // Human-readable explanation of why `Gamma::new` rejected its arguments.
        match self {
            GammaError::ShapeInvalid => write!(f, "Shape is NaN, zero, or less than zero."),
            GammaError::RateInvalid => write!(f, "Rate is NaN, zero, or less than zero."),
            GammaError::ShapeAndRateInfinite => write!(f, "Shape and rate are infinite"),
        }
    }
}

// No wrapped source error: `source()` keeps its default of `None`.
impl std::error::Error for GammaError {}
52
53impl Gamma {
54    /// Constructs a new gamma distribution with a shape (α)
55    /// of `shape` and a rate (β) of `rate`
56    ///
57    /// # Errors
58    ///
59    /// Returns an error if `shape` is 'NaN' or inf or `rate` is `NaN` or inf.
60    /// Also returns an error if `shape <= 0.0` or `rate <= 0.0`
61    ///
62    /// # Examples
63    ///
64    /// ```
65    /// use statrs::distribution::Gamma;
66    ///
67    /// let mut result = Gamma::new(3.0, 1.0);
68    /// assert!(result.is_ok());
69    ///
70    /// result = Gamma::new(0.0, 0.0);
71    /// assert!(result.is_err());
72    /// ```
73    pub fn new(shape: f64, rate: f64) -> Result<Gamma, GammaError> {
74        if shape.is_nan() || shape <= 0.0 {
75            return Err(GammaError::ShapeInvalid);
76        }
77
78        if rate.is_nan() || rate <= 0.0 {
79            return Err(GammaError::RateInvalid);
80        }
81
82        if shape.is_infinite() && rate.is_infinite() {
83            return Err(GammaError::ShapeAndRateInfinite);
84        }
85
86        Ok(Gamma { shape, rate })
87    }
88
89    /// Returns the shape (α) of the gamma distribution
90    ///
91    /// # Examples
92    ///
93    /// ```
94    /// use statrs::distribution::Gamma;
95    ///
96    /// let n = Gamma::new(3.0, 1.0).unwrap();
97    /// assert_eq!(n.shape(), 3.0);
98    /// ```
99    pub fn shape(&self) -> f64 {
100        self.shape
101    }
102
103    /// Returns the rate (β) of the gamma distribution
104    ///
105    /// # Examples
106    ///
107    /// ```
108    /// use statrs::distribution::Gamma;
109    ///
110    /// let n = Gamma::new(3.0, 1.0).unwrap();
111    /// assert_eq!(n.rate(), 1.0);
112    /// ```
113    pub fn rate(&self) -> f64 {
114        self.rate
115    }
116}
117
118impl std::fmt::Display for Gamma {
119    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
120        write!(f, "Γ({}, {})", self.shape, self.rate)
121    }
122}
123
#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<f64> for Gamma {
    /// Draws a single variate by delegating to `sample_unchecked` with this
    /// distribution's shape and rate.
    fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 {
        let Gamma { shape, rate } = *self;
        sample_unchecked(rng, shape, rate)
    }
}
131
132impl ContinuousCDF<f64, f64> for Gamma {
133    /// Calculates the cumulative distribution function for the gamma
134    /// distribution
135    /// at `x`
136    ///
137    /// # Formula
138    ///
139    /// ```text
140    /// (1 / Γ(α)) * γ(α, β * x)
141    /// ```
142    ///
143    /// where `α` is the shape, `β` is the rate, `Γ` is the gamma function,
144    /// and `γ` is the lower incomplete gamma function
145    fn cdf(&self, x: f64) -> f64 {
146        if x <= 0.0 {
147            0.0
148        } else if ulps_eq!(x, self.shape) && self.rate.is_infinite() {
149            1.0
150        } else if self.rate.is_infinite() {
151            0.0
152        } else if x.is_infinite() {
153            1.0
154        } else {
155            gamma::gamma_lr(self.shape, x * self.rate)
156        }
157    }
158
159    /// Calculates the survival function for the gamma
160    /// distribution at `x`
161    ///
162    /// # Formula
163    ///
164    /// ```text
165    /// (1 / Γ(α)) * γ(α, β * x)
166    /// ```
167    ///
168    /// where `α` is the shape, `β` is the rate, `Γ` is the gamma function,
169    /// and `γ` is the upper incomplete gamma function
170    fn sf(&self, x: f64) -> f64 {
171        if x <= 0.0 {
172            1.0
173        } else if ulps_eq!(x, self.shape) && self.rate.is_infinite() {
174            0.0
175        } else if self.rate.is_infinite() {
176            1.0
177        } else if x.is_infinite() {
178            0.0
179        } else {
180            gamma::gamma_ur(self.shape, x * self.rate)
181        }
182    }
183
184    fn inverse_cdf(&self, p: f64) -> f64 {
185        if !(0.0..=1.0).contains(&p) {
186            panic!("default inverse_cdf implementation should be provided probability on [0,1]")
187        }
188        if p == 0.0 {
189            return self.min();
190        };
191        if p == 1.0 {
192            return self.max();
193        };
194
195        // Bisection search for MAX_ITERS.0 iterations
196        let mut high = 2.0;
197        let mut low = 1.0;
198        while self.cdf(low) > p {
199            low /= 2.0;
200        }
201        while self.cdf(high) < p {
202            high *= 2.0;
203        }
204        let mut x_0 = (high + low) / 2.0;
205
206        for _ in 0..8 {
207            if self.cdf(x_0) >= p {
208                high = x_0;
209            } else {
210                low = x_0;
211            }
212            if prec::convergence(&mut x_0, (high + low) / 2.0) {
213                break;
214            }
215        }
216
217        // Newton Raphson, for at least one step
218        for _ in 0..4 {
219            let x_next = x_0 - (self.cdf(x_0) - p) / self.pdf(x_0);
220            if prec::convergence(&mut x_0, x_next) {
221                break;
222            }
223        }
224
225        x_0
226    }
227}
228
229impl Min<f64> for Gamma {
230    /// Returns the minimum value in the domain of the
231    /// gamma distribution representable by a double precision
232    /// float
233    ///
234    /// # Formula
235    ///
236    /// ```text
237    /// 0
238    /// ```
239    fn min(&self) -> f64 {
240        0.0
241    }
242}
243
244impl Max<f64> for Gamma {
245    /// Returns the maximum value in the domain of the
246    /// gamma distribution representable by a double precision
247    /// float
248    ///
249    /// # Formula
250    ///
251    /// ```text
252    /// f64::INFINITY
253    /// ```
254    fn max(&self) -> f64 {
255        f64::INFINITY
256    }
257}
258
259impl Distribution<f64> for Gamma {
260    /// Returns the mean of the gamma distribution
261    ///
262    /// # Formula
263    ///
264    /// ```text
265    /// α / β
266    /// ```
267    ///
268    /// where `α` is the shape and `β` is the rate
269    fn mean(&self) -> Option<f64> {
270        Some(self.shape / self.rate)
271    }
272
273    /// Returns the variance of the gamma distribution
274    ///
275    /// # Formula
276    ///
277    /// ```text
278    /// α / β^2
279    /// ```
280    ///
281    /// where `α` is the shape and `β` is the rate
282    fn variance(&self) -> Option<f64> {
283        Some(self.shape / (self.rate * self.rate))
284    }
285
286    /// Returns the entropy of the gamma distribution
287    ///
288    /// # Formula
289    ///
290    /// ```text
291    /// α - ln(β) + ln(Γ(α)) + (1 - α) * ψ(α)
292    /// ```
293    ///
294    /// where `α` is the shape, `β` is the rate, `Γ` is the gamma function,
295    /// and `ψ` is the digamma function
296    fn entropy(&self) -> Option<f64> {
297        let entr = self.shape - self.rate.ln()
298            + gamma::ln_gamma(self.shape)
299            + (1.0 - self.shape) * gamma::digamma(self.shape);
300        Some(entr)
301    }
302
303    /// Returns the skewness of the gamma distribution
304    ///
305    /// # Formula
306    ///
307    /// ```text
308    /// 2 / sqrt(α)
309    /// ```
310    ///
311    /// where `α` is the shape
312    fn skewness(&self) -> Option<f64> {
313        Some(2.0 / self.shape.sqrt())
314    }
315}
316
317impl Mode<Option<f64>> for Gamma {
318    /// Returns the mode for the gamma distribution
319    ///
320    /// # Formula
321    ///
322    /// ```text
323    /// (α - 1) / β, where α≥1
324    /// ```
325    ///
326    /// where `α` is the shape and `β` is the rate
327    fn mode(&self) -> Option<f64> {
328        if self.shape < 1.0 {
329            None
330        } else {
331            Some((self.shape - 1.0) / self.rate)
332        }
333    }
334}
335
336impl Continuous<f64, f64> for Gamma {
337    /// Calculates the probability density function for the gamma distribution
338    /// at `x`
339    ///
340    /// # Remarks
341    ///
342    /// Returns `NAN` if any of `shape` or `rate` are `f64::INFINITY`
343    /// or if `x` is `f64::INFINITY`
344    ///
345    /// # Formula
346    ///
347    /// ```text
348    /// (β^α / Γ(α)) * x^(α - 1) * e^(-β * x)
349    /// ```
350    ///
351    /// where `α` is the shape, `β` is the rate, and `Γ` is the gamma function
352    fn pdf(&self, x: f64) -> f64 {
353        if x < 0.0 {
354            0.0
355        } else if ulps_eq!(self.shape, 1.0) {
356            self.rate * (-self.rate * x).exp()
357        } else if self.shape > 160.0 {
358            self.ln_pdf(x).exp()
359        } else if x.is_infinite() {
360            0.0
361        } else {
362            self.rate.powf(self.shape) * x.powf(self.shape - 1.0) * (-self.rate * x).exp()
363                / gamma::gamma(self.shape)
364        }
365    }
366
367    /// Calculates the log probability density function for the gamma
368    /// distribution
369    /// at `x`
370    ///
371    /// # Remarks
372    ///
373    /// Returns `NAN` if any of `shape` or `rate` are `f64::INFINITY`
374    /// or if `x` is `f64::INFINITY`
375    ///
376    /// # Formula
377    ///
378    /// ```text
379    /// ln((β^α / Γ(α)) * x^(α - 1) * e ^(-β * x))
380    /// ```
381    ///
382    /// where `α` is the shape, `β` is the rate, and `Γ` is the gamma function
383    fn ln_pdf(&self, x: f64) -> f64 {
384        if x < 0.0 {
385            f64::NEG_INFINITY
386        } else if ulps_eq!(self.shape, 1.0) {
387            self.rate.ln() - self.rate * x
388        } else if x.is_infinite() {
389            f64::NEG_INFINITY
390        } else {
391            self.shape * self.rate.ln() + (self.shape - 1.0) * x.ln()
392                - self.rate * x
393                - gamma::ln_gamma(self.shape)
394        }
395    }
396}
/// Samples from a gamma distribution with a shape of `shape` and a
/// rate of `rate` using `rng` as the source of randomness. Implementation from:
///
/// _"A Simple Method for Generating Gamma Variables"_ - Marsaglia & Tsang
///
/// ACM Transactions on Mathematical Software, Vol. 26, No. 3, September 2000,
/// Pages 363-372
///
/// The parameters are not validated; callers are expected to pass values
/// accepted by `Gamma::new`.
#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
pub fn sample_unchecked<R: ::rand::Rng + ?Sized>(rng: &mut R, shape: f64, rate: f64) -> f64 {
    // For shape < 1, sample with shape + 1 and scale the result by
    // U^(1 / shape), where U is uniform.
    let (alpha, boost) = if shape < 1.0 {
        (shape + 1.0, rng.gen::<f64>().powf(1.0 / shape))
    } else {
        (shape, 1.0)
    };

    let d = alpha - 1.0 / 3.0;
    let c = 1.0 / (9.0 * d).sqrt();
    loop {
        // Draw standard normals until 1 + c*z is positive.
        let (z, base) = loop {
            let z = super::normal::sample_unchecked(rng, 0.0, 1.0);
            let base = 1.0 + c * z;
            if base > 0.0 {
                break (z, base);
            }
        };

        let v = base * base * base;
        let z_sq = z * z;
        let u: f64 = rng.gen();
        // Cheap squeeze test first, then the exact log-space acceptance test.
        if u < 1.0 - 0.0331 * z_sq * z_sq || u.ln() < 0.5 * z_sq + d * (1.0 - v + v.ln()) {
            return boost * d * v / rate;
        }
    }
}
435
/// Unit tests for the gamma distribution. The helper functions used below
/// (`create_ok`, `test_create_err`, `test_relative`, `test_absolute`) are
/// presumably generated by `testing_boiler!` — see that macro for their exact
/// semantics and tolerances.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distribution::internal::*;
    use crate::testing_boiler;

    testing_boiler!(shape: f64, rate: f64; Gamma; GammaError);

    /// Parameter pairs that `Gamma::new` must accept, including a single
    /// infinite rate.
    #[test]
    fn test_create() {
        let valid = [
            (1.0, 0.1),
            (1.0, 1.0),
            (10.0, 10.0),
            (10.0, 1.0),
            (10.0, f64::INFINITY),
        ];

        for (s, r) in valid {
            create_ok(s, r);
        }
    }

    /// Parameter pairs that `Gamma::new` must reject, with the expected error
    /// variant (the shape is validated before the rate).
    #[test]
    fn test_bad_create() {
        let invalid = [
            (0.0, 0.0, GammaError::ShapeInvalid),
            (1.0, f64::NAN, GammaError::RateInvalid),
            (1.0, -1.0, GammaError::RateInvalid),
            (-1.0, 1.0, GammaError::ShapeInvalid),
            (-1.0, -1.0, GammaError::ShapeInvalid),
            (-1.0, f64::NAN, GammaError::ShapeInvalid),
            (
                f64::INFINITY,
                f64::INFINITY,
                GammaError::ShapeAndRateInfinite,
            ),
        ];
        for (s, r, err) in invalid {
            test_create_err(s, r, err);
        }
    }

    /// mean = shape / rate.
    #[test]
    fn test_mean() {
        let f = |x: Gamma| x.mean().unwrap();
        let test = [
            ((1.0, 0.1), 10.0),
            ((1.0, 1.0), 1.0),
            ((10.0, 10.0), 1.0),
            ((10.0, 1.0), 10.0),
            ((10.0, f64::INFINITY), 0.0),
        ];
        for ((s, r), res) in test {
            test_relative(s, r, res, f);
        }
    }

    /// variance = shape / rate^2.
    #[test]
    fn test_variance() {
        let f = |x: Gamma| x.variance().unwrap();
        let test = [
            ((1.0, 0.1), 100.0),
            ((1.0, 1.0), 1.0),
            ((10.0, 10.0), 0.1),
            ((10.0, 1.0), 10.0),
            ((10.0, f64::INFINITY), 0.0),
        ];
        for ((s, r), res) in test {
            test_relative(s, r, res, f);
        }
    }

    /// Entropy reference values; an infinite rate drives the entropy to -inf.
    #[test]
    fn test_entropy() {
        let f = |x: Gamma| x.entropy().unwrap();
        let test = [
            ((1.0, 0.1), 3.302585092994045628506840223),
            ((1.0, 1.0), 1.0),
            ((10.0, 10.0), 0.2334690854869339583626209),
            ((10.0, 1.0), 2.53605417848097964238061239),
            ((10.0, f64::INFINITY), f64::NEG_INFINITY),
        ];
        for ((s, r), res) in test {
            test_relative(s, r, res, f);
        }
    }

    /// skewness = 2 / sqrt(shape), independent of the rate.
    #[test]
    fn test_skewness() {
        let f = |x: Gamma| x.skewness().unwrap();
        let test = [
            ((1.0, 0.1), 2.0),
            ((1.0, 1.0), 2.0),
            ((10.0, 10.0), 0.6324555320336758663997787),
            ((10.0, 1.0), 0.63245553203367586639977870),
            ((10.0, f64::INFINITY), 0.6324555320336758),
        ];
        for ((s, r), res) in test {
            test_relative(s, r, res, f);
        }
    }

    /// mode = (shape - 1) / rate. A mode of 0 is checked with an absolute
    /// tolerance because a relative comparison against 0 is not meaningful.
    #[test]
    fn test_mode() {
        let f = |x: Gamma| x.mode().unwrap();
        let test = [((1.0, 0.1), 0.0), ((1.0, 1.0), 0.0)];
        for &((s, r), res) in test.iter() {
            test_absolute(s, r, res, 10e-6, f);
        }
        let test = [
            ((10.0, 10.0), 0.9),
            ((10.0, 1.0), 9.0),
            ((10.0, f64::INFINITY), 0.0),
        ];
        for ((s, r), res) in test {
            test_relative(s, r, res, f);
        }
    }

    /// The support is [0, inf) for every parameterisation.
    #[test]
    fn test_min_max() {
        let f = |x: Gamma| x.min();
        let test = [
            ((1.0, 0.1), 0.0),
            ((1.0, 1.0), 0.0),
            ((10.0, 10.0), 0.0),
            ((10.0, 1.0), 0.0),
            ((10.0, f64::INFINITY), 0.0),
        ];
        for ((s, r), res) in test {
            test_relative(s, r, res, f);
        }
        let f = |x: Gamma| x.max();
        let test = [
            ((1.0, 0.1), f64::INFINITY),
            ((1.0, 1.0), f64::INFINITY),
            ((10.0, 10.0), f64::INFINITY),
            ((10.0, 1.0), f64::INFINITY),
            ((10.0, f64::INFINITY), f64::INFINITY),
        ];
        for ((s, r), res) in test {
            test_relative(s, r, res, f);
        }
    }

    /// Density reference values at x = 1 and x = 10 for finite parameters.
    #[test]
    fn test_pdf() {
        let f = |arg: f64| move |x: Gamma| x.pdf(arg);
        let test = [
            ((1.0, 0.1), 1.0, 0.090483741803595961836995),
            ((1.0, 0.1), 10.0, 0.036787944117144234201693),
            ((1.0, 1.0), 1.0, 0.367879441171442321595523),
            ((1.0, 1.0), 10.0, 0.000045399929762484851535),
            ((10.0, 10.0), 1.0, 1.251100357211332989847649),
            ((10.0, 10.0), 10.0, 1.025153212086870580621609e-30),
            ((10.0, 1.0), 1.0, 0.000001013777119630297402),
            ((10.0, 1.0), 10.0, 0.125110035721133298984764),
        ];
        for ((s, r), x, res) in test {
            test_relative(s, r, res, f(x));
        }
        // TODO: test special
        // test_is_nan((10.0, f64::INFINITY), pdf(1.0)); // is this really the behavior we want?
        // TODO: test special
        // (10.0, f64::INFINITY, f64::INFINITY, 0.0, pdf(f64::INFINITY)),];
    }

    /// For shape = 1 (exponential case) the density at 0 equals the rate.
    #[test]
    fn test_pdf_at_zero() {
        test_relative(1.0, 0.1, 0.1, |x| x.pdf(0.0));
        test_relative(1.0, 0.1, 0.1f64.ln(), |x| x.ln_pdf(0.0));
    }

    /// Log-density reference values, including x = inf with an infinite rate.
    #[test]
    fn test_ln_pdf() {
        let f = |arg: f64| move |x: Gamma| x.ln_pdf(arg);
        let test = [
            ((1.0, 0.1), 1.0, -2.40258509299404563405795),
            ((1.0, 0.1), 10.0, -3.30258509299404562850684),
            ((1.0, 1.0), 1.0, -1.0),
            ((1.0, 1.0), 10.0, -10.0),
            ((10.0, 10.0), 1.0, 0.224023449858987228972196),
            ((10.0, 10.0), 10.0, -69.0527107131946016148658),
            ((10.0, 1.0), 1.0, -13.8018274800814696112077),
            ((10.0, 1.0), 10.0, -2.07856164313505845504579),
            ((10.0, f64::INFINITY), f64::INFINITY, f64::NEG_INFINITY),
        ];
        for ((s, r), x, res) in test {
            test_relative(s, r, res, f(x));
        }
        // TODO: test special
        // test_is_nan((10.0, f64::INFINITY), f(1.0)); // is this really the behavior we want?
    }

    /// CDF reference values; with an infinite rate the CDF is 0 below the
    /// shape and 1 at the shape.
    #[test]
    fn test_cdf() {
        let f = |arg: f64| move |x: Gamma| x.cdf(arg);
        let test = [
            ((1.0, 0.1), 1.0, 0.095162581964040431858607),
            ((1.0, 0.1), 10.0, 0.632120558828557678404476),
            ((1.0, 1.0), 1.0, 0.632120558828557678404476),
            ((1.0, 1.0), 10.0, 0.999954600070237515148464),
            ((10.0, 10.0), 1.0, 0.542070285528147791685835),
            ((10.0, 10.0), 10.0, 0.999999999999999999999999),
            ((10.0, 1.0), 1.0, 0.000000111425478338720677),
            ((10.0, 1.0), 10.0, 0.542070285528147791685835),
            ((10.0, f64::INFINITY), 1.0, 0.0),
            ((10.0, f64::INFINITY), 10.0, 1.0),
        ];
        for ((s, r), x, res) in test {
            test_relative(s, r, res, f(x));
        }
    }

    #[test]
    fn test_cdf_at_zero() {
        test_relative(1.0, 0.1, 0.0, |x| x.cdf(0.0));
    }

    /// cdf(inverse_cdf(p)) must round-trip p for p = 1e-5 .. 1e-1.
    #[test]
    fn test_cdf_inverse_identity() {
        let f = |p: f64| move |g: Gamma| g.cdf(g.inverse_cdf(p));
        let params = [
            (1.0, 0.1),
            (1.0, 1.0),
            (10.0, 10.0),
            (10.0, 1.0),
            (100.0, 200.0),
        ];

        for (s, r) in params {
            for n in -5..0 {
                let p = 10.0f64.powi(n);
                test_relative(s, r, p, f(p));
            }
        }

        // test case from issue #200
        {
            let x = 20.5567;
            let f = |x: f64| move |g: Gamma| g.inverse_cdf(g.cdf(x));
            test_relative(3.0, 0.5, x, f(x))
        }
    }

    /// Survival-function reference values (complements of the CDF table).
    #[test]
    fn test_sf() {
        let f = |arg: f64| move |x: Gamma| x.sf(arg);
        let test = [
            ((1.0, 0.1), 1.0, 0.9048374180359595),
            ((1.0, 0.1), 10.0, 0.3678794411714419),
            ((1.0, 1.0), 1.0, 0.3678794411714419),
            ((1.0, 1.0), 10.0, 4.539992976249074e-5),
            ((10.0, 10.0), 1.0, 0.4579297144718528),
            ((10.0, 10.0), 10.0, 1.1253473960842808e-31),
            ((10.0, 1.0), 1.0, 0.9999998885745217),
            ((10.0, 1.0), 10.0, 0.4579297144718528),
            ((10.0, f64::INFINITY), 1.0, 1.0),
            ((10.0, f64::INFINITY), 10.0, 0.0),
        ];
        for ((s, r), x, res) in test {
            test_relative(s, r, res, f(x));
        }
    }

    #[test]
    fn test_sf_at_zero() {
        test_relative(1.0, 0.1, 1.0, |x| x.sf(0.0));
    }

    /// Generic consistency checks shared by all continuous distributions,
    /// evaluated over the interval [0, 20].
    #[test]
    fn test_continuous() {
        test::check_continuous_distribution(&create_ok(1.0, 0.5), 0.0, 20.0);
        test::check_continuous_distribution(&create_ok(9.0, 2.0), 0.0, 20.0);
    }
}