statrs/distribution/
internal.rs

1use num_traits::{Bounded, Float, Num};
2
/// Checks whether `arr` is usable as a set of multinomial weights.
///
/// Returns `true` only when no element of `arr` is `f64::NAN` or negative
/// and the elements sum to a non-zero value. Zero-valued elements are
/// accepted when `incl_zero` is `true` and rejected otherwise. An empty
/// slice sums to zero and is therefore never valid.
pub fn is_valid_multinomial(arr: &[f64], incl_zero: bool) -> bool {
    // Accumulate the weights front to back, bailing out with `None` as soon
    // as an unacceptable weight is met.
    let total = arr.iter().try_fold(0.0_f64, |acc, &weight| {
        let rejected = weight.is_nan() || weight < 0.0 || (!incl_zero && weight == 0.0);
        if rejected {
            None
        } else {
            Some(acc + weight)
        }
    });
    total.map_or(false, |sum| sum != 0.0)
}
16
17/// Implements univariate function bisection searching for criteria
18/// ```text
19/// smallest k such that f(k) >= z
20/// ```
21/// Evaluates to `None` if
22/// - provided interval has lower bound greater than upper bound
23/// - function found not semi-monotone on the provided interval containing `z`
24/// Evaluates to `Some(k)`, where `k` satisfies the search criteria
25pub fn integral_bisection_search<K: Num + Clone, T: Num + PartialOrd>(
26    f: impl Fn(&K) -> T, z: T, lb: K, ub: K,
27) -> Option<K> {
28    if !(f(&lb)..=f(&ub)).contains(&z) {
29        return None;
30    }
31    let two = K::one() + K::one();
32    let mut lb = lb;
33    let mut ub = ub;
34    loop {
35        let mid = (lb.clone() + ub.clone()) / two.clone();
36        if !(f(&lb)..=f(&ub)).contains(&f(&mid)) {
37            // if f found not monotone on the interval
38            return None;
39        } else if f(&lb) == z {
40            return Some(lb);
41        } else if f(&ub) == z {
42            return Some(ub);
43        } else if (lb.clone() + K::one()) == ub {
44            // no more elements to search
45            return Some(ub);
46        } else if f(&mid) >= z {
47            ub = mid;
48        } else {
49            lb = mid;
50        }
51    }
52}
53
/// Shared test helpers and unit tests for the functions in this file.
///
/// Only compiled for test builds with the `nightly` feature enabled
/// (presumably because `testing_boiler!` invokes constructors through
/// `call_once`, which needs unstable function-trait support — confirm).
#[macro_use]
#[cfg(all(test, feature = "nightly"))]
pub mod test {
    use super::*;
    use crate::consts::ACC;
    use crate::distribution::{Continuous, ContinuousCDF, Discrete, DiscreteCDF};

    /// Generates the helper functions used by a distribution's unit tests.
    ///
    /// `$arg` is the argument type passed to `<$dist>::new` via `call_once`
    /// and `$dist` is the distribution type under test. The expansion defines
    /// `try_create`, `bad_create_case`, `get_value`, `test_case`,
    /// `test_case_special` and `test_none` at the call site; `ACC`,
    /// `assert_relative_eq!` and `assert_abs_diff_eq!` must be in scope there.
    #[macro_export]
    macro_rules! testing_boiler {
        ($arg:ty, $dist:ty) => {
            // Constructs the distribution, failing the test if `new` errors.
            fn try_create(arg: $arg) -> $dist {
                let n = <$dist>::new.call_once(arg);
                assert!(n.is_ok());
                n.unwrap()
            }

            // Asserts that `new` rejects the given arguments.
            fn bad_create_case(arg: $arg) {
                let n = <$dist>::new.call_once(arg);
                assert!(n.is_err());
            }

            // Constructs the distribution and returns `eval` applied to it.
            fn get_value<F, T>(arg: $arg, eval: F) -> T
            where
                F: Fn($dist) -> T,
            {
                let n = try_create(arg);
                eval(n)
            }

            // Compares `eval`'s result to `expected` with relative tolerance `ACC`.
            fn test_case<F, T>(arg: $arg, expected: T, eval: F)
            where
                F: Fn($dist) -> T,
                T: ::core::fmt::Debug + ::approx::RelativeEq<Epsilon = f64>,
            {
                let x = get_value(arg, eval);
                assert_relative_eq!(expected, x, max_relative = ACC);
            }

            // Compares `eval`'s result to `expected` with absolute tolerance `acc`.
            #[allow(dead_code)] // This is not used by all distributions.
            fn test_case_special<F, T>(arg: $arg, expected: T, acc: f64, eval: F)
            where
                F: Fn($dist) -> T,
                T: ::core::fmt::Debug + ::approx::AbsDiffEq<Epsilon = f64>,
            {
                let x = get_value(arg, eval);
                assert_abs_diff_eq!(expected, x, epsilon = acc);
            }

            // Asserts that `eval` yields `None` for the constructed distribution.
            #[allow(dead_code)] // This is not used by all distributions.
            fn test_none<F, T>(arg: $arg, eval: F)
            where
                F: Fn($dist) -> Option<T>,
                T: ::core::cmp::PartialEq + ::core::fmt::Debug,
            {
                let x = get_value(arg, eval);
                assert_eq!(None, x);
            }
        };
    }

    /// cdf should be the integral of the pdf
    ///
    /// Integrates the pdf from `x_min` in increments of `step` until `x_max`
    /// is reached, checking at every point that the density is non-negative,
    /// that `ln_pdf` agrees with `pdf().ln()`, and that the running integral
    /// stays within `1e-3` of the cdf. The final integral must lie in
    /// `(0.99, 1.001]`.
    fn check_integrate_pdf_is_cdf<D: ContinuousCDF<f64, f64> + Continuous<f64, f64>>(
        dist: &D,
        x_min: f64,
        x_max: f64,
        step: f64,
    ) {
        let mut prev_x = x_min;
        let mut prev_density = dist.pdf(x_min);
        let mut sum = 0.0;

        loop {
            let x = prev_x + step;
            let density = dist.pdf(x);

            assert!(density >= 0.0);

            let ln_density = dist.ln_pdf(x);

            assert_almost_eq!(density.ln(), ln_density, 1e-10);

            // triangle rule
            sum += (prev_density + density) * step / 2.0;

            let cdf = dist.cdf(x);
            if (sum - cdf).abs() > 1e-3 {
                // Carry the diagnostics in the panic message itself so they
                // show up in the failure summary.
                panic!(
                    "Integral of pdf doesn't equal cdf!\n\
                     Integration from {} by {} to {} = {}\n\
                     cdf = {}",
                    x_min, step, x, sum, cdf
                );
            }

            if x >= x_max {
                break;
            }
            prev_x = x;
            prev_density = density;
        }

        assert!(sum > 0.99);
        assert!(sum <= 1.001);
    }

    /// cdf should be the sum of the pmf
    ///
    /// Sums the pmf over `0..x_max + 3`, checking that every probability lies
    /// in `[0, 1]`, that the running sum matches the cdf to `1e-10`, and that
    /// more than 99% of the mass has been seen by `x_max`.
    fn check_sum_pmf_is_cdf<D: DiscreteCDF<u64, f64> + Discrete<u64, f64>>(dist: &D, x_max: u64) {
        let mut sum = 0.0;

        // go slightly beyond x_max to test for off-by-one errors
        for i in 0..x_max + 3 {
            let prob = dist.pmf(i);

            assert!(prob >= 0.0);
            assert!(prob <= 1.0);

            sum += prob;

            if i == x_max {
                assert!(sum > 0.99);
            }

            assert_almost_eq!(sum, dist.cdf(i), 1e-10);
            // Disabled checks that probe the cdf with `f64` arguments; the
            // `DiscreteCDF<u64, f64>` bound used here takes `u64` instead.
            // assert_almost_eq!(sum, dist.cdf(i as f64), 1e-10);
            // assert_almost_eq!(sum, dist.cdf(i as f64 + 0.1), 1e-10);
            // assert_almost_eq!(sum, dist.cdf(i as f64 + 0.5), 1e-10);
            // assert_almost_eq!(sum, dist.cdf(i as f64 + 0.9), 1e-10);
        }

        assert!(sum > 0.99);
        assert!(sum <= 1.0 + 1e-10);
    }

    /// Does a series of checks that all continuous distributions must obey.
    /// 99% of the probability mass should be between x_min and x_max.
    ///
    /// Checks the pdf, ln_pdf and cdf at both infinities, then integrates the
    /// pdf over `[x_min, x_max]` in 100000 steps and compares it to the cdf.
    pub fn check_continuous_distribution<D: ContinuousCDF<f64, f64> + Continuous<f64, f64>>(
        dist: &D,
        x_min: f64,
        x_max: f64,
    ) {
        assert_eq!(dist.pdf(f64::NEG_INFINITY), 0.0);
        assert_eq!(dist.pdf(f64::INFINITY), 0.0);
        assert_eq!(dist.ln_pdf(f64::NEG_INFINITY), f64::NEG_INFINITY);
        assert_eq!(dist.ln_pdf(f64::INFINITY), f64::NEG_INFINITY);
        assert_eq!(dist.cdf(f64::NEG_INFINITY), 0.0);
        assert_eq!(dist.cdf(f64::INFINITY), 1.0);

        check_integrate_pdf_is_cdf(dist, x_min, x_max, (x_max - x_min) / 100000.0);
    }

    /// Does a series of checks that all positive discrete distributions must
    /// obey.
    /// 99% of the probability mass should be between 0 and x_max (inclusive).
    pub fn check_discrete_distribution<D: DiscreteCDF<u64, f64> + Discrete<u64, f64>>(
        dist: &D,
        x_max: u64,
    ) {
        // Disabled checks that probe the cdf with `f64` arguments; the
        // `DiscreteCDF<u64, f64>` bound used here takes `u64` instead.
        // assert_eq!(dist.cdf(f64::NEG_INFINITY), 0.0);
        // assert_eq!(dist.cdf(-10.0), 0.0);
        // assert_eq!(dist.cdf(-1.0), 0.0);
        // assert_eq!(dist.cdf(-0.01), 0.0);
        // assert_eq!(dist.cdf(f64::INFINITY), 1.0);

        check_sum_pmf_is_cdf(dist, x_max);
    }

    /// With zeros allowed: NaN, negative and all-zero inputs are rejected,
    /// while a mix of zero and positive weights is accepted.
    #[test]
    fn test_is_valid_multinomial() {
        let invalid = [1.0, f64::NAN, 3.0];
        assert!(!is_valid_multinomial(&invalid, true));
        let invalid2 = [-2.0, 5.0, 1.0, 6.2];
        assert!(!is_valid_multinomial(&invalid2, true));
        let invalid3 = [0.0, 0.0, 0.0];
        assert!(!is_valid_multinomial(&invalid3, true));
        let valid = [5.2, 0.0, 1e-15, 1000000.12];
        assert!(is_valid_multinomial(&valid, true));
    }

    /// With zeros disallowed, a single zero weight invalidates the input.
    #[test]
    fn test_is_valid_multinomial_no_zero() {
        let invalid = [5.2, 0.0, 1e-15, 1000000.12];
        assert!(!is_valid_multinomial(&invalid, false));
    }

    /// Searches a strictly increasing table with a gap at `needle`: every
    /// stored value is found at its own index, and the missing value resolves
    /// to the index of the next larger element.
    #[test]
    fn test_integer_bisection() {
        fn search(z: usize, data: &[usize]) -> Option<usize> {
            integral_bisection_search(|idx: &usize| data[*idx], z, 0, data.len() - 1)
        }

        let needle = 3;
        // 0, 1, 2, 4, 5 — `needle` itself is skipped
        let data = (0..5)
            .map(|n| if n >= needle { n + 1 } else { n })
            .collect::<Vec<_>>();

        for (i, &value) in data.iter().enumerate() {
            assert_eq!(search(value, &data), Some(i));
        }
        {
            let infimum = search(needle, &data);
            let found_element = search(needle + 1, &data); // 4 > needle && member of range
            assert_eq!(found_element, Some(needle));
            assert_eq!(infimum, found_element)
        }
    }
}