statrs/distribution/discrete_uniform.rs

1use crate::distribution::{Discrete, DiscreteCDF};
2use crate::statistics::*;
3
/// The [Discrete
/// Uniform](https://en.wikipedia.org/wiki/Discrete_uniform_distribution)
/// distribution, which gives every integer in the inclusive range
/// `[min, max]` the same probability.
///
/// Build instances with [`DiscreteUniform::new`], which rejects `max < min`.
///
/// # Examples
///
/// ```
/// use statrs::distribution::{DiscreteUniform, Discrete};
/// use statrs::statistics::Distribution;
///
/// let n = DiscreteUniform::new(0, 5).unwrap();
/// assert_eq!(n.mean().unwrap(), 2.5);
/// assert_eq!(n.pmf(3), 1.0 / 6.0);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DiscreteUniform {
    /// Smallest integer in the support.
    min: i64,
    /// Largest integer in the support (`new` guarantees `min <= max`).
    max: i64,
}
23
/// Errors reported while constructing a [`DiscreteUniform`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum DiscreteUniformError {
    /// The requested maximum is smaller than the requested minimum.
    MinMaxInvalid,
}

impl std::fmt::Display for DiscreteUniformError {
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let message = match self {
            Self::MinMaxInvalid => "Maximum is less than minimum",
        };
        f.write_str(message)
    }
}

impl std::error::Error for DiscreteUniformError {}
42
43impl DiscreteUniform {
44    /// Constructs a new discrete uniform distribution with a minimum value
45    /// of `min` and a maximum value of `max`.
46    ///
47    /// # Errors
48    ///
49    /// Returns an error if `max < min`
50    ///
51    /// # Examples
52    ///
53    /// ```
54    /// use statrs::distribution::DiscreteUniform;
55    ///
56    /// let mut result = DiscreteUniform::new(0, 5);
57    /// assert!(result.is_ok());
58    ///
59    /// result = DiscreteUniform::new(5, 0);
60    /// assert!(result.is_err());
61    /// ```
62    pub fn new(min: i64, max: i64) -> Result<DiscreteUniform, DiscreteUniformError> {
63        if max < min {
64            Err(DiscreteUniformError::MinMaxInvalid)
65        } else {
66            Ok(DiscreteUniform { min, max })
67        }
68    }
69}
70
71impl std::fmt::Display for DiscreteUniform {
72    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73        write!(f, "Uni([{}, {}])", self.min, self.max)
74    }
75}
76
#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<i64> for DiscreteUniform {
    /// Draws one integer uniformly from the inclusive range `[min, max]`.
    fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> i64 {
        let support = self.min..=self.max;
        rng.gen_range(support)
    }
}
84
#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<f64> for DiscreteUniform {
    /// Draws an integer sample (via the `i64` sampler) and widens it to `f64`.
    fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 {
        let draw: i64 = rng.sample(self);
        draw as f64
    }
}
92
93impl DiscreteCDF<i64, f64> for DiscreteUniform {
94    /// Calculates the cumulative distribution function for the
95    /// discrete uniform distribution at `x`
96    ///
97    /// # Formula
98    ///
99    /// ```text
100    /// (floor(x) - min + 1) / (max - min + 1)
101    /// ```
102    fn cdf(&self, x: i64) -> f64 {
103        if x < self.min {
104            0.0
105        } else if x >= self.max {
106            1.0
107        } else {
108            let lower = self.min as f64;
109            let upper = self.max as f64;
110            let ans = (x as f64 - lower + 1.0) / (upper - lower + 1.0);
111            if ans > 1.0 {
112                1.0
113            } else {
114                ans
115            }
116        }
117    }
118
119    fn sf(&self, x: i64) -> f64 {
120        // 1. - self.cdf(x)
121        if x < self.min {
122            1.0
123        } else if x >= self.max {
124            0.0
125        } else {
126            let lower = self.min as f64;
127            let upper = self.max as f64;
128            let ans = (upper - x as f64) / (upper - lower + 1.0);
129            if ans > 1.0 {
130                1.0
131            } else {
132                ans
133            }
134        }
135    }
136}
137
138impl Min<i64> for DiscreteUniform {
139    /// Returns the minimum value in the domain of the discrete uniform
140    /// distribution
141    ///
142    /// # Remarks
143    ///
144    /// This is the same value as the minimum passed into the constructor
145    fn min(&self) -> i64 {
146        self.min
147    }
148}
149
150impl Max<i64> for DiscreteUniform {
151    /// Returns the maximum value in the domain of the discrete uniform
152    /// distribution
153    ///
154    /// # Remarks
155    ///
156    /// This is the same value as the maximum passed into the constructor
157    fn max(&self) -> i64 {
158        self.max
159    }
160}
161
162impl Distribution<f64> for DiscreteUniform {
163    /// Returns the mean of the discrete uniform distribution
164    ///
165    /// # Formula
166    ///
167    /// ```text
168    /// (min + max) / 2
169    /// ```
170    fn mean(&self) -> Option<f64> {
171        Some((self.min + self.max) as f64 / 2.0)
172    }
173
174    /// Returns the variance of the discrete uniform distribution
175    ///
176    /// # Formula
177    ///
178    /// ```text
179    /// ((max - min + 1)^2 - 1) / 12
180    /// ```
181    fn variance(&self) -> Option<f64> {
182        let diff = (self.max - self.min) as f64;
183        Some(((diff + 1.0) * (diff + 1.0) - 1.0) / 12.0)
184    }
185
186    /// Returns the entropy of the discrete uniform distribution
187    ///
188    /// # Formula
189    ///
190    /// ```text
191    /// ln(max - min + 1)
192    /// ```
193    fn entropy(&self) -> Option<f64> {
194        let diff = (self.max - self.min) as f64;
195        Some((diff + 1.0).ln())
196    }
197
198    /// Returns the skewness of the discrete uniform distribution
199    ///
200    /// # Formula
201    ///
202    /// ```text
203    /// 0
204    /// ```
205    fn skewness(&self) -> Option<f64> {
206        Some(0.0)
207    }
208}
209
210impl Median<f64> for DiscreteUniform {
211    /// Returns the median of the discrete uniform distribution
212    ///
213    /// # Formula
214    ///
215    /// ```text
216    /// (max + min) / 2
217    /// ```
218    fn median(&self) -> f64 {
219        (self.min + self.max) as f64 / 2.0
220    }
221}
222
223impl Mode<Option<i64>> for DiscreteUniform {
224    /// Returns the mode for the discrete uniform distribution
225    ///
226    /// # Remarks
227    ///
228    /// Since every element has an equal probability, mode simply
229    /// returns the middle element
230    ///
231    /// # Formula
232    ///
233    /// ```text
234    /// N/A // (max + min) / 2 for the middle element
235    /// ```
236    fn mode(&self) -> Option<i64> {
237        Some(((self.min + self.max) as f64 / 2.0).floor() as i64)
238    }
239}
240
241impl Discrete<i64, f64> for DiscreteUniform {
242    /// Calculates the probability mass function for the discrete uniform
243    /// distribution at `x`
244    ///
245    /// # Remarks
246    ///
247    /// Returns `0.0` if `x` is not in `[min, max]`
248    ///
249    /// # Formula
250    ///
251    /// ```text
252    /// 1 / (max - min + 1)
253    /// ```
254    fn pmf(&self, x: i64) -> f64 {
255        if x >= self.min && x <= self.max {
256            1.0 / (self.max - self.min + 1) as f64
257        } else {
258            0.0
259        }
260    }
261
262    /// Calculates the log probability mass function for the discrete uniform
263    /// distribution at `x`
264    ///
265    /// # Remarks
266    ///
267    /// Returns `f64::NEG_INFINITY` if `x` is not in `[min, max]`
268    ///
269    /// # Formula
270    ///
271    /// ```text
272    /// ln(1 / (max - min + 1))
273    /// ```
274    fn ln_pmf(&self, x: i64) -> f64 {
275        if x >= self.min && x <= self.max {
276            -((self.max - self.min + 1) as f64).ln()
277        } else {
278            f64::NEG_INFINITY
279        }
280    }
281}
282
#[rustfmt::skip]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::testing_boiler;

    testing_boiler!(min: i64, max: i64; DiscreteUniform; DiscreteUniformError);

    #[test]
    fn test_create() {
        create_ok(-10, 10);
        create_ok(0, 4);
        create_ok(10, 20);
        create_ok(20, 20);
    }

    #[test]
    fn test_bad_create() {
        create_err(-1, -2);
        create_err(6, 5);
    }

    #[test]
    fn test_mean() {
        let mean = |x: DiscreteUniform| x.mean().unwrap();
        test_exact(-10, 10, 0.0, mean);
        test_exact(0, 4, 2.0, mean);
        test_exact(10, 20, 15.0, mean);
        test_exact(20, 20, 20.0, mean);
    }

    #[test]
    fn test_variance() {
        let variance = |x: DiscreteUniform| x.variance().unwrap();
        test_exact(-10, 10, 36.66666666666666666667, variance);
        test_exact(0, 4, 2.0, variance);
        test_exact(10, 20, 10.0, variance);
        test_exact(20, 20, 0.0, variance);
    }

    #[test]
    fn test_entropy() {
        let entropy = |x: DiscreteUniform| x.entropy().unwrap();
        test_exact(-10, 10, 3.0445224377234229965005979803657054342845752874046093, entropy);
        test_exact(0, 4, 1.6094379124341003746007593332261876395256013542685181, entropy);
        test_exact(10, 20, 2.3978952727983705440619435779651292998217068539374197, entropy);
        test_exact(20, 20, 0.0, entropy);
    }

    #[test]
    fn test_skewness() {
        let skewness = |x: DiscreteUniform| x.skewness().unwrap();
        test_exact(-10, 10, 0.0, skewness);
        test_exact(0, 4, 0.0, skewness);
        test_exact(10, 20, 0.0, skewness);
        test_exact(20, 20, 0.0, skewness);
    }

    #[test]
    fn test_median() {
        let median = |x: DiscreteUniform| x.median();
        test_exact(-10, 10, 0.0, median);
        test_exact(0, 4, 2.0, median);
        test_exact(10, 20, 15.0, median);
        test_exact(20, 20, 20.0, median);
    }

    #[test]
    fn test_mode() {
        let mode = |x: DiscreteUniform| x.mode().unwrap();
        test_exact(-10, 10, 0, mode);
        test_exact(0, 4, 2, mode);
        test_exact(10, 20, 15, mode);
        test_exact(20, 20, 20, mode);
        // Even-sized support: the lower of the two middle values is reported.
        test_exact(0, 5, 2, mode);
        // Odd negative sum: the midpoint rounds toward negative infinity.
        test_exact(-3, 0, -2, mode);
    }

    #[test]
    fn test_pmf() {
        let pmf = |arg: i64| move |x: DiscreteUniform| x.pmf(arg);
        test_exact(-10, 10, 0.04761904761904761904762, pmf(-5));
        test_exact(-10, 10, 0.04761904761904761904762, pmf(1));
        test_exact(-10, 10, 0.04761904761904761904762, pmf(10));
        test_exact(-10, -10, 0.0, pmf(0));
        test_exact(-10, -10, 0.0, pmf(-11));
        test_exact(-10, -10, 1.0, pmf(-10));
    }

    #[test]
    fn test_ln_pmf() {
        let ln_pmf = |arg: i64| move |x: DiscreteUniform| x.ln_pmf(arg);
        test_exact(-10, 10, -3.0445224377234229965005979803657054342845752874046093, ln_pmf(-5));
        test_exact(-10, 10, -3.0445224377234229965005979803657054342845752874046093, ln_pmf(1));
        test_exact(-10, 10, -3.0445224377234229965005979803657054342845752874046093, ln_pmf(10));
        test_exact(-10, -10, f64::NEG_INFINITY, ln_pmf(0));
        test_exact(-10, -10, 0.0, ln_pmf(-10));
    }

    #[test]
    fn test_cdf() {
        let cdf = |arg: i64| move |x: DiscreteUniform| x.cdf(arg);
        test_exact(-10, 10, 0.2857142857142857142857, cdf(-5));
        test_exact(-10, 10, 0.5714285714285714285714, cdf(1));
        test_exact(-10, 10, 1.0, cdf(10));
        test_exact(-10, -10, 1.0, cdf(-10));
    }

    #[test]
    fn test_sf() {
        let sf = |arg: i64| move |x: DiscreteUniform| x.sf(arg);
        test_exact(-10, 10, 0.7142857142857142857143, sf(-5));
        test_exact(-10, 10, 0.42857142857142855, sf(1));
        test_exact(-10, 10, 0.0, sf(10));
        test_exact(-10, -10, 0.0, sf(-10));
    }

    #[test]
    fn test_cdf_lower_bound() {
        let cdf = |arg: i64| move |x: DiscreteUniform| x.cdf(arg);
        test_exact(0, 3, 0.0, cdf(-1));
    }

    #[test]
    fn test_sf_lower_bound() {
        let sf = |arg: i64| move |x: DiscreteUniform| x.sf(arg);
        test_exact(0, 3, 1.0, sf(-1));
    }

    #[test]
    fn test_cdf_upper_bound() {
        let cdf = |arg: i64| move |x: DiscreteUniform| x.cdf(arg);
        test_exact(0, 3, 1.0, cdf(5));
    }

    #[test]
    fn test_sf_upper_bound() {
        let sf = |arg: i64| move |x: DiscreteUniform| x.sf(arg);
        test_exact(0, 3, 0.0, sf(5));
    }
}