// statrs/distribution/empirical.rs

1use crate::distribution::ContinuousCDF;
2use crate::statistics::*;
3use non_nan::NonNan;
4use std::collections::btree_map::{BTreeMap, Entry};
5use std::convert::Infallible;
6use std::ops::Bound;
7
mod non_nan {
    use core::cmp::Ordering;

    /// Wrapper witnessing that its contents are not NaN, which gives the
    /// wrapped float a total ordering and makes it usable as a map key.
    #[derive(Clone, Copy, PartialEq, Debug)]
    pub struct NonNan<T>(T);

    impl<T: Copy> NonNan<T> {
        /// Unwraps the inner value.
        pub fn get(self) -> T {
            self.0
        }
    }

    impl NonNan<f64> {
        /// Wraps `value`, refusing NaN with `None`.
        #[inline]
        pub fn new(value: f64) -> Option<Self> {
            (!value.is_nan()).then_some(Self(value))
        }
    }

    impl<T: PartialEq> Eq for NonNan<T> {}

    impl<T: PartialOrd> PartialOrd for NonNan<T> {
        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
            Some(self.cmp(other))
        }
    }

    impl<T: PartialOrd> Ord for NonNan<T> {
        fn cmp(&self, other: &Self) -> Ordering {
            // By construction the wrapped value is never NaN, so a total
            // ordering always exists here.
            self.0.partial_cmp(&other.0).unwrap()
        }
    }
}
45
/// Implements the [Empirical
/// Distribution](https://en.wikipedia.org/wiki/Empirical_distribution_function)
///
/// # Examples
///
/// ```
/// use statrs::distribution::{Continuous, Empirical};
/// use statrs::statistics::Distribution;
///
/// let samples = vec![0.0, 5.0, 10.0];
///
/// let empirical = Empirical::from_iter(samples);
/// assert_eq!(empirical.mean().unwrap(), 5.0);
/// ```
#[derive(Clone, PartialEq, Debug)]
pub struct Empirical {
    // keys are data points, values are number of data points with equal value
    data: BTreeMap<NonNan<f64>, u64>,

    // The following fields are only logically valid if !data.is_empty():
    /// Total amount of data points (== sum of all _values_ inside self.data).
    /// Must be 0 iff data.is_empty()
    sum: u64,
    // Running mean of all data points, maintained incrementally by add/remove.
    mean: f64,
    // Running sum of squared deviations from the mean (Welford's M2 term);
    // `variance()` divides this by (sum - 1) for the sample variance.
    var: f64,
}
72
73impl Empirical {
74    /// Constructs a new discrete uniform distribution with a minimum value
75    /// of `min` and a maximum value of `max`.
76    ///
77    /// Note that this will always succeed and never return the [`Err`][Result::Err] variant.
78    ///
79    /// # Examples
80    ///
81    /// ```
82    /// use statrs::distribution::Empirical;
83    ///
84    /// let mut result = Empirical::new();
85    /// assert!(result.is_ok());
86    /// ```
87    pub fn new() -> Result<Empirical, Infallible> {
88        Ok(Empirical {
89            data: BTreeMap::new(),
90            sum: 0,
91            mean: 0.0,
92            var: 0.0,
93        })
94    }
95
96    pub fn add(&mut self, data_point: f64) {
97        let map_key = match NonNan::new(data_point) {
98            Some(valid) => valid,
99            None => return,
100        };
101
102        self.sum += 1;
103        let sum = self.sum as f64;
104        self.var += (sum - 1.) * (data_point - self.mean) * (data_point - self.mean) / sum;
105        self.mean += (data_point - self.mean) / sum;
106
107        self.data
108            .entry(map_key)
109            .and_modify(|c| *c += 1)
110            .or_insert(1);
111    }
112
113    pub fn remove(&mut self, data_point: f64) {
114        let map_key = match NonNan::new(data_point) {
115            Some(valid) => valid,
116            None => return,
117        };
118
119        let mut entry = match self.data.entry(map_key) {
120            Entry::Occupied(entry) => entry,
121            Entry::Vacant(_) => return, // no entry found
122        };
123
124        if *entry.get() == 1 {
125            entry.remove();
126            if self.data.is_empty() {
127                // logically, this should not need special handling.
128                // FP math can result in mean or var being != 0.0 though.
129                self.sum = 0;
130                self.mean = 0.0;
131                self.var = 0.0;
132                return;
133            }
134        } else {
135            *entry.get_mut() -= 1;
136        }
137
138        // reset mean and var
139        let sum = self.sum as f64;
140        self.mean = (sum * self.mean - data_point) / (sum - 1.);
141        self.var -= (sum - 1.) * (data_point - self.mean) * (data_point - self.mean) / sum;
142        self.sum -= 1;
143    }
144
145    // Due to issues with rounding and floating-point accuracy the default
146    // implementation may be ill-behaved.
147    // Specialized inverse cdfs should be used whenever possible.
148    // Performs a binary search on the domain of `cdf` to obtain an approximation
149    // of `F^-1(p) := inf { x | F(x) >= p }`. Needless to say, performance may
150    // may be lacking.
151    // This function is identical to the default method implementation in the
152    // `ContinuousCDF` trait and is used to implement the rand trait `Distribution`.
153    fn __inverse_cdf(&self, p: f64) -> f64 {
154        if p == 0.0 {
155            return self.min();
156        };
157        if p == 1.0 {
158            return self.max();
159        };
160        let mut high = 2.0;
161        let mut low = -high;
162        while self.cdf(low) > p {
163            low = low + low;
164        }
165        while self.cdf(high) < p {
166            high = high + high;
167        }
168        let mut i = 16;
169        while i != 0 {
170            let mid = (high + low) / 2.0;
171            if self.cdf(mid) >= p {
172                high = mid;
173            } else {
174                low = mid;
175            }
176            i -= 1;
177        }
178        (high + low) / 2.0
179    }
180}
181
182impl std::fmt::Display for Empirical {
183    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
184        let mut enumerated_values = self
185            .data
186            .iter()
187            .flat_map(|(x, &count)| std::iter::repeat(x.get()).take(count as usize));
188
189        if let Some(x) = enumerated_values.next() {
190            write!(f, "Empirical([{x:.3e}")?;
191        } else {
192            return write!(f, "Empirical(∅)");
193        }
194
195        for val in enumerated_values.by_ref().take(4) {
196            write!(f, ", {val:.3e}")?;
197        }
198        if enumerated_values.next().is_some() {
199            write!(f, ", ...")?;
200        }
201        write!(f, "])")
202    }
203}
204
205impl FromIterator<f64> for Empirical {
206    fn from_iter<T: IntoIterator<Item = f64>>(iter: T) -> Self {
207        let mut empirical = Self::new().unwrap();
208        for elt in iter {
209            empirical.add(elt);
210        }
211        empirical
212    }
213}
214
#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<f64> for Empirical {
    /// Samples by inverse-transform: a uniform draw on the unit interval is
    /// pushed through the approximate empirical quantile function.
    fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 {
        use crate::distribution::Uniform;

        // (0.0, 1.0) are always valid uniform bounds, so unwrap cannot fail.
        let draw = Uniform::new(0.0, 1.0).unwrap().sample(rng);
        self.__inverse_cdf(draw)
    }
}
225
226/// Panics if number of samples is zero
227impl Max<f64> for Empirical {
228    fn max(&self) -> f64 {
229        self.data.keys().rev().map(|key| key.get()).next().unwrap()
230    }
231}
232
233/// Panics if number of samples is zero
234impl Min<f64> for Empirical {
235    fn min(&self) -> f64 {
236        self.data.keys().map(|key| key.get()).next().unwrap()
237    }
238}
239
240impl Distribution<f64> for Empirical {
241    fn mean(&self) -> Option<f64> {
242        if self.data.is_empty() {
243            None
244        } else {
245            Some(self.mean)
246        }
247    }
248
249    fn variance(&self) -> Option<f64> {
250        if self.data.is_empty() {
251            None
252        } else {
253            Some(self.var / (self.sum as f64 - 1.))
254        }
255    }
256}
257
258impl ContinuousCDF<f64, f64> for Empirical {
259    fn cdf(&self, x: f64) -> f64 {
260        let start = Bound::Unbounded;
261        let end = Bound::Included(NonNan::new(x).expect("x must not be NaN"));
262
263        let sum: u64 = self.data.range((start, end)).map(|(_, v)| v).sum();
264        sum as f64 / self.sum as f64
265    }
266
267    fn sf(&self, x: f64) -> f64 {
268        let start = Bound::Excluded(NonNan::new(x).expect("x must not be NaN"));
269        let end = Bound::Unbounded;
270
271        let sum: u64 = self.data.range((start, end)).map(|(_, v)| v).sum();
272        sum as f64 / self.sum as f64
273    }
274
275    fn inverse_cdf(&self, p: f64) -> f64 {
276        self.__inverse_cdf(p)
277    }
278}
279
#[cfg(test)]
mod tests {
    use super::*;

    // NaN inputs must be ignored, never panic.
    #[test]
    fn test_add_nan() {
        let mut empirical = Empirical::new().unwrap();

        // should not panic
        empirical.add(f64::NAN);
    }

    #[test]
    fn test_remove_nan() {
        let mut empirical = Empirical::new().unwrap();

        empirical.add(5.2);
        // should not panic
        empirical.remove(f64::NAN);
    }

    // Removing a value that was never added must be a no-op.
    #[test]
    fn test_remove_nonexisting() {
        let mut empirical = Empirical::new().unwrap();

        empirical.add(5.2);
        // should not panic
        empirical.remove(10.0);
    }

    // After removing every sample the distribution must report no
    // mean/variance again, i.e. the internal state was fully reset.
    #[test]
    fn test_remove_all() {
        let mut empirical = Empirical::new().unwrap();

        empirical.add(17.123);
        empirical.add(-10.0);
        empirical.add(0.0);
        empirical.remove(-10.0);
        empirical.remove(17.123);
        empirical.remove(0.0);

        assert!(empirical.mean().is_none());
        assert!(empirical.variance().is_none());
    }

    #[test]
    fn test_mean() {
        fn test_mean_for_samples(expected_mean: f64, samples: Vec<f64>) {
            let dist = Empirical::from_iter(samples);
            assert_relative_eq!(dist.mean().unwrap(), expected_mean);
        }

        // Empty distribution has no mean.
        let dist = Empirical::from_iter(vec![]);
        assert!(dist.mean().is_none());

        test_mean_for_samples(4.0, vec![4.0; 100]);
        test_mean_for_samples(-0.2, vec![-0.2; 100]);
        test_mean_for_samples(28.5, vec![21.3, 38.4, 12.7, 41.6]);
    }

    #[test]
    fn test_var() {
        fn test_var_for_samples(expected_var: f64, samples: Vec<f64>) {
            let dist = Empirical::from_iter(samples);
            assert_relative_eq!(dist.variance().unwrap(), expected_var);
        }

        // Empty distribution has no variance.
        let dist = Empirical::from_iter(vec![]);
        assert!(dist.variance().is_none());

        // Constant samples have zero (sample) variance.
        test_var_for_samples(0.0, vec![4.0; 100]);
        test_var_for_samples(0.0, vec![-0.2; 100]);
        test_var_for_samples(190.36666666666667, vec![21.3, 38.4, 12.7, 41.6]);
    }

    #[test]
    fn test_cdf() {
        let samples = vec![5.0, 10.0];
        let mut empirical = Empirical::from_iter(samples);
        assert_eq!(empirical.cdf(0.0), 0.0);
        assert_eq!(empirical.cdf(5.0), 0.5);
        assert_eq!(empirical.cdf(5.5), 0.5);
        assert_eq!(empirical.cdf(6.0), 0.5);
        assert_eq!(empirical.cdf(10.0), 1.0);
        assert_eq!(empirical.min(), 5.0);
        assert_eq!(empirical.max(), 10.0);
        // Adding duplicates shifts every step of the CDF.
        empirical.add(2.0);
        empirical.add(2.0);
        assert_eq!(empirical.cdf(0.0), 0.0);
        assert_eq!(empirical.cdf(5.0), 0.75);
        assert_eq!(empirical.cdf(5.5), 0.75);
        assert_eq!(empirical.cdf(6.0), 0.75);
        assert_eq!(empirical.cdf(10.0), 1.0);
        assert_eq!(empirical.min(), 2.0);
        assert_eq!(empirical.max(), 10.0);
        let unchanged = empirical.clone();
        empirical.add(2.0);
        empirical.remove(2.0);
        // because of rounding errors, this doesn't hold in general
        // due to the mean and variance being calculated in a streaming way
        assert_eq!(unchanged, empirical);
    }

    #[test]
    fn test_sf() {
        let samples = vec![5.0, 10.0];
        let mut empirical = Empirical::from_iter(samples);
        assert_eq!(empirical.sf(0.0), 1.0);
        assert_eq!(empirical.sf(5.0), 0.5);
        assert_eq!(empirical.sf(5.5), 0.5);
        assert_eq!(empirical.sf(6.0), 0.5);
        assert_eq!(empirical.sf(10.0), 0.0);
        assert_eq!(empirical.min(), 5.0);
        assert_eq!(empirical.max(), 10.0);
        // Adding duplicates shifts every step of the survival function.
        empirical.add(2.0);
        empirical.add(2.0);
        assert_eq!(empirical.sf(0.0), 1.0);
        assert_eq!(empirical.sf(5.0), 0.25);
        assert_eq!(empirical.sf(5.5), 0.25);
        assert_eq!(empirical.sf(6.0), 0.25);
        assert_eq!(empirical.sf(10.0), 0.0);
        assert_eq!(empirical.min(), 2.0);
        assert_eq!(empirical.max(), 10.0);
        let unchanged = empirical.clone();
        empirical.add(2.0);
        empirical.remove(2.0);
        // because of rounding errors, this doesn't hold in general
        // due to the mean and variance being calculated in a streaming way
        assert_eq!(unchanged, empirical);
    }

    // Display shows at most five samples and then an ellipsis.
    #[test]
    fn test_display() {
        let mut e = Empirical::new().unwrap();
        assert_eq!(e.to_string(), "Empirical(∅)");
        e.add(1.0);
        assert_eq!(e.to_string(), "Empirical([1.000e0])");
        e.add(1.0);
        assert_eq!(e.to_string(), "Empirical([1.000e0, 1.000e0])");
        e.add(2.0);
        assert_eq!(e.to_string(), "Empirical([1.000e0, 1.000e0, 2.000e0])");
        e.add(2.0);
        assert_eq!(
            e.to_string(),
            "Empirical([1.000e0, 1.000e0, 2.000e0, 2.000e0])"
        );
        e.add(5.0);
        assert_eq!(
            e.to_string(),
            "Empirical([1.000e0, 1.000e0, 2.000e0, 2.000e0, 5.000e0])"
        );
        e.add(5.0);
        assert_eq!(
            e.to_string(),
            "Empirical([1.000e0, 1.000e0, 2.000e0, 2.000e0, 5.000e0, ...])"
        );
    }
}