1use crate::distribution::{Continuous, ContinuousCDF};
2use crate::function::gamma;
3use crate::prec;
4use crate::statistics::*;
5
/// A [gamma distribution](https://en.wikipedia.org/wiki/Gamma_distribution)
/// described by a shape (α) and a rate (β, the inverse of the scale).
///
/// Instances are created through `Gamma::new`, which rejects NaN or
/// non-positive parameters as well as an all-infinite parameter pair.
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct Gamma {
    /// Shape parameter α; positive (validated in `Gamma::new`).
    shape: f64,
    /// Rate parameter β; positive (validated in `Gamma::new`), may be infinite.
    rate: f64,
}
25
/// Errors that can occur when constructing a `Gamma` distribution.
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
#[non_exhaustive]
pub enum GammaError {
    /// The shape is NaN, zero or less than zero.
    ShapeInvalid,

    /// The rate is NaN, zero or less than zero.
    RateInvalid,

    /// The shape and the rate are both infinite.
    ShapeAndRateInfinite,
}

impl std::fmt::Display for GammaError {
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        // The list "NaN, zero, or less than zero" needs its first comma;
        // without it the message read "NaN zero".
        match self {
            GammaError::ShapeInvalid => write!(f, "Shape is NaN, zero, or less than zero."),
            GammaError::RateInvalid => write!(f, "Rate is NaN, zero, or less than zero."),
            GammaError::ShapeAndRateInfinite => write!(f, "Shape and rate are infinite"),
        }
    }
}

impl std::error::Error for GammaError {}
52
53impl Gamma {
54 pub fn new(shape: f64, rate: f64) -> Result<Gamma, GammaError> {
74 if shape.is_nan() || shape <= 0.0 {
75 return Err(GammaError::ShapeInvalid);
76 }
77
78 if rate.is_nan() || rate <= 0.0 {
79 return Err(GammaError::RateInvalid);
80 }
81
82 if shape.is_infinite() && rate.is_infinite() {
83 return Err(GammaError::ShapeAndRateInfinite);
84 }
85
86 Ok(Gamma { shape, rate })
87 }
88
89 pub fn shape(&self) -> f64 {
100 self.shape
101 }
102
103 pub fn rate(&self) -> f64 {
114 self.rate
115 }
116}
117
118impl std::fmt::Display for Gamma {
119 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
120 write!(f, "Γ({}, {})", self.shape, self.rate)
121 }
122}
123
#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<f64> for Gamma {
    /// Draws one variate from this distribution by forwarding its shape and
    /// rate to the module-level `sample_unchecked`.
    fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 {
        let Gamma { shape, rate } = *self;
        sample_unchecked(rng, shape, rate)
    }
}
131
132impl ContinuousCDF<f64, f64> for Gamma {
133 fn cdf(&self, x: f64) -> f64 {
146 if x <= 0.0 {
147 0.0
148 } else if ulps_eq!(x, self.shape) && self.rate.is_infinite() {
149 1.0
150 } else if self.rate.is_infinite() {
151 0.0
152 } else if x.is_infinite() {
153 1.0
154 } else {
155 gamma::gamma_lr(self.shape, x * self.rate)
156 }
157 }
158
159 fn sf(&self, x: f64) -> f64 {
171 if x <= 0.0 {
172 1.0
173 } else if ulps_eq!(x, self.shape) && self.rate.is_infinite() {
174 0.0
175 } else if self.rate.is_infinite() {
176 1.0
177 } else if x.is_infinite() {
178 0.0
179 } else {
180 gamma::gamma_ur(self.shape, x * self.rate)
181 }
182 }
183
184 fn inverse_cdf(&self, p: f64) -> f64 {
185 if !(0.0..=1.0).contains(&p) {
186 panic!("default inverse_cdf implementation should be provided probability on [0,1]")
187 }
188 if p == 0.0 {
189 return self.min();
190 };
191 if p == 1.0 {
192 return self.max();
193 };
194
195 let mut high = 2.0;
197 let mut low = 1.0;
198 while self.cdf(low) > p {
199 low /= 2.0;
200 }
201 while self.cdf(high) < p {
202 high *= 2.0;
203 }
204 let mut x_0 = (high + low) / 2.0;
205
206 for _ in 0..8 {
207 if self.cdf(x_0) >= p {
208 high = x_0;
209 } else {
210 low = x_0;
211 }
212 if prec::convergence(&mut x_0, (high + low) / 2.0) {
213 break;
214 }
215 }
216
217 for _ in 0..4 {
219 let x_next = x_0 - (self.cdf(x_0) - p) / self.pdf(x_0);
220 if prec::convergence(&mut x_0, x_next) {
221 break;
222 }
223 }
224
225 x_0
226 }
227}
228
229impl Min<f64> for Gamma {
230 fn min(&self) -> f64 {
240 0.0
241 }
242}
243
244impl Max<f64> for Gamma {
245 fn max(&self) -> f64 {
255 f64::INFINITY
256 }
257}
258
259impl Distribution<f64> for Gamma {
260 fn mean(&self) -> Option<f64> {
270 Some(self.shape / self.rate)
271 }
272
273 fn variance(&self) -> Option<f64> {
283 Some(self.shape / (self.rate * self.rate))
284 }
285
286 fn entropy(&self) -> Option<f64> {
297 let entr = self.shape - self.rate.ln()
298 + gamma::ln_gamma(self.shape)
299 + (1.0 - self.shape) * gamma::digamma(self.shape);
300 Some(entr)
301 }
302
303 fn skewness(&self) -> Option<f64> {
313 Some(2.0 / self.shape.sqrt())
314 }
315}
316
317impl Mode<Option<f64>> for Gamma {
318 fn mode(&self) -> Option<f64> {
328 if self.shape < 1.0 {
329 None
330 } else {
331 Some((self.shape - 1.0) / self.rate)
332 }
333 }
334}
335
336impl Continuous<f64, f64> for Gamma {
337 fn pdf(&self, x: f64) -> f64 {
353 if x < 0.0 {
354 0.0
355 } else if ulps_eq!(self.shape, 1.0) {
356 self.rate * (-self.rate * x).exp()
357 } else if self.shape > 160.0 {
358 self.ln_pdf(x).exp()
359 } else if x.is_infinite() {
360 0.0
361 } else {
362 self.rate.powf(self.shape) * x.powf(self.shape - 1.0) * (-self.rate * x).exp()
363 / gamma::gamma(self.shape)
364 }
365 }
366
367 fn ln_pdf(&self, x: f64) -> f64 {
384 if x < 0.0 {
385 f64::NEG_INFINITY
386 } else if ulps_eq!(self.shape, 1.0) {
387 self.rate.ln() - self.rate * x
388 } else if x.is_infinite() {
389 f64::NEG_INFINITY
390 } else {
391 self.shape * self.rate.ln() + (self.shape - 1.0) * x.ln()
392 - self.rate * x
393 - gamma::ln_gamma(self.shape)
394 }
395 }
396}
#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
/// Draws one gamma(`shape`, `rate`) variate using the Marsaglia–Tsang
/// squeeze/rejection method without validating the parameters.
///
/// For `shape < 1` the sampler runs with `shape + 1` and multiplies the result
/// by `U^(1 / shape)` (one extra uniform draw taken up front).
pub fn sample_unchecked<R: ::rand::Rng + ?Sized>(rng: &mut R, shape: f64, rate: f64) -> f64 {
    let (alpha, boost) = if shape < 1.0 {
        (shape + 1.0, rng.gen::<f64>().powf(1.0 / shape))
    } else {
        (shape, 1.0)
    };

    let d = alpha - 1.0 / 3.0;
    let c = 1.0 / (9.0 * d).sqrt();
    loop {
        // Draw standard normals until 1 + c·z is positive.
        let (z, base) = loop {
            let z = super::normal::sample_unchecked(rng, 0.0, 1.0);
            let base = 1.0 + c * z;
            if base > 0.0 {
                break (z, base);
            }
        };

        let v = base * base * base;
        let z_sq = z * z;
        let u: f64 = rng.gen();
        // Cheap squeeze test first; the log test only runs if it fails.
        let accepted =
            u < 1.0 - 0.0331 * z_sq * z_sq || u.ln() < 0.5 * z_sq + d * (1.0 - v + v.ln());
        if accepted {
            return boost * d * v / rate;
        }
    }
}
435
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distribution::internal::*;
    use crate::testing_boiler;

    testing_boiler!(shape: f64, rate: f64; Gamma; GammaError);

    /// (shape, rate) pairs used by the construction, moment and min/max
    /// tests; the expected-value arrays below follow this order.
    const COMMON_PARAMS: [(f64, f64); 5] = [
        (1.0, 0.1),
        (1.0, 1.0),
        (10.0, 10.0),
        (10.0, 1.0),
        (10.0, f64::INFINITY),
    ];

    #[test]
    fn test_create() {
        for &(shape, rate) in COMMON_PARAMS.iter() {
            create_ok(shape, rate);
        }
    }

    #[test]
    fn test_bad_create() {
        let rejected = [
            (0.0, 0.0, GammaError::ShapeInvalid),
            (1.0, f64::NAN, GammaError::RateInvalid),
            (1.0, -1.0, GammaError::RateInvalid),
            (-1.0, 1.0, GammaError::ShapeInvalid),
            (-1.0, -1.0, GammaError::ShapeInvalid),
            // The shape is validated first, so its error wins over a bad rate.
            (-1.0, f64::NAN, GammaError::ShapeInvalid),
            (
                f64::INFINITY,
                f64::INFINITY,
                GammaError::ShapeAndRateInfinite,
            ),
        ];
        for &(shape, rate, expected) in rejected.iter() {
            test_create_err(shape, rate, expected);
        }
    }

    #[test]
    fn test_mean() {
        let expected = [10.0, 1.0, 1.0, 10.0, 0.0];
        for (&(shape, rate), &want) in COMMON_PARAMS.iter().zip(expected.iter()) {
            test_relative(shape, rate, want, |g: Gamma| g.mean().unwrap());
        }
    }

    #[test]
    fn test_variance() {
        let expected = [100.0, 1.0, 0.1, 10.0, 0.0];
        for (&(shape, rate), &want) in COMMON_PARAMS.iter().zip(expected.iter()) {
            test_relative(shape, rate, want, |g: Gamma| g.variance().unwrap());
        }
    }

    #[test]
    fn test_entropy() {
        let expected = [
            3.302585092994045628506840223,
            1.0,
            0.2334690854869339583626209,
            2.53605417848097964238061239,
            f64::NEG_INFINITY,
        ];
        for (&(shape, rate), &want) in COMMON_PARAMS.iter().zip(expected.iter()) {
            test_relative(shape, rate, want, |g: Gamma| g.entropy().unwrap());
        }
    }

    #[test]
    fn test_skewness() {
        let expected = [
            2.0,
            2.0,
            0.6324555320336758663997787,
            0.63245553203367586639977870,
            0.6324555320336758,
        ];
        for (&(shape, rate), &want) in COMMON_PARAMS.iter().zip(expected.iter()) {
            test_relative(shape, rate, want, |g: Gamma| g.skewness().unwrap());
        }
    }

    #[test]
    fn test_mode() {
        let mode = |g: Gamma| g.mode().unwrap();
        // Shape 1: the mode sits at the origin, so compare absolutely.
        for &(shape, rate) in COMMON_PARAMS[..2].iter() {
            test_absolute(shape, rate, 0.0, 10e-6, mode);
        }
        let expected = [0.9, 9.0, 0.0];
        for (&(shape, rate), &want) in COMMON_PARAMS[2..].iter().zip(expected.iter()) {
            test_relative(shape, rate, want, mode);
        }
    }

    #[test]
    fn test_min_max() {
        for &(shape, rate) in COMMON_PARAMS.iter() {
            test_relative(shape, rate, 0.0, |g: Gamma| g.min());
            test_relative(shape, rate, f64::INFINITY, |g: Gamma| g.max());
        }
    }

    #[test]
    fn test_pdf() {
        let cases = [
            ((1.0, 0.1), 1.0, 0.090483741803595961836995),
            ((1.0, 0.1), 10.0, 0.036787944117144234201693),
            ((1.0, 1.0), 1.0, 0.367879441171442321595523),
            ((1.0, 1.0), 10.0, 0.000045399929762484851535),
            ((10.0, 10.0), 1.0, 1.251100357211332989847649),
            ((10.0, 10.0), 10.0, 1.025153212086870580621609e-30),
            ((10.0, 1.0), 1.0, 0.000001013777119630297402),
            ((10.0, 1.0), 10.0, 0.125110035721133298984764),
        ];
        for &((shape, rate), x, want) in cases.iter() {
            test_relative(shape, rate, want, move |g: Gamma| g.pdf(x));
        }
    }

    #[test]
    fn test_pdf_at_zero() {
        // Shape 1 is the exponential distribution, whose density at 0 is the rate.
        test_relative(1.0, 0.1, 0.1, |g| g.pdf(0.0));
        test_relative(1.0, 0.1, 0.1f64.ln(), |g| g.ln_pdf(0.0));
    }

    #[test]
    fn test_ln_pdf() {
        let cases = [
            ((1.0, 0.1), 1.0, -2.40258509299404563405795),
            ((1.0, 0.1), 10.0, -3.30258509299404562850684),
            ((1.0, 1.0), 1.0, -1.0),
            ((1.0, 1.0), 10.0, -10.0),
            ((10.0, 10.0), 1.0, 0.224023449858987228972196),
            ((10.0, 10.0), 10.0, -69.0527107131946016148658),
            ((10.0, 1.0), 1.0, -13.8018274800814696112077),
            ((10.0, 1.0), 10.0, -2.07856164313505845504579),
            ((10.0, f64::INFINITY), f64::INFINITY, f64::NEG_INFINITY),
        ];
        for &((shape, rate), x, want) in cases.iter() {
            test_relative(shape, rate, want, move |g: Gamma| g.ln_pdf(x));
        }
    }

    #[test]
    fn test_cdf() {
        let cases = [
            ((1.0, 0.1), 1.0, 0.095162581964040431858607),
            ((1.0, 0.1), 10.0, 0.632120558828557678404476),
            ((1.0, 1.0), 1.0, 0.632120558828557678404476),
            ((1.0, 1.0), 10.0, 0.999954600070237515148464),
            ((10.0, 10.0), 1.0, 0.542070285528147791685835),
            ((10.0, 10.0), 10.0, 0.999999999999999999999999),
            ((10.0, 1.0), 1.0, 0.000000111425478338720677),
            ((10.0, 1.0), 10.0, 0.542070285528147791685835),
            ((10.0, f64::INFINITY), 1.0, 0.0),
            ((10.0, f64::INFINITY), 10.0, 1.0),
        ];
        for &((shape, rate), x, want) in cases.iter() {
            test_relative(shape, rate, want, move |g: Gamma| g.cdf(x));
        }
    }

    #[test]
    fn test_cdf_at_zero() {
        test_relative(1.0, 0.1, 0.0, |g| g.cdf(0.0));
    }

    #[test]
    fn test_cdf_inverse_identity() {
        let params = [
            (1.0, 0.1),
            (1.0, 1.0),
            (10.0, 10.0),
            (10.0, 1.0),
            (100.0, 200.0),
        ];
        // cdf(inverse_cdf(p)) should round-trip for p = 1e-5 ..= 1e-1.
        for &(shape, rate) in params.iter() {
            for exponent in -5..0 {
                let p = 10.0f64.powi(exponent);
                test_relative(shape, rate, p, move |g: Gamma| g.cdf(g.inverse_cdf(p)));
            }
        }

        // And the other direction: inverse_cdf(cdf(x)) should recover x.
        let x = 20.5567;
        test_relative(3.0, 0.5, x, move |g: Gamma| g.inverse_cdf(g.cdf(x)));
    }

    #[test]
    fn test_sf() {
        let cases = [
            ((1.0, 0.1), 1.0, 0.9048374180359595),
            ((1.0, 0.1), 10.0, 0.3678794411714419),
            ((1.0, 1.0), 1.0, 0.3678794411714419),
            ((1.0, 1.0), 10.0, 4.539992976249074e-5),
            ((10.0, 10.0), 1.0, 0.4579297144718528),
            ((10.0, 10.0), 10.0, 1.1253473960842808e-31),
            ((10.0, 1.0), 1.0, 0.9999998885745217),
            ((10.0, 1.0), 10.0, 0.4579297144718528),
            ((10.0, f64::INFINITY), 1.0, 1.0),
            ((10.0, f64::INFINITY), 10.0, 0.0),
        ];
        for &((shape, rate), x, want) in cases.iter() {
            test_relative(shape, rate, want, move |g: Gamma| g.sf(x));
        }
    }

    #[test]
    fn test_sf_at_zero() {
        test_relative(1.0, 0.1, 1.0, |g| g.sf(0.0));
    }

    #[test]
    fn test_continuous() {
        test::check_continuous_distribution(&create_ok(1.0, 0.5), 0.0, 20.0);
        test::check_continuous_distribution(&create_ok(9.0, 2.0), 0.0, 20.0);
    }
}