1use crate::distribution::{Continuous, ContinuousCDF};
2use crate::function::gamma;
3use crate::statistics::*;
4use crate::{Result, StatsError};
5use core::f64::INFINITY as INF;
6use rand::Rng;
7
/// Implements the [Gamma](https://en.wikipedia.org/wiki/Gamma_distribution)
/// distribution, parameterized by a shape `α` and a rate `β`.
///
/// Instances should be created through [`Gamma::new`], which validates both
/// parameters before storing them.
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct Gamma {
    /// Shape parameter `α`; strictly positive once constructed via `new`.
    shape: f64,
    /// Rate parameter `β` (the inverse of the scale); strictly positive once
    /// constructed via `new`, and may be `f64::INFINITY`.
    rate: f64,
}
27
28impl Gamma {
29 pub fn new(shape: f64, rate: f64) -> Result<Gamma> {
49 if shape.is_nan()
50 || rate.is_nan()
51 || shape.is_infinite() && rate.is_infinite()
52 || shape <= 0.0
53 || rate <= 0.0
54 {
55 return Err(StatsError::BadParams);
56 }
57 Ok(Gamma { shape, rate })
58 }
59
60 pub fn shape(&self) -> f64 {
71 self.shape
72 }
73
74 pub fn rate(&self) -> f64 {
85 self.rate
86 }
87}
88
89impl ::rand::distributions::Distribution<f64> for Gamma {
90 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
91 sample_unchecked(rng, self.shape, self.rate)
92 }
93}
94
95impl ContinuousCDF<f64, f64> for Gamma {
96 fn cdf(&self, x: f64) -> f64 {
109 if x <= 0.0 {
110 0.0
111 } else if ulps_eq!(x, self.shape) && self.rate.is_infinite() {
112 1.0
113 } else if self.rate.is_infinite() {
114 0.0
115 } else if x.is_infinite() {
116 1.0
117 } else {
118 gamma::gamma_lr(self.shape, x * self.rate)
119 }
120 }
121
122 fn sf(&self, x: f64) -> f64 {
134 if x <= 0.0 {
135 1.0
136 }
137 else if ulps_eq!(x, self.shape) && self.rate.is_infinite() {
138 0.0
139 }
140 else if self.rate.is_infinite() {
141 1.0
142 }
143 else if x.is_infinite() {
144 0.0
145 }
146 else {
147 gamma::gamma_ur(self.shape, x * self.rate)
148 }
149 }
150}
151
152impl Min<f64> for Gamma {
153 fn min(&self) -> f64 {
163 0.0
164 }
165}
166
167impl Max<f64> for Gamma {
168 fn max(&self) -> f64 {
178 INF
179 }
180}
181
182impl Distribution<f64> for Gamma {
183 fn mean(&self) -> Option<f64> {
193 Some(self.shape / self.rate)
194 }
195 fn variance(&self) -> Option<f64> {
205 Some(self.shape / (self.rate * self.rate))
206 }
207 fn entropy(&self) -> Option<f64> {
218 let entr = self.shape - self.rate.ln()
219 + gamma::ln_gamma(self.shape)
220 + (1.0 - self.shape) * gamma::digamma(self.shape);
221 Some(entr)
222 }
223 fn skewness(&self) -> Option<f64> {
233 Some(2.0 / self.shape.sqrt())
234 }
235}
236
237impl Mode<Option<f64>> for Gamma {
238 fn mode(&self) -> Option<f64> {
248 Some((self.shape - 1.0) / self.rate)
249 }
250}
251
252impl Continuous<f64, f64> for Gamma {
253 fn pdf(&self, x: f64) -> f64 {
269 if x < 0.0 {
270 0.0
271 } else if ulps_eq!(self.shape, 1.0) {
272 self.rate * (-self.rate * x).exp()
273 } else if self.shape > 160.0 {
274 self.ln_pdf(x).exp()
275 } else if x.is_infinite() {
276 0.0
277 } else {
278 self.rate.powf(self.shape) * x.powf(self.shape - 1.0) * (-self.rate * x).exp()
279 / gamma::gamma(self.shape)
280 }
281 }
282
283 fn ln_pdf(&self, x: f64) -> f64 {
300 if x < 0.0 {
301 f64::NEG_INFINITY
302 } else if ulps_eq!(self.shape, 1.0) {
303 self.rate.ln() - self.rate * x
304 } else if x.is_infinite() {
305 f64::NEG_INFINITY
306 } else {
307 self.shape * self.rate.ln() + (self.shape - 1.0) * x.ln()
308 - self.rate * x
309 - gamma::ln_gamma(self.shape)
310 }
311 }
312}
313pub fn sample_unchecked<R: Rng + ?Sized>(rng: &mut R, shape: f64, rate: f64) -> f64 {
325 let mut a = shape;
326 let mut afix = 1.0;
327 if shape < 1.0 {
328 a = shape + 1.0;
329 afix = rng.gen::<f64>().powf(1.0 / shape);
330 }
331
332 let d = a - 1.0 / 3.0;
333 let c = 1.0 / (9.0 * d).sqrt();
334 loop {
335 let mut x;
336 let mut v;
337 loop {
338 x = super::normal::sample_unchecked(rng, 0.0, 1.0);
339 v = 1.0 + c * x;
340 if v > 0.0 {
341 break;
342 };
343 }
344
345 v *= v * v;
346 x *= x;
347 let u: f64 = rng.gen();
348 if u < 1.0 - 0.0331 * x * x || u.ln() < 0.5 * x + d * (1.0 - v + v.ln()) {
349 return afix * d * v / rate;
350 }
351 }
352}
353
// Unit tests for `Gamma`. They are compiled only when testing with the
// "nightly" feature enabled.
#[cfg(all(test, feature = "nightly"))]
mod tests {
    use super::*;
    use crate::consts::ACC;
    use crate::distribution::internal::*;
    use crate::testing_boiler;

    testing_boiler!((f64, f64), Gamma);

    #[test]
    fn test_create() {
        let valid = [
            (1.0, 0.1),
            (1.0, 1.0),
            (10.0, 10.0),
            (10.0, 1.0),
            (10.0, INF),
        ];

        for &arg in valid.iter() {
            try_create(arg);
        }
    }

    #[test]
    fn test_bad_create() {
        let invalid = [
            (0.0, 0.0),
            (1.0, f64::NAN),
            (1.0, -1.0),
            (-1.0, 1.0),
            (-1.0, -1.0),
            (-1.0, f64::NAN),
            // NaN shape, zero rate/shape, and the "both infinite" rule were
            // previously not exercised.
            (f64::NAN, 1.0),
            (f64::NAN, f64::NAN),
            (1.0, 0.0),
            (0.0, 1.0),
            (INF, INF),
        ];
        for &arg in invalid.iter() {
            bad_create_case(arg);
        }
    }

    #[test]
    fn test_mean() {
        let f = |x: Gamma| x.mean().unwrap();
        let test = [
            ((1.0, 0.1), 10.0),
            ((1.0, 1.0), 1.0),
            ((10.0, 10.0), 1.0),
            ((10.0, 1.0), 10.0),
            ((10.0, INF), 0.0),
        ];
        for &(arg, res) in test.iter() {
            test_case(arg, res, f);
        }
    }

    #[test]
    fn test_variance() {
        let f = |x: Gamma| x.variance().unwrap();
        let test = [
            ((1.0, 0.1), 100.0),
            ((1.0, 1.0), 1.0),
            ((10.0, 10.0), 0.1),
            ((10.0, 1.0), 10.0),
            ((10.0, INF), 0.0),
        ];
        for &(arg, res) in test.iter() {
            test_case(arg, res, f);
        }
    }

    #[test]
    fn test_entropy() {
        let f = |x: Gamma| x.entropy().unwrap();
        let test = [
            ((1.0, 0.1), 3.302585092994045628506840223),
            ((1.0, 1.0), 1.0),
            ((10.0, 10.0), 0.2334690854869339583626209),
            ((10.0, 1.0), 2.53605417848097964238061239),
            ((10.0, INF), f64::NEG_INFINITY),
        ];
        for &(arg, res) in test.iter() {
            test_case(arg, res, f);
        }
    }

    #[test]
    fn test_skewness() {
        let f = |x: Gamma| x.skewness().unwrap();
        let test = [
            ((1.0, 0.1), 2.0),
            ((1.0, 1.0), 2.0),
            ((10.0, 10.0), 0.6324555320336758663997787),
            ((10.0, 1.0), 0.63245553203367586639977870),
            ((10.0, INF), 0.6324555320336758),
        ];
        for &(arg, res) in test.iter() {
            test_case(arg, res, f);
        }
    }

    #[test]
    fn test_mode() {
        let f = |x: Gamma| x.mode().unwrap();
        let test = [((1.0, 0.1), 0.0), ((1.0, 1.0), 0.0)];
        for &(arg, res) in test.iter() {
            test_case_special(arg, res, 10e-6, f);
        }
        let test = [((10.0, 10.0), 0.9), ((10.0, 1.0), 9.0), ((10.0, INF), 0.0)];
        for &(arg, res) in test.iter() {
            test_case(arg, res, f);
        }
    }

    #[test]
    fn test_min_max() {
        let f = |x: Gamma| x.min();
        let test = [
            ((1.0, 0.1), 0.0),
            ((1.0, 1.0), 0.0),
            ((10.0, 10.0), 0.0),
            ((10.0, 1.0), 0.0),
            ((10.0, INF), 0.0),
        ];
        for &(arg, res) in test.iter() {
            test_case(arg, res, f);
        }
        let f = |x: Gamma| x.max();
        let test = [
            ((1.0, 0.1), INF),
            ((1.0, 1.0), INF),
            ((10.0, 10.0), INF),
            ((10.0, 1.0), INF),
            ((10.0, INF), INF),
        ];
        for &(arg, res) in test.iter() {
            test_case(arg, res, f);
        }
    }

    #[test]
    fn test_pdf() {
        let f = |arg: f64| move |x: Gamma| x.pdf(arg);
        let test = [
            ((1.0, 0.1), 1.0, 0.090483741803595961836995),
            ((1.0, 0.1), 10.0, 0.036787944117144234201693),
            ((1.0, 1.0), 1.0, 0.367879441171442321595523),
            ((1.0, 1.0), 10.0, 0.000045399929762484851535),
            ((10.0, 10.0), 1.0, 1.251100357211332989847649),
            ((10.0, 10.0), 10.0, 1.025153212086870580621609e-30),
            ((10.0, 1.0), 1.0, 0.000001013777119630297402),
            ((10.0, 1.0), 10.0, 0.125110035721133298984764),
        ];
        for &(arg, x, res) in test.iter() {
            test_case(arg, res, f(x));
        }
    }

    #[test]
    fn test_pdf_at_zero() {
        test_case((1.0, 0.1), 0.1, |x| x.pdf(0.0));
        test_case((1.0, 0.1), 0.1f64.ln(), |x| x.ln_pdf(0.0));
    }

    #[test]
    fn test_ln_pdf() {
        let f = |arg: f64| move |x: Gamma| x.ln_pdf(arg);
        let test = [
            ((1.0, 0.1), 1.0, -2.40258509299404563405795),
            ((1.0, 0.1), 10.0, -3.30258509299404562850684),
            ((1.0, 1.0), 1.0, -1.0),
            ((1.0, 1.0), 10.0, -10.0),
            ((10.0, 10.0), 1.0, 0.224023449858987228972196),
            ((10.0, 10.0), 10.0, -69.0527107131946016148658),
            ((10.0, 1.0), 1.0, -13.8018274800814696112077),
            ((10.0, 1.0), 10.0, -2.07856164313505845504579),
            ((10.0, INF), INF, f64::NEG_INFINITY),
        ];
        for &(arg, x, res) in test.iter() {
            test_case(arg, res, f(x));
        }
    }

    #[test]
    fn test_cdf() {
        let f = |arg: f64| move |x: Gamma| x.cdf(arg);
        let test = [
            ((1.0, 0.1), 1.0, 0.095162581964040431858607),
            ((1.0, 0.1), 10.0, 0.632120558828557678404476),
            ((1.0, 1.0), 1.0, 0.632120558828557678404476),
            ((1.0, 1.0), 10.0, 0.999954600070237515148464),
            ((10.0, 10.0), 1.0, 0.542070285528147791685835),
            ((10.0, 10.0), 10.0, 0.999999999999999999999999),
            ((10.0, 1.0), 1.0, 0.000000111425478338720677),
            ((10.0, 1.0), 10.0, 0.542070285528147791685835),
            ((10.0, INF), 1.0, 0.0),
            ((10.0, INF), 10.0, 1.0),
        ];
        for &(arg, x, res) in test.iter() {
            test_case(arg, res, f(x));
        }
    }

    #[test]
    fn test_cdf_at_zero() {
        test_case((1.0, 0.1), 0.0, |x| x.cdf(0.0));
    }

    #[test]
    fn test_sf() {
        let f = |arg: f64| move |x: Gamma| x.sf(arg);
        let test = [
            ((1.0, 0.1), 1.0, 0.9048374180359595),
            ((1.0, 0.1), 10.0, 0.3678794411714419),
            ((1.0, 1.0), 1.0, 0.3678794411714419),
            ((1.0, 1.0), 10.0, 4.539992976249074e-5),
            ((10.0, 10.0), 1.0, 0.4579297144718528),
            ((10.0, 10.0), 10.0, 1.1253473960842808e-31),
            ((10.0, 1.0), 1.0, 0.9999998885745217),
            ((10.0, 1.0), 10.0, 0.4579297144718528),
            ((10.0, INF), 1.0, 1.0),
            ((10.0, INF), 10.0, 0.0),
        ];
        for &(arg, x, res) in test.iter() {
            test_case(arg, res, f(x));
        }
    }

    #[test]
    fn test_sf_at_zero() {
        test_case((1.0, 0.1), 1.0, |x| x.sf(0.0));
    }

    #[test]
    fn test_continuous() {
        test::check_continuous_distribution(&try_create((1.0, 0.5)), 0.0, 20.0);
        test::check_continuous_distribution(&try_create((9.0, 2.0)), 0.0, 20.0);
    }
}