1use crate::distribution::{ziggurat, Continuous, ContinuousCDF};
2use crate::function::erf;
3use crate::statistics::*;
4use crate::{consts, Result, StatsError};
5use rand::Rng;
6use std::f64;
7
/// A normal (Gaussian) distribution described by its mean μ and its
/// standard deviation σ.
///
/// Values are built through [`Normal::new`], which rejects a NaN mean and a
/// NaN, zero, or negative standard deviation.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Normal {
    /// Location parameter μ.
    mean: f64,
    /// Scale parameter σ; strictly positive (possibly `+∞`) when built via `new`.
    std_dev: f64,
}
26
27impl Normal {
28 pub fn new(mean: f64, std_dev: f64) -> Result<Normal> {
48 if mean.is_nan() || std_dev.is_nan() || std_dev <= 0.0 {
49 Err(StatsError::BadParams)
50 } else {
51 Ok(Normal { mean, std_dev })
52 }
53 }
54}
55
56impl ::rand::distributions::Distribution<f64> for Normal {
57 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
58 sample_unchecked(rng, self.mean, self.std_dev)
59 }
60}
61
62impl ContinuousCDF<f64, f64> for Normal {
63 fn cdf(&self, x: f64) -> f64 {
75 cdf_unchecked(x, self.mean, self.std_dev)
76 }
77
78 fn sf(&self, x: f64) -> f64 {
99 sf_unchecked(x, self.mean, self.std_dev)
100 }
101
102 fn inverse_cdf(&self, x: f64) -> f64 {
118 if !(0.0..=1.0).contains(&x) {
119 panic!("x must be in [0, 1]");
120 } else {
121 self.mean - (self.std_dev * f64::consts::SQRT_2 * erf::erfc_inv(2.0 * x))
122 }
123 }
124}
125
126impl Min<f64> for Normal {
127 fn min(&self) -> f64 {
136 f64::NEG_INFINITY
137 }
138}
139
140impl Max<f64> for Normal {
141 fn max(&self) -> f64 {
150 f64::INFINITY
151 }
152}
153
154impl Distribution<f64> for Normal {
155 fn mean(&self) -> Option<f64> {
161 Some(self.mean)
162 }
163 fn variance(&self) -> Option<f64> {
173 Some(self.std_dev * self.std_dev)
174 }
175 fn entropy(&self) -> Option<f64> {
185 Some(self.std_dev.ln() + consts::LN_SQRT_2PIE)
186 }
187 fn skewness(&self) -> Option<f64> {
195 Some(0.0)
196 }
197}
198
199impl Median<f64> for Normal {
200 fn median(&self) -> f64 {
210 self.mean
211 }
212}
213
214impl Mode<Option<f64>> for Normal {
215 fn mode(&self) -> Option<f64> {
225 Some(self.mean)
226 }
227}
228
229impl Continuous<f64, f64> for Normal {
230 fn pdf(&self, x: f64) -> f64 {
241 pdf_unchecked(x, self.mean, self.std_dev)
242 }
243
244 fn ln_pdf(&self, x: f64) -> f64 {
256 ln_pdf_unchecked(x, self.mean, self.std_dev)
257 }
258}
259
260pub fn cdf_unchecked(x: f64, mean: f64, std_dev: f64) -> f64 {
263 0.5 * erf::erfc((mean - x) / (std_dev * f64::consts::SQRT_2))
264}
265
266pub fn sf_unchecked(x: f64, mean: f64, std_dev: f64) -> f64 {
269 0.5 * erf::erfc((x - mean) / (std_dev * f64::consts::SQRT_2))
270}
271
272pub fn pdf_unchecked(x: f64, mean: f64, std_dev: f64) -> f64 {
275 let d = (x - mean) / std_dev;
276 (-0.5 * d * d).exp() / (consts::SQRT_2PI * std_dev)
277}
278
279pub fn ln_pdf_unchecked(x: f64, mean: f64, std_dev: f64) -> f64 {
282 let d = (x - mean) / std_dev;
283 (-0.5 * d * d) - consts::LN_SQRT_2PI - std_dev.ln()
284}
285
286pub fn sample_unchecked<R: Rng + ?Sized>(rng: &mut R, mean: f64, std_dev: f64) -> f64 {
288 mean + std_dev * ziggurat::sample_std_normal(rng)
289}
290
// Unit tests for `Normal`. They are compiled only for test builds with the
// `nightly` feature enabled; `rustfmt::skip` keeps rustfmt from reflowing the
// long reference-value tables below.
#[rustfmt::skip]
#[cfg(all(test, feature = "nightly"))]
mod tests {
    use crate::statistics::*;
    use crate::distribution::{ContinuousCDF, Continuous, Normal};
    use crate::distribution::internal::*;
    use crate::consts::ACC;

    // Builds a `Normal`, failing the test if `Normal::new` rejects the parameters.
    fn try_create(mean: f64, std_dev: f64) -> Normal {
        let n = Normal::new(mean, std_dev);
        assert!(n.is_ok());
        n.unwrap()
    }

    // Checks that valid parameters round-trip exactly through `mean()` and `std_dev()`.
    fn create_case(mean: f64, std_dev: f64) {
        let n = try_create(mean, std_dev);
        assert_eq!(mean, n.mean().unwrap());
        assert_eq!(std_dev, n.std_dev().unwrap());
    }

    // Checks that `Normal::new` returns an error for the given parameters.
    fn bad_create_case(mean: f64, std_dev: f64) {
        let n = Normal::new(mean, std_dev);
        assert!(n.is_err());
    }

    // Evaluates `eval` on a freshly built distribution and requires an exact
    // `assert_eq!` match with `expected` (note that `-0.0 == 0.0` under `==`).
    fn test_case<F>(mean: f64, std_dev: f64, expected: f64, eval: F)
        where F: Fn(Normal) -> f64
    {
        let n = try_create(mean, std_dev);
        let x = eval(n);
        assert_eq!(expected, x);
    }

    // Like `test_case`, but compares with `assert_almost_eq!` using tolerance
    // `acc` (presumably an absolute tolerance; see the macro definition).
    fn test_almost<F>(mean: f64, std_dev: f64, expected: f64, acc: f64, eval: F)
        where F: Fn(Normal) -> f64
    {
        let n = try_create(mean, std_dev);
        let x = eval(n);
        assert_almost_eq!(expected, x, acc);
    }

    // Valid parameters, including an infinite standard deviation, are accepted.
    #[test]
    fn test_create() {
        create_case(10.0, 0.1);
        create_case(-5.0, 1.0);
        create_case(0.0, 10.0);
        create_case(10.0, 100.0);
        create_case(-5.0, f64::INFINITY);
    }

    // A zero or negative standard deviation, or a NaN parameter, is rejected.
    #[test]
    fn test_bad_create() {
        bad_create_case(0.0, 0.0);
        bad_create_case(f64::NAN, 1.0);
        bad_create_case(1.0, f64::NAN);
        bad_create_case(f64::NAN, f64::NAN);
        bad_create_case(1.0, -1.0);
    }

    // Variance is σ², independent of the mean.
    #[test]
    fn test_variance() {
        let variance = |x: Normal| x.variance().unwrap();
        test_case(0.0, 0.1, 0.1 * 0.1, variance);
        test_case(0.0, 1.0, 1.0, variance);
        test_case(0.0, 10.0, 100.0, variance);
        test_case(0.0, f64::INFINITY, f64::INFINITY, variance);
    }

    // Entropy depends only on σ; the σ = 1 value is ln(√(2πe)).
    #[test]
    fn test_entropy() {
        let entropy = |x: Normal| x.entropy().unwrap();
        test_almost(0.0, 0.1, -0.8836465597893729422377, 1e-15, entropy);
        test_case(0.0, 1.0, 1.41893853320467274178, entropy);
        test_case(0.0, 10.0, 3.721523626198718425798, entropy);
        test_case(0.0, f64::INFINITY, f64::INFINITY, entropy);
    }

    // Skewness is always zero.
    #[test]
    fn test_skewness() {
        let skewness = |x: Normal| x.skewness().unwrap();
        test_case(0.0, 0.1, 0.0, skewness);
        test_case(4.0, 1.0, 0.0, skewness);
        test_case(0.3, 10.0, 0.0, skewness);
        test_case(0.0, f64::INFINITY, 0.0, skewness);
    }

    // Mode equals the mean.
    #[test]
    fn test_mode() {
        let mode = |x: Normal| x.mode().unwrap();
        test_case(-0.0, 1.0, 0.0, mode);
        test_case(0.0, 1.0, 0.0, mode);
        test_case(0.1, 1.0, 0.1, mode);
        test_case(1.0, 1.0, 1.0, mode);
        test_case(-10.0, 1.0, -10.0, mode);
        test_case(f64::INFINITY, 1.0, f64::INFINITY, mode);
    }

    // Median equals the mean.
    // NOTE(review): the first and fifth cases are equivalent under `==`
    // (-0.0 == 0.0), so one of them is redundant.
    #[test]
    fn test_median() {
        let median = |x: Normal| x.median();
        test_case(-0.0, 1.0, 0.0, median);
        test_case(0.0, 1.0, 0.0, median);
        test_case(0.1, 1.0, 0.1, median);
        test_case(1.0, 1.0, 1.0, median);
        test_case(-0.0, 1.0, -0.0, median);
        test_case(f64::INFINITY, 1.0, f64::INFINITY, median);
    }

    // The support is the whole real line for every parameter choice.
    #[test]
    fn test_min_max() {
        let min = |x: Normal| x.min();
        let max = |x: Normal| x.max();
        test_case(0.0, 0.1, f64::NEG_INFINITY, min);
        test_case(-3.0, 10.0, f64::NEG_INFINITY, min);
        test_case(0.0, 0.1, f64::INFINITY, max);
        test_case(-3.0, 10.0, f64::INFINITY, max);
    }

    // Reference density values. Tiny tail values (~5.5e-49) use tolerances
    // scaled to their magnitude. With σ = ∞ the density is zero everywhere.
    #[test]
    fn test_pdf() {
        let pdf = |arg: f64| move |x: Normal| x.pdf(arg);
        test_almost(10.0, 0.1, 5.530709549844416159162E-49, 1e-64, pdf(8.5));
        test_almost(10.0, 0.1, 0.5399096651318805195056, 1e-14, pdf(9.8));
        test_almost(10.0, 0.1, 3.989422804014326779399, 1e-15, pdf(10.0));
        test_almost(10.0, 0.1, 0.5399096651318805195056, 1e-14, pdf(10.2));
        test_almost(10.0, 0.1, 5.530709549844416159162E-49, 1e-64, pdf(11.5));
        test_case(-5.0, 1.0, 1.486719514734297707908E-6, pdf(-10.0));
        test_case(-5.0, 1.0, 0.01752830049356853736216, pdf(-7.5));
        test_almost(-5.0, 1.0, 0.3989422804014326779399, 1e-16, pdf(-5.0));
        test_case(-5.0, 1.0, 0.01752830049356853736216, pdf(-2.5));
        test_case(-5.0, 1.0, 1.486719514734297707908E-6, pdf(0.0));
        test_case(0.0, 10.0, 0.03520653267642994777747, pdf(-5.0));
        test_almost(0.0, 10.0, 0.03866681168028492069412, 1e-17, pdf(-2.5));
        test_almost(0.0, 10.0, 0.03989422804014326779399, 1e-17, pdf(0.0));
        test_almost(0.0, 10.0, 0.03866681168028492069412, 1e-17, pdf(2.5));
        test_case(0.0, 10.0, 0.03520653267642994777747, pdf(5.0));
        test_almost(10.0, 100.0, 4.398359598042719404845E-4, 1e-19, pdf(-200.0));
        test_case(10.0, 100.0, 0.002178521770325505313831, pdf(-100.0));
        test_case(10.0, 100.0, 0.003969525474770117655105, pdf(0.0));
        test_almost(10.0, 100.0, 0.002660852498987548218204, 1e-18, pdf(100.0));
        test_case(10.0, 100.0, 6.561581477467659126534E-4, pdf(200.0));
        test_case(-5.0, f64::INFINITY, 0.0, pdf(-5.0));
        test_case(-5.0, f64::INFINITY, 0.0, pdf(0.0));
        test_case(-5.0, f64::INFINITY, 0.0, pdf(100.0));
    }

    // Same reference points as `test_pdf`, compared in log space. With σ = ∞
    // the log density is -∞ everywhere.
    #[test]
    fn test_ln_pdf() {
        let ln_pdf = |arg: f64| move |x: Normal| x.ln_pdf(arg);
        test_almost(10.0, 0.1, (5.530709549844416159162E-49f64).ln(), 1e-13, ln_pdf(8.5));
        test_almost(10.0, 0.1, (0.5399096651318805195056f64).ln(), 1e-13, ln_pdf(9.8));
        test_almost(10.0, 0.1, (3.989422804014326779399f64).ln(), 1e-15, ln_pdf(10.0));
        test_almost(10.0, 0.1, (0.5399096651318805195056f64).ln(), 1e-13, ln_pdf(10.2));
        test_almost(10.0, 0.1, (5.530709549844416159162E-49f64).ln(), 1e-13, ln_pdf(11.5));
        test_case(-5.0, 1.0, (1.486719514734297707908E-6f64).ln(), ln_pdf(-10.0));
        test_case(-5.0, 1.0, (0.01752830049356853736216f64).ln(), ln_pdf(-7.5));
        test_almost(-5.0, 1.0, (0.3989422804014326779399f64).ln(), 1e-15, ln_pdf(-5.0));
        test_case(-5.0, 1.0, (0.01752830049356853736216f64).ln(), ln_pdf(-2.5));
        test_case(-5.0, 1.0, (1.486719514734297707908E-6f64).ln(), ln_pdf(0.0));
        test_case(0.0, 10.0, (0.03520653267642994777747f64).ln(), ln_pdf(-5.0));
        test_case(0.0, 10.0, (0.03866681168028492069412f64).ln(), ln_pdf(-2.5));
        test_case(0.0, 10.0, (0.03989422804014326779399f64).ln(), ln_pdf(0.0));
        test_case(0.0, 10.0, (0.03866681168028492069412f64).ln(), ln_pdf(2.5));
        test_case(0.0, 10.0, (0.03520653267642994777747f64).ln(), ln_pdf(5.0));
        test_case(10.0, 100.0, (4.398359598042719404845E-4f64).ln(), ln_pdf(-200.0));
        test_case(10.0, 100.0, (0.002178521770325505313831f64).ln(), ln_pdf(-100.0));
        test_almost(10.0, 100.0, (0.003969525474770117655105f64).ln(),1e-15, ln_pdf(0.0));
        test_almost(10.0, 100.0, (0.002660852498987548218204f64).ln(), 1e-15, ln_pdf(100.0));
        test_almost(10.0, 100.0, (6.561581477467659126534E-4f64).ln(), 1e-15, ln_pdf(200.0));
        test_case(-5.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(-5.0));
        test_case(-5.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(0.0));
        test_case(-5.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(100.0));
    }

    // CDF of N(5, 2²) across the lower tail, the centre, and the upper tail.
    #[test]
    fn test_cdf() {
        let cdf = |arg: f64| move |x: Normal| x.cdf(arg);
        test_case(5.0, 2.0, 0.0, cdf(f64::NEG_INFINITY));
        test_almost(5.0, 2.0, 0.0000002866515718, 1e-16, cdf(-5.0));
        test_almost(5.0, 2.0, 0.0002326290790, 1e-13, cdf(-2.0));
        test_almost(5.0, 2.0, 0.006209665325, 1e-12, cdf(0.0));
        test_case(5.0, 2.0, 0.30853753872598689636229538939166226011639782444542207, cdf(4.0));
        test_case(5.0, 2.0, 0.5, cdf(5.0));
        test_case(5.0, 2.0, 0.69146246127401310363770461060833773988360217555457859, cdf(6.0));
        test_almost(5.0, 2.0, 0.993790334674, 1e-12, cdf(10.0));
    }

    // Survival function of N(5, 2²) at the same points as `test_cdf`, i.e. 1 - cdf.
    #[test]
    fn test_sf() {
        let sf = |arg: f64| move |x: Normal| x.sf(arg);
        test_case(5.0, 2.0, 1.0, sf(f64::NEG_INFINITY));
        test_almost(5.0, 2.0, 0.9999997133484281, 1e-16, sf(-5.0));
        test_almost(5.0, 2.0, 0.9997673709209455, 1e-13, sf(-2.0));
        test_almost(5.0, 2.0, 0.9937903346744879, 1e-12, sf(0.0));
        test_case(5.0, 2.0, 0.6914624612740131, sf(4.0));
        test_case(5.0, 2.0, 0.5, sf(5.0));
        test_case(5.0, 2.0, 0.3085375387259869, sf(6.0));
        test_almost(5.0, 2.0, 0.006209665325512148, 1e-12, sf(10.0));
    }

    // Runs the shared continuous-distribution consistency check from
    // `distribution::internal` over the given ranges (see that helper for
    // exactly what it verifies).
    #[test]
    fn test_continuous() {
        test::check_continuous_distribution(&try_create(0.0, 1.0), -10.0, 10.0);
        test::check_continuous_distribution(&try_create(20.0, 0.5), 10.0, 30.0);
    }

    // The inputs are the CDF values of N(5, 2²) at the expected outputs, so
    // each case checks that `inverse_cdf` undoes `cdf`.
    // NOTE(review): the `-0.0` and `0.0` cases use the same input and are
    // equivalent under a 1e-14 tolerance.
    #[test]
    fn test_inverse_cdf() {
        let inverse_cdf = |arg: f64| move |x: Normal| x.inverse_cdf(arg);
        test_case(5.0, 2.0, f64::NEG_INFINITY, inverse_cdf( 0.0));
        test_almost(5.0, 2.0, -5.0, 1e-14, inverse_cdf(0.00000028665157187919391167375233287464535385442301361187883));
        test_almost(5.0, 2.0, -2.0, 1e-14, inverse_cdf(0.0002326290790355250363499258867279847735487493358890356));
        test_almost(5.0, 2.0, -0.0, 1e-14, inverse_cdf(0.0062096653257761351669781045741922211278977469230927036));
        test_almost(5.0, 2.0, 0.0, 1e-14, inverse_cdf(0.0062096653257761351669781045741922211278977469230927036));
        test_almost(5.0, 2.0, 4.0, 1e-14, inverse_cdf(0.30853753872598689636229538939166226011639782444542207));
        test_almost(5.0, 2.0, 5.0, 1e-14, inverse_cdf(0.5));
        test_almost(5.0, 2.0, 6.0, 1e-14, inverse_cdf(0.69146246127401310363770461060833773988360217555457859));
        test_almost(5.0, 2.0, 10.0, 1e-14, inverse_cdf(0.9937903346742238648330218954258077788721022530769078));
        test_case(5.0, 2.0, f64::INFINITY, inverse_cdf(1.0));
    }
}