1use crate::distribution::{self, poisson, Discrete, DiscreteCDF};
2use crate::function::{beta, gamma};
3use crate::statistics::*;
4use crate::{Result, StatsError};
5use rand::Rng;
6use std::f64;
7
8#[derive(Debug, Copy, Clone, PartialEq)]
39pub struct NegativeBinomial {
40 r: f64,
41 p: f64,
42}
43
44impl NegativeBinomial {
45 pub fn new(r: f64, p: f64) -> Result<NegativeBinomial> {
68 if p.is_nan() || p < 0.0 || p > 1.0 || r.is_nan() || r < 0.0 {
69 Err(StatsError::BadParams)
70 } else {
71 Ok(NegativeBinomial { r, p })
72 }
73 }
74
75 pub fn p(&self) -> f64 {
88 self.p
89 }
90
91 pub fn r(&self) -> f64 {
103 self.r
104 }
105}
106
107impl ::rand::distributions::Distribution<u64> for NegativeBinomial {
108 fn sample<R: Rng + ?Sized>(&self, r: &mut R) -> u64 {
109 let lambda = distribution::gamma::sample_unchecked(r, self.r, (1.0 - self.p) / self.p);
110 poisson::sample_unchecked(r, lambda).floor() as u64
111 }
112}
113
114impl DiscreteCDF<u64, f64> for NegativeBinomial {
115 fn cdf(&self, x: u64) -> f64 {
126 beta::beta_reg(self.r, x as f64 + 1.0, self.p)
127 }
128
129 fn sf(&self, x: u64) -> f64 {
146 beta::beta_reg(x as f64 + 1.0, self.r, 1. - self.p)
147 }
148}
149
150impl Min<u64> for NegativeBinomial {
151 fn min(&self) -> u64 {
161 0
162 }
163}
164
165impl Max<u64> for NegativeBinomial {
166 fn max(&self) -> u64 {
176 std::u64::MAX
177 }
178}
179
180impl DiscreteDistribution<f64> for NegativeBinomial {
181 fn mean(&self) -> Option<f64> {
189 Some(self.r * (1.0 - self.p) / self.p)
190 }
191 fn variance(&self) -> Option<f64> {
199 Some(self.r * (1.0 - self.p) / (self.p * self.p))
200 }
201 fn skewness(&self) -> Option<f64> {
209 Some((2.0 - self.p) / f64::sqrt(self.r * (1.0 - self.p)))
210 }
211}
212
213impl Mode<Option<f64>> for NegativeBinomial {
214 fn mode(&self) -> Option<f64> {
225 let mode = if self.r > 1.0 {
226 f64::floor((self.r - 1.0) * (1.0 - self.p) / self.p)
227 } else {
228 0.0
229 };
230 Some(mode)
231 }
232}
233
234impl Discrete<u64, f64> for NegativeBinomial {
235 fn pmf(&self, x: u64) -> f64 {
254 self.ln_pmf(x).exp()
255 }
256
257 fn ln_pmf(&self, x: u64) -> f64 {
276 let k = x as f64;
277 gamma::ln_gamma(self.r + k) - gamma::ln_gamma(self.r) - gamma::ln_gamma(k + 1.0)
278 + (self.r * self.p.ln())
279 + (k * (-self.p).ln_1p())
280 }
281}
282
#[rustfmt::skip]
#[cfg(all(test, feature = "nightly"))]
mod tests {
    use std::fmt::Debug;
    use crate::statistics::*;
    use crate::distribution::{DiscreteCDF, Discrete, NegativeBinomial};
    use crate::distribution::internal::test;

    // Builds a distribution, asserting that the parameters are accepted.
    fn try_create(r: f64, p: f64) -> NegativeBinomial {
        let r = NegativeBinomial::new(r, p);
        assert!(r.is_ok());
        r.unwrap()
    }

    // Asserts construction succeeds and the getters round-trip the parameters.
    fn create_case(r: f64, p: f64) {
        let dist = try_create(r, p);
        assert_eq!(p, dist.p());
        assert_eq!(r, dist.r());
    }

    // Asserts construction is rejected for the given parameters.
    fn bad_create_case(r: f64, p: f64) {
        let r = NegativeBinomial::new(r, p);
        assert!(r.is_err());
    }

    // Builds a distribution and evaluates `eval` on it.
    fn get_value<T, F>(r: f64, p: f64, eval: F) -> T
        where T: PartialEq + Debug,
              F: Fn(NegativeBinomial) -> T
    {
        let r = try_create(r, p);
        eval(r)
    }

    // Asserts `eval` returns exactly `expected`.
    fn test_case<T, F>(r: f64, p: f64, expected: T, eval: F)
        where T: PartialEq + Debug,
              F: Fn(NegativeBinomial) -> T
    {
        let x = get_value(r, p, eval);
        assert_eq!(expected, x);
    }

    // Like `test_case`, but a NaN `expected` matches any NaN result.
    fn test_case_or_nan<F>(r: f64, p: f64, expected: f64, eval: F)
        where F: Fn(NegativeBinomial) -> f64
    {
        let x = get_value(r, p, eval);
        if expected.is_nan() {
            assert!(x.is_nan())
        }
        else {
            assert_eq!(expected, x);
        }
    }

    // Asserts `eval` is within `acc` of `expected`.
    fn test_almost<F>(r: f64, p: f64, expected: f64, acc: f64, eval: F)
        where F: Fn(NegativeBinomial) -> f64
    {
        let x = get_value(r, p, eval);
        assert_almost_eq!(expected, x, acc);
    }

    #[test]
    fn test_create() {
        create_case(0.0, 0.0);
        create_case(0.3, 0.4);
        create_case(1.0, 0.3);
    }

    #[test]
    fn test_bad_create() {
        bad_create_case(f64::NAN, 1.0);
        bad_create_case(0.0, f64::NAN);
        bad_create_case(-1.0, 1.0);
        bad_create_case(2.0, 2.0);
    }

    #[test]
    fn test_mean() {
        let mean = |x: NegativeBinomial| x.mean().unwrap();
        test_case(4.0, 0.0, f64::INFINITY, mean);
        test_almost(3.0, 0.3, 7.0, 1e-15, mean);
        test_case(2.0, 1.0, 0.0, mean);
    }

    #[test]
    fn test_variance() {
        let variance = |x: NegativeBinomial| x.variance().unwrap();
        test_case(4.0, 0.0, f64::INFINITY, variance);
        test_almost(3.0, 0.3, 23.333333333333, 1e-12, variance);
        test_case(2.0, 1.0, 0.0, variance);
    }

    #[test]
    fn test_skewness() {
        let skewness = |x: NegativeBinomial| x.skewness().unwrap();
        test_case(0.0, 0.0, f64::INFINITY, skewness);
        test_almost(0.1, 0.3, 6.425396041, 1e-09, skewness);
        test_case(1.0, 1.0, f64::INFINITY, skewness);
    }

    #[test]
    fn test_mode() {
        let mode = |x: NegativeBinomial| x.mode().unwrap();
        test_case(0.0, 0.0, 0.0, mode);
        test_case(0.3, 0.0, 0.0, mode);
        test_case(1.0, 1.0, 0.0, mode);
        test_case(10.0, 0.01, 891.0, mode);
    }

    #[test]
    fn test_min_max() {
        let min = |x: NegativeBinomial| x.min();
        let max = |x: NegativeBinomial| x.max();
        test_case(1.0, 0.5, 0, min);
        test_case(1.0, 0.3, std::u64::MAX, max);
    }

    #[test]
    fn test_pmf() {
        let pmf = |arg: u64| move |x: NegativeBinomial| x.pmf(arg);
        test_almost(4.0, 0.5, 0.0625, 1e-8, pmf(0));
        test_almost(4.0, 0.5, 0.15625, 1e-8, pmf(3));
        test_case(1.0, 0.0, 0.0, pmf(0));
        test_case(1.0, 0.0, 0.0, pmf(1));
        test_almost(3.0, 0.2, 0.008, 1e-15, pmf(0));
        test_almost(3.0, 0.2, 0.0192, 1e-15, pmf(1));
        test_almost(3.0, 0.2, 0.04096, 1e-15, pmf(3));
        test_almost(10.0, 0.2, 1.024e-07, 1e-07, pmf(0));
        test_almost(10.0, 0.2, 8.192e-07, 1e-07, pmf(1));
        test_almost(10.0, 0.2, 0.001015706852, 1e-07, pmf(10));
        test_almost(1.0, 0.3, 0.3, 1e-15, pmf(0));
        test_almost(1.0, 0.3, 0.21, 1e-15, pmf(1));
        test_almost(3.0, 0.3, 0.027, 1e-15, pmf(0));
        // With p = 1 and x = 0, ln_pmf evaluates `0 * ln(0)`, which is NaN;
        // the NaN expectations below pin that current behaviour.
        test_case(0.3, 1.0, 0.0, pmf(1));
        test_case(0.3, 1.0, 0.0, pmf(3));
        test_case_or_nan(0.3, 1.0, f64::NAN, pmf(0));
        test_case(0.3, 1.0, 0.0, pmf(10));
        test_case_or_nan(1.0, 1.0, f64::NAN, pmf(0));
        test_case(1.0, 1.0, 0.0, pmf(1));
        test_case_or_nan(3.0, 1.0, f64::NAN, pmf(0));
        test_case(3.0, 1.0, 0.0, pmf(1));
        test_case(3.0, 1.0, 0.0, pmf(3));
        test_case_or_nan(10.0, 1.0, f64::NAN, pmf(0));
        test_case(10.0, 1.0, 0.0, pmf(1));
        test_case(10.0, 1.0, 0.0, pmf(10));
    }

    #[test]
    fn test_ln_pmf() {
        let ln_pmf = |arg: u64| move |x: NegativeBinomial| x.ln_pmf(arg);
        test_case(1.0, 0.0, f64::NEG_INFINITY, ln_pmf(0));
        test_case(1.0, 0.0, f64::NEG_INFINITY, ln_pmf(1));
        test_almost(3.0, 0.2, -4.828313737, 1e-08, ln_pmf(0));
        test_almost(3.0, 0.2, -3.952845, 1e-08, ln_pmf(1));
        test_almost(3.0, 0.2, -3.195159298, 1e-08, ln_pmf(3));
        test_almost(10.0, 0.2, -16.09437912, 1e-08, ln_pmf(0));
        test_almost(10.0, 0.2, -14.01493758, 1e-08, ln_pmf(1));
        test_almost(10.0, 0.2, -6.892170503, 1e-08, ln_pmf(10));
        test_almost(1.0, 0.3, -1.203972804, 1e-08, ln_pmf(0));
        test_almost(1.0, 0.3, -1.560647748, 1e-08, ln_pmf(1));
        test_almost(3.0, 0.3, -3.611918413, 1e-08, ln_pmf(0));
        // See test_pmf: p = 1, x = 0 currently yields NaN (0 * ln(0)).
        test_case(0.3, 1.0, f64::NEG_INFINITY, ln_pmf(1));
        test_case(0.3, 1.0, f64::NEG_INFINITY, ln_pmf(3));
        test_case_or_nan(0.3, 1.0, f64::NAN, ln_pmf(0));
        test_case(0.3, 1.0, f64::NEG_INFINITY, ln_pmf(10));
        test_case_or_nan(1.0, 1.0, f64::NAN, ln_pmf(0));
        test_case(1.0, 1.0, f64::NEG_INFINITY, ln_pmf(1));
        test_case_or_nan(3.0, 1.0, f64::NAN, ln_pmf(0));
        test_case(3.0, 1.0, f64::NEG_INFINITY, ln_pmf(1));
        test_case(3.0, 1.0, f64::NEG_INFINITY, ln_pmf(3));
        test_case_or_nan(10.0, 1.0, f64::NAN, ln_pmf(0));
        test_case(10.0, 1.0, f64::NEG_INFINITY, ln_pmf(1));
        test_case(10.0, 1.0, f64::NEG_INFINITY, ln_pmf(10));
    }

    #[test]
    fn test_cdf() {
        let cdf = |arg: u64| move |x: NegativeBinomial| x.cdf(arg);
        test_almost(1.0, 0.3, 0.3, 1e-08, cdf(0));
        test_almost(1.0, 0.3, 0.51, 1e-08, cdf(1));
        test_almost(1.0, 0.3, 0.83193, 1e-08, cdf(4));
        test_almost(1.0, 0.3, 0.9802267326, 1e-08, cdf(10));
        test_case(1.0, 1.0, 1.0, cdf(0));
        test_case(1.0, 1.0, 1.0, cdf(1));
        test_almost(10.0, 0.75, 0.05631351471, 1e-08, cdf(0));
        test_almost(10.0, 0.75, 0.1970973015, 1e-08, cdf(1));
        test_almost(10.0, 0.75, 0.9960578583, 1e-08, cdf(10));
    }

    #[test]
    fn test_sf() {
        let sf = |arg: u64| move |x: NegativeBinomial| x.sf(arg);
        test_almost(1.0, 0.3, 0.7, 1e-08, sf(0));
        test_almost(1.0, 0.3, 0.49, 1e-08, sf(1));
        test_almost(1.0, 0.3, 0.1680699999999986, 1e-08, sf(4));
        test_almost(1.0, 0.3, 0.019773267430000074, 1e-08, sf(10));
        test_case(1.0, 1.0, 0.0, sf(0));
        test_case(1.0, 1.0, 0.0, sf(1));
        test_almost(10.0, 0.75, 0.9436864852905275, 1e-08, sf(0));
        test_almost(10.0, 0.75, 0.8029026985168456, 1e-08, sf(1));
        test_almost(10.0, 0.75, 0.003942141664083465, 1e-08, sf(10));
    }

    #[test]
    fn test_cdf_upper_bound() {
        let cdf = |arg: u64| move |x: NegativeBinomial| x.cdf(arg);
        test_case(3.0, 0.5, 1.0, cdf(100));
    }

    #[test]
    fn test_discrete() {
        test::check_discrete_distribution(&try_create(5.0, 0.3), 35);
        test::check_discrete_distribution(&try_create(10.0, 0.7), 21);
    }

    #[test]
    fn test_sf_upper_bound() {
        let sf = |arg: u64| move |x: NegativeBinomial| x.sf(arg);
        test_almost(3.0, 0.5, 5.282409836586059e-28, 1e-28, sf(100));
    }
}