statrs/distribution/
gamma.rs

1use crate::distribution::{Continuous, ContinuousCDF};
2use crate::function::gamma;
3use crate::statistics::*;
4use crate::{Result, StatsError};
5use core::f64::INFINITY as INF;
6use rand::Rng;
7
/// A [gamma distribution](https://en.wikipedia.org/wiki/Gamma_distribution)
/// parameterized by a shape (α) and a rate (β).
///
/// Instances are created through [`Gamma::new`], which validates both
/// parameters.
///
/// # Examples
///
/// ```
/// use statrs::distribution::{Gamma, Continuous};
/// use statrs::statistics::Distribution;
/// use statrs::prec;
///
/// let n = Gamma::new(3.0, 1.0).unwrap();
/// assert_eq!(n.mean().unwrap(), 3.0);
/// assert!(prec::almost_eq(n.pdf(2.0), 0.270670566473225383788, 1e-15));
/// ```
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Gamma {
    /// Shape parameter (α).
    shape: f64,
    /// Rate parameter (β).
    rate: f64,
}
27
28impl Gamma {
29    /// Constructs a new gamma distribution with a shape (α)
30    /// of `shape` and a rate (β) of `rate`
31    ///
32    /// # Errors
33    ///
34    /// Returns an error if `shape` is 'NaN' or inf or `rate` is `NaN` or inf.
35    /// Also returns an error if `shape <= 0.0` or `rate <= 0.0`
36    ///
37    /// # Examples
38    ///
39    /// ```
40    /// use statrs::distribution::Gamma;
41    ///
42    /// let mut result = Gamma::new(3.0, 1.0);
43    /// assert!(result.is_ok());
44    ///
45    /// result = Gamma::new(0.0, 0.0);
46    /// assert!(result.is_err());
47    /// ```
48    pub fn new(shape: f64, rate: f64) -> Result<Gamma> {
49        if shape.is_nan()
50            || rate.is_nan()
51            || shape.is_infinite() && rate.is_infinite()
52            || shape <= 0.0
53            || rate <= 0.0
54        {
55            return Err(StatsError::BadParams);
56        }
57        Ok(Gamma { shape, rate })
58    }
59
60    /// Returns the shape (α) of the gamma distribution
61    ///
62    /// # Examples
63    ///
64    /// ```
65    /// use statrs::distribution::Gamma;
66    ///
67    /// let n = Gamma::new(3.0, 1.0).unwrap();
68    /// assert_eq!(n.shape(), 3.0);
69    /// ```
70    pub fn shape(&self) -> f64 {
71        self.shape
72    }
73
74    /// Returns the rate (β) of the gamma distribution
75    ///
76    /// # Examples
77    ///
78    /// ```
79    /// use statrs::distribution::Gamma;
80    ///
81    /// let n = Gamma::new(3.0, 1.0).unwrap();
82    /// assert_eq!(n.rate(), 1.0);
83    /// ```
84    pub fn rate(&self) -> f64 {
85        self.rate
86    }
87}
88
89impl ::rand::distributions::Distribution<f64> for Gamma {
90    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
91        sample_unchecked(rng, self.shape, self.rate)
92    }
93}
94
95impl ContinuousCDF<f64, f64> for Gamma {
96    /// Calculates the cumulative distribution function for the gamma
97    /// distribution
98    /// at `x`
99    ///
100    /// # Formula
101    ///
102    /// ```ignore
103    /// (1 / Γ(α)) * γ(α, β * x)
104    /// ```
105    ///
106    /// where `α` is the shape, `β` is the rate, `Γ` is the gamma function,
107    /// and `γ` is the lower incomplete gamma function
108    fn cdf(&self, x: f64) -> f64 {
109        if x <= 0.0 {
110            0.0
111        } else if ulps_eq!(x, self.shape) && self.rate.is_infinite() {
112            1.0
113        } else if self.rate.is_infinite() {
114            0.0
115        } else if x.is_infinite() {
116            1.0
117        } else {
118            gamma::gamma_lr(self.shape, x * self.rate)
119        }
120    }
121
122    /// Calculates the survival function for the gamma
123    /// distribution at `x`
124    ///
125    /// # Formula
126    ///
127    /// ```ignore
128    /// (1 / Γ(α)) * γ(α, β * x)
129    /// ```
130    ///
131    /// where `α` is the shape, `β` is the rate, `Γ` is the gamma function,
132    /// and `γ` is the upper incomplete gamma function
133    fn sf(&self, x: f64) -> f64 {
134        if x <= 0.0 {
135            1.0
136        }
137        else if ulps_eq!(x, self.shape) && self.rate.is_infinite() {
138            0.0
139        }
140        else if self.rate.is_infinite() {
141            1.0
142        }
143        else if x.is_infinite() {
144            0.0
145        }
146        else {
147            gamma::gamma_ur(self.shape, x * self.rate)
148        }
149    }
150}
151
152impl Min<f64> for Gamma {
153    /// Returns the minimum value in the domain of the
154    /// gamma distribution representable by a double precision
155    /// float
156    ///
157    /// # Formula
158    ///
159    /// ```ignore
160    /// 0
161    /// ```
162    fn min(&self) -> f64 {
163        0.0
164    }
165}
166
167impl Max<f64> for Gamma {
168    /// Returns the maximum value in the domain of the
169    /// gamma distribution representable by a double precision
170    /// float
171    ///
172    /// # Formula
173    ///
174    /// ```ignore
175    /// INF
176    /// ```
177    fn max(&self) -> f64 {
178        INF
179    }
180}
181
182impl Distribution<f64> for Gamma {
183    /// Returns the mean of the gamma distribution
184    ///
185    /// # Formula
186    ///
187    /// ```ignore
188    /// α / β
189    /// ```
190    ///
191    /// where `α` is the shape and `β` is the rate
192    fn mean(&self) -> Option<f64> {
193        Some(self.shape / self.rate)
194    }
195    /// Returns the variance of the gamma distribution
196    ///
197    /// # Formula
198    ///
199    /// ```ignore
200    /// α / β^2
201    /// ```
202    ///
203    /// where `α` is the shape and `β` is the rate
204    fn variance(&self) -> Option<f64> {
205        Some(self.shape / (self.rate * self.rate))
206    }
207    /// Returns the entropy of the gamma distribution
208    ///
209    /// # Formula
210    ///
211    /// ```ignore
212    /// α - ln(β) + ln(Γ(α)) + (1 - α) * ψ(α)
213    /// ```
214    ///
215    /// where `α` is the shape, `β` is the rate, `Γ` is the gamma function,
216    /// and `ψ` is the digamma function
217    fn entropy(&self) -> Option<f64> {
218        let entr = self.shape - self.rate.ln()
219            + gamma::ln_gamma(self.shape)
220            + (1.0 - self.shape) * gamma::digamma(self.shape);
221        Some(entr)
222    }
223    /// Returns the skewness of the gamma distribution
224    ///
225    /// # Formula
226    ///
227    /// ```ignore
228    /// 2 / sqrt(α)
229    /// ```
230    ///
231    /// where `α` is the shape
232    fn skewness(&self) -> Option<f64> {
233        Some(2.0 / self.shape.sqrt())
234    }
235}
236
237impl Mode<Option<f64>> for Gamma {
238    /// Returns the mode for the gamma distribution
239    ///
240    /// # Formula
241    ///
242    /// ```ignore
243    /// (α - 1) / β
244    /// ```
245    ///
246    /// where `α` is the shape and `β` is the rate
247    fn mode(&self) -> Option<f64> {
248        Some((self.shape - 1.0) / self.rate)
249    }
250}
251
252impl Continuous<f64, f64> for Gamma {
253    /// Calculates the probability density function for the gamma distribution
254    /// at `x`
255    ///
256    /// # Remarks
257    ///
258    /// Returns `NAN` if any of `shape` or `rate` are `INF`
259    /// or if `x` is `INF`
260    ///
261    /// # Formula
262    ///
263    /// ```ignore
264    /// (β^α / Γ(α)) * x^(α - 1) * e^(-β * x)
265    /// ```
266    ///
267    /// where `α` is the shape, `β` is the rate, and `Γ` is the gamma function
268    fn pdf(&self, x: f64) -> f64 {
269        if x < 0.0 {
270            0.0
271        } else if ulps_eq!(self.shape, 1.0) {
272            self.rate * (-self.rate * x).exp()
273        } else if self.shape > 160.0 {
274            self.ln_pdf(x).exp()
275        } else if x.is_infinite() {
276            0.0
277        } else {
278            self.rate.powf(self.shape) * x.powf(self.shape - 1.0) * (-self.rate * x).exp()
279                / gamma::gamma(self.shape)
280        }
281    }
282
283    /// Calculates the log probability density function for the gamma
284    /// distribution
285    /// at `x`
286    ///
287    /// # Remarks
288    ///
289    /// Returns `NAN` if any of `shape` or `rate` are `INF`
290    /// or if `x` is `INF`
291    ///
292    /// # Formula
293    ///
294    /// ```ignore
295    /// ln((β^α / Γ(α)) * x^(α - 1) * e ^(-β * x))
296    /// ```
297    ///
298    /// where `α` is the shape, `β` is the rate, and `Γ` is the gamma function
299    fn ln_pdf(&self, x: f64) -> f64 {
300        if x < 0.0 {
301            f64::NEG_INFINITY
302        } else if ulps_eq!(self.shape, 1.0) {
303            self.rate.ln() - self.rate * x
304        } else if x.is_infinite() {
305            f64::NEG_INFINITY
306        } else {
307            self.shape * self.rate.ln() + (self.shape - 1.0) * x.ln()
308                - self.rate * x
309                - gamma::ln_gamma(self.shape)
310        }
311    }
312}
313/// Samples from a gamma distribution with a shape of `shape` and a
314/// rate of `rate` using `rng` as the source of randomness. Implementation from:
315/// <br />
316/// <div>
317/// <i>"A Simple Method for Generating Gamma Variables"</i> - Marsaglia & Tsang
318/// </div>
319/// <div>
320/// ACM Transactions on Mathematical Software, Vol. 26, No. 3, September 2000,
321/// Pages 363-372
322/// </div>
323/// <br />
324pub fn sample_unchecked<R: Rng + ?Sized>(rng: &mut R, shape: f64, rate: f64) -> f64 {
325    let mut a = shape;
326    let mut afix = 1.0;
327    if shape < 1.0 {
328        a = shape + 1.0;
329        afix = rng.gen::<f64>().powf(1.0 / shape);
330    }
331
332    let d = a - 1.0 / 3.0;
333    let c = 1.0 / (9.0 * d).sqrt();
334    loop {
335        let mut x;
336        let mut v;
337        loop {
338            x = super::normal::sample_unchecked(rng, 0.0, 1.0);
339            v = 1.0 + c * x;
340            if v > 0.0 {
341                break;
342            };
343        }
344
345        v *= v * v;
346        x *= x;
347        let u: f64 = rng.gen();
348        if u < 1.0 - 0.0331 * x * x || u.ln() < 0.5 * x + d * (1.0 - v + v.ln()) {
349            return afix * d * v / rate;
350        }
351    }
352}
353
// Unit tests for `Gamma`. Each table entry is `((shape, rate), expected)` or
// `((shape, rate), x, expected)`; `try_create`, `bad_create_case`,
// `test_case` and `test_case_special` are generated by `testing_boiler!`.
#[cfg(all(test, feature = "nightly"))]
mod tests {
    use super::*;
    use crate::consts::ACC;
    use crate::distribution::internal::*;
    use crate::testing_boiler;

    testing_boiler!((f64, f64), Gamma);

    // Parameter pairs that `Gamma::new` must accept, including a single
    // infinite rate.
    #[test]
    fn test_create() {
        let valid = [
            (1.0, 0.1),
            (1.0, 1.0),
            (10.0, 10.0),
            (10.0, 1.0),
            (10.0, INF),
        ];

        for &arg in valid.iter() {
            try_create(arg);
        }
    }

    // Parameter pairs that `Gamma::new` must reject: non-positive values,
    // NaN in either position, and both parameters infinite.
    #[test]
    fn test_bad_create() {
        let invalid = [
            (0.0, 0.0),
            (1.0, f64::NAN),
            (f64::NAN, 1.0),
            (1.0, -1.0),
            (-1.0, 1.0),
            (-1.0, -1.0),
            (-1.0, f64::NAN),
            (INF, INF),
        ];
        for &arg in invalid.iter() {
            bad_create_case(arg);
        }
    }

    #[test]
    fn test_mean() {
        let f = |x: Gamma| x.mean().unwrap();
        let test = [
            ((1.0, 0.1), 10.0),
            ((1.0, 1.0), 1.0),
            ((10.0, 10.0), 1.0),
            ((10.0, 1.0), 10.0),
            ((10.0, INF), 0.0),
        ];
        for &(arg, res) in test.iter() {
            test_case(arg, res, f);
        }
    }

    #[test]
    fn test_variance() {
        let f = |x: Gamma| x.variance().unwrap();
        let test = [
            ((1.0, 0.1), 100.0),
            ((1.0, 1.0), 1.0),
            ((10.0, 10.0), 0.1),
            ((10.0, 1.0), 10.0),
            ((10.0, INF), 0.0),
        ];
        for &(arg, res) in test.iter() {
            test_case(arg, res, f);
        }
    }

    #[test]
    fn test_entropy() {
        let f = |x: Gamma| x.entropy().unwrap();
        let test = [
            ((1.0, 0.1), 3.302585092994045628506840223),
            ((1.0, 1.0), 1.0),
            ((10.0, 10.0), 0.2334690854869339583626209),
            ((10.0, 1.0), 2.53605417848097964238061239),
            ((10.0, INF), f64::NEG_INFINITY),
        ];
        for &(arg, res) in test.iter() {
            test_case(arg, res, f);
        }
    }

    #[test]
    fn test_skewness() {
        let f = |x: Gamma| x.skewness().unwrap();
        let test = [
            ((1.0, 0.1), 2.0),
            ((1.0, 1.0), 2.0),
            ((10.0, 10.0), 0.6324555320336758663997787),
            ((10.0, 1.0), 0.63245553203367586639977870),
            ((10.0, INF), 0.6324555320336758),
        ];
        for &(arg, res) in test.iter() {
            test_case(arg, res, f);
        }
    }

    #[test]
    fn test_mode() {
        let f = |x: Gamma| x.mode().unwrap();
        // shape == 1 is checked with an explicit tolerance.
        let test = [((1.0, 0.1), 0.0), ((1.0, 1.0), 0.0)];
        for &(arg, res) in test.iter() {
            test_case_special(arg, res, 10e-6, f);
        }
        let test = [((10.0, 10.0), 0.9), ((10.0, 1.0), 9.0), ((10.0, INF), 0.0)];
        for &(arg, res) in test.iter() {
            test_case(arg, res, f);
        }
    }

    #[test]
    fn test_min_max() {
        let f = |x: Gamma| x.min();
        let test = [
            ((1.0, 0.1), 0.0),
            ((1.0, 1.0), 0.0),
            ((10.0, 10.0), 0.0),
            ((10.0, 1.0), 0.0),
            ((10.0, INF), 0.0),
        ];
        for &(arg, res) in test.iter() {
            test_case(arg, res, f);
        }
        let f = |x: Gamma| x.max();
        let test = [
            ((1.0, 0.1), INF),
            ((1.0, 1.0), INF),
            ((10.0, 10.0), INF),
            ((10.0, 1.0), INF),
            ((10.0, INF), INF),
        ];
        for &(arg, res) in test.iter() {
            test_case(arg, res, f);
        }
    }

    #[test]
    fn test_pdf() {
        let f = |arg: f64| move |x: Gamma| x.pdf(arg);
        let test = [
            ((1.0, 0.1), 1.0, 0.090483741803595961836995),
            ((1.0, 0.1), 10.0, 0.036787944117144234201693),
            ((1.0, 1.0), 1.0, 0.367879441171442321595523),
            ((1.0, 1.0), 10.0, 0.000045399929762484851535),
            ((10.0, 10.0), 1.0, 1.251100357211332989847649),
            ((10.0, 10.0), 10.0, 1.025153212086870580621609e-30),
            ((10.0, 1.0), 1.0, 0.000001013777119630297402),
            ((10.0, 1.0), 10.0, 0.125110035721133298984764),
        ];
        for &(arg, x, res) in test.iter() {
            test_case(arg, res, f(x));
        }
        //TODO: test special
        // test_is_nan((10.0, INF), pdf(1.0)); // is this really the behavior we want?
        //TODO: test special
        // (10.0, INF, INF, 0.0, pdf(INF)),];
    }

    #[test]
    fn test_pdf_at_zero() {
        test_case((1.0, 0.1), 0.1, |x| x.pdf(0.0));
        test_case((1.0, 0.1), 0.1f64.ln(), |x| x.ln_pdf(0.0));
    }

    #[test]
    fn test_ln_pdf() {
        let f = |arg: f64| move |x: Gamma| x.ln_pdf(arg);
        let test = [
            ((1.0, 0.1), 1.0, -2.40258509299404563405795),
            ((1.0, 0.1), 10.0, -3.30258509299404562850684),
            ((1.0, 1.0), 1.0, -1.0),
            ((1.0, 1.0), 10.0, -10.0),
            ((10.0, 10.0), 1.0, 0.224023449858987228972196),
            ((10.0, 10.0), 10.0, -69.0527107131946016148658),
            ((10.0, 1.0), 1.0, -13.8018274800814696112077),
            ((10.0, 1.0), 10.0, -2.07856164313505845504579),
            ((10.0, INF), INF, f64::NEG_INFINITY),
        ];
        for &(arg, x, res) in test.iter() {
            test_case(arg, res, f(x));
        }
        // TODO: test special
        // test_is_nan((10.0, INF), f(1.0)); // is this really the behavior we want?
    }

    #[test]
    fn test_cdf() {
        let f = |arg: f64| move |x: Gamma| x.cdf(arg);
        let test = [
            ((1.0, 0.1), 1.0, 0.095162581964040431858607),
            ((1.0, 0.1), 10.0, 0.632120558828557678404476),
            ((1.0, 1.0), 1.0, 0.632120558828557678404476),
            ((1.0, 1.0), 10.0, 0.999954600070237515148464),
            ((10.0, 10.0), 1.0, 0.542070285528147791685835),
            ((10.0, 10.0), 10.0, 0.999999999999999999999999),
            ((10.0, 1.0), 1.0, 0.000000111425478338720677),
            ((10.0, 1.0), 10.0, 0.542070285528147791685835),
            ((10.0, INF), 1.0, 0.0),
            ((10.0, INF), 10.0, 1.0),
        ];
        for &(arg, x, res) in test.iter() {
            test_case(arg, res, f(x));
        }
    }

    #[test]
    fn test_cdf_at_zero() {
        test_case((1.0, 0.1), 0.0, |x| x.cdf(0.0));
    }

    #[test]
    fn test_sf() {
        let f = |arg: f64| move |x: Gamma| x.sf(arg);
        let test = [
            ((1.0, 0.1), 1.0, 0.9048374180359595),
            ((1.0, 0.1), 10.0, 0.3678794411714419),
            ((1.0, 1.0), 1.0, 0.3678794411714419),
            ((1.0, 1.0), 10.0, 4.539992976249074e-5),
            ((10.0, 10.0), 1.0, 0.4579297144718528),
            ((10.0, 10.0), 10.0, 1.1253473960842808e-31),
            ((10.0, 1.0), 1.0, 0.9999998885745217),
            ((10.0, 1.0), 10.0, 0.4579297144718528),
            ((10.0, INF), 1.0, 1.0),
            ((10.0, INF), 10.0, 0.0),
        ];
        for &(arg, x, res) in test.iter() {
            test_case(arg, res, f(x));
        }
    }

    #[test]
    fn test_sf_at_zero() {
        test_case((1.0, 0.1), 1.0, |x| x.sf(0.0));
    }

    #[test]
    fn test_continuous() {
        test::check_continuous_distribution(&try_create((1.0, 0.5)), 0.0, 20.0);
        test::check_continuous_distribution(&try_create((9.0, 2.0)), 0.0, 20.0);
    }
}