statrs/distribution/
geometric.rs

1use crate::distribution::{Discrete, DiscreteCDF};
2use crate::statistics::*;
3use crate::{Result, StatsError};
4use rand::distributions::OpenClosed01;
5use rand::Rng;
6use std::{f64, u64};
7
/// The shifted
/// [Geometric](https://en.wikipedia.org/wiki/Geometric_distribution)
/// distribution: the number of independent trials, each succeeding with
/// probability `p`, that are needed to observe the first success
///
/// # Examples
///
/// ```
/// use statrs::distribution::{Geometric, Discrete};
/// use statrs::statistics::Distribution;
///
/// let n = Geometric::new(0.3).unwrap();
/// assert_eq!(n.mean().unwrap(), 1.0 / 0.3);
/// assert_eq!(n.pmf(1), 0.3);
/// assert_eq!(n.pmf(2), 0.21);
/// ```
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Geometric {
    /// Probability of success on a single trial.
    p: f64,
}
27
28impl Geometric {
29    /// Constructs a new shifted geometric distribution with a probability
30    /// of `p`
31    ///
32    /// # Errors
33    ///
34    /// Returns an error if `p` is not in `(0, 1]`
35    ///
36    /// # Examples
37    ///
38    /// ```
39    /// use statrs::distribution::Geometric;
40    ///
41    /// let mut result = Geometric::new(0.5);
42    /// assert!(result.is_ok());
43    ///
44    /// result = Geometric::new(0.0);
45    /// assert!(result.is_err());
46    /// ```
47    pub fn new(p: f64) -> Result<Geometric> {
48        if p <= 0.0 || p > 1.0 || p.is_nan() {
49            Err(StatsError::BadParams)
50        } else {
51            Ok(Geometric { p })
52        }
53    }
54
55    /// Returns the probability `p` of the geometric
56    /// distribution
57    ///
58    /// # Examples
59    ///
60    /// ```
61    /// use statrs::distribution::Geometric;
62    ///
63    /// let n = Geometric::new(0.5).unwrap();
64    /// assert_eq!(n.p(), 0.5);
65    /// ```
66    pub fn p(&self) -> f64 {
67        self.p
68    }
69}
70
71impl ::rand::distributions::Distribution<f64> for Geometric {
72    fn sample<R: Rng + ?Sized>(&self, r: &mut R) -> f64 {
73        if ulps_eq!(self.p, 1.0) {
74            1.0
75        } else {
76            let x: f64 = r.sample(OpenClosed01);
77            x.log(1.0 - self.p).ceil()
78        }
79    }
80}
81
82impl DiscreteCDF<u64, f64> for Geometric {
83    /// Calculates the cumulative distribution function for the geometric
84    /// distribution at `x`
85    ///
86    /// # Formula
87    ///
88    /// ```ignore
89    /// 1 - (1 - p) ^ x
90    /// ```
91    fn cdf(&self, x: u64) -> f64 {
92        if x == 0 {
93            0.0
94        } else {
95            // 1 - (1 - p) ^ x = 1 - exp(log(1 - p)*x)
96            //                 = -expm1(log1p(-p)*x))
97            //                 = -((-p).ln_1p() * x).exp_m1()
98            -((-self.p).ln_1p() * (x as f64)).exp_m1()
99        }
100    }
101
102    /// Calculates the survival function for the geometric
103    /// distribution at `x`
104    ///
105    /// # Formula
106    ///
107    /// ```ignore
108    /// (1 - p) ^ x
109    /// ```
110    fn sf(&self, x: u64) -> f64 {
111        // (1-p) ^ x = exp(log(1-p)*x)
112        //           = exp(log1p(-p) * x)
113        if x == 0 {
114            1.0
115        } else {
116            ((-self.p).ln_1p() * (x as f64)).exp()
117        }
118    }
119}
120
121impl Min<u64> for Geometric {
122    /// Returns the minimum value in the domain of the
123    /// geometric distribution representable by a 64-bit
124    /// integer
125    ///
126    /// # Formula
127    ///
128    /// ```ignore
129    /// 1
130    /// ```
131    fn min(&self) -> u64 {
132        1
133    }
134}
135
136impl Max<u64> for Geometric {
137    /// Returns the maximum value in the domain of the
138    /// geometric distribution representable by a 64-bit
139    /// integer
140    ///
141    /// # Formula
142    ///
143    /// ```ignore
144    /// 2^63 - 1
145    /// ```
146    fn max(&self) -> u64 {
147        u64::MAX
148    }
149}
150
151impl Distribution<f64> for Geometric {
152    /// Returns the mean of the geometric distribution
153    ///
154    /// # Formula
155    ///
156    /// ```ignore
157    /// 1 / p
158    /// ```
159    fn mean(&self) -> Option<f64> {
160        Some(1.0 / self.p)
161    }
162    /// Returns the standard deviation of the geometric distribution
163    ///
164    /// # Formula
165    ///
166    /// ```ignore
167    /// (1 - p) / p^2
168    /// ```
169    fn variance(&self) -> Option<f64> {
170        Some((1.0 - self.p) / (self.p * self.p))
171    }
172    /// Returns the entropy of the geometric distribution
173    ///
174    /// # Formula
175    ///
176    /// ```ignore
177    /// (-(1 - p) * log_2(1 - p) - p * log_2(p)) / p
178    /// ```
179    fn entropy(&self) -> Option<f64> {
180        let inv = 1.0 / self.p;
181        Some(-inv * (1. - self.p).log(2.0) + (inv - 1.).log(2.0))
182    }
183    /// Returns the skewness of the geometric distribution
184    ///
185    /// # Formula
186    ///
187    /// ```ignore
188    /// (2 - p) / sqrt(1 - p)
189    /// ```
190    fn skewness(&self) -> Option<f64> {
191        if ulps_eq!(self.p, 1.0) {
192            return Some(f64::INFINITY);
193        };
194        Some((2.0 - self.p) / (1.0 - self.p).sqrt())
195    }
196}
197
198impl Mode<Option<u64>> for Geometric {
199    /// Returns the mode of the geometric distribution
200    ///
201    /// # Formula
202    ///
203    /// ```ignore
204    /// 1
205    /// ```
206    fn mode(&self) -> Option<u64> {
207        Some(1)
208    }
209}
210
211impl Median<f64> for Geometric {
212    /// Returns the median of the geometric distribution
213    ///
214    /// # Remarks
215    ///
216    /// # Formula
217    ///
218    /// ```ignore
219    /// ceil(-1 / log_2(1 - p))
220    /// ```
221    fn median(&self) -> f64 {
222        (-f64::consts::LN_2 / (1.0 - self.p).ln()).ceil()
223    }
224}
225
226impl Discrete<u64, f64> for Geometric {
227    /// Calculates the probability mass function for the geometric
228    /// distribution at `x`
229    ///
230    /// # Formula
231    ///
232    /// ```ignore
233    /// (1 - p)^(x - 1) * p
234    /// ```
235    fn pmf(&self, x: u64) -> f64 {
236        if x == 0 {
237            0.0
238        } else {
239            (1.0 - self.p).powi(x as i32 - 1) * self.p
240        }
241    }
242
243    /// Calculates the log probability mass function for the geometric
244    /// distribution at `x`
245    ///
246    /// # Formula
247    ///
248    /// ```ignore
249    /// ln((1 - p)^(x - 1) * p)
250    /// ```
251    fn ln_pmf(&self, x: u64) -> f64 {
252        if x == 0 {
253            f64::NEG_INFINITY
254        } else if ulps_eq!(self.p, 1.0) && x == 1 {
255            0.0
256        } else if ulps_eq!(self.p, 1.0) {
257            f64::NEG_INFINITY
258        } else {
259            ((x - 1) as f64 * (1.0 - self.p).ln()) + self.p.ln()
260        }
261    }
262}
263
#[rustfmt::skip]
#[cfg(all(test, feature = "nightly"))]
mod tests {
    use std::fmt::Debug;
    use crate::statistics::*;
    use crate::distribution::{DiscreteCDF, Discrete, Geometric};
    use crate::distribution::internal::*;
    use crate::consts::ACC;

    /// Builds a `Geometric` from `p`, failing the test if `p` is rejected.
    fn try_create(p: f64) -> Geometric {
        let built = Geometric::new(p);
        assert!(built.is_ok());
        built.unwrap()
    }

    /// Checks that a valid `p` is accepted and stored unchanged.
    fn create_case(p: f64) {
        let dist = try_create(p);
        assert_eq!(p, dist.p());
    }

    /// Checks that an invalid `p` is rejected.
    fn bad_create_case(p: f64) {
        assert!(Geometric::new(p).is_err());
    }

    /// Evaluates `eval` on a distribution built from `p`.
    fn get_value<T, F>(p: f64, eval: F) -> T
        where T: PartialEq + Debug,
              F: Fn(Geometric) -> T
    {
        eval(try_create(p))
    }

    /// Asserts that `eval` yields exactly `expected`.
    fn test_case<T, F>(p: f64, expected: T, eval: F)
        where T: PartialEq + Debug,
              F: Fn(Geometric) -> T
    {
        let actual = get_value(p, eval);
        assert_eq!(expected, actual);
    }

    /// Asserts that `eval` yields `expected` to within `acc`.
    fn test_almost<F>(p: f64, expected: f64, acc: f64, eval: F)
        where F: Fn(Geometric) -> f64
    {
        let actual = get_value(p, eval);
        assert_almost_eq!(expected, actual, acc);
    }

    /// Asserts that `eval` yields `NaN`.
    fn test_is_nan<F>(p: f64, eval: F)
        where F: Fn(Geometric) -> f64
    {
        assert!(get_value(p, eval).is_nan());
    }

    #[test]
    fn test_create() {
        for &p in &[0.3, 1.0] {
            create_case(p);
        }
    }

    #[test]
    fn test_bad_create() {
        for &p in &[f64::NAN, 0.0, -1.0, 2.0] {
            bad_create_case(p);
        }
    }

    #[test]
    fn test_mean() {
        let mean = |x: Geometric| x.mean().unwrap();
        test_case(0.3, 1.0 / 0.3, mean);
        test_case(1.0, 1.0, mean);
    }

    #[test]
    fn test_variance() {
        let variance = |x: Geometric| x.variance().unwrap();
        test_case(0.3, 0.7 / (0.3 * 0.3), variance);
        test_case(1.0, 0.0, variance);
    }

    #[test]
    fn test_entropy() {
        let entropy = |x: Geometric| x.entropy().unwrap();
        test_almost(0.3, 2.937636330768973333333, 1e-14, entropy);
        test_is_nan(1.0, entropy);
    }

    #[test]
    fn test_skewness() {
        let skewness = |x: Geometric| x.skewness().unwrap();
        test_almost(0.3, 2.031888635868469187947, 1e-15, skewness);
        test_case(1.0, f64::INFINITY, skewness);
    }

    #[test]
    fn test_median() {
        let median = |x: Geometric| x.median();
        // Disabled in the original suite: (0.99, 1.0)
        let cases: [(f64, f64); 5] = [
            (0.0001, 6932.0),
            (0.1, 7.0),
            (0.3, 2.0),
            (0.9, 1.0),
            (1.0, 0.0),
        ];
        for &(p, expected) in &cases {
            test_case(p, expected, median);
        }
    }

    #[test]
    fn test_mode() {
        let mode = |x: Geometric| x.mode().unwrap();
        for &p in &[0.3, 1.0] {
            test_case(p, 1, mode);
        }
    }

    #[test]
    fn test_min_max() {
        let min = |x: Geometric| x.min();
        let max = |x: Geometric| x.max();
        test_case(0.3, 1, min);
        test_case(0.3, u64::MAX, max);
    }

    #[test]
    fn test_pmf() {
        let pmf = |arg: u64| move |x: Geometric| x.pmf(arg);

        // (p, expected, x) compared for exact equality
        let exact: [(f64, f64, u64); 4] = [
            (0.3, 0.3, 1),
            (0.3, 0.21, 2),
            (1.0, 1.0, 1),
            (1.0, 0.0, 2),
        ];
        for &(p, expected, x) in &exact {
            test_case(p, expected, pmf(x));
        }

        // (p, expected, x) compared to within 1e-10
        let approximate: [(f64, f64, u64); 2] = [
            (0.5, 0.5, 1),
            (0.5, 0.25, 2),
        ];
        for &(p, expected, x) in &approximate {
            test_almost(p, expected, 1e-10, pmf(x));
        }
    }

    #[test]
    fn test_pmf_lower_bound() {
        let pmf = |arg: u64| move |x: Geometric| x.pmf(arg);
        test_case(0.3, 0.0, pmf(0));
    }

    #[test]
    fn test_ln_pmf() {
        let ln_pmf = |arg: u64| move |x: Geometric| x.ln_pmf(arg);
        test_almost(0.3, -1.203972804325935992623, 1e-15, ln_pmf(1));
        test_almost(0.3, -1.560647748264668371535, 1e-15, ln_pmf(2));
        test_case(1.0, 0.0, ln_pmf(1));
        test_case(1.0, f64::NEG_INFINITY, ln_pmf(2));
    }

    #[test]
    fn test_ln_pmf_lower_bound() {
        let ln_pmf = |arg: u64| move |x: Geometric| x.ln_pmf(arg);
        test_case(0.3, f64::NEG_INFINITY, ln_pmf(0));
    }

    #[test]
    fn test_cdf() {
        let cdf = |arg: u64| move |x: Geometric| x.cdf(arg);
        for &x in &[1u64, 2u64] {
            test_case(1.0, 1.0, cdf(x));
        }
        test_almost(0.5, 0.5, 1e-15, cdf(1));
        test_almost(0.5, 0.75, 1e-15, cdf(2));
    }

    #[test]
    fn test_sf() {
        let sf = |arg: u64| move |x: Geometric| x.sf(arg);
        for &x in &[1u64, 2u64] {
            test_case(1.0, 0.0, sf(x));
        }
        test_almost(0.5, 0.5, 1e-15, sf(1));
        test_almost(0.5, 0.25, 1e-15, sf(2));
    }

    #[test]
    fn test_cdf_small_p() {
        // Reference value computed with the arbitrary precision Python
        // library mpmath:
        //
        //   import mpmath
        //   mpmath.mp.dps = 400
        //   p = mpmath.mpf(1e-9)
        //   k = 5
        //   cdf = float(1 - (1 - p)**k)
        //   # cdf is 4.99999999e-09
        let dist = Geometric::new(1e-9f64).unwrap();

        let actual = dist.cdf(5u64);
        let expected = 4.99999999e-09;
        assert_relative_eq!(actual, expected, epsilon = 0.0, max_relative = 1e-15);
    }

    #[test]
    fn test_sf_small_p() {
        let dist = Geometric::new(1e-9f64).unwrap();

        let actual = dist.sf(5u64);
        let expected = 0.999999995;
        assert_relative_eq!(actual, expected, epsilon = 0.0, max_relative = 1e-15);
    }

    #[test]
    fn test_cdf_very_small_p() {
        // Reference values computed with the arbitrary precision Python
        // library mpmath, e.g.:
        //
        //   import mpmath
        //   mpmath.mp.dps = 400
        //   p = mpmath.mpf(1e-17)
        //   k = 100000000000000
        //   cdf = float(1 - (1 - p)**k)
        //   # cdf is 0.0009995001666250085
        let dist = Geometric::new(1e-17f64).unwrap();

        let cases: [(u64, f64); 2] = [
            (10u64, 1e-16f64),
            (100000000000000u64, 0.0009995001666250085f64),
        ];
        for &(k, expected) in &cases {
            let actual = dist.cdf(k);
            assert_relative_eq!(actual, expected, epsilon = 0.0, max_relative = 1e-15);
        }
    }

    #[test]
    fn test_sf_very_small_p() {
        let dist = Geometric::new(1e-17f64).unwrap();

        let cases: [(u64, f64); 2] = [
            (10u64, 0.9999999999999999),
            (100000000000000u64, 0.999000499833375),
        ];
        for &(k, expected) in &cases {
            let actual = dist.sf(k);
            assert_relative_eq!(actual, expected, epsilon = 0.0, max_relative = 1e-15);
        }
    }

    #[test]
    fn test_cdf_lower_bound() {
        let cdf = |arg: u64| move |x: Geometric| x.cdf(arg);
        test_case(0.3, 0.0, cdf(0));
    }

    #[test]
    fn test_sf_lower_bound() {
        let sf = |arg: u64| move |x: Geometric| x.sf(arg);
        test_case(0.3, 1.0, sf(0));
    }

    #[test]
    fn test_discrete() {
        let cases: [(f64, u64); 3] = [(0.3, 100), (0.6, 100), (1.0, 1)];
        for &(p, upper) in &cases {
            test::check_discrete_distribution(&try_create(p), upper);
        }
    }
}