statrs/statistics/
slice_statistics.rs

1use crate::statistics::*;
2use core::ops::{Index, IndexMut};
3
/// A thin newtype wrapper around a container of samples.
///
/// Most of the trait implementations in this file reach the wrapped value
/// through `AsRef<[f64]>` / `AsMut<[f64]>`, so any container that can be
/// viewed as a slice of `f64` (for example an array or a `Vec<f64>`) can be
/// wrapped.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Debug)]
pub struct Data<D>(D);
6
7impl<D, I> std::fmt::Display for Data<D>
8where
9    D: Clone + IntoIterator<Item = I>,
10    I: Clone + std::fmt::Display,
11{
12    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
13        let mut tee = self.0.clone().into_iter();
14        write!(f, "Data([")?;
15
16        if let Some(v) = tee.next() {
17            write!(f, "{v}")?;
18        }
19        for _ in 1..5 {
20            if let Some(v) = tee.next() {
21                write!(f, ", {v}")?;
22            }
23        }
24        if tee.next().is_some() {
25            write!(f, "...")?;
26        }
27
28        write!(f, "])")
29    }
30}
31
32impl<D: AsRef<[f64]>> Index<usize> for Data<D> {
33    type Output = f64;
34
35    fn index(&self, i: usize) -> &f64 {
36        &self.0.as_ref()[i]
37    }
38}
39
40impl<D: AsMut<[f64]> + AsRef<[f64]>> IndexMut<usize> for Data<D> {
41    fn index_mut(&mut self, i: usize) -> &mut f64 {
42        &mut self.0.as_mut()[i]
43    }
44}
45
46impl<D: AsMut<[f64]> + AsRef<[f64]>> Data<D> {
47    pub fn new(data: D) -> Self {
48        Data(data)
49    }
50
51    pub fn swap(&mut self, i: usize, j: usize) {
52        self.0.as_mut().swap(i, j)
53    }
54
55    pub fn len(&self) -> usize {
56        self.0.as_ref().len()
57    }
58
59    pub fn is_empty(&self) -> bool {
60        self.0.as_ref().len() == 0
61    }
62
63    pub fn iter(&self) -> core::slice::Iter<'_, f64> {
64        self.0.as_ref().iter()
65    }
66
67    // Selection algorithm from Numerical Recipes
68    // See: https://en.wikipedia.org/wiki/Selection_algorithm
69    fn select_inplace(&mut self, rank: usize) -> f64 {
70        if rank == 0 {
71            return self.min();
72        }
73        if rank > self.len() - 1 {
74            return self.max();
75        }
76
77        let mut low = 0;
78        let mut high = self.len() - 1;
79        loop {
80            if high <= low + 1 {
81                if high == low + 1 && self[high] < self[low] {
82                    self.swap(low, high)
83                }
84                return self[rank];
85            }
86
87            let middle = (low + high) / 2;
88            self.swap(middle, low + 1);
89
90            if self[low] > self[high] {
91                self.swap(low, high);
92            }
93            if self[low + 1] > self[high] {
94                self.swap(low + 1, high);
95            }
96            if self[low] > self[low + 1] {
97                self.swap(low, low + 1);
98            }
99
100            let mut begin = low + 1;
101            let mut end = high;
102            let pivot = self[begin];
103            loop {
104                loop {
105                    begin += 1;
106                    if self[begin] >= pivot {
107                        break;
108                    }
109                }
110                loop {
111                    end -= 1;
112                    if self[end] <= pivot {
113                        break;
114                    }
115                }
116                if end < begin {
117                    break;
118                }
119                self.swap(begin, end);
120            }
121
122            self[low + 1] = self[end];
123            self[end] = pivot;
124
125            if end >= rank {
126                high = end - 1;
127            }
128            if end <= rank {
129                low = begin;
130            }
131        }
132    }
133}
134
#[cfg(feature = "rand")]
impl<D: AsRef<[f64]>> ::rand::distributions::Distribution<f64> for Data<D> {
    /// Returns one of the stored values, picked by `SliceRandom::choose`.
    ///
    /// # Panics
    ///
    /// Panics if the data is empty, because `choose` then yields `None`.
    fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 {
        use rand::prelude::SliceRandom;

        let values: &[f64] = self.0.as_ref();
        values.choose(rng).copied().unwrap()
    }
}
143
144impl<D: AsMut<[f64]> + AsRef<[f64]>> OrderStatistics<f64> for Data<D> {
145    fn order_statistic(&mut self, order: usize) -> f64 {
146        let n = self.len();
147        match order {
148            1 => self.min(),
149            _ if order == n => self.max(),
150            _ if order < 1 || order > n => f64::NAN,
151            _ => self.select_inplace(order - 1),
152        }
153    }
154
155    fn median(&mut self) -> f64 {
156        let k = self.len() / 2;
157        if self.len() % 2 != 0 {
158            self.select_inplace(k)
159        } else {
160            (self.select_inplace(k.saturating_sub(1)) + self.select_inplace(k)) / 2.0
161        }
162    }
163
164    fn quantile(&mut self, tau: f64) -> f64 {
165        if !(0.0..=1.0).contains(&tau) || self.is_empty() {
166            return f64::NAN;
167        }
168
169        let h = (self.len() as f64 + 1.0 / 3.0) * tau + 1.0 / 3.0;
170        let hf = h as i64;
171
172        if hf <= 0 || tau == 0.0 {
173            return self.min();
174        }
175        if hf >= self.len() as i64 || ulps_eq!(tau, 1.0) {
176            return self.max();
177        }
178
179        let a = self.select_inplace((hf as usize).saturating_sub(1));
180        let b = self.select_inplace(hf as usize);
181        a + (h - hf as f64) * (b - a)
182    }
183
184    fn percentile(&mut self, p: usize) -> f64 {
185        self.quantile(p as f64 / 100.0)
186    }
187
188    fn lower_quartile(&mut self) -> f64 {
189        self.quantile(0.25)
190    }
191
192    fn upper_quartile(&mut self) -> f64 {
193        self.quantile(0.75)
194    }
195
196    fn interquartile_range(&mut self) -> f64 {
197        self.upper_quartile() - self.lower_quartile()
198    }
199
200    fn ranks(&mut self, tie_breaker: RankTieBreaker) -> Vec<f64> {
201        let n = self.len();
202        let mut ranks: Vec<f64> = vec![0.0; n];
203        let mut enumerated: Vec<_> = self.iter().enumerate().collect();
204        enumerated.sort_by(|(_, el_a), (_, el_b)| el_a.partial_cmp(el_b).unwrap());
205        match tie_breaker {
206            RankTieBreaker::First => {
207                for (i, idx) in enumerated.into_iter().map(|(idx, _)| idx).enumerate() {
208                    ranks[idx] = (i + 1) as f64
209                }
210                ranks
211            }
212            _ => {
213                let mut prev = 0;
214                let mut prev_idx = 0;
215                let mut prev_elt = 0.0;
216                for (i, (idx, elt)) in enumerated.iter().cloned().enumerate() {
217                    if i == 0 {
218                        prev_idx = idx;
219                        prev_elt = *elt;
220                    }
221                    if (*elt - prev_elt).abs() <= 0.0 {
222                        continue;
223                    }
224                    if i == prev + 1 {
225                        ranks[prev_idx] = i as f64;
226                    } else {
227                        handle_rank_ties(&mut ranks, &enumerated, prev, i, tie_breaker);
228                    }
229                    prev = i;
230                    prev_idx = idx;
231                    prev_elt = *elt;
232                }
233
234                handle_rank_ties(&mut ranks, &enumerated, prev, n, tie_breaker);
235                ranks
236            }
237        }
238    }
239}
240
241impl<D: AsMut<[f64]> + AsRef<[f64]>> Min<f64> for Data<D> {
242    /// Returns the minimum value in the data
243    ///
244    /// # Remarks
245    ///
246    /// Returns `f64::NAN` if data is empty or an entry is `f64::NAN`
247    ///
248    /// # Examples
249    ///
250    /// ```
251    /// use statrs::statistics::Min;
252    /// use statrs::statistics::Data;
253    ///
254    /// let x = [];
255    /// let x = Data::new(x);
256    /// assert!(x.min().is_nan());
257    ///
258    /// let y = [0.0, f64::NAN, 3.0, -2.0];
259    /// let y = Data::new(y);
260    /// assert!(y.min().is_nan());
261    ///
262    /// let z = [0.0, 3.0, -2.0];
263    /// let z = Data::new(z);
264    /// assert_eq!(z.min(), -2.0);
265    /// ```
266    fn min(&self) -> f64 {
267        Statistics::min(self.iter())
268    }
269}
270
271impl<D: AsMut<[f64]> + AsRef<[f64]>> Max<f64> for Data<D> {
272    /// Returns the maximum value in the data
273    ///
274    /// # Remarks
275    ///
276    /// Returns `f64::NAN` if data is empty or an entry is `f64::NAN`
277    ///
278    /// # Examples
279    ///
280    /// ```
281    /// use statrs::statistics::Max;
282    /// use statrs::statistics::Data;
283    ///
284    /// let x = [];
285    /// let x = Data::new(x);
286    /// assert!(x.max().is_nan());
287    ///
288    /// let y = [0.0, f64::NAN, 3.0, -2.0];
289    /// let y = Data::new(y);
290    /// assert!(y.max().is_nan());
291    ///
292    /// let z = [0.0, 3.0, -2.0];
293    /// let z = Data::new(z);
294    /// assert_eq!(z.max(), 3.0);
295    /// ```
296    fn max(&self) -> f64 {
297        Statistics::max(self.iter())
298    }
299}
300
301impl<D: AsMut<[f64]> + AsRef<[f64]>> Distribution<f64> for Data<D> {
302    /// Evaluates the sample mean, an estimate of the population
303    /// mean.
304    ///
305    /// # Remarks
306    ///
307    /// Returns `f64::NAN` if data is empty or an entry is `f64::NAN`
308    ///
309    /// # Examples
310    ///
311    /// ```
312    /// #[macro_use]
313    /// extern crate statrs;
314    ///
315    /// use statrs::statistics::Distribution;
316    /// use statrs::statistics::Data;
317    ///
318    /// # fn main() {
319    /// let x = [];
320    /// let x = Data::new(x);
321    /// assert!(x.mean().unwrap().is_nan());
322    ///
323    /// let y = [0.0, f64::NAN, 3.0, -2.0];
324    /// let y = Data::new(y);
325    /// assert!(y.mean().unwrap().is_nan());
326    ///
327    /// let z = [0.0, 3.0, -2.0];
328    /// let z = Data::new(z);
329    /// assert_almost_eq!(z.mean().unwrap(), 1.0 / 3.0, 1e-15);
330    /// # }
331    /// ```
332    fn mean(&self) -> Option<f64> {
333        Some(Statistics::mean(self.iter()))
334    }
335
336    /// Estimates the unbiased population variance from the provided samples
337    ///
338    /// # Remarks
339    ///
340    /// On a dataset of size `N`, `N-1` is used as a normalizer (Bessel's
341    /// correction).
342    ///
343    /// Returns `f64::NAN` if data has less than two entries or if any entry is
344    /// `f64::NAN`
345    ///
346    /// # Examples
347    ///
348    /// ```
349    /// use statrs::statistics::Distribution;
350    /// use statrs::statistics::Data;
351    ///
352    /// let x = [];
353    /// let x = Data::new(x);
354    /// assert!(x.variance().unwrap().is_nan());
355    ///
356    /// let y = [0.0, f64::NAN, 3.0, -2.0];
357    /// let y = Data::new(y);
358    /// assert!(y.variance().unwrap().is_nan());
359    ///
360    /// let z = [0.0, 3.0, -2.0];
361    /// let z = Data::new(z);
362    /// assert_eq!(z.variance().unwrap(), 19.0 / 3.0);
363    /// ```
364    fn variance(&self) -> Option<f64> {
365        Some(Statistics::variance(self.iter()))
366    }
367}
368
369impl<D: AsMut<[f64]> + AsRef<[f64]> + Clone> Median<f64> for Data<D> {
370    /// Returns the median value from the data
371    ///
372    /// # Remarks
373    ///
374    /// Returns `f64::NAN` if data is empty
375    ///
376    /// # Examples
377    ///
378    /// ```
379    /// use statrs::statistics::Median;
380    /// use statrs::statistics::Data;
381    ///
382    /// let x = [];
383    /// let x = Data::new(x);
384    /// assert!(x.median().is_nan());
385    ///
386    /// let y = [0.0, 3.0, -2.0];
387    /// let y = Data::new(y);
388    /// assert_eq!(y.median(), 0.0);
389    fn median(&self) -> f64 {
390        let mut v = self.clone();
391        OrderStatistics::median(&mut v)
392    }
393}
394
395fn handle_rank_ties(
396    ranks: &mut [f64],
397    index: &[(usize, &f64)],
398    a: usize,
399    b: usize,
400    tie_breaker: RankTieBreaker,
401) {
402    let rank = match tie_breaker {
403        // equivalent to (b + a - 1) as f64 / 2.0 + 1.0 but less overflow issues
404        RankTieBreaker::Average => b as f64 / 2.0 + a as f64 / 2.0 + 0.5,
405        RankTieBreaker::Min => (a + 1) as f64,
406        RankTieBreaker::Max => b as f64,
407        RankTieBreaker::First => unreachable!(),
408    };
409    for i in &index[a..b] {
410        ranks[i.0] = rank
411    }
412}
413
// Unit tests for the order-statistics, quantile, rank and median
// implementations on `Data`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::statistics::*;

    // Order statistics are 1-based; 0 and len + 1 are out of range and
    // expected to yield NaN.
    #[test]
    fn test_order_statistic_short() {
        let data = [-1.0, 5.0, 0.0, -3.0, 10.0, -0.5, 4.0, 1.0, 6.0];
        let mut data = Data::new(data);
        assert!(data.order_statistic(0).is_nan());
        assert_eq!(data.order_statistic(1), -3.0);
        assert_eq!(data.order_statistic(2), -1.0);
        assert_eq!(data.order_statistic(3), -0.5);
        assert_eq!(data.order_statistic(7), 5.0);
        assert_eq!(data.order_statistic(8), 6.0);
        assert_eq!(data.order_statistic(9), 10.0);
        assert!(data.order_statistic(10).is_nan());
    }

    // Quantiles on ten values, including the clamped extremes (0.0, 0.01,
    // 0.99, 1.0) and several interpolated interior points.
    #[test]
    fn test_quantile_short() {
        let data = [-1.0, 5.0, 0.0, -3.0, 10.0, -0.5, 4.0, 0.2, 1.0, 6.0];
        let mut data = Data::new(data);
        assert_eq!(data.quantile(0.0), -3.0);
        assert_eq!(data.quantile(1.0), 10.0);
        assert_almost_eq!(data.quantile(0.5), 3.0 / 5.0, 1e-15);
        assert_almost_eq!(data.quantile(0.2), -4.0 / 5.0, 1e-15);
        assert_eq!(data.quantile(0.7), 137.0 / 30.0);
        assert_eq!(data.quantile(0.01), -3.0);
        assert_eq!(data.quantile(0.99), 10.0);
        assert_almost_eq!(data.quantile(0.52), 287.0 / 375.0, 1e-15);
        assert_almost_eq!(data.quantile(0.325), -37.0 / 240.0, 1e-15);
    }

    // Ranks for every tie breaker, on already-sorted input first and then on
    // shuffled input, each with and without tied values.
    #[test]
    fn test_ranks() {
        let sorted_distinct = [1.0, 2.0, 4.0, 7.0, 8.0, 9.0, 10.0, 12.0];
        let mut sorted_distinct = Data::new(sorted_distinct);
        let sorted_ties = [1.0, 2.0, 2.0, 7.0, 9.0, 9.0, 10.0, 12.0];
        let mut sorted_ties = Data::new(sorted_ties);
        assert_eq!(
            sorted_distinct.ranks(RankTieBreaker::Average),
            [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
        );
        assert_eq!(
            sorted_ties.ranks(RankTieBreaker::Average),
            [1.0, 2.5, 2.5, 4.0, 5.5, 5.5, 7.0, 8.0]
        );
        assert_eq!(
            sorted_distinct.ranks(RankTieBreaker::Min),
            [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
        );
        assert_eq!(
            sorted_ties.ranks(RankTieBreaker::Min),
            [1.0, 2.0, 2.0, 4.0, 5.0, 5.0, 7.0, 8.0]
        );
        assert_eq!(
            sorted_distinct.ranks(RankTieBreaker::Max),
            [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
        );
        assert_eq!(
            sorted_ties.ranks(RankTieBreaker::Max),
            [1.0, 3.0, 3.0, 4.0, 6.0, 6.0, 7.0, 8.0]
        );
        assert_eq!(
            sorted_distinct.ranks(RankTieBreaker::First),
            [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
        );
        assert_eq!(
            sorted_ties.ranks(RankTieBreaker::First),
            [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
        );

        // Same values in shuffled order: ranks must follow the original
        // element positions.
        let distinct = [1.0, 8.0, 12.0, 7.0, 2.0, 9.0, 10.0, 4.0];
        let distinct = Data::new(distinct);
        let ties = [1.0, 9.0, 12.0, 7.0, 2.0, 9.0, 10.0, 2.0];
        let ties = Data::new(ties);
        assert_eq!(
            distinct.clone().ranks(RankTieBreaker::Average),
            [1.0, 5.0, 8.0, 4.0, 2.0, 6.0, 7.0, 3.0]
        );
        assert_eq!(
            ties.clone().ranks(RankTieBreaker::Average),
            [1.0, 5.5, 8.0, 4.0, 2.5, 5.5, 7.0, 2.5]
        );
        assert_eq!(
            distinct.clone().ranks(RankTieBreaker::Min),
            [1.0, 5.0, 8.0, 4.0, 2.0, 6.0, 7.0, 3.0]
        );
        assert_eq!(
            ties.clone().ranks(RankTieBreaker::Min),
            [1.0, 5.0, 8.0, 4.0, 2.0, 5.0, 7.0, 2.0]
        );
        assert_eq!(
            distinct.clone().ranks(RankTieBreaker::Max),
            [1.0, 5.0, 8.0, 4.0, 2.0, 6.0, 7.0, 3.0]
        );
        assert_eq!(
            ties.clone().ranks(RankTieBreaker::Max),
            [1.0, 6.0, 8.0, 4.0, 3.0, 6.0, 7.0, 3.0]
        );
        assert_eq!(
            distinct.clone().ranks(RankTieBreaker::First),
            [1.0, 5.0, 8.0, 4.0, 2.0, 6.0, 7.0, 3.0]
        );
        assert_eq!(
            ties.clone().ranks(RankTieBreaker::First),
            [1.0, 5.0, 8.0, 4.0, 2.0, 6.0, 7.0, 3.0]
        );
    }

    // Median for an even and an odd number of elements.
    #[test]
    fn test_median_short() {
        let even = [-1.0, 5.0, 0.0, -3.0, 10.0, -0.5, 4.0, 0.2, 1.0, 6.0];
        let even = Data::new(even);
        assert_eq!(even.median(), 0.6);

        let odd = [-1.0, 5.0, 0.0, -3.0, 10.0, -0.5, 4.0, 0.2, 1.0];
        let odd = Data::new(odd);
        assert_eq!(odd.median(), 0.2);
    }

    // Large inputs in which every element is identical.
    #[test]
    fn test_median_long_constant_seq() {
        let even = vec![2.0; 100000];
        let even = Data::new(even);
        assert_eq!(2.0, even.median());

        let odd = vec![2.0; 100001];
        let odd = Data::new(odd);
        assert_eq!(2.0, odd.median());
    }

    // TODO: test codeplex issue 5667 (Math.NET)

    // Infinite values in any position must not disturb the median; each
    // median is taken twice on the same data to check it is repeatable.
    #[test]
    fn test_median_robust_on_infinities() {
        let data3 = [2.0, f64::NEG_INFINITY, f64::INFINITY];
        let data3 = Data::new(data3);
        assert_eq!(data3.median(), 2.0);
        assert_eq!(data3.median(), 2.0);

        let data3 = [f64::NEG_INFINITY, 2.0, f64::INFINITY];
        let data3 = Data::new(data3);
        assert_eq!(data3.median(), 2.0);
        assert_eq!(data3.median(), 2.0);

        let data3 = [f64::NEG_INFINITY, f64::INFINITY, 2.0];
        let data3 = Data::new(data3);
        assert_eq!(data3.median(), 2.0);
        assert_eq!(data3.median(), 2.0);

        let data4 = [f64::NEG_INFINITY, 2.0, 3.0, f64::INFINITY];
        let data4 = Data::new(data4);
        assert_eq!(data4.median(), 2.5);
        assert_eq!(data4.median(), 2.5);
    }
    // Smoke test: only checks that the call completes on a small sorted
    // input; the returned value is not asserted.
    #[test]
    fn test_foo() {
        let arr = [0.0, 1.0, 2.0, 3.0];
        let mut arr = Data::new(arr);
        arr.order_statistic(2);
    }
}