statrs/distribution/
negative_binomial.rs

1use crate::distribution::{self, poisson, Discrete, DiscreteCDF};
2use crate::function::{beta, gamma};
3use crate::statistics::*;
4use crate::{Result, StatsError};
5use rand::Rng;
6use std::f64;
7
/// Implements the
/// [negative binomial](http://en.wikipedia.org/wiki/Negative_binomial_distribution)
/// distribution.
///
/// *Please read the parameter definitions carefully.*  The wikipedia article
/// lists several conventions in common use for parameterizing the negative
/// binomial distribution; this type uses the one described below.
///
/// The distribution is discrete and has two parameters, `r` and `p`.  For an
/// integral `r` it models the number of failures observed in a sequence of
/// independent Bernoulli trials that stops once `r` successes have occurred,
/// where `p` is the success probability of each individual trial.
///
/// Non-integral values of `r` are also accepted, which generalizes the more
/// familiar integral case.
///
/// # Examples
///
/// ```
/// use statrs::distribution::{NegativeBinomial, Discrete};
/// use statrs::statistics::DiscreteDistribution;
/// use statrs::prec::almost_eq;
///
/// let r = NegativeBinomial::new(4.0, 0.5).unwrap();
/// assert_eq!(r.mean().unwrap(), 4.0);
/// assert!(almost_eq(r.pmf(0), 0.0625, 1e-8));
/// assert!(almost_eq(r.pmf(3), 0.15625, 1e-8));
/// ```
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct NegativeBinomial {
    /// Number of successes at which the trial sequence stops (may be non-integral).
    r: f64,
    /// Success probability of a single Bernoulli trial.
    p: f64,
}
43
44impl NegativeBinomial {
45    /// Constructs a new negative binomial distribution with parameters `r`
46    /// and `p`.  When `r` is an integer, the negative binomial distribution
47    /// can be interpreted as the distribution of the number of failures in
48    /// a sequence of Bernoulli trials that continue until `r` successes occur.
49    /// `p` is the probability of success in a single Bernoulli trial.
50    ///
51    /// # Errors
52    ///
53    /// Returns an error if `p` is `NaN`, less than `0.0`,
54    /// greater than `1.0`, or if `r` is `NaN` or less than `0`
55    ///
56    /// # Examples
57    ///
58    /// ```
59    /// use statrs::distribution::NegativeBinomial;
60    ///
61    /// let mut result = NegativeBinomial::new(4.0, 0.5);
62    /// assert!(result.is_ok());
63    ///
64    /// result = NegativeBinomial::new(-0.5, 5.0);
65    /// assert!(result.is_err());
66    /// ```
67    pub fn new(r: f64, p: f64) -> Result<NegativeBinomial> {
68        if p.is_nan() || p < 0.0 || p > 1.0 || r.is_nan() || r < 0.0 {
69            Err(StatsError::BadParams)
70        } else {
71            Ok(NegativeBinomial { r, p })
72        }
73    }
74
75    /// Returns the probability of success `p` of a single
76    /// Bernoulli trial associated with the negative binomial
77    /// distribution.
78    ///
79    /// # Examples
80    ///
81    /// ```
82    /// use statrs::distribution::NegativeBinomial;
83    ///
84    /// let r = NegativeBinomial::new(5.0, 0.5).unwrap();
85    /// assert_eq!(r.p(), 0.5);
86    /// ```
87    pub fn p(&self) -> f64 {
88        self.p
89    }
90
91    /// Returns the number `r` of success of this negative
92    /// binomial distribution.
93    ///
94    /// # Examples
95    ///
96    /// ```
97    /// use statrs::distribution::NegativeBinomial;
98    ///
99    /// let r = NegativeBinomial::new(5.0, 0.5).unwrap();
100    /// assert_eq!(r.r(), 5.0);
101    /// ```
102    pub fn r(&self) -> f64 {
103        self.r
104    }
105}
106
107impl ::rand::distributions::Distribution<u64> for NegativeBinomial {
108    fn sample<R: Rng + ?Sized>(&self, r: &mut R) -> u64 {
109        let lambda = distribution::gamma::sample_unchecked(r, self.r, (1.0 - self.p) / self.p);
110        poisson::sample_unchecked(r, lambda).floor() as u64
111    }
112}
113
114impl DiscreteCDF<u64, f64> for NegativeBinomial {
115    /// Calculates the cumulative distribution function for the
116    /// negative binomial distribution at `x`.
117    ///
118    /// # Formula
119    ///
120    /// ```ignore
121    /// I_(p)(r, x+1)
122    /// ```
123    ///
124    /// where `I_(x)(a, b)` is the regularized incomplete beta function.
125    fn cdf(&self, x: u64) -> f64 {
126        beta::beta_reg(self.r, x as f64 + 1.0, self.p)
127    }
128
129    /// Calculates the survival function for the
130    /// negative binomial distribution at `x`
131    ///
132    /// Note that due to extending the distribution to the reals
133    /// (allowing positive real values for `r`), while still technically
134    /// a discrete distribution the CDF behaves more like that of a
135    /// continuous distribution rather than a discrete distribution
136    /// (i.e. a smooth graph rather than a step-ladder)
137    ///
138    /// # Formula
139    ///
140    /// ```ignore
141    /// I_(1-p)(x+1, r)
142    /// ```
143    ///
144    /// where `I_(x)(a, b)` is the regularized incomplete beta function
145    fn sf(&self, x: u64) -> f64 {
146        beta::beta_reg(x as f64 + 1.0, self.r, 1. - self.p)
147    }
148}
149
150impl Min<u64> for NegativeBinomial {
151    /// Returns the minimum value in the domain of the
152    /// negative binomial distribution representable by a 64-bit
153    /// integer.
154    ///
155    /// # Formula
156    ///
157    /// ```ignore
158    /// 0
159    /// ```
160    fn min(&self) -> u64 {
161        0
162    }
163}
164
165impl Max<u64> for NegativeBinomial {
166    /// Returns the maximum value in the domain of the
167    /// negative binomial distribution representable by a 64-bit
168    /// integer.
169    ///
170    /// # Formula
171    ///
172    /// ```ignore
173    /// u64::MAX
174    /// ```
175    fn max(&self) -> u64 {
176        std::u64::MAX
177    }
178}
179
180impl DiscreteDistribution<f64> for NegativeBinomial {
181    /// Returns the mean of the negative binomial distribution.
182    ///
183    /// # Formula
184    ///
185    /// ```ignore
186    /// r * (1-p) / p
187    /// ```
188    fn mean(&self) -> Option<f64> {
189        Some(self.r * (1.0 - self.p) / self.p)
190    }
191    /// Returns the variance of the negative binomial distribution.
192    ///
193    /// # Formula
194    ///
195    /// ```ignore
196    /// r * (1-p) / p^2
197    /// ```
198    fn variance(&self) -> Option<f64> {
199        Some(self.r * (1.0 - self.p) / (self.p * self.p))
200    }
201    /// Returns the skewness of the negative binomial distribution.
202    ///
203    /// # Formula
204    ///
205    /// ```ignore
206    /// (2-p) / sqrt(r * (1-p))
207    /// ```
208    fn skewness(&self) -> Option<f64> {
209        Some((2.0 - self.p) / f64::sqrt(self.r * (1.0 - self.p)))
210    }
211}
212
213impl Mode<Option<f64>> for NegativeBinomial {
214    /// Returns the mode for the negative binomial distribution.
215    ///
216    /// # Formula
217    ///
218    /// ```ignore
219    /// if r > 1 then
220    ///     floor((r - 1) * (1-p / p))
221    /// else
222    ///     0
223    /// ```
224    fn mode(&self) -> Option<f64> {
225        let mode = if self.r > 1.0 {
226            f64::floor((self.r - 1.0) * (1.0 - self.p) / self.p)
227        } else {
228            0.0
229        };
230        Some(mode)
231    }
232}
233
234impl Discrete<u64, f64> for NegativeBinomial {
235    /// Calculates the probability mass function for the negative binomial
236    /// distribution at `x`.
237    ///
238    /// # Formula
239    ///
240    /// When `r` is an integer, the formula is:
241    ///
242    /// ```ignore
243    /// (x + r - 1 choose x) * (1 - p)^x * p^r
244    /// ```
245    ///
246    /// The general formula for real `r` is:
247    ///
248    /// ```ignore
249    /// Γ(r + x)/(Γ(r) * Γ(x + 1)) * (1 - p)^x * p^r
250    /// ```
251    ///
252    /// where Γ(x) is the Gamma function.
253    fn pmf(&self, x: u64) -> f64 {
254        self.ln_pmf(x).exp()
255    }
256
257    /// Calculates the log probability mass function for the negative binomial
258    /// distribution at `x`.
259    ///
260    /// # Formula
261    ///
262    /// When `r` is an integer, the formula is:
263    ///
264    /// ```ignore
265    /// ln((x + r - 1 choose x) * (1 - p)^x * p^r)
266    /// ```
267    ///
268    /// The general formula for real `r` is:
269    ///
270    /// ```ignore
271    /// ln(Γ(r + x)/(Γ(r) * Γ(x + 1)) * (1 - p)^x * p^r)
272    /// ```
273    ///
274    /// where Γ(x) is the Gamma function.
275    fn ln_pmf(&self, x: u64) -> f64 {
276        let k = x as f64;
277        gamma::ln_gamma(self.r + k) - gamma::ln_gamma(self.r) - gamma::ln_gamma(k + 1.0)
278            + (self.r * self.p.ln())
279            + (k * (-self.p).ln_1p())
280    }
281}
282
#[rustfmt::skip]
#[cfg(all(test, feature = "nightly"))]
mod tests {
    use std::fmt::Debug;
    use crate::statistics::*;
    use crate::distribution::{DiscreteCDF, Discrete, NegativeBinomial};
    use crate::distribution::internal::test;
    use crate::consts::ACC;

    /// Builds a distribution, asserting that the parameters are accepted.
    fn try_create(r: f64, p: f64) -> NegativeBinomial {
        let r = NegativeBinomial::new(r, p);
        assert!(r.is_ok());
        r.unwrap()
    }

    /// Asserts that construction succeeds and the getters round-trip.
    fn create_case(r: f64, p: f64) {
        let dist = try_create(r, p);
        assert_eq!(p, dist.p());
        assert_eq!(r, dist.r());
    }

    /// Asserts that construction is rejected for the given parameters.
    fn bad_create_case(r: f64, p: f64) {
        let r = NegativeBinomial::new(r, p);
        assert!(r.is_err());
    }

    /// Builds a distribution and applies `eval` to it.
    fn get_value<T, F>(r: f64, p: f64, eval: F) -> T
        where T: PartialEq + Debug,
                F: Fn(NegativeBinomial) -> T
    {
        let r = try_create(r, p);
        eval(r)
    }

    /// Asserts exact equality between `expected` and `eval(dist)`.
    fn test_case<T, F>(r: f64, p: f64, expected: T, eval: F)
        where T: PartialEq + Debug,
                F: Fn(NegativeBinomial) -> T
    {
        let x = get_value(r, p, eval);
        assert_eq!(expected, x);
    }

    /// Like `test_case`, but treats an expected `NaN` as "result is NaN"
    /// (NaN never compares equal to itself).
    fn test_case_or_nan<F>(r: f64, p: f64, expected: f64, eval: F)
        where F: Fn(NegativeBinomial) -> f64
    {
        let x = get_value(r, p, eval);
        if expected.is_nan() {
            assert!(x.is_nan())
        }
        else {
            assert_eq!(expected, x);
        }
    }

    /// Asserts that `eval(dist)` is within `acc` of `expected`.
    fn test_almost<F>(r: f64, p: f64, expected: f64, acc: f64, eval: F)
        where F: Fn(NegativeBinomial) -> f64
    {
        let x = get_value(r, p, eval);
        assert_almost_eq!(expected, x, acc);
    }

    #[test]
    fn test_create() {
        create_case(0.0, 0.0);
        create_case(0.3, 0.4);
        create_case(1.0, 0.3);
    }

    #[test]
    fn test_bad_create() {
        bad_create_case(f64::NAN, 1.0);
        bad_create_case(0.0, f64::NAN);
        bad_create_case(-1.0, 1.0);
        bad_create_case(2.0, 2.0);
    }

    #[test]
    fn test_mean() {
        let mean = |x: NegativeBinomial| x.mean().unwrap();
        test_case(4.0, 0.0, f64::INFINITY, mean);
        test_almost(3.0, 0.3, 7.0, 1e-15 , mean);
        test_case(2.0, 1.0, 0.0, mean);
    }

    #[test]
    fn test_variance() {
        let variance = |x: NegativeBinomial| x.variance().unwrap();
        test_case(4.0, 0.0, f64::INFINITY, variance);
        test_almost(3.0, 0.3, 23.333333333333, 1e-12, variance);
        test_case(2.0, 1.0, 0.0, variance);
    }

    #[test]
    fn test_skewness() {
        let skewness = |x: NegativeBinomial| x.skewness().unwrap();
        test_case(0.0, 0.0, f64::INFINITY, skewness);
        test_almost(0.1, 0.3, 6.425396041, 1e-09, skewness);
        test_case(1.0, 1.0, f64::INFINITY, skewness);
    }

    #[test]
    fn test_mode() {
        let mode = |x: NegativeBinomial| x.mode().unwrap();
        test_case(0.0, 0.0, 0.0, mode);
        test_case(0.3, 0.0, 0.0, mode);
        test_case(1.0, 1.0, 0.0, mode);
        test_case(10.0, 0.01, 891.0, mode);
    }

    #[test]
    fn test_min_max() {
        let min = |x: NegativeBinomial| x.min();
        let max = |x: NegativeBinomial| x.max();
        test_case(1.0, 0.5, 0, min);
        test_case(1.0, 0.3, std::u64::MAX, max);
    }

    #[test]
    fn test_pmf() {
        let pmf = |arg: u64| move |x: NegativeBinomial| x.pmf(arg);
        test_almost(4.0, 0.5, 0.0625, 1e-8, pmf(0));
        test_almost(4.0, 0.5, 0.15625, 1e-8, pmf(3));
        test_case(1.0, 0.0, 0.0, pmf(0));
        test_case(1.0, 0.0, 0.0, pmf(1));
        test_almost(3.0, 0.2, 0.008, 1e-15, pmf(0));
        test_almost(3.0, 0.2, 0.0192, 1e-15, pmf(1));
        test_almost(3.0, 0.2, 0.04096, 1e-15, pmf(3));
        test_almost(10.0, 0.2, 1.024e-07, 1e-07, pmf(0));
        test_almost(10.0, 0.2, 8.192e-07, 1e-07, pmf(1));
        test_almost(10.0, 0.2, 0.001015706852, 1e-07, pmf(10));
        test_almost(1.0, 0.3, 0.3, 1e-15,  pmf(0));
        test_almost(1.0, 0.3, 0.21, 1e-15, pmf(1));
        test_almost(3.0, 0.3, 0.027, 1e-15, pmf(0));
        test_case(0.3, 1.0, 0.0, pmf(1));
        test_case(0.3, 1.0, 0.0, pmf(3));
        test_case_or_nan(0.3, 1.0, f64::NAN, pmf(0));
        test_case(0.3, 1.0, 0.0, pmf(10));
        test_case_or_nan(1.0, 1.0, f64::NAN, pmf(0));
        test_case(1.0, 1.0, 0.0, pmf(1));
        test_case_or_nan(3.0, 1.0, f64::NAN, pmf(0));
        test_case(3.0, 1.0, 0.0, pmf(1));
        test_case(3.0, 1.0, 0.0, pmf(3));
        test_case_or_nan(10.0, 1.0, f64::NAN, pmf(0));
        test_case(10.0, 1.0, 0.0, pmf(1));
        test_case(10.0, 1.0, 0.0, pmf(10));
    }

    #[test]
    fn test_ln_pmf() {
        let ln_pmf = |arg: u64| move |x: NegativeBinomial| x.ln_pmf(arg);
        test_case(1.0, 0.0, f64::NEG_INFINITY, ln_pmf(0));
        test_case(1.0, 0.0, f64::NEG_INFINITY, ln_pmf(1));
        test_almost(3.0, 0.2, -4.828313737, 1e-08, ln_pmf(0));
        test_almost(3.0, 0.2, -3.952845, 1e-08, ln_pmf(1));
        test_almost(3.0, 0.2, -3.195159298, 1e-08, ln_pmf(3));
        test_almost(10.0, 0.2, -16.09437912, 1e-08, ln_pmf(0));
        test_almost(10.0, 0.2, -14.01493758, 1e-08, ln_pmf(1));
        test_almost(10.0, 0.2, -6.892170503, 1e-08, ln_pmf(10));
        test_almost(1.0, 0.3, -1.203972804, 1e-08,  ln_pmf(0));
        test_almost(1.0, 0.3, -1.560647748, 1e-08, ln_pmf(1));
        test_almost(3.0, 0.3, -3.611918413, 1e-08, ln_pmf(0));
        test_case(0.3, 1.0, f64::NEG_INFINITY, ln_pmf(1));
        test_case(0.3, 1.0, f64::NEG_INFINITY, ln_pmf(3));
        test_case_or_nan(0.3, 1.0, f64::NAN, ln_pmf(0));
        test_case(0.3, 1.0, f64::NEG_INFINITY, ln_pmf(10));
        test_case_or_nan(1.0, 1.0, f64::NAN, ln_pmf(0));
        test_case(1.0, 1.0, f64::NEG_INFINITY, ln_pmf(1));
        test_case_or_nan(3.0, 1.0, f64::NAN, ln_pmf(0));
        test_case(3.0, 1.0, f64::NEG_INFINITY, ln_pmf(1));
        test_case(3.0, 1.0, f64::NEG_INFINITY, ln_pmf(3));
        test_case_or_nan(10.0, 1.0, f64::NAN, ln_pmf(0));
        test_case(10.0, 1.0, f64::NEG_INFINITY, ln_pmf(1));
        test_case(10.0, 1.0, f64::NEG_INFINITY, ln_pmf(10));
    }

    #[test]
    fn test_cdf() {
        let cdf = |arg: u64| move |x: NegativeBinomial| x.cdf(arg);
        test_almost(1.0, 0.3, 0.3, 1e-08, cdf(0));
        test_almost(1.0, 0.3, 0.51, 1e-08, cdf(1));
        test_almost(1.0, 0.3, 0.83193, 1e-08, cdf(4));
        test_almost(1.0, 0.3, 0.9802267326, 1e-08, cdf(10));
        test_case(1.0, 1.0, 1.0, cdf(0));
        test_case(1.0, 1.0, 1.0, cdf(1));
        test_almost(10.0, 0.75, 0.05631351471, 1e-08, cdf(0));
        test_almost(10.0, 0.75, 0.1970973015, 1e-08, cdf(1));
        test_almost(10.0, 0.75, 0.9960578583, 1e-08, cdf(10));
    }

    #[test]
    fn test_sf() {
        let sf = |arg: u64| move |x: NegativeBinomial| x.sf(arg);
        test_almost(1.0, 0.3, 0.7, 1e-08, sf(0));
        test_almost(1.0, 0.3, 0.49, 1e-08, sf(1));
        test_almost(1.0, 0.3, 0.1680699999999986, 1e-08, sf(4));
        test_almost(1.0, 0.3, 0.019773267430000074, 1e-08, sf(10));
        test_case(1.0, 1.0, 0.0, sf(0));
        test_case(1.0, 1.0, 0.0, sf(1));
        test_almost(10.0, 0.75, 0.9436864852905275, 1e-08, sf(0));
        test_almost(10.0, 0.75, 0.8029026985168456, 1e-08, sf(1));
        test_almost(10.0, 0.75, 0.003942141664083465, 1e-08, sf(10));
    }

    #[test]
    fn test_cdf_upper_bound() {
        let cdf = |arg: u64| move |x: NegativeBinomial| x.cdf(arg);
        test_case(3.0, 0.5, 1.0, cdf(100));
    }

    #[test]
    fn test_discrete() {
        test::check_discrete_distribution(&try_create(5.0, 0.3), 35);
        test::check_discrete_distribution(&try_create(10.0, 0.7), 21);
    }

    #[test]
    fn test_sf_upper_bound() {
        let sf = |arg: u64| move |x: NegativeBinomial| x.sf(arg);
        test_almost(3.0, 0.5, 5.282409836586059e-28, 1e-28, sf(100));
    }
}