1use crate::distribution::{Discrete, DiscreteCDF};
2use crate::statistics::*;
3use crate::{Result, StatsError};
4use rand::Rng;
5
/// A discrete uniform distribution: every integer in the closed range
/// `[min, max]` is equally likely.
///
/// Instances are built through `DiscreteUniform::new`, which rejects
/// `max < min`, so `min <= max` holds for every constructed value.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct DiscreteUniform {
    /// Inclusive lower bound of the support.
    min: i64,
    /// Inclusive upper bound of the support.
    max: i64,
}
25
26impl DiscreteUniform {
27 pub fn new(min: i64, max: i64) -> Result<DiscreteUniform> {
46 if max < min {
47 Err(StatsError::BadParams)
48 } else {
49 Ok(DiscreteUniform { min, max })
50 }
51 }
52}
53
54impl ::rand::distributions::Distribution<f64> for DiscreteUniform {
55 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
56 rng.gen_range(self.min..=self.max) as f64
57 }
58}
59
60impl DiscreteCDF<i64, f64> for DiscreteUniform {
61 fn cdf(&self, x: i64) -> f64 {
70 if x < self.min {
71 0.0
72 } else if x >= self.max {
73 1.0
74 } else {
75 let lower = self.min as f64;
76 let upper = self.max as f64;
77 let ans = (x as f64 - lower + 1.0) / (upper - lower + 1.0);
78 if ans > 1.0 {
79 1.0
80 } else {
81 ans
82 }
83 }
84 }
85
86 fn sf(&self, x: i64) -> f64 {
87 if x < self.min {
89 1.0
90 } else if x >= self.max {
91 0.0
92 } else {
93 let lower = self.min as f64;
94 let upper = self.max as f64;
95 let ans = (upper - x as f64) / (upper - lower + 1.0);
96 if ans > 1.0 {
97 1.0
98 } else {
99 ans
100 }
101 }
102 }
103}
104
105impl Min<i64> for DiscreteUniform {
106 fn min(&self) -> i64 {
113 self.min
114 }
115}
116
117impl Max<i64> for DiscreteUniform {
118 fn max(&self) -> i64 {
125 self.max
126 }
127}
128
129impl Distribution<f64> for DiscreteUniform {
130 fn mean(&self) -> Option<f64> {
138 Some((self.min + self.max) as f64 / 2.0)
139 }
140 fn variance(&self) -> Option<f64> {
148 let diff = (self.max - self.min) as f64;
149 Some(((diff + 1.0) * (diff + 1.0) - 1.0) / 12.0)
150 }
151 fn entropy(&self) -> Option<f64> {
159 let diff = (self.max - self.min) as f64;
160 Some((diff + 1.0).ln())
161 }
162 fn skewness(&self) -> Option<f64> {
170 Some(0.0)
171 }
172}
173
174impl Median<f64> for DiscreteUniform {
175 fn median(&self) -> f64 {
183 (self.min + self.max) as f64 / 2.0
184 }
185}
186
187impl Mode<Option<i64>> for DiscreteUniform {
188 fn mode(&self) -> Option<i64> {
201 Some(((self.min + self.max) as f64 / 2.0).floor() as i64)
202 }
203}
204
205impl Discrete<i64, f64> for DiscreteUniform {
206 fn pmf(&self, x: i64) -> f64 {
219 if x >= self.min && x <= self.max {
220 1.0 / (self.max - self.min + 1) as f64
221 } else {
222 0.0
223 }
224 }
225
226 fn ln_pmf(&self, x: i64) -> f64 {
239 if x >= self.min && x <= self.max {
240 -((self.max - self.min + 1) as f64).ln()
241 } else {
242 f64::NEG_INFINITY
243 }
244 }
245}
246
#[rustfmt::skip]
#[cfg(all(test, feature = "nightly"))]
mod tests {
    use std::fmt::Debug;
    use crate::statistics::*;
    use crate::distribution::{DiscreteCDF, Discrete, DiscreteUniform};

    /// Builds a distribution from parameters the test expects to be valid.
    /// Panics with the constructor's error if they are rejected.
    fn try_create(min: i64, max: i64) -> DiscreteUniform {
        DiscreteUniform::new(min, max)
            .expect("DiscreteUniform::new rejected parameters expected to be valid")
    }

    /// Asserts that construction succeeds and the bounds round-trip.
    fn create_case(min: i64, max: i64) {
        let n = try_create(min, max);
        assert_eq!(min, n.min());
        assert_eq!(max, n.max());
    }

    /// Asserts that construction is rejected.
    fn bad_create_case(min: i64, max: i64) {
        let n = DiscreteUniform::new(min, max);
        assert!(n.is_err());
    }

    /// Builds a valid distribution and evaluates `eval` on it.
    fn get_value<T, F>(min: i64, max: i64, eval: F) -> T
        where T: PartialEq + Debug,
              F: Fn(DiscreteUniform) -> T
    {
        let n = try_create(min, max);
        eval(n)
    }

    /// Asserts that `eval` on `[min, max]` yields exactly `expected`.
    fn test_case<T, F>(min: i64, max: i64, expected: T, eval: F)
        where T: PartialEq + Debug,
              F: Fn(DiscreteUniform) -> T
    {
        let x = get_value(min, max, eval);
        assert_eq!(expected, x);
    }

    #[test]
    fn test_create() {
        create_case(-10, 10);
        create_case(0, 4);
        create_case(10, 20);
        create_case(20, 20);
    }

    #[test]
    fn test_bad_create() {
        bad_create_case(-1, -2);
        bad_create_case(6, 5);
    }

    #[test]
    fn test_mean() {
        let mean = |x: DiscreteUniform| x.mean().unwrap();
        test_case(-10, 10, 0.0, mean);
        test_case(0, 4, 2.0, mean);
        test_case(10, 20, 15.0, mean);
        test_case(20, 20, 20.0, mean);
    }

    #[test]
    fn test_variance() {
        let variance = |x: DiscreteUniform| x.variance().unwrap();
        test_case(-10, 10, 36.66666666666666666667, variance);
        test_case(0, 4, 2.0, variance);
        test_case(10, 20, 10.0, variance);
        test_case(20, 20, 0.0, variance);
    }

    #[test]
    fn test_entropy() {
        let entropy = |x: DiscreteUniform| x.entropy().unwrap();
        test_case(-10, 10, 3.0445224377234229965005979803657054342845752874046093, entropy);
        test_case(0, 4, 1.6094379124341003746007593332261876395256013542685181, entropy);
        test_case(10, 20, 2.3978952727983705440619435779651292998217068539374197, entropy);
        test_case(20, 20, 0.0, entropy);
    }

    #[test]
    fn test_skewness() {
        let skewness = |x: DiscreteUniform| x.skewness().unwrap();
        test_case(-10, 10, 0.0, skewness);
        test_case(0, 4, 0.0, skewness);
        test_case(10, 20, 0.0, skewness);
        test_case(20, 20, 0.0, skewness);
    }

    #[test]
    fn test_median() {
        let median = |x: DiscreteUniform| x.median();
        test_case(-10, 10, 0.0, median);
        test_case(0, 4, 2.0, median);
        test_case(10, 20, 15.0, median);
        test_case(20, 20, 20.0, median);
    }

    #[test]
    fn test_mode() {
        let mode = |x: DiscreteUniform| x.mode().unwrap();
        test_case(-10, 10, 0, mode);
        test_case(0, 4, 2, mode);
        test_case(10, 20, 15, mode);
        test_case(20, 20, 20, mode);
    }

    #[test]
    fn test_pmf() {
        let pmf = |arg: i64| move |x: DiscreteUniform| x.pmf(arg);
        test_case(-10, 10, 0.04761904761904761904762, pmf(-5));
        test_case(-10, 10, 0.04761904761904761904762, pmf(1));
        test_case(-10, 10, 0.04761904761904761904762, pmf(10));
        test_case(-10, -10, 0.0, pmf(0));
        test_case(-10, -10, 1.0, pmf(-10));
    }

    #[test]
    fn test_ln_pmf() {
        let ln_pmf = |arg: i64| move |x: DiscreteUniform| x.ln_pmf(arg);
        test_case(-10, 10, -3.0445224377234229965005979803657054342845752874046093, ln_pmf(-5));
        test_case(-10, 10, -3.0445224377234229965005979803657054342845752874046093, ln_pmf(1));
        test_case(-10, 10, -3.0445224377234229965005979803657054342845752874046093, ln_pmf(10));
        test_case(-10, -10, f64::NEG_INFINITY, ln_pmf(0));
        test_case(-10, -10, 0.0, ln_pmf(-10));
    }

    #[test]
    fn test_cdf() {
        let cdf = |arg: i64| move |x: DiscreteUniform| x.cdf(arg);
        test_case(-10, 10, 0.2857142857142857142857, cdf(-5));
        test_case(-10, 10, 0.5714285714285714285714, cdf(1));
        test_case(-10, 10, 1.0, cdf(10));
        test_case(-10, -10, 1.0, cdf(-10));
    }

    #[test]
    fn test_sf() {
        let sf = |arg: i64| move |x: DiscreteUniform| x.sf(arg);
        test_case(-10, 10, 0.7142857142857142857143, sf(-5));
        test_case(-10, 10, 0.42857142857142855, sf(1));
        test_case(-10, 10, 0.0, sf(10));
        test_case(-10, -10, 0.0, sf(-10));
    }

    #[test]
    fn test_cdf_lower_bound() {
        let cdf = |arg: i64| move |x: DiscreteUniform| x.cdf(arg);
        test_case(0, 3, 0.0, cdf(-1));
    }

    #[test]
    fn test_sf_lower_bound() {
        let sf = |arg: i64| move |x: DiscreteUniform| x.sf(arg);
        test_case(0, 3, 1.0, sf(-1));
    }

    #[test]
    fn test_cdf_upper_bound() {
        let cdf = |arg: i64| move |x: DiscreteUniform| x.cdf(arg);
        test_case(0, 3, 1.0, cdf(5));
    }

    /// At `x == min` exactly one outcome lies at or below `x`.
    #[test]
    fn test_cdf_at_min() {
        let cdf = |arg: i64| move |x: DiscreteUniform| x.cdf(arg);
        test_case(-10, 10, 1.0 / 21.0, cdf(-10));
    }

    /// At `x == min` every outcome except `min` itself lies above `x`.
    #[test]
    fn test_sf_at_min() {
        let sf = |arg: i64| move |x: DiscreteUniform| x.sf(arg);
        test_case(-10, 10, 20.0 / 21.0, sf(-10));
    }
}