1use crate::distribution::{Continuous, ContinuousCDF};
2use crate::statistics::*;
3use std::f64;
4use std::fmt::Debug;
5
/// Continuous uniform distribution on the closed interval `[min, max]`.
///
/// Values produced by [`Uniform::new`] and [`Uniform::standard`] always have
/// finite bounds with `min < max`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Uniform {
    /// Lower end of the support.
    min: f64,
    /// Upper end of the support.
    max: f64,
}
25
/// Reasons why [`Uniform::new`] can reject its parameters.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[non_exhaustive]
pub enum UniformError {
    /// The lower bound `min` is NaN or infinite.
    MinInvalid,

    /// The upper bound `max` is NaN or infinite.
    MaxInvalid,

    /// Both bounds are finite, but `max` is not strictly greater than `min`.
    MaxNotGreaterThanMin,
}
39
40impl std::fmt::Display for UniformError {
41 #[cfg_attr(coverage_nightly, coverage(off))]
42 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
43 match self {
44 UniformError::MinInvalid => write!(f, "Minimum is NaN or infinite"),
45 UniformError::MaxInvalid => write!(f, "Maximum is NaN or infinite"),
46 UniformError::MaxNotGreaterThanMin => {
47 write!(f, "Maximum is not greater than the minimum")
48 }
49 }
50 }
51}
52
53impl std::error::Error for UniformError {}
54
55impl Uniform {
56 pub fn new(min: f64, max: f64) -> Result<Uniform, UniformError> {
80 if !min.is_finite() {
81 return Err(UniformError::MinInvalid);
82 }
83
84 if !max.is_finite() {
85 return Err(UniformError::MaxInvalid);
86 }
87
88 if min < max {
89 Ok(Uniform { min, max })
90 } else {
91 Err(UniformError::MaxNotGreaterThanMin)
92 }
93 }
94
95 pub fn standard() -> Self {
106 Self { min: 0.0, max: 1.0 }
107 }
108}
109
110impl Default for Uniform {
111 fn default() -> Self {
112 Self::standard()
113 }
114}
115
116impl std::fmt::Display for Uniform {
117 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
118 write!(f, "Uni([{},{}])", self.min, self.max)
119 }
120}
121
#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<f64> for Uniform {
    /// Draws one value from `[min, max]`, with both endpoints included.
    ///
    /// The draw is delegated to `rand`'s own uniform float sampler.
    fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 {
        rng.sample(rand::distributions::Uniform::new_inclusive(self.min, self.max))
    }
}
130
131impl ContinuousCDF<f64, f64> for Uniform {
132 fn cdf(&self, x: f64) -> f64 {
142 if x <= self.min {
143 0.0
144 } else if x >= self.max {
145 1.0
146 } else {
147 (x - self.min) / (self.max - self.min)
148 }
149 }
150
151 fn sf(&self, x: f64) -> f64 {
160 if x <= self.min {
161 1.0
162 } else if x >= self.max {
163 0.0
164 } else {
165 (self.max - x) / (self.max - self.min)
166 }
167 }
168
169 fn inverse_cdf(&self, p: f64) -> f64 {
171 if !(0.0..=1.0).contains(&p) {
172 panic!("p must be in [0, 1], was {p}");
173 } else if p == 0.0 {
174 self.min
175 } else if p == 1.0 {
176 self.max
177 } else {
178 (self.max - self.min) * p + self.min
179 }
180 }
181}
182
183impl Min<f64> for Uniform {
184 fn min(&self) -> f64 {
185 self.min
186 }
187}
188
189impl Max<f64> for Uniform {
190 fn max(&self) -> f64 {
191 self.max
192 }
193}
194
195impl Distribution<f64> for Uniform {
196 fn mean(&self) -> Option<f64> {
204 Some((self.min + self.max) / 2.0)
205 }
206
207 fn variance(&self) -> Option<f64> {
215 Some((self.max - self.min) * (self.max - self.min) / 12.0)
216 }
217
218 fn entropy(&self) -> Option<f64> {
226 Some((self.max - self.min).ln())
227 }
228
229 fn skewness(&self) -> Option<f64> {
237 Some(0.0)
238 }
239}
240
241impl Median<f64> for Uniform {
242 fn median(&self) -> f64 {
250 (self.min + self.max) / 2.0
251 }
252}
253
254impl Mode<Option<f64>> for Uniform {
255 fn mode(&self) -> Option<f64> {
268 Some((self.min + self.max) / 2.0)
269 }
270}
271
272impl Continuous<f64, f64> for Uniform {
273 fn pdf(&self, x: f64) -> f64 {
286 if x < self.min || x > self.max {
287 0.0
288 } else {
289 1.0 / (self.max - self.min)
290 }
291 }
292
293 fn ln_pdf(&self, x: f64) -> f64 {
307 if x < self.min || x > self.max {
308 f64::NEG_INFINITY
309 } else {
310 -(self.max - self.min).ln()
311 }
312 }
313}
314
#[rustfmt::skip]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distribution::internal::*;
    use crate::testing_boiler;

    testing_boiler!(min: f64, max: f64; Uniform; UniformError);

    #[test]
    fn test_create() {
        create_ok(0.0, 0.1);
        create_ok(0.0, 1.0);
        create_ok(-5.0, 11.0);
        create_ok(-5.0, 100.0);
    }

    #[test]
    fn test_bad_create() {
        let invalid = [
            (0.0, 0.0, UniformError::MaxNotGreaterThanMin),
            (f64::NAN, 1.0, UniformError::MinInvalid),
            (1.0, f64::NAN, UniformError::MaxInvalid),
            (f64::NAN, f64::NAN, UniformError::MinInvalid),
            (f64::NEG_INFINITY, 1.0, UniformError::MinInvalid),
            (0.0, f64::INFINITY, UniformError::MaxInvalid),
            (1.0, 0.0, UniformError::MaxNotGreaterThanMin),
        ];

        for (min, max, err) in invalid {
            test_create_err(min, max, err);
        }
    }

    #[test]
    fn test_variance() {
        let variance = |x: Uniform| x.variance().unwrap();
        test_exact(-0.0, 2.0, 1.0 / 3.0, variance);
        test_exact(0.0, 2.0, 1.0 / 3.0, variance);
        test_absolute(0.1, 4.0, 1.2675, 1e-15, variance);
        test_exact(10.0, 11.0, 1.0 / 12.0, variance);
    }

    #[test]
    fn test_entropy() {
        let entropy = |x: Uniform| x.entropy().unwrap();
        test_exact(-0.0, 2.0, 0.6931471805599453094172, entropy);
        test_exact(0.0, 2.0, 0.6931471805599453094172, entropy);
        test_absolute(0.1, 4.0, 1.360976553135600743431, 1e-15, entropy);
        test_exact(1.0, 10.0, 2.19722457733621938279, entropy);
        test_exact(10.0, 11.0, 0.0, entropy);
    }

    #[test]
    fn test_skewness() {
        let skewness = |x: Uniform| x.skewness().unwrap();
        test_exact(-0.0, 2.0, 0.0, skewness);
        test_exact(0.0, 2.0, 0.0, skewness);
        test_exact(0.1, 4.0, 0.0, skewness);
        test_exact(1.0, 10.0, 0.0, skewness);
        test_exact(10.0, 11.0, 0.0, skewness);
    }

    #[test]
    fn test_mode() {
        let mode = |x: Uniform| x.mode().unwrap();
        test_exact(-0.0, 2.0, 1.0, mode);
        test_exact(0.0, 2.0, 1.0, mode);
        test_exact(0.1, 4.0, 2.05, mode);
        test_exact(1.0, 10.0, 5.5, mode);
        test_exact(10.0, 11.0, 10.5, mode);
    }

    #[test]
    fn test_median() {
        let median = |x: Uniform| x.median();
        test_exact(-0.0, 2.0, 1.0, median);
        test_exact(0.0, 2.0, 1.0, median);
        test_exact(0.1, 4.0, 2.05, median);
        test_exact(1.0, 10.0, 5.5, median);
        test_exact(10.0, 11.0, 10.5, median);
    }

    #[test]
    fn test_pdf() {
        let pdf = |arg: f64| move |x: Uniform| x.pdf(arg);
        test_exact(0.0, 0.1, 0.0, pdf(-5.0));
        test_exact(0.0, 0.1, 10.0, pdf(0.05));
        test_exact(0.0, 0.1, 0.0, pdf(5.0));
        test_exact(0.0, 1.0, 0.0, pdf(-5.0));
        test_exact(0.0, 1.0, 1.0, pdf(0.5));
        test_exact(0.0, 1.0, 0.0, pdf(5.0));
        test_exact(0.0, 10.0, 0.0, pdf(-5.0));
        test_exact(0.0, 10.0, 0.1, pdf(1.0));
        test_exact(0.0, 10.0, 0.1, pdf(5.0));
        test_exact(0.0, 10.0, 0.0, pdf(11.0));
        test_exact(-5.0, 100.0, 0.0, pdf(-10.0));
        test_exact(-5.0, 100.0, 0.009523809523809523809524, pdf(-5.0));
        test_exact(-5.0, 100.0, 0.009523809523809523809524, pdf(0.0));
        test_exact(-5.0, 100.0, 0.0, pdf(101.0));
    }

    #[test]
    fn test_pdf_support_is_closed() {
        // Both endpoints of [min, max] carry the full density.
        let pdf = |arg: f64| move |x: Uniform| x.pdf(arg);
        test_exact(0.0, 10.0, 0.1, pdf(0.0));
        test_exact(0.0, 10.0, 0.1, pdf(10.0));
    }

    #[test]
    fn test_ln_pdf() {
        let ln_pdf = |arg: f64| move |x: Uniform| x.ln_pdf(arg);
        test_exact(0.0, 0.1, f64::NEG_INFINITY, ln_pdf(-5.0));
        test_absolute(0.0, 0.1, 2.302585092994045684018, 1e-15, ln_pdf(0.05));
        test_exact(0.0, 0.1, f64::NEG_INFINITY, ln_pdf(5.0));
        test_exact(0.0, 1.0, f64::NEG_INFINITY, ln_pdf(-5.0));
        test_exact(0.0, 1.0, 0.0, ln_pdf(0.5));
        test_exact(0.0, 1.0, f64::NEG_INFINITY, ln_pdf(5.0));
        test_exact(0.0, 10.0, f64::NEG_INFINITY, ln_pdf(-5.0));
        test_exact(0.0, 10.0, -2.302585092994045684018, ln_pdf(1.0));
        test_exact(0.0, 10.0, -2.302585092994045684018, ln_pdf(5.0));
        test_exact(0.0, 10.0, f64::NEG_INFINITY, ln_pdf(11.0));
        test_exact(-5.0, 100.0, f64::NEG_INFINITY, ln_pdf(-10.0));
        test_exact(-5.0, 100.0, -4.653960350157523371101, ln_pdf(-5.0));
        test_exact(-5.0, 100.0, -4.653960350157523371101, ln_pdf(0.0));
        test_exact(-5.0, 100.0, f64::NEG_INFINITY, ln_pdf(101.0));
    }

    #[test]
    fn test_cdf() {
        let cdf = |arg: f64| move |x: Uniform| x.cdf(arg);
        test_exact(0.0, 0.1, 0.5, cdf(0.05));
        test_exact(0.0, 1.0, 0.5, cdf(0.5));
        test_exact(0.0, 10.0, 0.1, cdf(1.0));
        test_exact(0.0, 10.0, 0.5, cdf(5.0));
        test_exact(-5.0, 100.0, 0.0, cdf(-5.0));
        test_exact(-5.0, 100.0, 0.04761904761904761904762, cdf(0.0));
    }

    #[test]
    fn test_inverse_cdf() {
        let inverse_cdf = |arg: f64| move |x: Uniform| x.inverse_cdf(arg);
        test_exact(0.0, 0.1, 0.05, inverse_cdf(0.5));
        test_exact(0.0, 10.0, 5.0, inverse_cdf(0.5));
        test_exact(1.0, 10.0, 1.0, inverse_cdf(0.0));
        test_exact(1.0, 10.0, 4.0, inverse_cdf(1.0 / 3.0));
        test_exact(1.0, 10.0, 10.0, inverse_cdf(1.0));
    }

    #[test]
    #[should_panic(expected = "p must be in [0, 1]")]
    fn test_inverse_cdf_out_of_range() {
        let _ = create_ok(0.0, 1.0).inverse_cdf(1.5);
    }

    #[test]
    fn test_cdf_lower_bound() {
        let cdf = |arg: f64| move |x: Uniform| x.cdf(arg);
        test_exact(0.0, 3.0, 0.0, cdf(-1.0));
    }

    #[test]
    fn test_cdf_upper_bound() {
        let cdf = |arg: f64| move |x: Uniform| x.cdf(arg);
        test_exact(0.0, 3.0, 1.0, cdf(5.0));
    }

    #[test]
    fn test_sf() {
        let sf = |arg: f64| move |x: Uniform| x.sf(arg);
        test_exact(0.0, 0.1, 0.5, sf(0.05));
        test_exact(0.0, 1.0, 0.5, sf(0.5));
        test_exact(0.0, 10.0, 0.9, sf(1.0));
        test_exact(0.0, 10.0, 0.5, sf(5.0));
        test_exact(-5.0, 100.0, 1.0, sf(-5.0));
        test_exact(-5.0, 100.0, 0.9523809523809523, sf(0.0));
    }

    #[test]
    fn test_sf_lower_bound() {
        let sf = |arg: f64| move |x: Uniform| x.sf(arg);
        test_exact(0.0, 3.0, 1.0, sf(-1.0));
    }

    #[test]
    fn test_sf_upper_bound() {
        let sf = |arg: f64| move |x: Uniform| x.sf(arg);
        test_exact(0.0, 3.0, 0.0, sf(5.0));
    }

    #[test]
    fn test_continuous() {
        test::check_continuous_distribution(&create_ok(0.0, 10.0), 0.0, 10.0);
        test::check_continuous_distribution(&create_ok(-2.0, 15.0), -2.0, 15.0);
    }

    #[cfg(feature = "rand")]
    #[test]
    fn test_samples_in_range() {
        use rand::rngs::StdRng;
        use rand::SeedableRng;
        use rand::distributions::Distribution;

        let seed = [
            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
            19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31
        ];
        let mut r: StdRng = SeedableRng::from_seed(seed);

        let min = -0.5;
        let max = 0.5;
        let num_trials = 10_000;
        let n = create_ok(min, max);

        // Sampling uses an inclusive range, so `max` itself is a legal draw.
        assert!((0..num_trials)
            .map(|_| n.sample::<StdRng>(&mut r))
            .all(|v| (min <= v) && (v <= max))
        );
    }

    #[test]
    fn test_default() {
        let n = Uniform::default();

        let n_mean = n.mean().unwrap();
        let n_std = n.std_dev().unwrap();

        assert_almost_eq!(n_mean, 0.5, 1e-15);
        // Standard deviation of U(0, 1) is sqrt(1/12).
        assert_almost_eq!(n_std, 0.288_675_134_594_812_9, 1e-15);
    }
}