statrs/distribution/
normal.rs

1use crate::distribution::{ziggurat, Continuous, ContinuousCDF};
2use crate::function::erf;
3use crate::statistics::*;
4use crate::{consts, Result, StatsError};
5use rand::Rng;
6use std::f64;
7
/// Implements the [Normal](https://en.wikipedia.org/wiki/Normal_distribution)
/// distribution, parameterised by its mean and its standard deviation
///
/// # Examples
///
/// ```
/// use statrs::distribution::{Normal, Continuous};
/// use statrs::statistics::Distribution;
///
/// let n = Normal::new(0.0, 1.0).unwrap();
/// assert_eq!(n.mean().unwrap(), 0.0);
/// assert_eq!(n.pdf(1.0), 0.2419707245191433497978);
/// ```
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct Normal {
    /// Mean `μ` of the distribution; `Normal::new` rejects `NaN`.
    mean: f64,
    /// Standard deviation `σ`; `Normal::new` rejects `NaN` and any value
    /// that is not strictly greater than zero.
    std_dev: f64,
}
26
27impl Normal {
28    ///  Constructs a new normal distribution with a mean of `mean`
29    /// and a standard deviation of `std_dev`
30    ///
31    /// # Errors
32    ///
33    /// Returns an error if `mean` or `std_dev` are `NaN` or if
34    /// `std_dev <= 0.0`
35    ///
36    /// # Examples
37    ///
38    /// ```
39    /// use statrs::distribution::Normal;
40    ///
41    /// let mut result = Normal::new(0.0, 1.0);
42    /// assert!(result.is_ok());
43    ///
44    /// result = Normal::new(0.0, 0.0);
45    /// assert!(result.is_err());
46    /// ```
47    pub fn new(mean: f64, std_dev: f64) -> Result<Normal> {
48        if mean.is_nan() || std_dev.is_nan() || std_dev <= 0.0 {
49            Err(StatsError::BadParams)
50        } else {
51            Ok(Normal { mean, std_dev })
52        }
53    }
54}
55
56impl ::rand::distributions::Distribution<f64> for Normal {
57    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
58        sample_unchecked(rng, self.mean, self.std_dev)
59    }
60}
61
62impl ContinuousCDF<f64, f64> for Normal {
63    /// Calculates the cumulative distribution function for the
64    /// normal distribution at `x`
65    ///
66    /// # Formula
67    ///
68    /// ```ignore
69    /// (1 / 2) * (1 + erf((x - μ) / (σ * sqrt(2))))
70    /// ```
71    ///
72    /// where `μ` is the mean, `σ` is the standard deviation, and
73    /// `erf` is the error function
74    fn cdf(&self, x: f64) -> f64 {
75        cdf_unchecked(x, self.mean, self.std_dev)
76    }
77
78    /// Calculates the survival function for the
79    /// normal distribution at `x`
80    ///
81    /// # Formula
82    ///
83    /// ```ignore
84    /// (1 / 2) * (1 + erf(-(x - μ) / (σ * sqrt(2))))
85    /// ```
86    ///
87    /// where `μ` is the mean, `σ` is the standard deviation, and
88    /// `erf` is the error function
89    ///
90    /// note that this calculates the complement due to flipping
91    /// the sign of the argument error function with respect to the cdf.
92    ///
93    /// the normal cdf Φ (and internal error function) as the following property:
94    /// ```ignore
95    ///  Φ(-x) + Φ(x) = 1
96    ///  Φ(-x)        = 1 - Φ(x) 
97    /// ```
98    fn sf(&self, x: f64) -> f64 {
99        sf_unchecked(x, self.mean, self.std_dev)
100    }
101
102    /// Calculates the inverse cumulative distribution function for the
103    /// normal distribution at `x`
104    ///
105    /// # Panics
106    ///
107    /// If `x < 0.0` or `x > 1.0`
108    ///
109    /// # Formula
110    ///
111    /// ```ignore
112    /// μ - sqrt(2) * σ * erfc_inv(2x)
113    /// ```
114    ///
115    /// where `μ` is the mean, `σ` is the standard deviation and `erfc_inv` is
116    /// the inverse of the complementary error function
117    fn inverse_cdf(&self, x: f64) -> f64 {
118        if !(0.0..=1.0).contains(&x) {
119            panic!("x must be in [0, 1]");
120        } else {
121            self.mean - (self.std_dev * f64::consts::SQRT_2 * erf::erfc_inv(2.0 * x))
122        }
123    }
124}
125
impl Min<f64> for Normal {
    /// Returns the minimum value in the domain of the
    /// normal distribution representable by a double precision float
    ///
    /// # Formula
    ///
    /// ```ignore
    /// -INF
    /// ```
    ///
    /// # Remarks
    ///
    /// The result is a constant and does not depend on the mean or the
    /// standard deviation of the distribution
    fn min(&self) -> f64 {
        f64::NEG_INFINITY
    }
}
139
impl Max<f64> for Normal {
    /// Returns the maximum value in the domain of the
    /// normal distribution representable by a double precision float
    ///
    /// # Formula
    ///
    /// ```ignore
    /// INF
    /// ```
    ///
    /// # Remarks
    ///
    /// The result is a constant and does not depend on the mean or the
    /// standard deviation of the distribution
    fn max(&self) -> f64 {
        f64::INFINITY
    }
}
153
154impl Distribution<f64> for Normal {
155    /// Returns the mean of the normal distribution
156    ///
157    /// # Remarks
158    ///
159    /// This is the same mean used to construct the distribution
160    fn mean(&self) -> Option<f64> {
161        Some(self.mean)
162    }
163    /// Returns the variance of the normal distribution
164    ///
165    /// # Formula
166    ///
167    /// ```ignore
168    /// σ^2
169    /// ```
170    ///
171    /// where `σ` is the standard deviation
172    fn variance(&self) -> Option<f64> {
173        Some(self.std_dev * self.std_dev)
174    }
175    /// Returns the entropy of the normal distribution
176    ///
177    /// # Formula
178    ///
179    /// ```ignore
180    /// (1 / 2) * ln(2σ^2 * π * e)
181    /// ```
182    ///
183    /// where `σ` is the standard deviation
184    fn entropy(&self) -> Option<f64> {
185        Some(self.std_dev.ln() + consts::LN_SQRT_2PIE)
186    }
187    /// Returns the skewness of the normal distribution
188    ///
189    /// # Formula
190    ///
191    /// ```ignore
192    /// 0
193    /// ```
194    fn skewness(&self) -> Option<f64> {
195        Some(0.0)
196    }
197}
198
199impl Median<f64> for Normal {
200    /// Returns the median of the normal distribution
201    ///
202    /// # Formula
203    ///
204    /// ```ignore
205    /// μ
206    /// ```
207    ///
208    /// where `μ` is the mean
209    fn median(&self) -> f64 {
210        self.mean
211    }
212}
213
214impl Mode<Option<f64>> for Normal {
215    /// Returns the mode of the normal distribution
216    ///
217    /// # Formula
218    ///
219    /// ```ignore
220    /// μ
221    /// ```
222    ///
223    /// where `μ` is the mean
224    fn mode(&self) -> Option<f64> {
225        Some(self.mean)
226    }
227}
228
229impl Continuous<f64, f64> for Normal {
230    /// Calculates the probability density function for the normal distribution
231    /// at `x`
232    ///
233    /// # Formula
234    ///
235    /// ```ignore
236    /// (1 / sqrt(2σ^2 * π)) * e^(-(x - μ)^2 / 2σ^2)
237    /// ```
238    ///
239    /// where `μ` is the mean and `σ` is the standard deviation
240    fn pdf(&self, x: f64) -> f64 {
241        pdf_unchecked(x, self.mean, self.std_dev)
242    }
243
244    /// Calculates the log probability density function for the normal
245    /// distribution
246    /// at `x`
247    ///
248    /// # Formula
249    ///
250    /// ```ignore
251    /// ln((1 / sqrt(2σ^2 * π)) * e^(-(x - μ)^2 / 2σ^2))
252    /// ```
253    ///
254    /// where `μ` is the mean and `σ` is the standard deviation
255    fn ln_pdf(&self, x: f64) -> f64 {
256        ln_pdf_unchecked(x, self.mean, self.std_dev)
257    }
258}
259
260/// performs an unchecked cdf calculation for a normal distribution
261/// with the given mean and standard deviation at x
262pub fn cdf_unchecked(x: f64, mean: f64, std_dev: f64) -> f64 {
263    0.5 * erf::erfc((mean - x) / (std_dev * f64::consts::SQRT_2))
264}
265
266/// performs an unchecked sf calculation for a normal distribution
267/// with the given mean and standard deviation at x
268pub fn sf_unchecked(x: f64, mean: f64, std_dev: f64) -> f64 {
269    0.5 * erf::erfc((x - mean) / (std_dev * f64::consts::SQRT_2))
270}
271
272/// performs an unchecked pdf calculation for a normal distribution
273/// with the given mean and standard deviation at x
274pub fn pdf_unchecked(x: f64, mean: f64, std_dev: f64) -> f64 {
275    let d = (x - mean) / std_dev;
276    (-0.5 * d * d).exp() / (consts::SQRT_2PI * std_dev)
277}
278
279/// performs an unchecked log(pdf) calculation for a normal distribution
280/// with the given mean and standard deviation at x
281pub fn ln_pdf_unchecked(x: f64, mean: f64, std_dev: f64) -> f64 {
282    let d = (x - mean) / std_dev;
283    (-0.5 * d * d) - consts::LN_SQRT_2PI - std_dev.ln()
284}
285
286/// draws a sample from a normal distribution using the Box-Muller algorithm
287pub fn sample_unchecked<R: Rng + ?Sized>(rng: &mut R, mean: f64, std_dev: f64) -> f64 {
288    mean + std_dev * ziggurat::sample_std_normal(rng)
289}
290
// Unit tests for `Normal`. Per the `cfg` attribute below, the module is only
// compiled for test builds that also enable the `nightly` feature.
// `rustfmt::skip` keeps the hand-aligned tables of expected values intact.
#[rustfmt::skip]
#[cfg(all(test, feature = "nightly"))]
mod tests {
    use crate::statistics::*;
    use crate::distribution::{ContinuousCDF, Continuous, Normal};
    use crate::distribution::internal::*;
    // NOTE(review): `ACC` is not referenced by name anywhere in this module —
    // confirm whether this import is still needed.
    use crate::consts::ACC;

    /// Builds a `Normal`, asserting that construction succeeds.
    fn try_create(mean: f64, std_dev: f64) -> Normal {
        let n = Normal::new(mean, std_dev);
        assert!(n.is_ok());
        n.unwrap()
    }

    /// Checks that a freshly built distribution reports exactly the mean
    /// and standard deviation it was constructed with.
    fn create_case(mean: f64, std_dev: f64) {
        let n = try_create(mean, std_dev);
        assert_eq!(mean, n.mean().unwrap());
        assert_eq!(std_dev, n.std_dev().unwrap());
    }

    /// Asserts that `Normal::new` rejects the given parameter pair.
    fn bad_create_case(mean: f64, std_dev: f64) {
        let n = Normal::new(mean, std_dev);
        assert!(n.is_err());
    }

    /// Evaluates `eval` on a distribution built from `mean` / `std_dev`
    /// and requires the result to equal `expected` exactly.
    fn test_case<F>(mean: f64, std_dev: f64, expected: f64, eval: F)
        where F: Fn(Normal) -> f64
    {
        let n = try_create(mean, std_dev);
        let x = eval(n);
        assert_eq!(expected, x);
    }

    /// Same as `test_case`, but compares through `assert_almost_eq!` with
    /// `acc` as its third argument (presumably an absolute tolerance —
    /// confirm against the macro definition).
    fn test_almost<F>(mean: f64, std_dev: f64, expected: f64, acc: f64, eval: F)
        where F: Fn(Normal) -> f64
    {
        let n = try_create(mean, std_dev);
        let x = eval(n);
        assert_almost_eq!(expected, x, acc);
    }

    #[test]
    fn test_create() {
        create_case(10.0, 0.1);
        create_case(-5.0, 1.0);
        create_case(0.0, 10.0);
        create_case(10.0, 100.0);
        // An infinite standard deviation is accepted by the constructor.
        create_case(-5.0, f64::INFINITY);
    }

    #[test]
    fn test_bad_create() {
        // Zero and negative spreads, and any NaN parameter, must be rejected.
        bad_create_case(0.0, 0.0);
        bad_create_case(f64::NAN, 1.0);
        bad_create_case(1.0, f64::NAN);
        bad_create_case(f64::NAN, f64::NAN);
        bad_create_case(1.0, -1.0);
    }

    #[test]
    fn test_variance() {
        let variance = |x: Normal| x.variance().unwrap();
        test_case(0.0, 0.1, 0.1 * 0.1, variance);
        test_case(0.0, 1.0, 1.0, variance);
        test_case(0.0, 10.0, 100.0, variance);
        test_case(0.0, f64::INFINITY, f64::INFINITY, variance);
    }

    #[test]
    fn test_entropy() {
        let entropy = |x: Normal| x.entropy().unwrap();
        test_almost(0.0, 0.1, -0.8836465597893729422377, 1e-15, entropy);
        test_case(0.0, 1.0, 1.41893853320467274178, entropy);
        test_case(0.0, 10.0, 3.721523626198718425798, entropy);
        test_case(0.0, f64::INFINITY, f64::INFINITY, entropy);
    }

    #[test]
    fn test_skewness() {
        let skewness = |x: Normal| x.skewness().unwrap();
        test_case(0.0, 0.1, 0.0, skewness);
        test_case(4.0, 1.0, 0.0, skewness);
        test_case(0.3, 10.0, 0.0, skewness);
        test_case(0.0, f64::INFINITY, 0.0, skewness);
    }

    #[test]
    fn test_mode() {
        let mode = |x: Normal| x.mode().unwrap();
        // `-0.0 == 0.0` under `assert_eq!`, so the sign of zero is not checked.
        test_case(-0.0, 1.0, 0.0, mode);
        test_case(0.0, 1.0, 0.0, mode);
        test_case(0.1, 1.0, 0.1, mode);
        test_case(1.0, 1.0, 1.0, mode);
        test_case(-10.0, 1.0, -10.0, mode);
        test_case(f64::INFINITY, 1.0, f64::INFINITY, mode);
    }

    #[test]
    fn test_median() {
        let median = |x: Normal| x.median();
        // `-0.0 == 0.0` under `assert_eq!`, so the sign of zero is not checked.
        test_case(-0.0, 1.0, 0.0, median);
        test_case(0.0, 1.0, 0.0, median);
        test_case(0.1, 1.0, 0.1, median);
        test_case(1.0, 1.0, 1.0, median);
        test_case(-0.0, 1.0, -0.0, median);
        test_case(f64::INFINITY, 1.0, f64::INFINITY, median);
    }

    #[test]
    fn test_min_max() {
        let min = |x: Normal| x.min();
        let max = |x: Normal| x.max();
        test_case(0.0, 0.1, f64::NEG_INFINITY, min);
        test_case(-3.0, 10.0, f64::NEG_INFINITY, min);
        test_case(0.0, 0.1, f64::INFINITY, max);
        test_case(-3.0, 10.0, f64::INFINITY, max);
    }

    #[test]
    fn test_pdf() {
        // Curried helper: fixes the evaluation point and yields a closure
        // that takes the distribution.
        let pdf = |arg: f64| move |x: Normal| x.pdf(arg);
        test_almost(10.0, 0.1, 5.530709549844416159162E-49, 1e-64, pdf(8.5));
        test_almost(10.0, 0.1, 0.5399096651318805195056, 1e-14, pdf(9.8));
        test_almost(10.0, 0.1, 3.989422804014326779399, 1e-15, pdf(10.0));
        test_almost(10.0, 0.1, 0.5399096651318805195056, 1e-14, pdf(10.2));
        test_almost(10.0, 0.1, 5.530709549844416159162E-49, 1e-64, pdf(11.5));
        test_case(-5.0, 1.0, 1.486719514734297707908E-6, pdf(-10.0));
        test_case(-5.0, 1.0, 0.01752830049356853736216, pdf(-7.5));
        test_almost(-5.0, 1.0, 0.3989422804014326779399, 1e-16, pdf(-5.0));
        test_case(-5.0, 1.0, 0.01752830049356853736216, pdf(-2.5));
        test_case(-5.0, 1.0, 1.486719514734297707908E-6, pdf(0.0));
        test_case(0.0, 10.0, 0.03520653267642994777747, pdf(-5.0));
        test_almost(0.0, 10.0, 0.03866681168028492069412, 1e-17, pdf(-2.5));
        test_almost(0.0, 10.0, 0.03989422804014326779399, 1e-17, pdf(0.0));
        test_almost(0.0, 10.0, 0.03866681168028492069412, 1e-17, pdf(2.5));
        test_case(0.0, 10.0, 0.03520653267642994777747, pdf(5.0));
        test_almost(10.0, 100.0, 4.398359598042719404845E-4, 1e-19, pdf(-200.0));
        test_case(10.0, 100.0, 0.002178521770325505313831, pdf(-100.0));
        test_case(10.0, 100.0, 0.003969525474770117655105, pdf(0.0));
        test_almost(10.0, 100.0, 0.002660852498987548218204, 1e-18, pdf(100.0));
        test_case(10.0, 100.0, 6.561581477467659126534E-4, pdf(200.0));
        // With an infinite standard deviation the density is expected to be 0.
        test_case(-5.0, f64::INFINITY, 0.0, pdf(-5.0));
        test_case(-5.0, f64::INFINITY, 0.0, pdf(0.0));
        test_case(-5.0, f64::INFINITY, 0.0, pdf(100.0));
    }

    #[test]
    fn test_ln_pdf() {
        // Expected values are the logarithms of the densities used in `test_pdf`.
        let ln_pdf = |arg: f64| move |x: Normal| x.ln_pdf(arg);
        test_almost(10.0, 0.1, (5.530709549844416159162E-49f64).ln(), 1e-13, ln_pdf(8.5));
        test_almost(10.0, 0.1, (0.5399096651318805195056f64).ln(), 1e-13, ln_pdf(9.8));
        test_almost(10.0, 0.1, (3.989422804014326779399f64).ln(), 1e-15, ln_pdf(10.0));
        test_almost(10.0, 0.1, (0.5399096651318805195056f64).ln(), 1e-13, ln_pdf(10.2));
        test_almost(10.0, 0.1, (5.530709549844416159162E-49f64).ln(), 1e-13, ln_pdf(11.5));
        test_case(-5.0, 1.0, (1.486719514734297707908E-6f64).ln(), ln_pdf(-10.0));
        test_case(-5.0, 1.0, (0.01752830049356853736216f64).ln(), ln_pdf(-7.5));
        test_almost(-5.0, 1.0, (0.3989422804014326779399f64).ln(), 1e-15, ln_pdf(-5.0));
        test_case(-5.0, 1.0, (0.01752830049356853736216f64).ln(), ln_pdf(-2.5));
        test_case(-5.0, 1.0, (1.486719514734297707908E-6f64).ln(), ln_pdf(0.0));
        test_case(0.0, 10.0, (0.03520653267642994777747f64).ln(), ln_pdf(-5.0));
        test_case(0.0, 10.0, (0.03866681168028492069412f64).ln(), ln_pdf(-2.5));
        test_case(0.0, 10.0, (0.03989422804014326779399f64).ln(), ln_pdf(0.0));
        test_case(0.0, 10.0, (0.03866681168028492069412f64).ln(), ln_pdf(2.5));
        test_case(0.0, 10.0, (0.03520653267642994777747f64).ln(), ln_pdf(5.0));
        test_case(10.0, 100.0, (4.398359598042719404845E-4f64).ln(), ln_pdf(-200.0));
        test_case(10.0, 100.0, (0.002178521770325505313831f64).ln(), ln_pdf(-100.0));
        test_almost(10.0, 100.0, (0.003969525474770117655105f64).ln(),1e-15, ln_pdf(0.0));
        test_almost(10.0, 100.0, (0.002660852498987548218204f64).ln(), 1e-15, ln_pdf(100.0));
        test_almost(10.0, 100.0, (6.561581477467659126534E-4f64).ln(), 1e-15, ln_pdf(200.0));
        // With an infinite standard deviation the log-density is expected to be -INF.
        test_case(-5.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(-5.0));
        test_case(-5.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(0.0));
        test_case(-5.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(100.0));
    }

    #[test]
    fn test_cdf() {
        let cdf = |arg: f64| move |x: Normal| x.cdf(arg);
        test_case(5.0, 2.0, 0.0, cdf(f64::NEG_INFINITY));
        test_almost(5.0, 2.0, 0.0000002866515718, 1e-16, cdf(-5.0));
        test_almost(5.0, 2.0, 0.0002326290790, 1e-13, cdf(-2.0));
        test_almost(5.0, 2.0, 0.006209665325, 1e-12, cdf(0.0));
        test_case(5.0, 2.0, 0.30853753872598689636229538939166226011639782444542207, cdf(4.0));
        test_case(5.0, 2.0, 0.5, cdf(5.0));
        test_case(5.0, 2.0, 0.69146246127401310363770461060833773988360217555457859, cdf(6.0));
        test_almost(5.0, 2.0, 0.993790334674, 1e-12, cdf(10.0));
    }

    #[test]
    fn test_sf() {
        // Same evaluation points as `test_cdf`; expected values are the complements.
        let sf = |arg: f64| move |x: Normal| x.sf(arg);
        test_case(5.0, 2.0, 1.0, sf(f64::NEG_INFINITY));
        test_almost(5.0, 2.0, 0.9999997133484281, 1e-16, sf(-5.0));
        test_almost(5.0, 2.0, 0.9997673709209455, 1e-13, sf(-2.0));
        test_almost(5.0, 2.0, 0.9937903346744879, 1e-12, sf(0.0));
        test_case(5.0, 2.0, 0.6914624612740131, sf(4.0));
        test_case(5.0, 2.0, 0.5, sf(5.0));
        test_case(5.0, 2.0, 0.3085375387259869, sf(6.0));
        test_almost(5.0, 2.0, 0.006209665325512148, 1e-12, sf(10.0));
    }

    #[test]
    fn test_continuous() {
        // Delegates to the shared checker from `distribution::internal`,
        // passing each distribution together with an interval.
        test::check_continuous_distribution(&try_create(0.0, 1.0), -10.0, 10.0);
        test::check_continuous_distribution(&try_create(20.0, 0.5), 10.0, 30.0);
    }

    #[test]
    fn test_inverse_cdf() {
        // The probabilities are the cdf values from `test_cdf` given to more
        // digits; the expected quantiles are the matching evaluation points.
        let inverse_cdf = |arg: f64| move |x: Normal| x.inverse_cdf(arg);
        test_case(5.0, 2.0, f64::NEG_INFINITY, inverse_cdf( 0.0));
        test_almost(5.0, 2.0, -5.0, 1e-14, inverse_cdf(0.00000028665157187919391167375233287464535385442301361187883));
        test_almost(5.0, 2.0, -2.0, 1e-14, inverse_cdf(0.0002326290790355250363499258867279847735487493358890356));
        test_almost(5.0, 2.0, -0.0, 1e-14, inverse_cdf(0.0062096653257761351669781045741922211278977469230927036));
        test_almost(5.0, 2.0, 0.0, 1e-14, inverse_cdf(0.0062096653257761351669781045741922211278977469230927036));
        test_almost(5.0, 2.0, 4.0, 1e-14, inverse_cdf(0.30853753872598689636229538939166226011639782444542207));
        test_almost(5.0, 2.0, 5.0, 1e-14, inverse_cdf(0.5));
        test_almost(5.0, 2.0, 6.0, 1e-14, inverse_cdf(0.69146246127401310363770461060833773988360217555457859));
        test_almost(5.0, 2.0, 10.0, 1e-14, inverse_cdf(0.9937903346742238648330218954258077788721022530769078));
        test_case(5.0, 2.0, f64::INFINITY, inverse_cdf(1.0));
    }
}