// statrs/distribution/discrete_uniform.rs

1use crate::distribution::{Discrete, DiscreteCDF};
2use crate::statistics::*;
3use crate::{Result, StatsError};
4use rand::Rng;

/// The [discrete uniform
/// distribution](https://en.wikipedia.org/wiki/Discrete_uniform_distribution),
/// which assigns equal probability to every integer in the closed interval
/// `[min, max]`.
///
/// # Examples
///
/// ```
/// use statrs::distribution::{DiscreteUniform, Discrete};
/// use statrs::statistics::Distribution;
///
/// let n = DiscreteUniform::new(0, 5).unwrap();
/// assert_eq!(n.mean().unwrap(), 2.5);
/// assert_eq!(n.pmf(3), 1.0 / 6.0);
/// ```
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct DiscreteUniform {
    /// Smallest value in the support (inclusive).
    min: i64,
    /// Largest value in the support (inclusive); `DiscreteUniform::new`
    /// rejects any `max` below `min`.
    max: i64,
}

26impl DiscreteUniform {
27    /// Constructs a new discrete uniform distribution with a minimum value
28    /// of `min` and a maximum value of `max`.
29    ///
30    /// # Errors
31    ///
32    /// Returns an error if `max < min`
33    ///
34    /// # Examples
35    ///
36    /// ```
37    /// use statrs::distribution::DiscreteUniform;
38    ///
39    /// let mut result = DiscreteUniform::new(0, 5);
40    /// assert!(result.is_ok());
41    ///
42    /// result = DiscreteUniform::new(5, 0);
43    /// assert!(result.is_err());
44    /// ```
45    pub fn new(min: i64, max: i64) -> Result<DiscreteUniform> {
46        if max < min {
47            Err(StatsError::BadParams)
48        } else {
49            Ok(DiscreteUniform { min, max })
50        }
51    }
52}
53
54impl ::rand::distributions::Distribution<f64> for DiscreteUniform {
55    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
56        rng.gen_range(self.min..=self.max) as f64
57    }
58}
59
60impl DiscreteCDF<i64, f64> for DiscreteUniform {
61    /// Calculates the cumulative distribution function for the
62    /// discrete uniform distribution at `x`
63    ///
64    /// # Formula
65    ///
66    /// ```ignore
67    /// (floor(x) - min + 1) / (max - min + 1)
68    /// ```
69    fn cdf(&self, x: i64) -> f64 {
70        if x < self.min {
71            0.0
72        } else if x >= self.max {
73            1.0
74        } else {
75            let lower = self.min as f64;
76            let upper = self.max as f64;
77            let ans = (x as f64 - lower + 1.0) / (upper - lower + 1.0);
78            if ans > 1.0 {
79                1.0
80            } else {
81                ans
82            }
83        }
84    }
85
86    fn sf(&self, x: i64) -> f64 {
87        //1. - self.cdf(x)
88        if x < self.min {
89            1.0
90        } else if x >= self.max {
91            0.0
92        } else {
93            let lower = self.min as f64;
94            let upper = self.max as f64;
95            let ans = (upper - x as f64) / (upper - lower + 1.0);
96            if ans > 1.0 {
97                1.0
98            } else {
99                ans
100            }
101        }
102    }
103}
104
105impl Min<i64> for DiscreteUniform {
106    /// Returns the minimum value in the domain of the discrete uniform
107    /// distribution
108    ///
109    /// # Remarks
110    ///
111    /// This is the same value as the minimum passed into the constructor
112    fn min(&self) -> i64 {
113        self.min
114    }
115}
116
117impl Max<i64> for DiscreteUniform {
118    /// Returns the maximum value in the domain of the discrete uniform
119    /// distribution
120    ///
121    /// # Remarks
122    ///
123    /// This is the same value as the maximum passed into the constructor
124    fn max(&self) -> i64 {
125        self.max
126    }
127}
128
129impl Distribution<f64> for DiscreteUniform {
130    /// Returns the mean of the discrete uniform distribution
131    ///
132    /// # Formula
133    ///
134    /// ```ignore
135    /// (min + max) / 2
136    /// ```
137    fn mean(&self) -> Option<f64> {
138        Some((self.min + self.max) as f64 / 2.0)
139    }
140    /// Returns the variance of the discrete uniform distribution
141    ///
142    /// # Formula
143    ///
144    /// ```ignore
145    /// ((max - min + 1)^2 - 1) / 12
146    /// ```
147    fn variance(&self) -> Option<f64> {
148        let diff = (self.max - self.min) as f64;
149        Some(((diff + 1.0) * (diff + 1.0) - 1.0) / 12.0)
150    }
151    /// Returns the entropy of the discrete uniform distribution
152    ///
153    /// # Formula
154    ///
155    /// ```ignore
156    /// ln(max - min + 1)
157    /// ```
158    fn entropy(&self) -> Option<f64> {
159        let diff = (self.max - self.min) as f64;
160        Some((diff + 1.0).ln())
161    }
162    /// Returns the skewness of the discrete uniform distribution
163    ///
164    /// # Formula
165    ///
166    /// ```ignore
167    /// 0
168    /// ```
169    fn skewness(&self) -> Option<f64> {
170        Some(0.0)
171    }
172}
173
174impl Median<f64> for DiscreteUniform {
175    /// Returns the median of the discrete uniform distribution
176    ///
177    /// # Formula
178    ///
179    /// ```ignore
180    /// (max + min) / 2
181    /// ```
182    fn median(&self) -> f64 {
183        (self.min + self.max) as f64 / 2.0
184    }
185}
186
187impl Mode<Option<i64>> for DiscreteUniform {
188    /// Returns the mode for the discrete uniform distribution
189    ///
190    /// # Remarks
191    ///
192    /// Since every element has an equal probability, mode simply
193    /// returns the middle element
194    ///
195    /// # Formula
196    ///
197    /// ```ignore
198    /// N/A // (max + min) / 2 for the middle element
199    /// ```
200    fn mode(&self) -> Option<i64> {
201        Some(((self.min + self.max) as f64 / 2.0).floor() as i64)
202    }
203}
204
205impl Discrete<i64, f64> for DiscreteUniform {
206    /// Calculates the probability mass function for the discrete uniform
207    /// distribution at `x`
208    ///
209    /// # Remarks
210    ///
211    /// Returns `0.0` if `x` is not in `[min, max]`
212    ///
213    /// # Formula
214    ///
215    /// ```ignore
216    /// 1 / (max - min + 1)
217    /// ```
218    fn pmf(&self, x: i64) -> f64 {
219        if x >= self.min && x <= self.max {
220            1.0 / (self.max - self.min + 1) as f64
221        } else {
222            0.0
223        }
224    }
225
226    /// Calculates the log probability mass function for the discrete uniform
227    /// distribution at `x`
228    ///
229    /// # Remarks
230    ///
231    /// Returns `f64::NEG_INFINITY` if `x` is not in `[min, max]`
232    ///
233    /// # Formula
234    ///
235    /// ```ignore
236    /// ln(1 / (max - min + 1))
237    /// ```
238    fn ln_pmf(&self, x: i64) -> f64 {
239        if x >= self.min && x <= self.max {
240            -((self.max - self.min + 1) as f64).ln()
241        } else {
242            f64::NEG_INFINITY
243        }
244    }
245}
246
// These tests are compiled only when the crate's `nightly` feature is enabled.
#[rustfmt::skip]
#[cfg(all(test, feature = "nightly"))]
mod tests {
    use std::fmt::Debug;
    use crate::statistics::*;
    use crate::distribution::{DiscreteCDF, Discrete, DiscreteUniform};

    /// Builds a distribution, failing the test if the parameters are rejected.
    fn try_create(min: i64, max: i64) -> DiscreteUniform {
        let n = DiscreteUniform::new(min, max);
        assert!(n.is_ok());
        n.unwrap()
    }

    /// Checks that valid parameters round-trip through `min()` / `max()`.
    fn create_case(min: i64, max: i64) {
        let n = try_create(min, max);
        assert_eq!(min, n.min());
        assert_eq!(max, n.max());
    }

    /// Checks that invalid parameters are rejected by the constructor.
    fn bad_create_case(min: i64, max: i64) {
        let n = DiscreteUniform::new(min, max);
        assert!(n.is_err());
    }

    /// Builds a distribution and applies `eval` to it.
    fn get_value<T, F>(min: i64, max: i64, eval: F) -> T
        where T: PartialEq + Debug,
              F: Fn(DiscreteUniform) -> T
    {
        let n = try_create(min, max);
        eval(n)
    }

    /// Asserts that `eval` on the distribution exactly equals `expected`.
    fn test_case<T, F>(min: i64, max: i64, expected: T, eval: F)
        where T: PartialEq + Debug,
              F: Fn(DiscreteUniform) -> T
    {
        let x = get_value(min, max, eval);
        assert_eq!(expected, x);
    }

    #[test]
    fn test_create() {
        create_case(-10, 10);
        create_case(0, 4);
        create_case(10, 20);
        create_case(20, 20);
    }

    #[test]
    fn test_bad_create() {
        bad_create_case(-1, -2);
        bad_create_case(6, 5);
    }

    #[test]
    fn test_mean() {
        let mean = |x: DiscreteUniform| x.mean().unwrap();
        test_case(-10, 10, 0.0, mean);
        test_case(0, 4, 2.0, mean);
        test_case(10, 20, 15.0, mean);
        test_case(20, 20, 20.0, mean);
    }

    #[test]
    fn test_variance() {
        let variance = |x: DiscreteUniform| x.variance().unwrap();
        test_case(-10, 10, 36.66666666666666666667, variance);
        test_case(0, 4, 2.0, variance);
        test_case(10, 20, 10.0, variance);
        test_case(20, 20, 0.0, variance);
    }

    #[test]
    fn test_entropy() {
        let entropy = |x: DiscreteUniform| x.entropy().unwrap();
        test_case(-10, 10, 3.0445224377234229965005979803657054342845752874046093, entropy);
        test_case(0, 4, 1.6094379124341003746007593332261876395256013542685181, entropy);
        test_case(10, 20, 2.3978952727983705440619435779651292998217068539374197, entropy);
        test_case(20, 20, 0.0, entropy);
    }

    #[test]
    fn test_skewness() {
        let skewness = |x: DiscreteUniform| x.skewness().unwrap();
        test_case(-10, 10, 0.0, skewness);
        test_case(0, 4, 0.0, skewness);
        test_case(10, 20, 0.0, skewness);
        test_case(20, 20, 0.0, skewness);
    }

    #[test]
    fn test_median() {
        let median = |x: DiscreteUniform| x.median();
        test_case(-10, 10, 0.0, median);
        test_case(0, 4, 2.0, median);
        test_case(10, 20, 15.0, median);
        test_case(20, 20, 20.0, median);
    }

    #[test]
    fn test_mode() {
        let mode = |x: DiscreteUniform| x.mode().unwrap();
        test_case(-10, 10, 0, mode);
        test_case(0, 4, 2, mode);
        test_case(10, 20, 15, mode);
        test_case(20, 20, 20, mode);
    }

    #[test]
    fn test_mode_even_support_rounds_down() {
        // With an even number of points the midpoint is a half-integer; the
        // mode rounds it towards negative infinity.
        let mode = |x: DiscreteUniform| x.mode().unwrap();
        test_case(0, 3, 1, mode);
        test_case(-3, 0, -2, mode);
    }

    #[test]
    fn test_pmf() {
        let pmf = |arg: i64| move |x: DiscreteUniform| x.pmf(arg);
        test_case(-10, 10, 0.04761904761904761904762, pmf(-5));
        test_case(-10, 10, 0.04761904761904761904762, pmf(1));
        test_case(-10, 10, 0.04761904761904761904762, pmf(10));
        test_case(-10, -10, 0.0, pmf(0));
        test_case(-10, -10, 1.0, pmf(-10));
    }

    #[test]
    fn test_pmf_outside_support() {
        let pmf = |arg: i64| move |x: DiscreteUniform| x.pmf(arg);
        test_case(0, 3, 0.0, pmf(-1));
        test_case(0, 3, 0.0, pmf(4));
    }

    #[test]
    fn test_ln_pmf() {
        let ln_pmf = |arg: i64| move |x: DiscreteUniform| x.ln_pmf(arg);
        test_case(-10, 10, -3.0445224377234229965005979803657054342845752874046093, ln_pmf(-5));
        test_case(-10, 10, -3.0445224377234229965005979803657054342845752874046093, ln_pmf(1));
        test_case(-10, 10, -3.0445224377234229965005979803657054342845752874046093, ln_pmf(10));
        test_case(-10, -10, f64::NEG_INFINITY, ln_pmf(0));
        test_case(-10, -10, 0.0, ln_pmf(-10));
    }

    #[test]
    fn test_cdf() {
        let cdf = |arg: i64| move |x: DiscreteUniform| x.cdf(arg);
        test_case(-10, 10, 0.2857142857142857142857, cdf(-5));
        test_case(-10, 10, 0.5714285714285714285714, cdf(1));
        test_case(-10, 10, 1.0, cdf(10));
        test_case(-10, -10, 1.0, cdf(-10));
    }

    #[test]
    fn test_sf() {
        let sf = |arg: i64| move |x: DiscreteUniform| x.sf(arg);
        test_case(-10, 10, 0.7142857142857142857143, sf(-5));
        test_case(-10, 10, 0.42857142857142855, sf(1));
        test_case(-10, 10, 0.0, sf(10));
        test_case(-10, -10, 0.0, sf(-10));
    }

    #[test]
    fn test_cdf_plus_sf_is_one() {
        // On [0, 3] every quarter is exactly representable, so the sum is exact.
        let n = try_create(0, 3);
        for x in 0..=3 {
            assert_eq!(1.0, n.cdf(x) + n.sf(x));
        }
    }

    #[test]
    fn test_cdf_lower_bound() {
        let cdf = |arg: i64| move |x: DiscreteUniform| x.cdf(arg);
        test_case(0, 3, 0.0, cdf(-1));
    }

    #[test]
    fn test_sf_lower_bound() {
        let sf = |arg: i64| move |x: DiscreteUniform| x.sf(arg);
        test_case(0, 3, 1.0, sf(-1));
    }

    #[test]
    fn test_cdf_upper_bound() {
        let cdf = |arg: i64| move |x: DiscreteUniform| x.cdf(arg);
        test_case(0, 3, 1.0, cdf(5));
    }

    #[test]
    fn test_sf_upper_bound() {
        let sf = |arg: i64| move |x: DiscreteUniform| x.sf(arg);
        test_case(0, 3, 0.0, sf(5));
    }
}