1use crate::consts;
2use crate::distribution::{Continuous, ContinuousCDF};
3use crate::function::erf;
4use crate::statistics::*;
5use std::f64;
6
/// A normal (Gaussian) distribution described by its mean `μ` and its
/// standard deviation `σ`.
///
/// Fields are private; build values with the validating constructor
/// `Normal::new` or with `Normal::standard` for `N(0, 1)`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Normal {
    /// Location parameter `μ` (validated to be non-NaN by `Normal::new`).
    mean: f64,
    /// Scale parameter `σ` (validated to be non-NaN and strictly positive by `Normal::new`).
    std_dev: f64,
}
25
/// Reasons why `Normal::new` can reject its parameters.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum NormalError {
    /// The mean was NaN.
    MeanInvalid,

    /// The standard deviation was NaN, zero, or negative.
    StandardDeviationInvalid,
}
36
37impl std::fmt::Display for NormalError {
38 #[cfg_attr(coverage_nightly, coverage(off))]
39 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
40 match self {
41 NormalError::MeanInvalid => write!(f, "Mean is NaN"),
42 NormalError::StandardDeviationInvalid => {
43 write!(f, "Standard deviation is NaN, zero or less than zero")
44 }
45 }
46 }
47}
48
49impl std::error::Error for NormalError {}
50
51impl Normal {
52 pub fn new(mean: f64, std_dev: f64) -> Result<Normal, NormalError> {
72 if mean.is_nan() {
73 return Err(NormalError::MeanInvalid);
74 }
75
76 if std_dev.is_nan() || std_dev <= 0.0 {
77 return Err(NormalError::StandardDeviationInvalid);
78 }
79
80 Ok(Normal { mean, std_dev })
81 }
82
83 pub fn standard() -> Normal {
95 Normal {
96 mean: 0.0,
97 std_dev: 1.0,
98 }
99 }
100}
101
102impl std::fmt::Display for Normal {
103 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104 write!(f, "N({},{})", self.mean, self.std_dev)
105 }
106}
107
#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<f64> for Normal {
    /// Draws one variate using this distribution's mean and standard
    /// deviation; see `sample_unchecked`.
    fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 {
        let Normal { mean, std_dev } = *self;
        sample_unchecked(rng, mean, std_dev)
    }
}
115
116impl ContinuousCDF<f64, f64> for Normal {
117 fn cdf(&self, x: f64) -> f64 {
129 cdf_unchecked(x, self.mean, self.std_dev)
130 }
131
132 fn sf(&self, x: f64) -> f64 {
153 sf_unchecked(x, self.mean, self.std_dev)
154 }
155
156 fn inverse_cdf(&self, x: f64) -> f64 {
173 if !(0.0..=1.0).contains(&x) {
174 panic!("x must be in [0, 1]");
175 } else {
176 self.mean - (self.std_dev * f64::consts::SQRT_2 * erf::erfc_inv(2.0 * x))
177 }
178 }
179}
180
181impl Min<f64> for Normal {
182 fn min(&self) -> f64 {
191 f64::NEG_INFINITY
192 }
193}
194
195impl Max<f64> for Normal {
196 fn max(&self) -> f64 {
205 f64::INFINITY
206 }
207}
208
209impl Distribution<f64> for Normal {
210 fn mean(&self) -> Option<f64> {
216 Some(self.mean)
217 }
218
219 fn variance(&self) -> Option<f64> {
229 Some(self.std_dev * self.std_dev)
230 }
231
232 fn std_dev(&self) -> Option<f64> {
236 Some(self.std_dev)
237 }
238
239 fn entropy(&self) -> Option<f64> {
249 Some(self.std_dev.ln() + consts::LN_SQRT_2PIE)
250 }
251
252 fn skewness(&self) -> Option<f64> {
260 Some(0.0)
261 }
262}
263
264impl Median<f64> for Normal {
265 fn median(&self) -> f64 {
275 self.mean
276 }
277}
278
279impl Mode<Option<f64>> for Normal {
280 fn mode(&self) -> Option<f64> {
290 Some(self.mean)
291 }
292}
293
294impl Continuous<f64, f64> for Normal {
295 fn pdf(&self, x: f64) -> f64 {
306 pdf_unchecked(x, self.mean, self.std_dev)
307 }
308
309 fn ln_pdf(&self, x: f64) -> f64 {
321 ln_pdf_unchecked(x, self.mean, self.std_dev)
322 }
323}
324
325pub fn cdf_unchecked(x: f64, mean: f64, std_dev: f64) -> f64 {
328 0.5 * erf::erfc((mean - x) / (std_dev * f64::consts::SQRT_2))
329}
330
331pub fn sf_unchecked(x: f64, mean: f64, std_dev: f64) -> f64 {
334 0.5 * erf::erfc((x - mean) / (std_dev * f64::consts::SQRT_2))
335}
336
337pub fn pdf_unchecked(x: f64, mean: f64, std_dev: f64) -> f64 {
340 let d = (x - mean) / std_dev;
341 (-0.5 * d * d).exp() / (consts::SQRT_2PI * std_dev)
342}
343
344pub fn ln_pdf_unchecked(x: f64, mean: f64, std_dev: f64) -> f64 {
347 let d = (x - mean) / std_dev;
348 (-0.5 * d * d) - consts::LN_SQRT_2PI - std_dev.ln()
349}
350
/// Draws one variate from `N(mean, std_dev²)`.
///
/// It takes a draw from `ziggurat::sample_std_normal`, which is presumably
/// standard normal, and scales it by `std_dev` before shifting it by `mean`.
/// The parameters are not validated.
#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
pub fn sample_unchecked<R: ::rand::Rng + ?Sized>(rng: &mut R, mean: f64, std_dev: f64) -> f64 {
    let z = crate::distribution::ziggurat::sample_std_normal(rng);
    mean + std_dev * z
}
359
360impl std::default::Default for Normal {
361 fn default() -> Self {
364 Self::standard()
365 }
366}
367
#[rustfmt::skip]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distribution::internal::*;
    use crate::testing_boiler;

    // Presumably generates the helpers used below (`create_ok`, `create_err`,
    // `test_create_err`, `test_exact`, `test_absolute`) for
    // `Normal::new(mean, std_dev)`.
    testing_boiler!(mean: f64, std_dev: f64; Normal; NormalError);

    #[test]
    fn test_create() {
        create_ok(10.0, 0.1);
        create_ok(-5.0, 1.0);
        create_ok(0.0, 10.0);
        create_ok(10.0, 100.0);
        create_ok(-5.0, f64::INFINITY);
    }

    #[test]
    fn test_bad_create() {
        test_create_err(f64::NAN, 1.0, NormalError::MeanInvalid);
        test_create_err(1.0, f64::NAN, NormalError::StandardDeviationInvalid);
        create_err(0.0, 0.0);
        create_err(f64::NAN, f64::NAN);
        create_err(1.0, -1.0);
    }

    #[test]
    fn test_standard_and_default_agree() {
        assert_eq!(Normal::standard(), create_ok(0.0, 1.0));
        assert_eq!(Normal::default(), Normal::standard());
    }

    #[test]
    fn test_display() {
        assert_eq!(create_ok(1.0, 2.0).to_string(), "N(1,2)");
        assert_eq!(create_ok(-0.5, 0.25).to_string(), "N(-0.5,0.25)");
    }

    #[test]
    fn test_error_display() {
        assert_eq!(NormalError::MeanInvalid.to_string(), "Mean is NaN");
        assert_eq!(
            NormalError::StandardDeviationInvalid.to_string(),
            "Standard deviation is NaN, zero or less than zero"
        );
    }

    #[test]
    fn test_variance() {
        let variance = |x: Normal| x.variance().unwrap();
        test_exact(0.0, 0.1, 0.1 * 0.1, variance);
        test_exact(0.0, 1.0, 1.0, variance);
        test_exact(0.0, 10.0, 100.0, variance);
        test_exact(0.0, f64::INFINITY, f64::INFINITY, variance);
    }

    #[test]
    fn test_entropy() {
        let entropy = |x: Normal| x.entropy().unwrap();
        test_absolute(0.0, 0.1, -0.8836465597893729422377, 1e-15, entropy);
        test_exact(0.0, 1.0, 1.41893853320467274178, entropy);
        test_exact(0.0, 10.0, 3.721523626198718425798, entropy);
        test_exact(0.0, f64::INFINITY, f64::INFINITY, entropy);
    }

    #[test]
    fn test_skewness() {
        let skewness = |x: Normal| x.skewness().unwrap();
        test_exact(0.0, 0.1, 0.0, skewness);
        test_exact(4.0, 1.0, 0.0, skewness);
        test_exact(0.3, 10.0, 0.0, skewness);
        test_exact(0.0, f64::INFINITY, 0.0, skewness);
    }

    #[test]
    fn test_mode() {
        let mode = |x: Normal| x.mode().unwrap();
        test_exact(-0.0, 1.0, 0.0, mode);
        test_exact(0.0, 1.0, 0.0, mode);
        test_exact(0.1, 1.0, 0.1, mode);
        test_exact(1.0, 1.0, 1.0, mode);
        test_exact(-10.0, 1.0, -10.0, mode);
        test_exact(f64::INFINITY, 1.0, f64::INFINITY, mode);
    }

    #[test]
    fn test_median() {
        let median = |x: Normal| x.median();
        test_exact(-0.0, 1.0, 0.0, median);
        test_exact(0.0, 1.0, 0.0, median);
        test_exact(0.1, 1.0, 0.1, median);
        test_exact(1.0, 1.0, 1.0, median);
        // Replaces a duplicate of the `-0.0` case (`-0.0 == 0.0`) with a real
        // negative mean.
        test_exact(-10.0, 1.0, -10.0, median);
        test_exact(f64::INFINITY, 1.0, f64::INFINITY, median);
    }

    #[test]
    fn test_min_max() {
        let min = |x: Normal| x.min();
        let max = |x: Normal| x.max();
        test_exact(0.0, 0.1, f64::NEG_INFINITY, min);
        test_exact(-3.0, 10.0, f64::NEG_INFINITY, min);
        test_exact(0.0, 0.1, f64::INFINITY, max);
        test_exact(-3.0, 10.0, f64::INFINITY, max);
    }

    #[test]
    fn test_pdf() {
        let pdf = |arg: f64| move |x: Normal| x.pdf(arg);
        test_absolute(10.0, 0.1, 5.530709549844416159162E-49, 1e-64, pdf(8.5));
        test_absolute(10.0, 0.1, 0.5399096651318805195056, 1e-14, pdf(9.8));
        test_absolute(10.0, 0.1, 3.989422804014326779399, 1e-15, pdf(10.0));
        test_absolute(10.0, 0.1, 0.5399096651318805195056, 1e-14, pdf(10.2));
        test_absolute(10.0, 0.1, 5.530709549844416159162E-49, 1e-64, pdf(11.5));
        test_exact(-5.0, 1.0, 1.486719514734297707908E-6, pdf(-10.0));
        test_exact(-5.0, 1.0, 0.01752830049356853736216, pdf(-7.5));
        test_absolute(-5.0, 1.0, 0.3989422804014326779399, 1e-16, pdf(-5.0));
        test_exact(-5.0, 1.0, 0.01752830049356853736216, pdf(-2.5));
        test_exact(-5.0, 1.0, 1.486719514734297707908E-6, pdf(0.0));
        test_exact(0.0, 10.0, 0.03520653267642994777747, pdf(-5.0));
        test_absolute(0.0, 10.0, 0.03866681168028492069412, 1e-17, pdf(-2.5));
        test_absolute(0.0, 10.0, 0.03989422804014326779399, 1e-17, pdf(0.0));
        test_absolute(0.0, 10.0, 0.03866681168028492069412, 1e-17, pdf(2.5));
        test_exact(0.0, 10.0, 0.03520653267642994777747, pdf(5.0));
        test_absolute(10.0, 100.0, 4.398359598042719404845E-4, 1e-19, pdf(-200.0));
        test_exact(10.0, 100.0, 0.002178521770325505313831, pdf(-100.0));
        test_exact(10.0, 100.0, 0.003969525474770117655105, pdf(0.0));
        test_absolute(10.0, 100.0, 0.002660852498987548218204, 1e-18, pdf(100.0));
        test_exact(10.0, 100.0, 6.561581477467659126534E-4, pdf(200.0));
        test_exact(-5.0, f64::INFINITY, 0.0, pdf(-5.0));
        test_exact(-5.0, f64::INFINITY, 0.0, pdf(0.0));
        test_exact(-5.0, f64::INFINITY, 0.0, pdf(100.0));
    }

    #[test]
    fn test_ln_pdf() {
        let ln_pdf = |arg: f64| move |x: Normal| x.ln_pdf(arg);
        test_absolute(10.0, 0.1, (5.530709549844416159162E-49f64).ln(), 1e-13, ln_pdf(8.5));
        test_absolute(10.0, 0.1, (0.5399096651318805195056f64).ln(), 1e-13, ln_pdf(9.8));
        test_absolute(10.0, 0.1, (3.989422804014326779399f64).ln(), 1e-15, ln_pdf(10.0));
        test_absolute(10.0, 0.1, (0.5399096651318805195056f64).ln(), 1e-13, ln_pdf(10.2));
        test_absolute(10.0, 0.1, (5.530709549844416159162E-49f64).ln(), 1e-13, ln_pdf(11.5));
        test_exact(-5.0, 1.0, (1.486719514734297707908E-6f64).ln(), ln_pdf(-10.0));
        test_exact(-5.0, 1.0, (0.01752830049356853736216f64).ln(), ln_pdf(-7.5));
        test_absolute(-5.0, 1.0, (0.3989422804014326779399f64).ln(), 1e-15, ln_pdf(-5.0));
        test_exact(-5.0, 1.0, (0.01752830049356853736216f64).ln(), ln_pdf(-2.5));
        test_exact(-5.0, 1.0, (1.486719514734297707908E-6f64).ln(), ln_pdf(0.0));
        test_exact(0.0, 10.0, (0.03520653267642994777747f64).ln(), ln_pdf(-5.0));
        test_exact(0.0, 10.0, (0.03866681168028492069412f64).ln(), ln_pdf(-2.5));
        test_exact(0.0, 10.0, (0.03989422804014326779399f64).ln(), ln_pdf(0.0));
        test_exact(0.0, 10.0, (0.03866681168028492069412f64).ln(), ln_pdf(2.5));
        test_exact(0.0, 10.0, (0.03520653267642994777747f64).ln(), ln_pdf(5.0));
        test_exact(10.0, 100.0, (4.398359598042719404845E-4f64).ln(), ln_pdf(-200.0));
        test_exact(10.0, 100.0, (0.002178521770325505313831f64).ln(), ln_pdf(-100.0));
        test_absolute(10.0, 100.0, (0.003969525474770117655105f64).ln(),1e-15, ln_pdf(0.0));
        test_absolute(10.0, 100.0, (0.002660852498987548218204f64).ln(), 1e-15, ln_pdf(100.0));
        test_absolute(10.0, 100.0, (6.561581477467659126534E-4f64).ln(), 1e-15, ln_pdf(200.0));
        test_exact(-5.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(-5.0));
        test_exact(-5.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(0.0));
        test_exact(-5.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(100.0));
    }

    #[test]
    fn test_cdf() {
        let cdf = |arg: f64| move |x: Normal| x.cdf(arg);
        test_exact(5.0, 2.0, 0.0, cdf(f64::NEG_INFINITY));
        test_absolute(5.0, 2.0, 0.0000002866515718, 1e-16, cdf(-5.0));
        test_absolute(5.0, 2.0, 0.0002326290790, 1e-13, cdf(-2.0));
        test_absolute(5.0, 2.0, 0.006209665325, 1e-12, cdf(0.0));
        test_exact(5.0, 2.0, 0.30853753872598689636229538939166226011639782444542207, cdf(4.0));
        test_exact(5.0, 2.0, 0.5, cdf(5.0));
        test_exact(5.0, 2.0, 0.69146246127401310363770461060833773988360217555457859, cdf(6.0));
        test_absolute(5.0, 2.0, 0.993790334674, 1e-12, cdf(10.0));
    }

    #[test]
    fn test_sf() {
        let sf = |arg: f64| move |x: Normal| x.sf(arg);
        test_exact(5.0, 2.0, 1.0, sf(f64::NEG_INFINITY));
        test_absolute(5.0, 2.0, 0.9999997133484281, 1e-16, sf(-5.0));
        test_absolute(5.0, 2.0, 0.9997673709209455, 1e-13, sf(-2.0));
        test_absolute(5.0, 2.0, 0.9937903346744879, 1e-12, sf(0.0));
        test_exact(5.0, 2.0, 0.6914624612740131, sf(4.0));
        test_exact(5.0, 2.0, 0.5, sf(5.0));
        test_exact(5.0, 2.0, 0.3085375387259869, sf(6.0));
        test_absolute(5.0, 2.0, 0.006209665325512148, 1e-12, sf(10.0));
    }

    #[test]
    fn test_continuous() {
        test::check_continuous_distribution(&create_ok(0.0, 1.0), -10.0, 10.0);
        test::check_continuous_distribution(&create_ok(20.0, 0.5), 10.0, 30.0);
    }

    #[test]
    fn test_inverse_cdf() {
        let inverse_cdf = |arg: f64| move |x: Normal| x.inverse_cdf(arg);
        test_exact(5.0, 2.0, f64::NEG_INFINITY, inverse_cdf( 0.0));
        test_absolute(5.0, 2.0, -5.0, 1e-14, inverse_cdf(0.00000028665157187919391167375233287464535385442301361187883));
        test_absolute(5.0, 2.0, -2.0, 1e-14, inverse_cdf(0.0002326290790355250363499258867279847735487493358890356));
        // A separate `-0.0` expectation would be redundant under an absolute
        // tolerance, so only the `0.0` case is kept.
        test_absolute(5.0, 2.0, 0.0, 1e-14, inverse_cdf(0.0062096653257761351669781045741922211278977469230927036));
        test_absolute(5.0, 2.0, 4.0, 1e-14, inverse_cdf(0.30853753872598689636229538939166226011639782444542207));
        test_absolute(5.0, 2.0, 5.0, 1e-14, inverse_cdf(0.5));
        test_absolute(5.0, 2.0, 6.0, 1e-14, inverse_cdf(0.69146246127401310363770461060833773988360217555457859));
        test_absolute(5.0, 2.0, 10.0, 1e-14, inverse_cdf(0.9937903346742238648330218954258077788721022530769078));
        test_exact(5.0, 2.0, f64::INFINITY, inverse_cdf(1.0));
    }

    // `inverse_cdf` documents a panic for arguments outside [0, 1]; pin it.
    #[test]
    #[should_panic(expected = "x must be in [0, 1]")]
    fn test_inverse_cdf_above_one_panics() {
        create_ok(0.0, 1.0).inverse_cdf(1.5);
    }

    #[test]
    #[should_panic(expected = "x must be in [0, 1]")]
    fn test_inverse_cdf_below_zero_panics() {
        create_ok(0.0, 1.0).inverse_cdf(-0.1);
    }

    #[test]
    #[should_panic(expected = "x must be in [0, 1]")]
    fn test_inverse_cdf_nan_panics() {
        create_ok(0.0, 1.0).inverse_cdf(f64::NAN);
    }

    #[test]
    fn test_default() {
        let n = Normal::default();

        let n_mean = n.mean().unwrap();
        let n_std = n.std_dev().unwrap();

        assert_almost_eq!(n_mean, 0.0, 1e-15);
        assert_almost_eq!(n_std, 1.0, 1e-15);
    }
}