statrs/distribution/
negative_binomial.rs

1use crate::distribution::{Discrete, DiscreteCDF};
2use crate::function::{beta, gamma};
3use crate::statistics::*;
4use std::f64;
5
/// The
/// [negative binomial](http://en.wikipedia.org/wiki/Negative_binomial_distribution)
/// distribution, a discrete distribution described by two parameters,
/// `r` and `p`.
///
/// *Check the parameterization before use.*  As the Wikipedia article points
/// out, several incompatible conventions for the parameters of the negative
/// binomial distribution are in common use.
///
/// In this implementation `p` is the probability of success in a single
/// Bernoulli trial.  For integer `r`, the distribution describes the number
/// of failures observed in a sequence of Bernoulli trials that is repeated
/// until `r` successes have occurred.
///
/// Non-integer values of `r` are accepted as well; this generalizes the more
/// familiar integer case.
///
/// # Examples
///
/// ```
/// use statrs::distribution::{NegativeBinomial, Discrete};
/// use statrs::statistics::DiscreteDistribution;
/// use statrs::prec::almost_eq;
///
/// let r = NegativeBinomial::new(4.0, 0.5).unwrap();
/// assert_eq!(r.mean().unwrap(), 4.0);
/// assert!(almost_eq(r.pmf(0), 0.0625, 1e-8));
/// assert!(almost_eq(r.pmf(3), 0.15625, 1e-8));
/// ```
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct NegativeBinomial {
    /// Number of successes at which the trials stop; may be non-integer.
    r: f64,
    /// Probability of success in a single Bernoulli trial.
    p: f64,
}
41
/// The ways in which constructing a [`NegativeBinomial`] can fail.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum NegativeBinomialError {
    /// The `r` parameter is NaN or negative.
    RInvalid,

    /// The `p` parameter is NaN or lies outside the closed interval `[0, 1]`.
    PInvalid,
}
52
53impl std::fmt::Display for NegativeBinomialError {
54    #[cfg_attr(coverage_nightly, coverage(off))]
55    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
56        match self {
57            NegativeBinomialError::RInvalid => write!(f, "r is NaN or less than zero"),
58            NegativeBinomialError::PInvalid => write!(f, "p is NaN or not in [0, 1]"),
59        }
60    }
61}
62
63impl std::error::Error for NegativeBinomialError {}
64
65impl NegativeBinomial {
66    /// Constructs a new negative binomial distribution with parameters `r`
67    /// and `p`.  When `r` is an integer, the negative binomial distribution
68    /// can be interpreted as the distribution of the number of failures in
69    /// a sequence of Bernoulli trials that continue until `r` successes occur.
70    /// `p` is the probability of success in a single Bernoulli trial.
71    ///
72    /// # Errors
73    ///
74    /// Returns an error if `p` is `NaN`, less than `0.0`,
75    /// greater than `1.0`, or if `r` is `NaN` or less than `0`
76    ///
77    /// # Examples
78    ///
79    /// ```
80    /// use statrs::distribution::NegativeBinomial;
81    ///
82    /// let mut result = NegativeBinomial::new(4.0, 0.5);
83    /// assert!(result.is_ok());
84    ///
85    /// result = NegativeBinomial::new(-0.5, 5.0);
86    /// assert!(result.is_err());
87    /// ```
88    pub fn new(r: f64, p: f64) -> Result<NegativeBinomial, NegativeBinomialError> {
89        if r.is_nan() || r < 0.0 {
90            return Err(NegativeBinomialError::RInvalid);
91        }
92
93        if p.is_nan() || !(0.0..=1.0).contains(&p) {
94            return Err(NegativeBinomialError::PInvalid);
95        }
96
97        Ok(NegativeBinomial { r, p })
98    }
99
100    /// Returns the probability of success `p` of a single
101    /// Bernoulli trial associated with the negative binomial
102    /// distribution.
103    ///
104    /// # Examples
105    ///
106    /// ```
107    /// use statrs::distribution::NegativeBinomial;
108    ///
109    /// let r = NegativeBinomial::new(5.0, 0.5).unwrap();
110    /// assert_eq!(r.p(), 0.5);
111    /// ```
112    pub fn p(&self) -> f64 {
113        self.p
114    }
115
116    /// Returns the number `r` of success of this negative
117    /// binomial distribution.
118    ///
119    /// # Examples
120    ///
121    /// ```
122    /// use statrs::distribution::NegativeBinomial;
123    ///
124    /// let r = NegativeBinomial::new(5.0, 0.5).unwrap();
125    /// assert_eq!(r.r(), 5.0);
126    /// ```
127    pub fn r(&self) -> f64 {
128        self.r
129    }
130}
131
132impl std::fmt::Display for NegativeBinomial {
133    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
134        write!(f, "NB({},{})", self.r, self.p)
135    }
136}
137
#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<u64> for NegativeBinomial {
    /// Draws a single sample in two steps: a value `lambda` is obtained from
    /// the crate's gamma sampler (called with `self.r` and `(1 - p) / p`), and
    /// the result of the crate's Poisson sampler for that `lambda` is floored
    /// and cast to `u64`.
    fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> u64 {
        use crate::distribution::{gamma, poisson};

        // NOTE(review): a Gamma-Poisson mixture matching this type's
        // parameterization (mean r * (1 - p) / p) needs a gamma *scale* of
        // (1 - p) / p.  Confirm that the third argument of
        // `gamma::sample_unchecked` is a scale and not a rate -- TODO verify.
        let failure_odds = (1.0 - self.p) / self.p;
        let lambda = gamma::sample_unchecked(rng, self.r, failure_odds);

        let draw = poisson::sample_unchecked(rng, lambda);
        draw.floor() as u64
    }
}
148
149impl DiscreteCDF<u64, f64> for NegativeBinomial {
150    /// Calculates the cumulative distribution function for the
151    /// negative binomial distribution at `x`.
152    ///
153    /// # Formula
154    ///
155    /// ```text
156    /// I_(p)(r, x+1)
157    /// ```
158    ///
159    /// where `I_(x)(a, b)` is the regularized incomplete beta function.
160    fn cdf(&self, x: u64) -> f64 {
161        beta::beta_reg(self.r, x as f64 + 1.0, self.p)
162    }
163
164    /// Calculates the survival function for the
165    /// negative binomial distribution at `x`
166    ///
167    /// Note that due to extending the distribution to the reals
168    /// (allowing positive real values for `r`), while still technically
169    /// a discrete distribution the CDF behaves more like that of a
170    /// continuous distribution rather than a discrete distribution
171    /// (i.e. a smooth graph rather than a step-ladder)
172    ///
173    /// # Formula
174    ///
175    /// ```text
176    /// I_(1-p)(x+1, r)
177    /// ```
178    ///
179    /// where `I_(x)(a, b)` is the regularized incomplete beta function
180    fn sf(&self, x: u64) -> f64 {
181        beta::beta_reg(x as f64 + 1.0, self.r, 1. - self.p)
182    }
183}
184
185impl Min<u64> for NegativeBinomial {
186    /// Returns the minimum value in the domain of the
187    /// negative binomial distribution representable by a 64-bit
188    /// integer.
189    ///
190    /// # Formula
191    ///
192    /// ```text
193    /// 0
194    /// ```
195    fn min(&self) -> u64 {
196        0
197    }
198}
199
200impl Max<u64> for NegativeBinomial {
201    /// Returns the maximum value in the domain of the
202    /// negative binomial distribution representable by a 64-bit
203    /// integer.
204    ///
205    /// # Formula
206    ///
207    /// ```text
208    /// u64::MAX
209    /// ```
210    fn max(&self) -> u64 {
211        u64::MAX
212    }
213}
214
215impl DiscreteDistribution<f64> for NegativeBinomial {
216    /// Returns the mean of the negative binomial distribution.
217    ///
218    /// # Formula
219    ///
220    /// ```text
221    /// r * (1-p) / p
222    /// ```
223    fn mean(&self) -> Option<f64> {
224        Some(self.r * (1.0 - self.p) / self.p)
225    }
226
227    /// Returns the variance of the negative binomial distribution.
228    ///
229    /// # Formula
230    ///
231    /// ```text
232    /// r * (1-p) / p^2
233    /// ```
234    fn variance(&self) -> Option<f64> {
235        Some(self.r * (1.0 - self.p) / (self.p * self.p))
236    }
237
238    /// Returns the skewness of the negative binomial distribution.
239    ///
240    /// # Formula
241    ///
242    /// ```text
243    /// (2-p) / sqrt(r * (1-p))
244    /// ```
245    fn skewness(&self) -> Option<f64> {
246        Some((2.0 - self.p) / f64::sqrt(self.r * (1.0 - self.p)))
247    }
248}
249
250impl Mode<Option<f64>> for NegativeBinomial {
251    /// Returns the mode for the negative binomial distribution.
252    ///
253    /// # Formula
254    ///
255    /// ```text
256    /// if r > 1 then
257    ///     floor((r - 1) * (1-p / p))
258    /// else
259    ///     0
260    /// ```
261    fn mode(&self) -> Option<f64> {
262        let mode = if self.r > 1.0 {
263            f64::floor((self.r - 1.0) * (1.0 - self.p) / self.p)
264        } else {
265            0.0
266        };
267        Some(mode)
268    }
269}
270
271impl Discrete<u64, f64> for NegativeBinomial {
272    /// Calculates the probability mass function for the negative binomial
273    /// distribution at `x`.
274    ///
275    /// # Formula
276    ///
277    /// When `r` is an integer, the formula is:
278    ///
279    /// ```text
280    /// (x + r - 1 choose x) * (1 - p)^x * p^r
281    /// ```
282    ///
283    /// The general formula for real `r` is:
284    ///
285    /// ```text
286    /// Γ(r + x)/(Γ(r) * Γ(x + 1)) * (1 - p)^x * p^r
287    /// ```
288    ///
289    /// where Γ(x) is the Gamma function.
290    fn pmf(&self, x: u64) -> f64 {
291        self.ln_pmf(x).exp()
292    }
293
294    /// Calculates the log probability mass function for the negative binomial
295    /// distribution at `x`.
296    ///
297    /// # Formula
298    ///
299    /// When `r` is an integer, the formula is:
300    ///
301    /// ```text
302    /// ln((x + r - 1 choose x) * (1 - p)^x * p^r)
303    /// ```
304    ///
305    /// The general formula for real `r` is:
306    ///
307    /// ```text
308    /// ln(Γ(r + x)/(Γ(r) * Γ(x + 1)) * (1 - p)^x * p^r)
309    /// ```
310    ///
311    /// where Γ(x) is the Gamma function.
312    fn ln_pmf(&self, x: u64) -> f64 {
313        let k = x as f64;
314        gamma::ln_gamma(self.r + k) - gamma::ln_gamma(self.r) - gamma::ln_gamma(k + 1.0)
315            + (self.r * self.p.ln())
316            + (k * (-self.p).ln_1p())
317    }
318}
319
#[rustfmt::skip]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distribution::internal::test;
    use crate::testing_boiler;

    // NOTE(review): the helpers used below (`create_ok`, `create_err`,
    // `test_create_err`, `test_exact`, `test_absolute`, `test_is_nan`) are
    // presumably generated by this macro invocation; their exact semantics
    // are defined where the macro lives, not in this file.
    testing_boiler!(r: f64, p: f64; NegativeBinomial; NegativeBinomialError);

    /// Parameter pairs that must be accepted by the constructor.
    #[test]
    fn test_create() {
        create_ok(0.0, 0.0);
        create_ok(0.3, 0.4);
        create_ok(1.0, 0.3);
    }

    /// Parameter pairs that must be rejected by the constructor.
    #[test]
    fn test_bad_create() {
        test_create_err(f64::NAN, 1.0, NegativeBinomialError::RInvalid);
        test_create_err(0.0, f64::NAN, NegativeBinomialError::PInvalid);
        create_err(-1.0, 1.0);
        create_err(2.0, 2.0);
    }

    #[test]
    fn test_mean() {
        let mean = |dist: NegativeBinomial| dist.mean().unwrap();
        test_exact(4.0, 0.0, f64::INFINITY, mean);
        test_absolute(3.0, 0.3, 7.0, 1e-15, mean);
        test_exact(2.0, 1.0, 0.0, mean);
    }

    #[test]
    fn test_variance() {
        let variance = |dist: NegativeBinomial| dist.variance().unwrap();
        test_exact(4.0, 0.0, f64::INFINITY, variance);
        test_absolute(3.0, 0.3, 23.333333333333, 1e-12, variance);
        test_exact(2.0, 1.0, 0.0, variance);
    }

    #[test]
    fn test_skewness() {
        let skewness = |dist: NegativeBinomial| dist.skewness().unwrap();
        test_exact(0.0, 0.0, f64::INFINITY, skewness);
        test_absolute(0.1, 0.3, 6.425396041, 1e-09, skewness);
        test_exact(1.0, 1.0, f64::INFINITY, skewness);
    }

    #[test]
    fn test_mode() {
        let mode = |dist: NegativeBinomial| dist.mode().unwrap();
        // r <= 1: the mode is 0.
        test_exact(0.0, 0.0, 0.0, mode);
        test_exact(0.3, 0.0, 0.0, mode);
        test_exact(1.0, 1.0, 0.0, mode);
        // r > 1: floor((r - 1) * (1 - p) / p).
        test_exact(10.0, 0.01, 891.0, mode);
    }

    #[test]
    fn test_min_max() {
        let min = |dist: NegativeBinomial| dist.min();
        let max = |dist: NegativeBinomial| dist.max();
        test_exact(1.0, 0.5, 0, min);
        test_exact(1.0, 0.3, u64::MAX, max);
    }

    #[test]
    fn test_pmf() {
        let pmf = |k: u64| move |dist: NegativeBinomial| dist.pmf(k);
        test_absolute(4.0, 0.5, 0.0625, 1e-8, pmf(0));
        test_absolute(4.0, 0.5, 0.15625, 1e-8, pmf(3));
        test_exact(1.0, 0.0, 0.0, pmf(0));
        test_exact(1.0, 0.0, 0.0, pmf(1));
        test_absolute(3.0, 0.2, 0.008, 1e-15, pmf(0));
        test_absolute(3.0, 0.2, 0.0192, 1e-15, pmf(1));
        test_absolute(3.0, 0.2, 0.04096, 1e-15, pmf(3));
        test_absolute(10.0, 0.2, 1.024e-07, 1e-07, pmf(0));
        test_absolute(10.0, 0.2, 8.192e-07, 1e-07, pmf(1));
        test_absolute(10.0, 0.2, 0.001015706852, 1e-07, pmf(10));
        test_absolute(1.0, 0.3, 0.3, 1e-15, pmf(0));
        test_absolute(1.0, 0.3, 0.21, 1e-15, pmf(1));
        test_absolute(3.0, 0.3, 0.027, 1e-15, pmf(0));
        // p == 1: NaN is expected at x == 0 and zero mass everywhere else.
        test_exact(0.3, 1.0, 0.0, pmf(1));
        test_exact(0.3, 1.0, 0.0, pmf(3));
        test_is_nan(0.3, 1.0, pmf(0));
        test_exact(0.3, 1.0, 0.0, pmf(1));
        test_exact(0.3, 1.0, 0.0, pmf(10));
        test_is_nan(1.0, 1.0, pmf(0));
        test_exact(1.0, 1.0, 0.0, pmf(1));
        test_is_nan(3.0, 1.0, pmf(0));
        test_exact(3.0, 1.0, 0.0, pmf(1));
        test_exact(3.0, 1.0, 0.0, pmf(3));
        test_is_nan(10.0, 1.0, pmf(0));
        test_exact(10.0, 1.0, 0.0, pmf(1));
        test_exact(10.0, 1.0, 0.0, pmf(10));
    }

    #[test]
    fn test_ln_pmf() {
        let ln_pmf = |k: u64| move |dist: NegativeBinomial| dist.ln_pmf(k);
        test_exact(1.0, 0.0, f64::NEG_INFINITY, ln_pmf(0));
        test_exact(1.0, 0.0, f64::NEG_INFINITY, ln_pmf(1));
        test_absolute(3.0, 0.2, -4.828313737, 1e-08, ln_pmf(0));
        test_absolute(3.0, 0.2, -3.952845, 1e-08, ln_pmf(1));
        test_absolute(3.0, 0.2, -3.195159298, 1e-08, ln_pmf(3));
        test_absolute(10.0, 0.2, -16.09437912, 1e-08, ln_pmf(0));
        test_absolute(10.0, 0.2, -14.01493758, 1e-08, ln_pmf(1));
        test_absolute(10.0, 0.2, -6.892170503, 1e-08, ln_pmf(10));
        test_absolute(1.0, 0.3, -1.203972804, 1e-08, ln_pmf(0));
        test_absolute(1.0, 0.3, -1.560647748, 1e-08, ln_pmf(1));
        test_absolute(3.0, 0.3, -3.611918413, 1e-08, ln_pmf(0));
        // p == 1: NaN is expected at x == 0 and -inf everywhere else.
        test_exact(0.3, 1.0, f64::NEG_INFINITY, ln_pmf(1));
        test_exact(0.3, 1.0, f64::NEG_INFINITY, ln_pmf(3));
        test_is_nan(0.3, 1.0, ln_pmf(0));
        test_exact(0.3, 1.0, f64::NEG_INFINITY, ln_pmf(1));
        test_exact(0.3, 1.0, f64::NEG_INFINITY, ln_pmf(10));
        test_is_nan(1.0, 1.0, ln_pmf(0));
        test_exact(1.0, 1.0, f64::NEG_INFINITY, ln_pmf(1));
        test_is_nan(3.0, 1.0, ln_pmf(0));
        test_exact(3.0, 1.0, f64::NEG_INFINITY, ln_pmf(1));
        test_exact(3.0, 1.0, f64::NEG_INFINITY, ln_pmf(3));
        test_is_nan(10.0, 1.0, ln_pmf(0));
        test_exact(10.0, 1.0, f64::NEG_INFINITY, ln_pmf(1));
        test_exact(10.0, 1.0, f64::NEG_INFINITY, ln_pmf(10));
    }

    #[test]
    fn test_cdf() {
        let cdf = |k: u64| move |dist: NegativeBinomial| dist.cdf(k);
        test_absolute(1.0, 0.3, 0.3, 1e-08, cdf(0));
        test_absolute(1.0, 0.3, 0.51, 1e-08, cdf(1));
        test_absolute(1.0, 0.3, 0.83193, 1e-08, cdf(4));
        test_absolute(1.0, 0.3, 0.9802267326, 1e-08, cdf(10));
        test_exact(1.0, 1.0, 1.0, cdf(0));
        test_exact(1.0, 1.0, 1.0, cdf(1));
        test_absolute(10.0, 0.75, 0.05631351471, 1e-08, cdf(0));
        test_absolute(10.0, 0.75, 0.1970973015, 1e-08, cdf(1));
        test_absolute(10.0, 0.75, 0.9960578583, 1e-08, cdf(10));
    }

    #[test]
    fn test_sf() {
        let sf = |k: u64| move |dist: NegativeBinomial| dist.sf(k);
        test_absolute(1.0, 0.3, 0.7, 1e-08, sf(0));
        test_absolute(1.0, 0.3, 0.49, 1e-08, sf(1));
        test_absolute(1.0, 0.3, 0.1680699999999986, 1e-08, sf(4));
        test_absolute(1.0, 0.3, 0.019773267430000074, 1e-08, sf(10));
        test_exact(1.0, 1.0, 0.0, sf(0));
        test_exact(1.0, 1.0, 0.0, sf(1));
        test_absolute(10.0, 0.75, 0.9436864852905275, 1e-08, sf(0));
        test_absolute(10.0, 0.75, 0.8029026985168456, 1e-08, sf(1));
        test_absolute(10.0, 0.75, 0.003942141664083465, 1e-08, sf(10));
    }

    /// Far into the upper tail the CDF is expected to be exactly 1.
    #[test]
    fn test_cdf_upper_bound() {
        let cdf = |k: u64| move |dist: NegativeBinomial| dist.cdf(k);
        test_exact(3.0, 0.5, 1.0, cdf(100));
    }

    /// Far into the upper tail the survival function is expected to stay
    /// close to a tiny positive value.
    #[test]
    fn test_sf_upper_bound() {
        let sf = |k: u64| move |dist: NegativeBinomial| dist.sf(k);
        test_absolute(3.0, 0.5, 5.282409836586059e-28, 1e-28, sf(100));
    }

    #[test]
    fn test_discrete() {
        test::check_discrete_distribution(&create_ok(5.0, 0.3), 35);
        test::check_discrete_distribution(&create_ok(10.0, 0.7), 21);
    }
}