statrs/distribution/
uniform.rs

1use crate::distribution::{Continuous, ContinuousCDF};
2use crate::statistics::*;
3use std::f64;
4use std::fmt::Debug;
5
/// Implements the
/// [Continuous Uniform](https://en.wikipedia.org/wiki/Uniform_distribution_(continuous))
/// distribution, parameterised by a lower bound `min` and an upper bound
/// `max`.
///
/// # Examples
///
/// ```
/// use statrs::distribution::{Uniform, Continuous};
/// use statrs::statistics::Distribution;
///
/// let n = Uniform::new(0.0, 1.0).unwrap();
/// assert_eq!(n.mean().unwrap(), 0.5);
/// assert_eq!(n.pdf(0.5), 1.0);
/// ```
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Uniform {
    /// Lower bound of the distribution.
    min: f64,
    /// Upper bound of the distribution.
    max: f64,
}
25
/// The ways in which constructing a [`Uniform`] can fail.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum UniformError {
    /// The minimum is NaN or infinite.
    MinInvalid,

    /// The maximum is NaN or infinite.
    MaxInvalid,

    /// The maximum is not greater than the minimum.
    MaxNotGreaterThanMin,
}
39
40impl std::fmt::Display for UniformError {
41    #[cfg_attr(coverage_nightly, coverage(off))]
42    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
43        match self {
44            UniformError::MinInvalid => write!(f, "Minimum is NaN or infinite"),
45            UniformError::MaxInvalid => write!(f, "Maximum is NaN or infinite"),
46            UniformError::MaxNotGreaterThanMin => {
47                write!(f, "Maximum is not greater than the minimum")
48            }
49        }
50    }
51}
52
// Marker impl: `UniformError` carries no underlying source error, so the
// default methods of `std::error::Error` are sufficient.
impl std::error::Error for UniformError {}
54
55impl Uniform {
56    /// Constructs a new uniform distribution with a min of `min` and a max
57    /// of `max`.
58    ///
59    /// # Errors
60    ///
61    /// Returns an error if `min` or `max` are `NaN` or infinite.
62    /// Returns an error if `min >= max`.
63    ///
64    /// # Examples
65    ///
66    /// ```
67    /// use statrs::distribution::Uniform;
68    /// use std::f64;
69    ///
70    /// let mut result = Uniform::new(0.0, 1.0);
71    /// assert!(result.is_ok());
72    ///
73    /// result = Uniform::new(f64::NAN, f64::NAN);
74    /// assert!(result.is_err());
75    ///
76    /// result = Uniform::new(f64::NEG_INFINITY, 1.0);
77    /// assert!(result.is_err());
78    /// ```
79    pub fn new(min: f64, max: f64) -> Result<Uniform, UniformError> {
80        if !min.is_finite() {
81            return Err(UniformError::MinInvalid);
82        }
83
84        if !max.is_finite() {
85            return Err(UniformError::MaxInvalid);
86        }
87
88        if min < max {
89            Ok(Uniform { min, max })
90        } else {
91            Err(UniformError::MaxNotGreaterThanMin)
92        }
93    }
94
95    /// Constructs a new standard uniform distribution with
96    /// a lower bound 0 and an upper bound of 1.
97    ///
98    /// # Examples
99    ///
100    /// ```
101    /// use statrs::distribution::Uniform;
102    ///
103    /// let uniform = Uniform::standard();
104    /// ```
105    pub fn standard() -> Self {
106        Self { min: 0.0, max: 1.0 }
107    }
108}
109
110impl Default for Uniform {
111    fn default() -> Self {
112        Self::standard()
113    }
114}
115
116impl std::fmt::Display for Uniform {
117    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
118        write!(f, "Uni([{},{}])", self.min, self.max)
119    }
120}
121
#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<f64> for Uniform {
    /// Draws a single sample by delegating to `rand`'s inclusive uniform
    /// sampler built from this distribution's `min` and `max`.
    fn sample<R>(&self, rng: &mut R) -> f64
    where
        R: ::rand::Rng + ?Sized,
    {
        rng.sample(::rand::distributions::Uniform::new_inclusive(
            self.min, self.max,
        ))
    }
}
130
131impl ContinuousCDF<f64, f64> for Uniform {
132    /// Calculates the cumulative distribution function for the uniform
133    /// distribution
134    /// at `x`
135    ///
136    /// # Formula
137    ///
138    /// ```text
139    /// (x - min) / (max - min)
140    /// ```
141    fn cdf(&self, x: f64) -> f64 {
142        if x <= self.min {
143            0.0
144        } else if x >= self.max {
145            1.0
146        } else {
147            (x - self.min) / (self.max - self.min)
148        }
149    }
150
151    /// Calculates the survival function for the uniform
152    /// distribution at `x`
153    ///
154    /// # Formula
155    ///
156    /// ```text
157    /// (max - x) / (max - min)
158    /// ```
159    fn sf(&self, x: f64) -> f64 {
160        if x <= self.min {
161            1.0
162        } else if x >= self.max {
163            0.0
164        } else {
165            (self.max - x) / (self.max - self.min)
166        }
167    }
168
169    /// Finds the value of `x` where `F(p) = x`
170    fn inverse_cdf(&self, p: f64) -> f64 {
171        if !(0.0..=1.0).contains(&p) {
172            panic!("p must be in [0, 1], was {p}");
173        } else if p == 0.0 {
174            self.min
175        } else if p == 1.0 {
176            self.max
177        } else {
178            (self.max - self.min) * p + self.min
179        }
180    }
181}
182
183impl Min<f64> for Uniform {
184    fn min(&self) -> f64 {
185        self.min
186    }
187}
188
189impl Max<f64> for Uniform {
190    fn max(&self) -> f64 {
191        self.max
192    }
193}
194
195impl Distribution<f64> for Uniform {
196    /// Returns the mean for the continuous uniform distribution
197    ///
198    /// # Formula
199    ///
200    /// ```text
201    /// (min + max) / 2
202    /// ```
203    fn mean(&self) -> Option<f64> {
204        Some((self.min + self.max) / 2.0)
205    }
206
207    /// Returns the variance for the continuous uniform distribution
208    ///
209    /// # Formula
210    ///
211    /// ```text
212    /// (max - min)^2 / 12
213    /// ```
214    fn variance(&self) -> Option<f64> {
215        Some((self.max - self.min) * (self.max - self.min) / 12.0)
216    }
217
218    /// Returns the entropy for the continuous uniform distribution
219    ///
220    /// # Formula
221    ///
222    /// ```text
223    /// ln(max - min)
224    /// ```
225    fn entropy(&self) -> Option<f64> {
226        Some((self.max - self.min).ln())
227    }
228
229    /// Returns the skewness for the continuous uniform distribution
230    ///
231    /// # Formula
232    ///
233    /// ```text
234    /// 0
235    /// ```
236    fn skewness(&self) -> Option<f64> {
237        Some(0.0)
238    }
239}
240
241impl Median<f64> for Uniform {
242    /// Returns the median for the continuous uniform distribution
243    ///
244    /// # Formula
245    ///
246    /// ```text
247    /// (min + max) / 2
248    /// ```
249    fn median(&self) -> f64 {
250        (self.min + self.max) / 2.0
251    }
252}
253
254impl Mode<Option<f64>> for Uniform {
255    /// Returns the mode for the continuous uniform distribution
256    ///
257    /// # Remarks
258    ///
259    /// Since every element has an equal probability, mode simply
260    /// returns the middle element
261    ///
262    /// # Formula
263    ///
264    /// ```text
265    /// N/A // (max + min) / 2 for the middle element
266    /// ```
267    fn mode(&self) -> Option<f64> {
268        Some((self.min + self.max) / 2.0)
269    }
270}
271
272impl Continuous<f64, f64> for Uniform {
273    /// Calculates the probability density function for the continuous uniform
274    /// distribution at `x`
275    ///
276    /// # Remarks
277    ///
278    /// Returns `0.0` if `x` is not in `[min, max]`
279    ///
280    /// # Formula
281    ///
282    /// ```text
283    /// 1 / (max - min)
284    /// ```
285    fn pdf(&self, x: f64) -> f64 {
286        if x < self.min || x > self.max {
287            0.0
288        } else {
289            1.0 / (self.max - self.min)
290        }
291    }
292
293    /// Calculates the log probability density function for the continuous
294    /// uniform
295    /// distribution at `x`
296    ///
297    /// # Remarks
298    ///
299    /// Returns `f64::NEG_INFINITY` if `x` is not in `[min, max]`
300    ///
301    /// # Formula
302    ///
303    /// ```text
304    /// ln(1 / (max - min))
305    /// ```
306    fn ln_pdf(&self, x: f64) -> f64 {
307        if x < self.min || x > self.max {
308            f64::NEG_INFINITY
309        } else {
310            -(self.max - self.min).ln()
311        }
312    }
313}
314
#[rustfmt::skip]
#[cfg(test)]
mod tests {
    //! Unit tests for [`Uniform`]: construction, summary statistics,
    //! density / distribution functions and (with the `rand` feature)
    //! sampling.

    use super::*;
    use crate::distribution::internal::*;
    use crate::testing_boiler;

    testing_boiler!(min: f64, max: f64; Uniform; UniformError);

    #[test]
    fn test_create() {
        create_ok(0.0, 0.1);
        create_ok(0.0, 1.0);
        create_ok(-5.0, 11.0);
        create_ok(-5.0, 100.0);
    }

    #[test]
    fn test_bad_create() {
        let invalid = [
            (0.0, 0.0, UniformError::MaxNotGreaterThanMin),
            (f64::NAN, 1.0, UniformError::MinInvalid),
            (1.0, f64::NAN, UniformError::MaxInvalid),
            (f64::NAN, f64::NAN, UniformError::MinInvalid),
            (0.0, f64::INFINITY, UniformError::MaxInvalid),
            (1.0, 0.0, UniformError::MaxNotGreaterThanMin),
        ];

        for (min, max, err) in invalid {
            test_create_err(min, max, err);
        }
    }

    #[test]
    fn test_variance() {
        let variance = |x: Uniform| x.variance().unwrap();
        test_exact(-0.0, 2.0, 1.0 / 3.0, variance);
        test_exact(0.0, 2.0, 1.0 / 3.0, variance);
        test_absolute(0.1, 4.0, 1.2675, 1e-15, variance);
        test_exact(10.0, 11.0, 1.0 / 12.0, variance);
    }

    #[test]
    fn test_entropy() {
        let entropy = |x: Uniform| x.entropy().unwrap();
        test_exact(-0.0, 2.0, 0.6931471805599453094172, entropy);
        test_exact(0.0, 2.0, 0.6931471805599453094172, entropy);
        test_absolute(0.1, 4.0, 1.360976553135600743431, 1e-15, entropy);
        test_exact(1.0, 10.0, 2.19722457733621938279, entropy);
        test_exact(10.0, 11.0, 0.0, entropy);
    }

    #[test]
    fn test_skewness() {
        let skewness = |x: Uniform| x.skewness().unwrap();
        test_exact(-0.0, 2.0, 0.0, skewness);
        test_exact(0.0, 2.0, 0.0, skewness);
        test_exact(0.1, 4.0, 0.0, skewness);
        test_exact(1.0, 10.0, 0.0, skewness);
        test_exact(10.0, 11.0, 0.0, skewness);
    }

    #[test]
    fn test_mode() {
        let mode = |x: Uniform| x.mode().unwrap();
        test_exact(-0.0, 2.0, 1.0, mode);
        test_exact(0.0, 2.0, 1.0, mode);
        test_exact(0.1, 4.0, 2.05, mode);
        test_exact(1.0, 10.0, 5.5, mode);
        test_exact(10.0, 11.0, 10.5, mode);
    }

    #[test]
    fn test_median() {
        let median = |x: Uniform| x.median();
        test_exact(-0.0, 2.0, 1.0, median);
        test_exact(0.0, 2.0, 1.0, median);
        test_exact(0.1, 4.0, 2.05, median);
        test_exact(1.0, 10.0, 5.5, median);
        test_exact(10.0, 11.0, 10.5, median);
    }

    #[test]
    fn test_pdf() {
        // Each distribution is probed below, inside and above its bounds.
        let pdf = |arg: f64| move |x: Uniform| x.pdf(arg);
        test_exact(0.0, 0.1, 0.0, pdf(-5.0));
        test_exact(0.0, 0.1, 10.0, pdf(0.05));
        test_exact(0.0, 0.1, 0.0, pdf(5.0));
        test_exact(0.0, 1.0, 0.0, pdf(-5.0));
        test_exact(0.0, 1.0, 1.0, pdf(0.5));
        test_exact(0.0, 1.0, 0.0, pdf(5.0));
        test_exact(0.0, 10.0, 0.0, pdf(-5.0));
        test_exact(0.0, 10.0, 0.1, pdf(1.0));
        test_exact(0.0, 10.0, 0.1, pdf(5.0));
        test_exact(0.0, 10.0, 0.0, pdf(11.0));
        test_exact(-5.0, 100.0, 0.0, pdf(-10.0));
        test_exact(-5.0, 100.0, 0.009523809523809523809524, pdf(-5.0));
        test_exact(-5.0, 100.0, 0.009523809523809523809524, pdf(0.0));
        test_exact(-5.0, 100.0, 0.0, pdf(101.0));
    }

    #[test]
    fn test_ln_pdf() {
        // Each distribution is probed below, inside and above its bounds.
        let ln_pdf = |arg: f64| move |x: Uniform| x.ln_pdf(arg);
        test_exact(0.0, 0.1, f64::NEG_INFINITY, ln_pdf(-5.0));
        test_absolute(0.0, 0.1, 2.302585092994045684018, 1e-15, ln_pdf(0.05));
        test_exact(0.0, 0.1, f64::NEG_INFINITY, ln_pdf(5.0));
        test_exact(0.0, 1.0, f64::NEG_INFINITY, ln_pdf(-5.0));
        test_exact(0.0, 1.0, 0.0, ln_pdf(0.5));
        test_exact(0.0, 1.0, f64::NEG_INFINITY, ln_pdf(5.0));
        test_exact(0.0, 10.0, f64::NEG_INFINITY, ln_pdf(-5.0));
        test_exact(0.0, 10.0, -2.302585092994045684018, ln_pdf(1.0));
        test_exact(0.0, 10.0, -2.302585092994045684018, ln_pdf(5.0));
        test_exact(0.0, 10.0, f64::NEG_INFINITY, ln_pdf(11.0));
        test_exact(-5.0, 100.0, f64::NEG_INFINITY, ln_pdf(-10.0));
        test_exact(-5.0, 100.0, -4.653960350157523371101, ln_pdf(-5.0));
        test_exact(-5.0, 100.0, -4.653960350157523371101, ln_pdf(0.0));
        test_exact(-5.0, 100.0, f64::NEG_INFINITY, ln_pdf(101.0));
    }

    #[test]
    fn test_cdf() {
        let cdf = |arg: f64| move |x: Uniform| x.cdf(arg);
        test_exact(0.0, 0.1, 0.5, cdf(0.05));
        test_exact(0.0, 1.0, 0.5, cdf(0.5));
        test_exact(0.0, 10.0, 0.1, cdf(1.0));
        test_exact(0.0, 10.0, 0.5, cdf(5.0));
        test_exact(-5.0, 100.0, 0.0, cdf(-5.0));
        test_exact(-5.0, 100.0, 0.04761904761904761904762, cdf(0.0));
    }

    #[test]
    fn test_inverse_cdf() {
        let inverse_cdf = |arg: f64| move |x: Uniform| x.inverse_cdf(arg);
        test_exact(0.0, 0.1, 0.05, inverse_cdf(0.5));
        test_exact(0.0, 10.0, 5.0, inverse_cdf(0.5));
        test_exact(1.0, 10.0, 1.0, inverse_cdf(0.0));
        test_exact(1.0, 10.0, 4.0, inverse_cdf(1.0 / 3.0));
        test_exact(1.0, 10.0, 10.0, inverse_cdf(1.0));
    }

    #[test]
    fn test_cdf_lower_bound() {
        let cdf = |arg: f64| move |x: Uniform| x.cdf(arg);
        test_exact(0.0, 3.0, 0.0, cdf(-1.0));
    }

    #[test]
    fn test_cdf_upper_bound() {
        let cdf = |arg: f64| move |x: Uniform| x.cdf(arg);
        test_exact(0.0, 3.0, 1.0, cdf(5.0));
    }

    #[test]
    fn test_sf() {
        let sf = |arg: f64| move |x: Uniform| x.sf(arg);
        test_exact(0.0, 0.1, 0.5, sf(0.05));
        test_exact(0.0, 1.0, 0.5, sf(0.5));
        test_exact(0.0, 10.0, 0.9, sf(1.0));
        test_exact(0.0, 10.0, 0.5, sf(5.0));
        test_exact(-5.0, 100.0, 1.0, sf(-5.0));
        test_exact(-5.0, 100.0, 0.9523809523809523, sf(0.0));
    }

    #[test]
    fn test_sf_lower_bound() {
        let sf = |arg: f64| move |x: Uniform| x.sf(arg);
        test_exact(0.0, 3.0, 1.0, sf(-1.0));
    }

    #[test]
    fn test_sf_upper_bound() {
        let sf = |arg: f64| move |x: Uniform| x.sf(arg);
        test_exact(0.0, 3.0, 0.0, sf(5.0));
    }

    #[test]
    fn test_continuous() {
        test::check_continuous_distribution(&create_ok(0.0, 10.0), 0.0, 10.0);
        test::check_continuous_distribution(&create_ok(-2.0, 15.0), -2.0, 15.0);
    }

    #[cfg(feature = "rand")]
    #[test]
    fn test_samples_in_range() {
        use rand::rngs::StdRng;
        use rand::SeedableRng;
        use rand::distributions::Distribution;

        let seed = [
            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
            19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31
        ];
        let mut r: StdRng = SeedableRng::from_seed(seed);

        let min = -0.5;
        let max = 0.5;
        let num_trials = 10_000;
        let n = create_ok(min, max);

        // The sampler is built with `new_inclusive`, so `max` itself is a
        // legitimate sample and the range check must be closed on both ends.
        assert!((0..num_trials)
            .map(|_| n.sample::<StdRng>(&mut r))
            .all(|v| (min <= v) && (v <= max))
        );
    }

    #[test]
    fn test_default() {
        let n = Uniform::default();

        let n_mean = n.mean().unwrap();
        let n_std  = n.std_dev().unwrap();

        // Check that the mean of the distribution is close to 1 / 2
        assert_almost_eq!(n_mean, 0.5, 1e-15);
        // Check that the standard deviation of the distribution is close to 1 / sqrt(12)
        assert_almost_eq!(n_std, 0.288_675_134_594_812_9, 1e-15);
    }
}
535}