1use crate::distribution::{Discrete, DiscreteCDF};
2use crate::statistics::*;
3
/// A discrete uniform distribution in which every integer of the inclusive
/// range `[min, max]` carries the same probability mass.
///
/// Instances are created through [`DiscreteUniform::new`], which rejects
/// parameter pairs where `max < min`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DiscreteUniform {
    /// Inclusive lower bound of the support.
    min: i64,
    /// Inclusive upper bound of the support.
    max: i64,
}
23
/// Reasons why [`DiscreteUniform::new`] can refuse to build a distribution.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum DiscreteUniformError {
    /// The supplied `max` is strictly smaller than the supplied `min`.
    MinMaxInvalid,
}
31
32impl std::fmt::Display for DiscreteUniformError {
33 #[cfg_attr(coverage_nightly, coverage(off))]
34 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
35 match self {
36 DiscreteUniformError::MinMaxInvalid => write!(f, "Maximum is less than minimum"),
37 }
38 }
39}
40
41impl std::error::Error for DiscreteUniformError {}
42
43impl DiscreteUniform {
44 pub fn new(min: i64, max: i64) -> Result<DiscreteUniform, DiscreteUniformError> {
63 if max < min {
64 Err(DiscreteUniformError::MinMaxInvalid)
65 } else {
66 Ok(DiscreteUniform { min, max })
67 }
68 }
69}
70
71impl std::fmt::Display for DiscreteUniform {
72 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73 write!(f, "Uni([{}, {}])", self.min, self.max)
74 }
75}
76
#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<i64> for DiscreteUniform {
    /// Draws an integer uniformly from the inclusive range `[min, max]`.
    fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> i64 {
        // The range is never empty because the constructor enforces min <= max.
        let support = self.min..=self.max;
        rng.gen_range(support)
    }
}
84
#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<f64> for DiscreteUniform {
    /// Draws an integer via the `i64` sampler and widens it to `f64`.
    fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 {
        let draw = <Self as ::rand::distributions::Distribution<i64>>::sample(self, rng);
        draw as f64
    }
}
92
93impl DiscreteCDF<i64, f64> for DiscreteUniform {
94 fn cdf(&self, x: i64) -> f64 {
103 if x < self.min {
104 0.0
105 } else if x >= self.max {
106 1.0
107 } else {
108 let lower = self.min as f64;
109 let upper = self.max as f64;
110 let ans = (x as f64 - lower + 1.0) / (upper - lower + 1.0);
111 if ans > 1.0 {
112 1.0
113 } else {
114 ans
115 }
116 }
117 }
118
119 fn sf(&self, x: i64) -> f64 {
120 if x < self.min {
122 1.0
123 } else if x >= self.max {
124 0.0
125 } else {
126 let lower = self.min as f64;
127 let upper = self.max as f64;
128 let ans = (upper - x as f64) / (upper - lower + 1.0);
129 if ans > 1.0 {
130 1.0
131 } else {
132 ans
133 }
134 }
135 }
136}
137
138impl Min<i64> for DiscreteUniform {
139 fn min(&self) -> i64 {
146 self.min
147 }
148}
149
150impl Max<i64> for DiscreteUniform {
151 fn max(&self) -> i64 {
158 self.max
159 }
160}
161
162impl Distribution<f64> for DiscreteUniform {
163 fn mean(&self) -> Option<f64> {
171 Some((self.min + self.max) as f64 / 2.0)
172 }
173
174 fn variance(&self) -> Option<f64> {
182 let diff = (self.max - self.min) as f64;
183 Some(((diff + 1.0) * (diff + 1.0) - 1.0) / 12.0)
184 }
185
186 fn entropy(&self) -> Option<f64> {
194 let diff = (self.max - self.min) as f64;
195 Some((diff + 1.0).ln())
196 }
197
198 fn skewness(&self) -> Option<f64> {
206 Some(0.0)
207 }
208}
209
210impl Median<f64> for DiscreteUniform {
211 fn median(&self) -> f64 {
219 (self.min + self.max) as f64 / 2.0
220 }
221}
222
223impl Mode<Option<i64>> for DiscreteUniform {
224 fn mode(&self) -> Option<i64> {
237 Some(((self.min + self.max) as f64 / 2.0).floor() as i64)
238 }
239}
240
241impl Discrete<i64, f64> for DiscreteUniform {
242 fn pmf(&self, x: i64) -> f64 {
255 if x >= self.min && x <= self.max {
256 1.0 / (self.max - self.min + 1) as f64
257 } else {
258 0.0
259 }
260 }
261
262 fn ln_pmf(&self, x: i64) -> f64 {
275 if x >= self.min && x <= self.max {
276 -((self.max - self.min + 1) as f64).ln()
277 } else {
278 f64::NEG_INFINITY
279 }
280 }
281}
282
// Unit tests for `DiscreteUniform`. Formatting is skipped so the long
// high-precision expected values stay on one line each.
#[rustfmt::skip]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::testing_boiler;

    // Generates the helpers used below (`create_ok`, `create_err`, `test_exact`)
    // for a `DiscreteUniform` parameterised by `(min: i64, max: i64)`.
    // NOTE(review): the helpers are defined by the `testing_boiler!` macro outside
    // this file — presumably `test_exact(min, max, expected, f)` constructs the
    // distribution and asserts `f(dist) == expected`; confirm against the macro.
    testing_boiler!(min: i64, max: i64; DiscreteUniform; DiscreteUniformError);

    // Valid parameter pairs, including the degenerate single-point range (20, 20).
    #[test]
    fn test_create() {
        create_ok(-10, 10);
        create_ok(0, 4);
        create_ok(10, 20);
        create_ok(20, 20);
    }

    // Pairs with max < min must be rejected by the constructor.
    #[test]
    fn test_bad_create() {
        create_err(-1, -2);
        create_err(6, 5);
    }

    // Mean is (min + max) / 2.
    #[test]
    fn test_mean() {
        let mean = |x: DiscreteUniform| x.mean().unwrap();
        test_exact(-10, 10, 0.0, mean);
        test_exact(0, 4, 2.0, mean);
        test_exact(10, 20, 15.0, mean);
        test_exact(20, 20, 20.0, mean);
    }

    // Variance is (n^2 - 1) / 12 with n = max - min + 1 (n = 21, 5, 11, 1).
    #[test]
    fn test_variance() {
        let variance = |x: DiscreteUniform| x.variance().unwrap();
        test_exact(-10, 10, 36.66666666666666666667, variance);
        test_exact(0, 4, 2.0, variance);
        test_exact(10, 20, 10.0, variance);
        test_exact(20, 20, 0.0, variance);
    }

    // Entropy is ln(n): ln 21, ln 5, ln 11, and ln 1 = 0.
    #[test]
    fn test_entropy() {
        let entropy = |x: DiscreteUniform| x.entropy().unwrap();
        test_exact(-10, 10, 3.0445224377234229965005979803657054342845752874046093, entropy);
        test_exact(0, 4, 1.6094379124341003746007593332261876395256013542685181, entropy);
        test_exact(10, 20, 2.3978952727983705440619435779651292998217068539374197, entropy);
        test_exact(20, 20, 0.0, entropy);
    }

    // Skewness is always zero.
    #[test]
    fn test_skewness() {
        let skewness = |x: DiscreteUniform| x.skewness().unwrap();
        test_exact(-10, 10, 0.0, skewness);
        test_exact(0, 4, 0.0, skewness);
        test_exact(10, 20, 0.0, skewness);
        test_exact(20, 20, 0.0, skewness);
    }

    // Median is (min + max) / 2.
    #[test]
    fn test_median() {
        let median = |x: DiscreteUniform| x.median();
        test_exact(-10, 10, 0.0, median);
        test_exact(0, 4, 2.0, median);
        test_exact(10, 20, 15.0, median);
        test_exact(20, 20, 20.0, median);
    }

    // Mode is the (floored) midpoint of the support.
    #[test]
    fn test_mode() {
        let mode = |x: DiscreteUniform| x.mode().unwrap();
        test_exact(-10, 10, 0, mode);
        test_exact(0, 4, 2, mode);
        test_exact(10, 20, 15, mode);
        test_exact(20, 20, 20, mode);
    }

    // Inside [-10, 10] the mass is 1/21; a single-point range has mass 1 at its
    // point and 0 elsewhere.
    #[test]
    fn test_pmf() {
        let pmf = |arg: i64| move |x: DiscreteUniform| x.pmf(arg);
        test_exact(-10, 10, 0.04761904761904761904762, pmf(-5));
        test_exact(-10, 10, 0.04761904761904761904762, pmf(1));
        test_exact(-10, 10, 0.04761904761904761904762, pmf(10));
        test_exact(-10, -10, 0.0, pmf(0));
        test_exact(-10, -10, 1.0, pmf(-10));
    }

    // ln_pmf is -ln(21) inside [-10, 10], -inf outside the support, and 0 for a
    // single-point range (the computed value is -0.0, which compares equal to 0.0).
    #[test]
    fn test_ln_pmf() {
        let ln_pmf = |arg: i64| move |x: DiscreteUniform| x.ln_pmf(arg);
        test_exact(-10, 10, -3.0445224377234229965005979803657054342845752874046093, ln_pmf(-5));
        test_exact(-10, 10, -3.0445224377234229965005979803657054342845752874046093, ln_pmf(1));
        test_exact(-10, 10, -3.0445224377234229965005979803657054342845752874046093, ln_pmf(10));
        test_exact(-10, -10, f64::NEG_INFINITY, ln_pmf(0));
        test_exact(-10, -10, 0.0, ln_pmf(-10));
    }

    // CDF on [-10, 10]: (x - min + 1) / 21, i.e. 6/21 at -5 and 12/21 at 1;
    // 1.0 at the upper bound.
    #[test]
    fn test_cdf() {
        let cdf = |arg: i64| move |x: DiscreteUniform| x.cdf(arg);
        test_exact(-10, 10, 0.2857142857142857142857, cdf(-5));
        test_exact(-10, 10, 0.5714285714285714285714, cdf(1));
        test_exact(-10, 10, 1.0, cdf(10));
        test_exact(-10, -10, 1.0, cdf(-10));
    }

    // SF on [-10, 10]: (max - x) / 21, i.e. 15/21 at -5 and 9/21 at 1;
    // 0.0 at the upper bound.
    #[test]
    fn test_sf() {
        let sf = |arg: i64| move |x: DiscreteUniform| x.sf(arg);
        test_exact(-10, 10, 0.7142857142857142857143, sf(-5));
        test_exact(-10, 10, 0.42857142857142855, sf(1));
        test_exact(-10, 10, 0.0, sf(10));
        test_exact(-10, -10, 0.0, sf(-10));
    }

    // Below the support the CDF is 0.
    #[test]
    fn test_cdf_lower_bound() {
        let cdf = |arg: i64| move |x: DiscreteUniform| x.cdf(arg);
        test_exact(0, 3, 0.0, cdf(-1));
    }

    // Below the support the survival function is 1.
    #[test]
    fn test_sf_lower_bound() {
        let sf = |arg: i64| move |x: DiscreteUniform| x.sf(arg);
        test_exact(0, 3, 1.0, sf(-1));
    }

    // Above the support the CDF is 1.
    #[test]
    fn test_cdf_upper_bound() {
        let cdf = |arg: i64| move |x: DiscreteUniform| x.cdf(arg);
        test_exact(0, 3, 1.0, cdf(5));
    }
}