statrs/distribution/
normal.rs

1use crate::consts;
2use crate::distribution::{Continuous, ContinuousCDF};
3use crate::function::erf;
4use crate::statistics::*;
5use std::f64;
6
/// The [Normal](https://en.wikipedia.org/wiki/Normal_distribution) (Gaussian)
/// distribution, parameterised by its mean `μ` and standard deviation `σ`.
///
/// Build one with [`Normal::new`], which validates both parameters, or with
/// [`Normal::standard`] / `Default` for `N(0, 1)`.
///
/// # Examples
///
/// ```
/// use statrs::distribution::{Normal, Continuous};
/// use statrs::statistics::Distribution;
///
/// let n = Normal::new(0.0, 1.0).unwrap();
/// assert_eq!(n.mean().unwrap(), 0.0);
/// assert_eq!(n.pdf(1.0), 0.2419707245191433497978);
/// ```
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Normal {
    /// Location parameter `μ`; it is also the median and the mode.
    mean: f64,
    /// Scale parameter `σ`; `new` only accepts values strictly greater than
    /// zero (positive infinity included).
    std_dev: f64,
}
25
/// Reasons why [`Normal::new`] can reject its parameters.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum NormalError {
    /// The supplied mean was NaN.
    MeanInvalid,

    /// The supplied standard deviation was NaN, zero, or negative.
    StandardDeviationInvalid,
}

impl std::fmt::Display for NormalError {
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        // Pick the fixed message first, then emit it verbatim.
        let message = match self {
            Self::MeanInvalid => "Mean is NaN",
            Self::StandardDeviationInvalid => {
                "Standard deviation is NaN, zero or less than zero"
            }
        };
        f.write_str(message)
    }
}

impl std::error::Error for NormalError {}
50
51impl Normal {
52    ///  Constructs a new normal distribution with a mean of `mean`
53    /// and a standard deviation of `std_dev`
54    ///
55    /// # Errors
56    ///
57    /// Returns an error if `mean` or `std_dev` are `NaN` or if
58    /// `std_dev <= 0.0`
59    ///
60    /// # Examples
61    ///
62    /// ```
63    /// use statrs::distribution::Normal;
64    ///
65    /// let mut result = Normal::new(0.0, 1.0);
66    /// assert!(result.is_ok());
67    ///
68    /// result = Normal::new(0.0, 0.0);
69    /// assert!(result.is_err());
70    /// ```
71    pub fn new(mean: f64, std_dev: f64) -> Result<Normal, NormalError> {
72        if mean.is_nan() {
73            return Err(NormalError::MeanInvalid);
74        }
75
76        if std_dev.is_nan() || std_dev <= 0.0 {
77            return Err(NormalError::StandardDeviationInvalid);
78        }
79
80        Ok(Normal { mean, std_dev })
81    }
82
83    /// Constructs a new standard normal distribution with a mean of 0
84    /// and a standard deviation of 1.
85    ///
86    ///
87    /// # Examples
88    ///
89    /// ```
90    /// use statrs::distribution::Normal;
91    ///
92    /// let mut result = Normal::standard();
93    /// ```
94    pub fn standard() -> Normal {
95        Normal {
96            mean: 0.0,
97            std_dev: 1.0,
98        }
99    }
100}
101
102impl std::fmt::Display for Normal {
103    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104        write!(f, "N({},{})", self.mean, self.std_dev)
105    }
106}
107
#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<f64> for Normal {
    /// Draws one variate from this distribution by forwarding its parameters
    /// to [`sample_unchecked`]; they were already validated by the constructor.
    fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 {
        let Normal { mean, std_dev } = *self;
        sample_unchecked(rng, mean, std_dev)
    }
}
115
116impl ContinuousCDF<f64, f64> for Normal {
117    /// Calculates the cumulative distribution function for the
118    /// normal distribution at `x`
119    ///
120    /// # Formula
121    ///
122    /// ```text
123    /// (1 / 2) * (1 + erf((x - μ) / (σ * sqrt(2))))
124    /// ```
125    ///
126    /// where `μ` is the mean, `σ` is the standard deviation, and
127    /// `erf` is the error function
128    fn cdf(&self, x: f64) -> f64 {
129        cdf_unchecked(x, self.mean, self.std_dev)
130    }
131
132    /// Calculates the survival function for the
133    /// normal distribution at `x`
134    ///
135    /// # Formula
136    ///
137    /// ```text
138    /// (1 / 2) * (1 + erf(-(x - μ) / (σ * sqrt(2))))
139    /// ```
140    ///
141    /// where `μ` is the mean, `σ` is the standard deviation, and
142    /// `erf` is the error function
143    ///
144    /// note that this calculates the complement due to flipping
145    /// the sign of the argument error function with respect to the cdf.
146    ///
147    /// the normal cdf Φ (and internal error function) as the following property:
148    /// ```text
149    ///  Φ(-x) + Φ(x) = 1
150    ///  Φ(-x)        = 1 - Φ(x)
151    /// ```
152    fn sf(&self, x: f64) -> f64 {
153        sf_unchecked(x, self.mean, self.std_dev)
154    }
155
156    /// Calculates the inverse cumulative distribution function for the
157    /// normal distribution at `x`.
158    /// In other languages, such as R, this is known as the the quantile function.
159    ///
160    /// # Panics
161    ///
162    /// If `x < 0.0` or `x > 1.0`
163    ///
164    /// # Formula
165    ///
166    /// ```text
167    /// μ - sqrt(2) * σ * erfc_inv(2x)
168    /// ```
169    ///
170    /// where `μ` is the mean, `σ` is the standard deviation and `erfc_inv` is
171    /// the inverse of the complementary error function
172    fn inverse_cdf(&self, x: f64) -> f64 {
173        if !(0.0..=1.0).contains(&x) {
174            panic!("x must be in [0, 1]");
175        } else {
176            self.mean - (self.std_dev * f64::consts::SQRT_2 * erf::erfc_inv(2.0 * x))
177        }
178    }
179}
180
181impl Min<f64> for Normal {
182    /// Returns the minimum value in the domain of the
183    /// normal distribution representable by a double precision float
184    ///
185    /// # Formula
186    ///
187    /// ```text
188    /// f64::NEG_INFINITY
189    /// ```
190    fn min(&self) -> f64 {
191        f64::NEG_INFINITY
192    }
193}
194
195impl Max<f64> for Normal {
196    /// Returns the maximum value in the domain of the
197    /// normal distribution representable by a double precision float
198    ///
199    /// # Formula
200    ///
201    /// ```text
202    /// f64::INFINITY
203    /// ```
204    fn max(&self) -> f64 {
205        f64::INFINITY
206    }
207}
208
209impl Distribution<f64> for Normal {
210    /// Returns the mean of the normal distribution
211    ///
212    /// # Remarks
213    ///
214    /// This is the same mean used to construct the distribution
215    fn mean(&self) -> Option<f64> {
216        Some(self.mean)
217    }
218
219    /// Returns the variance of the normal distribution
220    ///
221    /// # Formula
222    ///
223    /// ```text
224    /// σ^2
225    /// ```
226    ///
227    /// where `σ` is the standard deviation
228    fn variance(&self) -> Option<f64> {
229        Some(self.std_dev * self.std_dev)
230    }
231
232    /// Returns the standard deviation of the normal distribution
233    /// # Remarks
234    /// This is the same standard deviation used to construct the distribution
235    fn std_dev(&self) -> Option<f64> {
236        Some(self.std_dev)
237    }
238
239    /// Returns the entropy of the normal distribution
240    ///
241    /// # Formula
242    ///
243    /// ```text
244    /// (1 / 2) * ln(2σ^2 * π * e)
245    /// ```
246    ///
247    /// where `σ` is the standard deviation
248    fn entropy(&self) -> Option<f64> {
249        Some(self.std_dev.ln() + consts::LN_SQRT_2PIE)
250    }
251
252    /// Returns the skewness of the normal distribution
253    ///
254    /// # Formula
255    ///
256    /// ```text
257    /// 0
258    /// ```
259    fn skewness(&self) -> Option<f64> {
260        Some(0.0)
261    }
262}
263
264impl Median<f64> for Normal {
265    /// Returns the median of the normal distribution
266    ///
267    /// # Formula
268    ///
269    /// ```text
270    /// μ
271    /// ```
272    ///
273    /// where `μ` is the mean
274    fn median(&self) -> f64 {
275        self.mean
276    }
277}
278
279impl Mode<Option<f64>> for Normal {
280    /// Returns the mode of the normal distribution
281    ///
282    /// # Formula
283    ///
284    /// ```text
285    /// μ
286    /// ```
287    ///
288    /// where `μ` is the mean
289    fn mode(&self) -> Option<f64> {
290        Some(self.mean)
291    }
292}
293
294impl Continuous<f64, f64> for Normal {
295    /// Calculates the probability density function for the normal distribution
296    /// at `x`
297    ///
298    /// # Formula
299    ///
300    /// ```text
301    /// (1 / sqrt(2σ^2 * π)) * e^(-(x - μ)^2 / 2σ^2)
302    /// ```
303    ///
304    /// where `μ` is the mean and `σ` is the standard deviation
305    fn pdf(&self, x: f64) -> f64 {
306        pdf_unchecked(x, self.mean, self.std_dev)
307    }
308
309    /// Calculates the log probability density function for the normal
310    /// distribution
311    /// at `x`
312    ///
313    /// # Formula
314    ///
315    /// ```text
316    /// ln((1 / sqrt(2σ^2 * π)) * e^(-(x - μ)^2 / 2σ^2))
317    /// ```
318    ///
319    /// where `μ` is the mean and `σ` is the standard deviation
320    fn ln_pdf(&self, x: f64) -> f64 {
321        ln_pdf_unchecked(x, self.mean, self.std_dev)
322    }
323}
324
325/// performs an unchecked cdf calculation for a normal distribution
326/// with the given mean and standard deviation at x
327pub fn cdf_unchecked(x: f64, mean: f64, std_dev: f64) -> f64 {
328    0.5 * erf::erfc((mean - x) / (std_dev * f64::consts::SQRT_2))
329}
330
331/// performs an unchecked sf calculation for a normal distribution
332/// with the given mean and standard deviation at x
333pub fn sf_unchecked(x: f64, mean: f64, std_dev: f64) -> f64 {
334    0.5 * erf::erfc((x - mean) / (std_dev * f64::consts::SQRT_2))
335}
336
337/// performs an unchecked pdf calculation for a normal distribution
338/// with the given mean and standard deviation at x
339pub fn pdf_unchecked(x: f64, mean: f64, std_dev: f64) -> f64 {
340    let d = (x - mean) / std_dev;
341    (-0.5 * d * d).exp() / (consts::SQRT_2PI * std_dev)
342}
343
344/// performs an unchecked log(pdf) calculation for a normal distribution
345/// with the given mean and standard deviation at x
346pub fn ln_pdf_unchecked(x: f64, mean: f64, std_dev: f64) -> f64 {
347    let d = (x - mean) / std_dev;
348    (-0.5 * d * d) - consts::LN_SQRT_2PI - std_dev.ln()
349}
350
/// Draws a sample from a normal distribution with the given `mean` and
/// `std_dev`, without validating either parameter.
///
/// A standard normal variate `z` is produced by the crate's ziggurat sampler
/// (`ziggurat::sample_std_normal`). It is then shifted and scaled as
/// `mean + std_dev * z`.
#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
pub fn sample_unchecked<R: ::rand::Rng + ?Sized>(rng: &mut R, mean: f64, std_dev: f64) -> f64 {
    use crate::distribution::ziggurat;

    mean + std_dev * ziggurat::sample_std_normal(rng)
}
359
360impl std::default::Default for Normal {
361    /// Returns the standard normal distribution with a mean of 0
362    /// and a standard deviation of 1.
363    fn default() -> Self {
364        Self::standard()
365    }
366}
367
// Unit tests for `Normal`. Formatting is skipped so each table of reference
// values stays one case per line.
#[rustfmt::skip]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distribution::internal::*;
    use crate::testing_boiler;

    // NOTE(review): presumably generates the `create_ok` / `create_err` /
    // `test_create_err` / `test_exact` / `test_absolute` helpers used below,
    // taking `(mean, std_dev)` pairs — confirm against the macro definition.
    testing_boiler!(mean: f64, std_dev: f64; Normal; NormalError);

    // Valid parameter pairs, including an infinite standard deviation, which
    // `new` accepts because it only rejects NaN and non-positive values.
    #[test]
    fn test_create() {
        create_ok(10.0, 0.1);
        create_ok(-5.0, 1.0);
        create_ok(0.0, 10.0);
        create_ok(10.0, 100.0);
        create_ok(-5.0, f64::INFINITY);
    }

    // Invalid parameters; the first two cases pin which error variant is
    // reported for a NaN mean versus a NaN standard deviation.
    #[test]
    fn test_bad_create() {
        test_create_err(f64::NAN, 1.0, NormalError::MeanInvalid);
        test_create_err(1.0, f64::NAN, NormalError::StandardDeviationInvalid);
        create_err(0.0, 0.0);
        create_err(f64::NAN, f64::NAN);
        create_err(1.0, -1.0);
    }

    // Variance is σ², including the infinite case.
    #[test]
    fn test_variance() {
        let variance = |x: Normal| x.variance().unwrap();
        test_exact(0.0, 0.1, 0.1 * 0.1, variance);
        test_exact(0.0, 1.0, 1.0, variance);
        test_exact(0.0, 10.0, 100.0, variance);
        test_exact(0.0, f64::INFINITY, f64::INFINITY, variance);
    }

    // Entropy reference values; note it is negative for small σ.
    #[test]
    fn test_entropy() {
        let entropy = |x: Normal| x.entropy().unwrap();
        test_absolute(0.0, 0.1, -0.8836465597893729422377, 1e-15, entropy);
        test_exact(0.0, 1.0, 1.41893853320467274178, entropy);
        test_exact(0.0, 10.0, 3.721523626198718425798, entropy);
        test_exact(0.0, f64::INFINITY, f64::INFINITY, entropy);
    }

    // Skewness is identically zero for any parameters.
    #[test]
    fn test_skewness() {
        let skewness = |x: Normal| x.skewness().unwrap();
        test_exact(0.0, 0.1, 0.0, skewness);
        test_exact(4.0, 1.0, 0.0, skewness);
        test_exact(0.3, 10.0, 0.0, skewness);
        test_exact(0.0, f64::INFINITY, 0.0, skewness);
    }

    // The mode is the mean.
    #[test]
    fn test_mode() {
        let mode = |x: Normal| x.mode().unwrap();
        test_exact(-0.0, 1.0, 0.0, mode);
        test_exact(0.0, 1.0, 0.0, mode);
        test_exact(0.1, 1.0, 0.1, mode);
        test_exact(1.0, 1.0, 1.0, mode);
        test_exact(-10.0, 1.0, -10.0, mode);
        test_exact(f64::INFINITY, 1.0, f64::INFINITY, mode);
    }

    // The median is the mean.
    #[test]
    fn test_median() {
        let median = |x: Normal| x.median();
        test_exact(-0.0, 1.0, 0.0, median);
        test_exact(0.0, 1.0, 0.0, median);
        test_exact(0.1, 1.0, 0.1, median);
        test_exact(1.0, 1.0, 1.0, median);
        test_exact(-0.0, 1.0, -0.0, median);
        test_exact(f64::INFINITY, 1.0, f64::INFINITY, median);
    }

    // The support is the whole real line.
    #[test]
    fn test_min_max() {
        let min = |x: Normal| x.min();
        let max = |x: Normal| x.max();
        test_exact(0.0, 0.1, f64::NEG_INFINITY, min);
        test_exact(-3.0, 10.0, f64::NEG_INFINITY, min);
        test_exact(0.0, 0.1, f64::INFINITY, max);
        test_exact(-3.0, 10.0, f64::INFINITY, max);
    }

    // Density reference values, symmetric about the mean; with σ = ∞ the
    // density is exactly 0 everywhere.
    #[test]
    fn test_pdf() {
        let pdf = |arg: f64| move |x: Normal| x.pdf(arg);
        test_absolute(10.0, 0.1, 5.530709549844416159162E-49, 1e-64, pdf(8.5));
        test_absolute(10.0, 0.1, 0.5399096651318805195056, 1e-14, pdf(9.8));
        test_absolute(10.0, 0.1, 3.989422804014326779399, 1e-15, pdf(10.0));
        test_absolute(10.0, 0.1, 0.5399096651318805195056, 1e-14, pdf(10.2));
        test_absolute(10.0, 0.1, 5.530709549844416159162E-49, 1e-64, pdf(11.5));
        test_exact(-5.0, 1.0, 1.486719514734297707908E-6, pdf(-10.0));
        test_exact(-5.0, 1.0, 0.01752830049356853736216, pdf(-7.5));
        test_absolute(-5.0, 1.0, 0.3989422804014326779399, 1e-16, pdf(-5.0));
        test_exact(-5.0, 1.0, 0.01752830049356853736216, pdf(-2.5));
        test_exact(-5.0, 1.0, 1.486719514734297707908E-6, pdf(0.0));
        test_exact(0.0, 10.0, 0.03520653267642994777747, pdf(-5.0));
        test_absolute(0.0, 10.0, 0.03866681168028492069412, 1e-17, pdf(-2.5));
        test_absolute(0.0, 10.0, 0.03989422804014326779399, 1e-17, pdf(0.0));
        test_absolute(0.0, 10.0, 0.03866681168028492069412, 1e-17, pdf(2.5));
        test_exact(0.0, 10.0, 0.03520653267642994777747, pdf(5.0));
        test_absolute(10.0, 100.0, 4.398359598042719404845E-4, 1e-19, pdf(-200.0));
        test_exact(10.0, 100.0, 0.002178521770325505313831, pdf(-100.0));
        test_exact(10.0, 100.0, 0.003969525474770117655105, pdf(0.0));
        test_absolute(10.0, 100.0, 0.002660852498987548218204, 1e-18, pdf(100.0));
        test_exact(10.0, 100.0, 6.561581477467659126534E-4, pdf(200.0));
        test_exact(-5.0, f64::INFINITY, 0.0, pdf(-5.0));
        test_exact(-5.0, f64::INFINITY, 0.0, pdf(0.0));
        test_exact(-5.0, f64::INFINITY, 0.0, pdf(100.0));
    }

    // Same reference points as `test_pdf`, compared in log space; σ = ∞
    // yields -∞.
    #[test]
    fn test_ln_pdf() {
        let ln_pdf = |arg: f64| move |x: Normal| x.ln_pdf(arg);
        test_absolute(10.0, 0.1, (5.530709549844416159162E-49f64).ln(), 1e-13, ln_pdf(8.5));
        test_absolute(10.0, 0.1, (0.5399096651318805195056f64).ln(), 1e-13, ln_pdf(9.8));
        test_absolute(10.0, 0.1, (3.989422804014326779399f64).ln(), 1e-15, ln_pdf(10.0));
        test_absolute(10.0, 0.1, (0.5399096651318805195056f64).ln(), 1e-13, ln_pdf(10.2));
        test_absolute(10.0, 0.1, (5.530709549844416159162E-49f64).ln(), 1e-13, ln_pdf(11.5));
        test_exact(-5.0, 1.0, (1.486719514734297707908E-6f64).ln(), ln_pdf(-10.0));
        test_exact(-5.0, 1.0, (0.01752830049356853736216f64).ln(), ln_pdf(-7.5));
        test_absolute(-5.0, 1.0, (0.3989422804014326779399f64).ln(), 1e-15, ln_pdf(-5.0));
        test_exact(-5.0, 1.0, (0.01752830049356853736216f64).ln(), ln_pdf(-2.5));
        test_exact(-5.0, 1.0, (1.486719514734297707908E-6f64).ln(), ln_pdf(0.0));
        test_exact(0.0, 10.0, (0.03520653267642994777747f64).ln(), ln_pdf(-5.0));
        test_exact(0.0, 10.0, (0.03866681168028492069412f64).ln(), ln_pdf(-2.5));
        test_exact(0.0, 10.0, (0.03989422804014326779399f64).ln(), ln_pdf(0.0));
        test_exact(0.0, 10.0, (0.03866681168028492069412f64).ln(), ln_pdf(2.5));
        test_exact(0.0, 10.0, (0.03520653267642994777747f64).ln(), ln_pdf(5.0));
        test_exact(10.0, 100.0, (4.398359598042719404845E-4f64).ln(), ln_pdf(-200.0));
        test_exact(10.0, 100.0, (0.002178521770325505313831f64).ln(), ln_pdf(-100.0));
        test_absolute(10.0, 100.0, (0.003969525474770117655105f64).ln(),1e-15, ln_pdf(0.0));
        test_absolute(10.0, 100.0, (0.002660852498987548218204f64).ln(), 1e-15, ln_pdf(100.0));
        test_absolute(10.0, 100.0, (6.561581477467659126534E-4f64).ln(), 1e-15, ln_pdf(200.0));
        test_exact(-5.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(-5.0));
        test_exact(-5.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(0.0));
        test_exact(-5.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(100.0));
    }

    // cdf of N(5, 2²) from -∞ into the upper tail; exactly 0.5 at the mean.
    #[test]
    fn test_cdf() {
        let cdf = |arg: f64| move |x: Normal| x.cdf(arg);
        test_exact(5.0, 2.0, 0.0, cdf(f64::NEG_INFINITY));
        test_absolute(5.0, 2.0, 0.0000002866515718, 1e-16, cdf(-5.0));
        test_absolute(5.0, 2.0, 0.0002326290790, 1e-13, cdf(-2.0));
        test_absolute(5.0, 2.0, 0.006209665325, 1e-12, cdf(0.0));
        test_exact(5.0, 2.0, 0.30853753872598689636229538939166226011639782444542207, cdf(4.0));
        test_exact(5.0, 2.0, 0.5, cdf(5.0));
        test_exact(5.0, 2.0, 0.69146246127401310363770461060833773988360217555457859, cdf(6.0));
        test_absolute(5.0, 2.0, 0.993790334674, 1e-12, cdf(10.0));
    }

    // Survival function at the same points as `test_cdf`; each expected value
    // is the complement of the matching cdf value.
    #[test]
    fn test_sf() {
        let sf = |arg: f64| move |x: Normal| x.sf(arg);
        test_exact(5.0, 2.0, 1.0, sf(f64::NEG_INFINITY));
        test_absolute(5.0, 2.0, 0.9999997133484281, 1e-16, sf(-5.0));
        test_absolute(5.0, 2.0, 0.9997673709209455, 1e-13, sf(-2.0));
        test_absolute(5.0, 2.0, 0.9937903346744879, 1e-12, sf(0.0));
        test_exact(5.0, 2.0, 0.6914624612740131, sf(4.0));
        test_exact(5.0, 2.0, 0.5, sf(5.0));
        test_exact(5.0, 2.0, 0.3085375387259869, sf(6.0));
        test_absolute(5.0, 2.0, 0.006209665325512148, 1e-12, sf(10.0));
    }

    // NOTE(review): presumably a generic consistency check between pdf and
    // cdf over the given interval — confirm in `test::check_continuous_distribution`.
    #[test]
    fn test_continuous() {
        test::check_continuous_distribution(&create_ok(0.0, 1.0), -10.0, 10.0);
        test::check_continuous_distribution(&create_ok(20.0, 0.5), 10.0, 30.0);
    }

    // Round-trips the cdf values from `test_cdf` back to their arguments;
    // the endpoints 0 and 1 map to -∞ and +∞.
    #[test]
    fn test_inverse_cdf() {
        let inverse_cdf = |arg: f64| move |x: Normal| x.inverse_cdf(arg);
        test_exact(5.0, 2.0, f64::NEG_INFINITY, inverse_cdf( 0.0));
        test_absolute(5.0, 2.0, -5.0, 1e-14, inverse_cdf(0.00000028665157187919391167375233287464535385442301361187883));
        test_absolute(5.0, 2.0, -2.0, 1e-14, inverse_cdf(0.0002326290790355250363499258867279847735487493358890356));
        test_absolute(5.0, 2.0, -0.0, 1e-14, inverse_cdf(0.0062096653257761351669781045741922211278977469230927036));
        test_absolute(5.0, 2.0, 0.0, 1e-14, inverse_cdf(0.0062096653257761351669781045741922211278977469230927036));
        test_absolute(5.0, 2.0, 4.0, 1e-14, inverse_cdf(0.30853753872598689636229538939166226011639782444542207));
        test_absolute(5.0, 2.0, 5.0, 1e-14, inverse_cdf(0.5));
        test_absolute(5.0, 2.0, 6.0, 1e-14, inverse_cdf(0.69146246127401310363770461060833773988360217555457859));
        test_absolute(5.0, 2.0, 10.0, 1e-14, inverse_cdf(0.9937903346742238648330218954258077788721022530769078));
        test_exact(5.0, 2.0, f64::INFINITY, inverse_cdf(1.0));
    }

    // `Default` must produce the standard normal N(0, 1).
    #[test]
    fn test_default() {
        let n = Normal::default();

        let n_mean = n.mean().unwrap();
        let n_std  = n.std_dev().unwrap();

        // Check that the mean of the distribution is close to 0
        assert_almost_eq!(n_mean, 0.0, 1e-15);
        // Check that the standard deviation of the distribution is close to 1
        assert_almost_eq!(n_std, 1.0, 1e-15);
    }
}