1use crate::distribution::ContinuousCDF;
2use crate::statistics::*;
3use non_nan::NonNan;
4use std::collections::btree_map::{BTreeMap, Entry};
5use std::convert::Infallible;
6use std::ops::Bound;
7
mod non_nan {
    //! Totally ordered wrapper for floating-point values that are not NaN.
    //!
    //! `f64` only implements `PartialOrd` because NaN is unordered. Rejecting
    //! NaN at construction makes the partial order total, which is what lets
    //! [`NonNan`] be used as a `BTreeMap` key.
    use std::cmp::Ordering;

    /// A value that is guaranteed not to be NaN.
    ///
    /// Outside this module the only way to obtain one is [`NonNan::new`], so
    /// every instance upholds the invariant.
    #[derive(Clone, Copy, PartialEq, Debug)]
    pub struct NonNan<T>(T);

    impl NonNan<f64> {
        /// Wraps `x`, returning `None` if it is NaN.
        #[inline]
        pub fn new(x: f64) -> Option<Self> {
            Some(x).filter(|value| !value.is_nan()).map(NonNan)
        }
    }

    impl<T: Copy> NonNan<T> {
        /// Returns the wrapped value.
        pub fn get(self) -> T {
            let NonNan(inner) = self;
            inner
        }
    }

    // Reflexivity holds because NaN, the only float with `x != x`, can never
    // be wrapped.
    impl<T: PartialEq> Eq for NonNan<T> {}

    impl<T: PartialOrd> Ord for NonNan<T> {
        fn cmp(&self, other: &Self) -> Ordering {
            let (lhs, rhs) = (&self.0, &other.0);
            // Never `None`: neither side can be NaN.
            lhs.partial_cmp(rhs).unwrap()
        }
    }

    impl<T: PartialOrd> PartialOrd for NonNan<T> {
        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
            Some(Ord::cmp(self, other))
        }
    }
}
45
/// An empirical distribution: the distribution defined by a finite multiset
/// of observed `f64` samples.
///
/// Samples are kept in sorted order together with their multiplicities, and
/// the sample mean and sum of squared deviations are maintained incrementally
/// as samples are added and removed. NaN samples are never stored.
///
/// NOTE(review): the derived `PartialEq` also compares the floating-point
/// accumulators `mean` and `var`, so two distributions holding the same
/// samples may compare unequal if rounding differed along the way they were
/// built (e.g. different insertion order).
#[derive(Clone, PartialEq, Debug)]
pub struct Empirical {
    /// Each distinct (non-NaN) sample value mapped to the number of times it
    /// has been added.
    data: BTreeMap<NonNan<f64>, u64>,

    /// Total number of samples counting multiplicity, i.e. the sum of the
    /// counts in `data`. Despite the name this is a count, not a sum of the
    /// sample values.
    sum: u64,
    /// Running mean of all samples; `0.0` when the distribution is empty.
    mean: f64,
    /// Running sum of squared deviations from the mean (not yet divided by
    /// `sum - 1`); `0.0` when the distribution is empty.
    var: f64,
}
72
73impl Empirical {
74 pub fn new() -> Result<Empirical, Infallible> {
88 Ok(Empirical {
89 data: BTreeMap::new(),
90 sum: 0,
91 mean: 0.0,
92 var: 0.0,
93 })
94 }
95
96 pub fn add(&mut self, data_point: f64) {
97 let map_key = match NonNan::new(data_point) {
98 Some(valid) => valid,
99 None => return,
100 };
101
102 self.sum += 1;
103 let sum = self.sum as f64;
104 self.var += (sum - 1.) * (data_point - self.mean) * (data_point - self.mean) / sum;
105 self.mean += (data_point - self.mean) / sum;
106
107 self.data
108 .entry(map_key)
109 .and_modify(|c| *c += 1)
110 .or_insert(1);
111 }
112
113 pub fn remove(&mut self, data_point: f64) {
114 let map_key = match NonNan::new(data_point) {
115 Some(valid) => valid,
116 None => return,
117 };
118
119 let mut entry = match self.data.entry(map_key) {
120 Entry::Occupied(entry) => entry,
121 Entry::Vacant(_) => return, };
123
124 if *entry.get() == 1 {
125 entry.remove();
126 if self.data.is_empty() {
127 self.sum = 0;
130 self.mean = 0.0;
131 self.var = 0.0;
132 return;
133 }
134 } else {
135 *entry.get_mut() -= 1;
136 }
137
138 let sum = self.sum as f64;
140 self.mean = (sum * self.mean - data_point) / (sum - 1.);
141 self.var -= (sum - 1.) * (data_point - self.mean) * (data_point - self.mean) / sum;
142 self.sum -= 1;
143 }
144
145 fn __inverse_cdf(&self, p: f64) -> f64 {
154 if p == 0.0 {
155 return self.min();
156 };
157 if p == 1.0 {
158 return self.max();
159 };
160 let mut high = 2.0;
161 let mut low = -high;
162 while self.cdf(low) > p {
163 low = low + low;
164 }
165 while self.cdf(high) < p {
166 high = high + high;
167 }
168 let mut i = 16;
169 while i != 0 {
170 let mid = (high + low) / 2.0;
171 if self.cdf(mid) >= p {
172 high = mid;
173 } else {
174 low = mid;
175 }
176 i -= 1;
177 }
178 (high + low) / 2.0
179 }
180}
181
182impl std::fmt::Display for Empirical {
183 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
184 let mut enumerated_values = self
185 .data
186 .iter()
187 .flat_map(|(x, &count)| std::iter::repeat(x.get()).take(count as usize));
188
189 if let Some(x) = enumerated_values.next() {
190 write!(f, "Empirical([{x:.3e}")?;
191 } else {
192 return write!(f, "Empirical(∅)");
193 }
194
195 for val in enumerated_values.by_ref().take(4) {
196 write!(f, ", {val:.3e}")?;
197 }
198 if enumerated_values.next().is_some() {
199 write!(f, ", ...")?;
200 }
201 write!(f, "])")
202 }
203}
204
205impl FromIterator<f64> for Empirical {
206 fn from_iter<T: IntoIterator<Item = f64>>(iter: T) -> Self {
207 let mut empirical = Self::new().unwrap();
208 for elt in iter {
209 empirical.add(elt);
210 }
211 empirical
212 }
213}
214
#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<f64> for Empirical {
    /// Draws a value by inverse-transform sampling: a variate from
    /// `Uniform::new(0.0, 1.0)` is mapped through the inverse CDF.
    fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 {
        use crate::distribution::Uniform;

        let unit = Uniform::new(0.0, 1.0).unwrap();
        let p: f64 = ::rand::distributions::Distribution::sample(&unit, rng);
        self.__inverse_cdf(p)
    }
}
225
226impl Max<f64> for Empirical {
228 fn max(&self) -> f64 {
229 self.data.keys().rev().map(|key| key.get()).next().unwrap()
230 }
231}
232
233impl Min<f64> for Empirical {
235 fn min(&self) -> f64 {
236 self.data.keys().map(|key| key.get()).next().unwrap()
237 }
238}
239
240impl Distribution<f64> for Empirical {
241 fn mean(&self) -> Option<f64> {
242 if self.data.is_empty() {
243 None
244 } else {
245 Some(self.mean)
246 }
247 }
248
249 fn variance(&self) -> Option<f64> {
250 if self.data.is_empty() {
251 None
252 } else {
253 Some(self.var / (self.sum as f64 - 1.))
254 }
255 }
256}
257
258impl ContinuousCDF<f64, f64> for Empirical {
259 fn cdf(&self, x: f64) -> f64 {
260 let start = Bound::Unbounded;
261 let end = Bound::Included(NonNan::new(x).expect("x must not be NaN"));
262
263 let sum: u64 = self.data.range((start, end)).map(|(_, v)| v).sum();
264 sum as f64 / self.sum as f64
265 }
266
267 fn sf(&self, x: f64) -> f64 {
268 let start = Bound::Excluded(NonNan::new(x).expect("x must not be NaN"));
269 let end = Bound::Unbounded;
270
271 let sum: u64 = self.data.range((start, end)).map(|(_, v)| v).sum();
272 sum as f64 / self.sum as f64
273 }
274
275 fn inverse_cdf(&self, p: f64) -> f64 {
276 self.__inverse_cdf(p)
277 }
278}
279
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_add_nan() {
        let mut empirical = Empirical::new().unwrap();

        empirical.add(f64::NAN);
        // NaN is ignored, so the distribution must still be empty.
        assert!(empirical.mean().is_none());
        assert!(empirical.variance().is_none());
        assert_eq!(empirical.to_string(), "Empirical(∅)");
    }

    #[test]
    fn test_remove_nan() {
        let mut empirical = Empirical::new().unwrap();

        empirical.add(5.2);
        empirical.remove(f64::NAN);
        // Removing NaN is a no-op: the single sample is still present.
        assert_eq!(empirical.mean(), Some(5.2));
        assert_eq!(empirical.cdf(5.2), 1.0);
    }

    #[test]
    fn test_remove_nonexisting() {
        let mut empirical = Empirical::new().unwrap();

        empirical.add(5.2);
        empirical.remove(10.0);
        // Removing a value that was never added is a no-op.
        assert_eq!(empirical.mean(), Some(5.2));
        assert_eq!(empirical.cdf(5.2), 1.0);
    }

    #[test]
    fn test_remove_all() {
        let mut empirical = Empirical::new().unwrap();

        empirical.add(17.123);
        empirical.add(-10.0);
        empirical.add(0.0);
        empirical.remove(-10.0);
        empirical.remove(17.123);
        empirical.remove(0.0);

        assert!(empirical.mean().is_none());
        assert!(empirical.variance().is_none());
    }

    #[test]
    fn test_mean() {
        fn test_mean_for_samples(expected_mean: f64, samples: Vec<f64>) {
            let dist = Empirical::from_iter(samples);
            assert_relative_eq!(dist.mean().unwrap(), expected_mean);
        }

        let dist = Empirical::from_iter(vec![]);
        assert!(dist.mean().is_none());

        test_mean_for_samples(4.0, vec![4.0; 100]);
        test_mean_for_samples(-0.2, vec![-0.2; 100]);
        test_mean_for_samples(28.5, vec![21.3, 38.4, 12.7, 41.6]);
    }

    #[test]
    fn test_var() {
        fn test_var_for_samples(expected_var: f64, samples: Vec<f64>) {
            let dist = Empirical::from_iter(samples);
            assert_relative_eq!(dist.variance().unwrap(), expected_var);
        }

        let dist = Empirical::from_iter(vec![]);
        assert!(dist.variance().is_none());

        test_var_for_samples(0.0, vec![4.0; 100]);
        test_var_for_samples(0.0, vec![-0.2; 100]);
        test_var_for_samples(190.36666666666667, vec![21.3, 38.4, 12.7, 41.6]);
    }

    #[test]
    fn test_cdf() {
        let samples = vec![5.0, 10.0];
        let mut empirical = Empirical::from_iter(samples);
        assert_eq!(empirical.cdf(0.0), 0.0);
        assert_eq!(empirical.cdf(5.0), 0.5);
        assert_eq!(empirical.cdf(5.5), 0.5);
        assert_eq!(empirical.cdf(6.0), 0.5);
        assert_eq!(empirical.cdf(10.0), 1.0);
        assert_eq!(empirical.min(), 5.0);
        assert_eq!(empirical.max(), 10.0);
        empirical.add(2.0);
        empirical.add(2.0);
        assert_eq!(empirical.cdf(0.0), 0.0);
        assert_eq!(empirical.cdf(5.0), 0.75);
        assert_eq!(empirical.cdf(5.5), 0.75);
        assert_eq!(empirical.cdf(6.0), 0.75);
        assert_eq!(empirical.cdf(10.0), 1.0);
        assert_eq!(empirical.min(), 2.0);
        assert_eq!(empirical.max(), 10.0);
        let unchanged = empirical.clone();
        empirical.add(2.0);
        empirical.remove(2.0);
        assert_eq!(unchanged, empirical);
    }

    #[test]
    fn test_sf() {
        let samples = vec![5.0, 10.0];
        let mut empirical = Empirical::from_iter(samples);
        assert_eq!(empirical.sf(0.0), 1.0);
        assert_eq!(empirical.sf(5.0), 0.5);
        assert_eq!(empirical.sf(5.5), 0.5);
        assert_eq!(empirical.sf(6.0), 0.5);
        assert_eq!(empirical.sf(10.0), 0.0);
        assert_eq!(empirical.min(), 5.0);
        assert_eq!(empirical.max(), 10.0);
        empirical.add(2.0);
        empirical.add(2.0);
        assert_eq!(empirical.sf(0.0), 1.0);
        assert_eq!(empirical.sf(5.0), 0.25);
        assert_eq!(empirical.sf(5.5), 0.25);
        assert_eq!(empirical.sf(6.0), 0.25);
        assert_eq!(empirical.sf(10.0), 0.0);
        assert_eq!(empirical.min(), 2.0);
        assert_eq!(empirical.max(), 10.0);
        let unchanged = empirical.clone();
        empirical.add(2.0);
        empirical.remove(2.0);
        assert_eq!(unchanged, empirical);
    }

    #[test]
    fn test_inverse_cdf_endpoints() {
        // p = 0 and p = 1 map to the smallest and largest samples.
        let empirical = Empirical::from_iter(vec![5.0, 10.0, 2.0]);
        assert_eq!(empirical.inverse_cdf(0.0), 2.0);
        assert_eq!(empirical.inverse_cdf(1.0), 10.0);
    }

    #[test]
    fn test_display() {
        let mut e = Empirical::new().unwrap();
        assert_eq!(e.to_string(), "Empirical(∅)");
        e.add(1.0);
        assert_eq!(e.to_string(), "Empirical([1.000e0])");
        e.add(1.0);
        assert_eq!(e.to_string(), "Empirical([1.000e0, 1.000e0])");
        e.add(2.0);
        assert_eq!(e.to_string(), "Empirical([1.000e0, 1.000e0, 2.000e0])");
        e.add(2.0);
        assert_eq!(
            e.to_string(),
            "Empirical([1.000e0, 1.000e0, 2.000e0, 2.000e0])"
        );
        e.add(5.0);
        assert_eq!(
            e.to_string(),
            "Empirical([1.000e0, 1.000e0, 2.000e0, 2.000e0, 5.000e0])"
        );
        e.add(5.0);
        assert_eq!(
            e.to_string(),
            "Empirical([1.000e0, 1.000e0, 2.000e0, 2.000e0, 5.000e0, ...])"
        );
    }
}