1use num_traits::{Bounded, Float, Num};
2
/// Checks whether `arr` can be used as a set of multinomial weights.
///
/// Every entry must be non-NaN and strictly positive. When `incl_zero` is
/// `true`, entries may also be exactly zero. In addition, the entries must not
/// sum to zero, so an empty slice or an all-zero slice is rejected. The
/// weights do not need to be normalized.
pub fn is_valid_multinomial(arr: &[f64], incl_zero: bool) -> bool {
    let total = arr.iter().try_fold(0.0, |acc, &weight| {
        // The lower bound admits zero only when `incl_zero` is set.
        let out_of_range = if incl_zero {
            weight < 0.0
        } else {
            weight <= 0.0
        };
        if out_of_range || weight.is_nan() {
            None
        } else {
            Some(acc + weight)
        }
    });
    matches!(total, Some(sum) if sum != 0.0)
}
16
17pub fn integral_bisection_search<K: Num + Clone, T: Num + PartialOrd>(
26 f: impl Fn(&K) -> T, z: T, lb: K, ub: K,
27) -> Option<K> {
28 if !(f(&lb)..=f(&ub)).contains(&z) {
29 return None;
30 }
31 let two = K::one() + K::one();
32 let mut lb = lb;
33 let mut ub = ub;
34 loop {
35 let mid = (lb.clone() + ub.clone()) / two.clone();
36 if !(f(&lb)..=f(&ub)).contains(&f(&mid)) {
37 return None;
39 } else if f(&lb) == z {
40 return Some(lb);
41 } else if f(&ub) == z {
42 return Some(ub);
43 } else if (lb.clone() + K::one()) == ub {
44 return Some(ub);
46 } else if f(&mid) >= z {
47 ub = mid;
48 } else {
49 lb = mid;
50 }
51 }
52}
53
/// Shared helpers and self-tests for the distribution test suites.
///
/// Only compiled for tests with the `nightly` feature enabled.
#[macro_use]
#[cfg(all(test, feature = "nightly"))]
pub mod test {
    use super::*;
    use crate::consts::ACC;
    use crate::distribution::{Continuous, ContinuousCDF, Discrete, DiscreteCDF};

    /// Expands to a set of helper functions for testing the distribution type
    /// `$dist`, whose `new` constructor takes a single argument of type `$arg`
    /// (typically a tuple of parameters).
    ///
    /// The expansion refers to `ACC`, `assert_relative_eq!` and
    /// `assert_abs_diff_eq!` by unqualified name, so they must be in scope at
    /// the invocation site. It invokes the constructor through `Fn::call` /
    /// `FnOnce::call_once`, which are unstable library APIs. That is
    /// presumably why this module is gated on the `nightly` feature (confirm
    /// against the crate root).
    #[macro_export]
    macro_rules! testing_boiler {
        ($arg:ty, $dist:ty) => {
            // Builds `$dist` from `arg`, panicking if construction fails.
            fn try_create(arg: $arg) -> $dist {
                let n = <$dist>::new.call_once(arg);
                assert!(n.is_ok());
                n.unwrap()
            }

            // Asserts that constructing `$dist` from `arg` is rejected.
            fn bad_create_case(arg: $arg) {
                let n = <$dist>::new.call(arg);
                assert!(n.is_err());
            }

            // Builds `$dist` from `arg` and applies `eval` to it.
            fn get_value<F, T>(arg: $arg, eval: F) -> T
            where
                F: Fn($dist) -> T,
            {
                let n = try_create(arg);
                eval(n)
            }

            // Asserts that `eval(dist)` equals `expected` within relative
            // accuracy `ACC`.
            fn test_case<F, T>(arg: $arg, expected: T, eval: F)
            where
                F: Fn($dist) -> T,
                T: ::core::fmt::Debug + ::approx::RelativeEq<Epsilon = f64>,
            {
                let x = get_value(arg, eval);
                assert_relative_eq!(expected, x, max_relative = ACC);
            }

            // Asserts that `eval(dist)` equals `expected` within absolute
            // tolerance `acc`. `dead_code` is allowed, presumably because not
            // every invoking test module needs this helper.
            #[allow(dead_code)]
            fn test_case_special<F, T>(arg: $arg, expected: T, acc: f64, eval: F)
            where
                F: Fn($dist) -> T,
                T: ::core::fmt::Debug + ::approx::AbsDiffEq<Epsilon = f64>,
            {
                let x = get_value(arg, eval);
                assert_abs_diff_eq!(expected, x, epsilon = acc);
            }

            // Asserts that `eval(dist)` is `None`. `dead_code` is allowed,
            // presumably because not every invoking test module needs it.
            #[allow(dead_code)]
            fn test_none<F, T>(arg: $arg, eval: F)
            where
                F: Fn($dist) -> Option<T>,
                T: ::core::cmp::PartialEq + ::core::fmt::Debug,
            {
                let x = get_value(arg, eval);
                assert_eq!(None, x);
            }
        };
    }

    /// Integrates `dist.pdf` with the trapezoidal rule, starting at `x_min` and
    /// advancing in increments of `step`. At every sample point it checks that:
    /// - the density is non-negative;
    /// - `ln_pdf(x)` matches `pdf(x).ln()` (via `assert_almost_eq!` with
    ///   precision `1e-10`);
    /// - the running integral is within `1e-3` of `cdf(x)`.
    ///
    /// The running integral starts at 0, so `x_min` should be a point where the
    /// CDF is effectively 0. Integration stops at the first sample `>= x_max`,
    /// and the total mass must then lie in `(0.99, 1.001]`.
    fn check_integrate_pdf_is_cdf<D: ContinuousCDF<f64, f64> + Continuous<f64, f64>>(
        dist: &D,
        x_min: f64,
        x_max: f64,
        step: f64,
    ) {
        let mut prev_x = x_min;
        let mut prev_density = dist.pdf(x_min);
        let mut sum = 0.0;

        loop {
            let x = prev_x + step;
            let density = dist.pdf(x);

            assert!(density >= 0.0);

            let ln_density = dist.ln_pdf(x);

            assert_almost_eq!(density.ln(), ln_density, 1e-10);

            // Trapezoid over [prev_x, x].
            sum += (prev_density + density) * step / 2.0;

            let cdf = dist.cdf(x);
            if (sum - cdf).abs() > 1e-3 {
                println!("Integral of pdf doesn't equal cdf!");
                println!("Integration from {} by {} to {} = {}", x_min, step, x, sum);
                println!("cdf = {}", cdf);
                panic!();
            }

            if x >= x_max {
                break;
            }
            prev_x = x;
            prev_density = density;
        }

        assert!(sum > 0.99);
        assert!(sum <= 1.001);
    }

    /// Sums `dist.pmf(i)` for `i` in `0..x_max + 3` and checks:
    /// - every probability lies in `[0, 1]`;
    /// - the running sum matches `cdf(i)` (via `assert_almost_eq!` with
    ///   precision `1e-10`);
    /// - the accumulated mass exceeds `0.99` by `x_max`;
    /// - the final total is at most `1 + 1e-10`.
    fn check_sum_pmf_is_cdf<D: DiscreteCDF<u64, f64> + Discrete<u64, f64>>(dist: &D, x_max: u64) {
        let mut sum = 0.0;

        for i in 0..x_max + 3 {
            let prob = dist.pmf(i);

            assert!(prob >= 0.0);
            assert!(prob <= 1.0);

            sum += prob;

            if i == x_max {
                assert!(sum > 0.99);
            }

            assert_almost_eq!(sum, dist.cdf(i), 1e-10);
        }

        assert!(sum > 0.99);
        assert!(sum <= 1.0 + 1e-10);
    }

    /// Sanity-checks a continuous distribution.
    ///
    /// First it checks `pdf`, `ln_pdf` and `cdf` at `±∞`. It then integrates
    /// the pdf over `[x_min, x_max]` in 100000 steps and compares the result
    /// against the cdf.
    pub fn check_continuous_distribution<D: ContinuousCDF<f64, f64> + Continuous<f64, f64>>(
        dist: &D,
        x_min: f64,
        x_max: f64,
    ) {
        assert_eq!(dist.pdf(f64::NEG_INFINITY), 0.0);
        assert_eq!(dist.pdf(f64::INFINITY), 0.0);
        assert_eq!(dist.ln_pdf(f64::NEG_INFINITY), f64::NEG_INFINITY);
        assert_eq!(dist.ln_pdf(f64::INFINITY), f64::NEG_INFINITY);
        assert_eq!(dist.cdf(f64::NEG_INFINITY), 0.0);
        assert_eq!(dist.cdf(f64::INFINITY), 1.0);

        check_integrate_pdf_is_cdf(dist, x_min, x_max, (x_max - x_min) / 100000.0);
    }

    /// Sanity-checks a discrete distribution by comparing its summed pmf with
    /// its cdf up to a little past `x_max`.
    pub fn check_discrete_distribution<D: DiscreteCDF<u64, f64> + Discrete<u64, f64>>(
        dist: &D,
        x_max: u64,
    ) {
        check_sum_pmf_is_cdf(dist, x_max);
    }

    #[test]
    fn test_is_valid_multinomial() {
        let invalid = [1.0, f64::NAN, 3.0];
        assert!(!is_valid_multinomial(&invalid, true));
        let invalid2 = [-2.0, 5.0, 1.0, 6.2];
        assert!(!is_valid_multinomial(&invalid2, true));
        let invalid3 = [0.0, 0.0, 0.0];
        assert!(!is_valid_multinomial(&invalid3, true));
        let valid = [5.2, 0.0, 1e-15, 1000000.12];
        assert!(is_valid_multinomial(&valid, true));
    }

    #[test]
    fn test_is_valid_multinomial_no_zero() {
        let invalid = [5.2, 0.0, 1e-15, 1000000.12];
        assert!(!is_valid_multinomial(&invalid, false));
    }

    #[test]
    fn test_integer_bisection() {
        // Searches `data` (treated as a monotone function of the index) for `z`.
        fn search(z: usize, data: &[usize]) -> Option<usize> {
            integral_bisection_search(|idx: &usize| data[*idx], z, 0, data.len() - 1)
        }

        // Strictly increasing data with `needle` itself skipped: [0, 1, 2, 4, 5].
        let needle = 3;
        let data = (0..5)
            .map(|n| if n >= needle { n + 1 } else { n })
            .collect::<Vec<_>>();

        for (i, &value) in data.iter().enumerate() {
            assert_eq!(search(value, &data), Some(i));
        }

        // A missing value resolves to the first index whose value exceeds it.
        let infimum = search(needle, &data);
        let found_element = search(needle + 1, &data);
        assert_eq!(found_element, Some(needle));
        assert_eq!(infimum, found_element);
    }
}