statrs/distribution/geometric.rs

1use crate::distribution::{Discrete, DiscreteCDF};
2use crate::statistics::*;
3use std::f64;
4
/// Implements the
/// [Geometric](https://en.wikipedia.org/wiki/Geometric_distribution)
/// distribution.
///
/// This is the "number of trials" parameterization: it models the number of
/// independent Bernoulli(`p`) trials up to and including the first success.
/// The support is therefore `{1, 2, 3, ...}` and
/// `P(X = k) = (1 - p)^(k - 1) * p`, with mean `1 / p`.
///
/// # Examples
///
/// ```
/// use statrs::distribution::{Geometric, Discrete};
/// use statrs::statistics::Distribution;
///
/// let n = Geometric::new(0.3).unwrap();
/// assert_eq!(n.mean().unwrap(), 1.0 / 0.3);
/// assert_eq!(n.pmf(1), 0.3);
/// assert_eq!(n.pmf(2), 0.21);
/// ```
#[derive(Copy, Clone, PartialEq, Debug)]
pub struct Geometric {
    /// Success probability of a single trial; `Geometric::new` only accepts
    /// values in `(0, 1]`.
    p: f64,
}
24
/// Errors that can be reported while constructing a [`Geometric`]
/// distribution.
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
#[non_exhaustive]
pub enum GeometricError {
    /// The probability is NaN or not in `(0, 1]`.
    ProbabilityInvalid,
}

impl std::fmt::Display for GeometricError {
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        // Pick the fixed, human-readable message for each variant.
        let message = match *self {
            Self::ProbabilityInvalid => "Probability is NaN or not in (0, 1]",
        };
        f.write_str(message)
    }
}

// Uses the default `source`/`description`; the `Display` text is the message.
impl std::error::Error for GeometricError {}
43
44impl Geometric {
45    /// Constructs a new shifted geometric distribution with a probability
46    /// of `p`
47    ///
48    /// # Errors
49    ///
50    /// Returns an error if `p` is not in `(0, 1]`
51    ///
52    /// # Examples
53    ///
54    /// ```
55    /// use statrs::distribution::Geometric;
56    ///
57    /// let mut result = Geometric::new(0.5);
58    /// assert!(result.is_ok());
59    ///
60    /// result = Geometric::new(0.0);
61    /// assert!(result.is_err());
62    /// ```
63    pub fn new(p: f64) -> Result<Geometric, GeometricError> {
64        if p <= 0.0 || p > 1.0 || p.is_nan() {
65            Err(GeometricError::ProbabilityInvalid)
66        } else {
67            Ok(Geometric { p })
68        }
69    }
70
71    /// Returns the probability `p` of the geometric
72    /// distribution
73    ///
74    /// # Examples
75    ///
76    /// ```
77    /// use statrs::distribution::Geometric;
78    ///
79    /// let n = Geometric::new(0.5).unwrap();
80    /// assert_eq!(n.p(), 0.5);
81    /// ```
82    pub fn p(&self) -> f64 {
83        self.p
84    }
85}
86
87impl std::fmt::Display for Geometric {
88    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89        write!(f, "Geom({})", self.p)
90    }
91}
92
#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<u64> for Geometric {
    /// Draws a sample by inversion: `ceil(ln(U) / ln(1 - p))` with `U`
    /// uniform on `(0, 1]`, clamped to the support minimum of 1.
    fn sample<R: ::rand::Rng + ?Sized>(&self, r: &mut R) -> u64 {
        if ulps_eq!(self.p, 1.0) {
            // p == 1 (within a few ULPs): the first trial always succeeds.
            1
        } else {
            let u: f64 = r.sample(::rand::distributions::OpenClosed01);
            // Evaluate ln(1 - p) as ln_1p(-p): `1.0 - p` is badly rounded for
            // small p and becomes exactly 1.0 for p < 2^-54, which would make the
            // denominator 0 and turn every draw into 0 (outside the support).
            let k = (u.ln() / (-self.p).ln_1p()).ceil();
            // OpenClosed01 can yield U == 1.0, giving k == -0.0; clamp to 1.
            // Float-to-integer `as` casts saturate, so an enormous k (tiny p)
            // maps to u64::MAX rather than wrapping.
            k.max(1.0) as u64
        }
    }
}
108
#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<f64> for Geometric {
    /// Draws an integer-valued sample via the `u64` implementation and widens
    /// it to `f64`.
    fn sample<R: ::rand::Rng + ?Sized>(&self, r: &mut R) -> f64 {
        let trials: u64 = r.sample(self);
        trials as f64
    }
}
116
117impl DiscreteCDF<u64, f64> for Geometric {
118    /// Calculates the cumulative distribution function for the geometric
119    /// distribution at `x`
120    ///
121    /// # Formula
122    ///
123    /// ```text
124    /// 1 - (1 - p) ^ x
125    /// ```
126    fn cdf(&self, x: u64) -> f64 {
127        if x == 0 {
128            0.0
129        } else {
130            // 1 - (1 - p) ^ x = 1 - exp(log(1 - p)*x)
131            //                 = -expm1(log1p(-p)*x))
132            //                 = -((-p).ln_1p() * x).exp_m1()
133            -((-self.p).ln_1p() * (x as f64)).exp_m1()
134        }
135    }
136
137    /// Calculates the survival function for the geometric
138    /// distribution at `x`
139    ///
140    /// # Formula
141    ///
142    /// ```text
143    /// (1 - p) ^ x
144    /// ```
145    fn sf(&self, x: u64) -> f64 {
146        // (1-p) ^ x = exp(log(1-p)*x)
147        //           = exp(log1p(-p) * x)
148        if x == 0 {
149            1.0
150        } else {
151            ((-self.p).ln_1p() * (x as f64)).exp()
152        }
153    }
154}
155
156impl Min<u64> for Geometric {
157    /// Returns the minimum value in the domain of the
158    /// geometric distribution representable by a 64-bit
159    /// integer
160    ///
161    /// # Formula
162    ///
163    /// ```text
164    /// 1
165    /// ```
166    fn min(&self) -> u64 {
167        1
168    }
169}
170
171impl Max<u64> for Geometric {
172    /// Returns the maximum value in the domain of the
173    /// geometric distribution representable by a 64-bit
174    /// integer
175    ///
176    /// # Formula
177    ///
178    /// ```text
179    /// 2^63 - 1
180    /// ```
181    fn max(&self) -> u64 {
182        u64::MAX
183    }
184}
185
186impl Distribution<f64> for Geometric {
187    /// Returns the mean of the geometric distribution
188    ///
189    /// # Formula
190    ///
191    /// ```text
192    /// 1 / p
193    /// ```
194    fn mean(&self) -> Option<f64> {
195        Some(1.0 / self.p)
196    }
197
198    /// Returns the standard deviation of the geometric distribution
199    ///
200    /// # Formula
201    ///
202    /// ```text
203    /// (1 - p) / p^2
204    /// ```
205    fn variance(&self) -> Option<f64> {
206        Some((1.0 - self.p) / (self.p * self.p))
207    }
208
209    /// Returns the entropy of the geometric distribution
210    ///
211    /// # Formula
212    ///
213    /// ```text
214    /// (-(1 - p) * log_2(1 - p) - p * log_2(p)) / p
215    /// ```
216    fn entropy(&self) -> Option<f64> {
217        let inv = 1.0 / self.p;
218        Some(-inv * (1. - self.p).log(2.0) + (inv - 1.).log(2.0))
219    }
220
221    /// Returns the skewness of the geometric distribution
222    ///
223    /// # Formula
224    ///
225    /// ```text
226    /// (2 - p) / sqrt(1 - p)
227    /// ```
228    fn skewness(&self) -> Option<f64> {
229        if ulps_eq!(self.p, 1.0) {
230            return Some(f64::INFINITY);
231        };
232        Some((2.0 - self.p) / (1.0 - self.p).sqrt())
233    }
234}
235
236impl Mode<Option<u64>> for Geometric {
237    /// Returns the mode of the geometric distribution
238    ///
239    /// # Formula
240    ///
241    /// ```text
242    /// 1
243    /// ```
244    fn mode(&self) -> Option<u64> {
245        Some(1)
246    }
247}
248
249impl Median<f64> for Geometric {
250    /// Returns the median of the geometric distribution
251    ///
252    /// # Remarks
253    ///
254    /// # Formula
255    ///
256    /// ```text
257    /// ceil(-1 / log_2(1 - p))
258    /// ```
259    fn median(&self) -> f64 {
260        (-f64::consts::LN_2 / (1.0 - self.p).ln()).ceil()
261    }
262}
263
264impl Discrete<u64, f64> for Geometric {
265    /// Calculates the probability mass function for the geometric
266    /// distribution at `x`
267    ///
268    /// # Formula
269    ///
270    /// ```text
271    /// (1 - p)^(x - 1) * p
272    /// ```
273    fn pmf(&self, x: u64) -> f64 {
274        if x == 0 {
275            0.0
276        } else {
277            (1.0 - self.p).powi(x as i32 - 1) * self.p
278        }
279    }
280
281    /// Calculates the log probability mass function for the geometric
282    /// distribution at `x`
283    ///
284    /// # Formula
285    ///
286    /// ```text
287    /// ln((1 - p)^(x - 1) * p)
288    /// ```
289    fn ln_pmf(&self, x: u64) -> f64 {
290        if x == 0 {
291            f64::NEG_INFINITY
292        } else if ulps_eq!(self.p, 1.0) && x == 1 {
293            0.0
294        } else if ulps_eq!(self.p, 1.0) {
295            f64::NEG_INFINITY
296        } else {
297            ((x - 1) as f64 * (1.0 - self.p).ln()) + self.p.ln()
298        }
299    }
300}
301
#[rustfmt::skip]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distribution::internal::*;
    use crate::testing_boiler;

    testing_boiler!(p: f64; Geometric; GeometricError);

    #[test]
    fn test_create() {
        create_ok(0.3);
        create_ok(1.0);
    }

    #[test]
    fn test_bad_create() {
        create_err(f64::NAN);
        create_err(0.0);
        create_err(-1.0);
        create_err(2.0);
    }

    #[test]
    fn test_mean() {
        let mean = |x: Geometric| x.mean().unwrap();
        test_exact(0.3, 1.0 / 0.3, mean);
        test_exact(1.0, 1.0, mean);
    }

    #[test]
    fn test_variance() {
        let variance = |x: Geometric| x.variance().unwrap();
        test_exact(0.3, 0.7 / (0.3 * 0.3), variance);
        test_exact(1.0, 0.0, variance);
    }

    #[test]
    fn test_entropy() {
        let entropy = |x: Geometric| x.entropy().unwrap();
        test_absolute(0.3, 2.937636330768973333333, 1e-14, entropy);
        test_is_nan(1.0, entropy);
    }

    #[test]
    fn test_skewness() {
        let skewness = |x: Geometric| x.skewness().unwrap();
        test_absolute(0.3, 2.031888635868469187947, 1e-15, skewness);
        test_exact(1.0, f64::INFINITY, skewness);
    }

    #[test]
    fn test_median() {
        let median = |x: Geometric| x.median();
        test_exact(0.0001, 6932.0, median);
        test_exact(0.1, 7.0, median);
        test_exact(0.3, 2.0, median);
        test_exact(0.9, 1.0, median);
        // -1 / log_2(0.01) ~= 0.1505, so the median rounds up to 1.
        test_exact(0.99, 1.0, median);
        test_exact(1.0, 0.0, median);
    }

    #[test]
    fn test_mode() {
        let mode = |x: Geometric| x.mode().unwrap();
        test_exact(0.3, 1, mode);
        test_exact(1.0, 1, mode);
    }

    #[test]
    fn test_min_max() {
        let min = |x: Geometric| x.min();
        let max = |x: Geometric| x.max();
        test_exact(0.3, 1, min);
        test_exact(0.3, u64::MAX, max);
    }

    #[test]
    fn test_pmf() {
        let pmf = |arg: u64| move |x: Geometric| x.pmf(arg);
        test_exact(0.3, 0.3, pmf(1));
        test_exact(0.3, 0.21, pmf(2));
        test_exact(1.0, 1.0, pmf(1));
        test_exact(1.0, 0.0, pmf(2));
        test_absolute(0.5, 0.5, 1e-10, pmf(1));
        test_absolute(0.5, 0.25, 1e-10, pmf(2));
    }

    #[test]
    fn test_pmf_lower_bound() {
        let pmf = |arg: u64| move |x: Geometric| x.pmf(arg);
        test_exact(0.3, 0.0, pmf(0));
    }

    #[test]
    fn test_ln_pmf() {
        let ln_pmf = |arg: u64| move |x: Geometric| x.ln_pmf(arg);
        test_absolute(0.3, -1.203972804325935992623, 1e-15, ln_pmf(1));
        test_absolute(0.3, -1.560647748264668371535, 1e-15, ln_pmf(2));
        test_exact(1.0, 0.0, ln_pmf(1));
        test_exact(1.0, f64::NEG_INFINITY, ln_pmf(2));
    }

    #[test]
    fn test_ln_pmf_lower_bound() {
        let ln_pmf = |arg: u64| move |x: Geometric| x.ln_pmf(arg);
        test_exact(0.3, f64::NEG_INFINITY, ln_pmf(0));
    }

    #[test]
    fn test_cdf() {
        let cdf = |arg: u64| move |x: Geometric| x.cdf(arg);
        test_exact(1.0, 1.0, cdf(1));
        test_exact(1.0, 1.0, cdf(2));
        test_absolute(0.5, 0.5, 1e-15, cdf(1));
        test_absolute(0.5, 0.75, 1e-15, cdf(2));
    }

    #[test]
    fn test_sf() {
        let sf = |arg: u64| move |x: Geometric| x.sf(arg);
        test_exact(1.0, 0.0, sf(1));
        test_exact(1.0, 0.0, sf(2));
        test_absolute(0.5, 0.5, 1e-15, sf(1));
        test_absolute(0.5, 0.25, 1e-15, sf(2));
    }

    #[test]
    fn test_cdf_small_p() {
        //
        // Expected values were computed with the arbitrary precision
        // library mpmath in Python, e.g.:
        //
        //   import mpmath
        //   mpmath.mp.dps = 400
        //   p = mpmath.mpf(1e-9)
        //   k = 5
        //   cdf = float(1 - (1 - p)**k)
        //   # cdf is 4.99999999e-09
        //
        let geom = Geometric::new(1e-9f64).unwrap();

        let cdf = geom.cdf(5u64);
        let expected = 4.99999999e-09;
        assert_relative_eq!(cdf, expected, epsilon = 0.0, max_relative = 1e-15);
    }

    #[test]
    fn test_sf_small_p() {
        let geom = Geometric::new(1e-9f64).unwrap();

        let sf = geom.sf(5u64);
        let expected = 0.999999995;
        assert_relative_eq!(sf, expected, epsilon = 0.0, max_relative = 1e-15);
    }

    #[test]
    fn test_cdf_very_small_p() {
        //
        // Expected values were computed with the arbitrary precision
        // library mpmath in Python, e.g.:
        //
        //   import mpmath
        //   mpmath.mp.dps = 400
        //   p = mpmath.mpf(1e-17)
        //   k = 100000000000000
        //   cdf = float(1 - (1 - p)**k)
        //   # cdf is 0.0009995001666250085
        //
        let geom = Geometric::new(1e-17f64).unwrap();

        let cdf = geom.cdf(10u64);
        let expected = 1e-16f64;
        assert_relative_eq!(cdf, expected, epsilon = 0.0, max_relative = 1e-15);

        let cdf = geom.cdf(100000000000000u64);
        let expected = 0.0009995001666250085f64;
        assert_relative_eq!(cdf, expected, epsilon = 0.0, max_relative = 1e-15);
    }

    #[test]
    fn test_sf_very_small_p() {
        let geom = Geometric::new(1e-17f64).unwrap();

        let sf = geom.sf(10u64);
        let expected =  0.9999999999999999;
        assert_relative_eq!(sf, expected, epsilon = 0.0, max_relative = 1e-15);

        let sf = geom.sf(100000000000000u64);
        let expected = 0.999000499833375;
        assert_relative_eq!(sf, expected, epsilon = 0.0, max_relative = 1e-15);
    }

    #[test]
    fn test_cdf_lower_bound() {
        let cdf = |arg: u64| move |x: Geometric| x.cdf(arg);
        test_exact(0.3, 0.0, cdf(0));
    }

    #[test]
    fn test_sf_lower_bound() {
        let sf = |arg: u64| move |x: Geometric| x.sf(arg);
        test_exact(0.3, 1.0, sf(0));
    }

    #[test]
    fn test_discrete() {
        test::check_discrete_distribution(&create_ok(0.3), 100);
        test::check_discrete_distribution(&create_ok(0.6), 100);
        test::check_discrete_distribution(&create_ok(1.0), 1);
    }
}