ndhistogram/value/weightedmean.rs
1use std::{
2 marker::PhantomData,
3 ops::{AddAssign, Div, Mul},
4};
5
6use num_traits::{Float, NumOps, One, Signed};
7
8use crate::FillWithWeighted;
9
/// ndhistogram bin value that computes the weighted mean of the data samples
/// provided when filling.
///
/// `WeightedMean` has 4 type parameters:
/// - `T`: the type that is being averaged,
/// - `W`: the type of the weights that are being filled,
/// - `O`: the type of the output when calculating the mean and its uncertainty,
/// - `C`: the type that counts the number of fills.
///
/// This allows, for example, integers to be used when filling or counting,
/// but a floating point type to compute the mean.
/// All four type parameters have defaults (`f64`, `f64`, `f64` and `u32`
/// respectively), so in most cases you only need to specify the first two
/// (or none at all for plain floating-point data).
///
/// # Example
/// ```rust
/// use ndhistogram::{ndhistogram, Histogram, axis::Uniform, value::WeightedMean};
///
/// # fn main() -> Result<(), ndhistogram::Error> {
/// // create a histogram and fill it with some values
/// let mut hist = ndhistogram!(Uniform::new(10, 0.0, 10.0)?; WeightedMean<i32, i32>);
/// // arguments are: coordinate, value, weight
/// hist.fill_with_weighted(&0.0, 2, 1);
/// hist.fill_with_weighted(&0.0, 2, 2);
/// hist.fill_with_weighted(&0.0, 4, 3);
///
/// // (2*1 + 2*2 + 4*3) / (1 + 2 + 3) = 18 / 6 = 3
/// let weightedmean = hist.value(&0.0);
/// assert_eq!(weightedmean.unwrap().get(), 3.0);
/// # Ok(()) }
/// ```
#[derive(Copy, Default, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct WeightedMean<T = f64, W = f64, O = f64, C = u32> {
    /// Running sum of `value * weight`.
    sumwt: T,
    /// Running sum of `value * value * weight`.
    sumwt2: T,
    /// Running sum of the weights.
    sumw: W,
    /// Running sum of the squared weights.
    sumw2: W,
    /// Number of fills.
    count: C,
    /// Records the output type `O` without storing a value of it.
    phantom_output_type: PhantomData<O>,
}
50
51impl<T, W, O, C> WeightedMean<T, W, O, C>
52where
53 T: Copy,
54 W: Copy,
55 O: From<T> + From<W> + From<C> + NumOps + Signed + Copy,
56 C: Copy,
57{
58 /// Factory method to create a Mean from a set of values.
59 ///
60 /// Usually this will not be used as a [Histogram](crate::Histogram) will
61 /// be responsible for creating and filling values.
62 pub fn new<I>(values: I) -> Self
63 where
64 I: IntoIterator<Item = (T, W)>,
65 Self: FillWithWeighted<T, W> + Default,
66 {
67 let mut r = Self::default();
68 values
69 .into_iter()
70 .for_each(|it| r.fill_with_weighted(it.0, it.1));
71 r
72 }
73
74 /// Get the current value of the mean.
75 pub fn get(&self) -> O {
76 self.mean()
77 }
78
79 /// Get the current value of the mean.
80 pub fn mean(&self) -> <O as Div>::Output {
81 O::from(self.sumwt) / O::from(self.sumw)
82 }
83
84 /// Get the number of times the mean value has been filled.
85 pub fn num_samples(&self) -> C {
86 self.count
87 }
88
89 /// Compute the variance of the samples.
90 pub fn variance_of_samples(&self) -> O {
91 // weighted variance is:
92 // var = ((sumwt2 - sumwt*mu) / sumw) + mu2
93 let mu = self.mean();
94 let mu2 = mu * mu;
95 mu2 + (O::from(self.sumwt2) - (O::one() + O::one()) * O::from(self.sumwt) * mu)
96 / O::from(self.sumw)
97 }
98
99 /// The square root of the variance of the samples.
100 pub fn standard_deviation_of_samples(&self) -> O
101 where
102 O: Float,
103 {
104 self.variance_of_samples().sqrt()
105 }
106
107 /// The square of the standard error of the mean.
108 pub fn variance_of_mean(&self) -> O {
109 self.variance_of_samples() * O::from(self.sumw2) / (O::from(self.sumw) * O::from(self.sumw))
110 }
111
112 /// Compute the standard error of the mean.
113 pub fn standard_error_of_mean(&self) -> O
114 where
115 O: Float,
116 {
117 self.variance_of_mean().sqrt()
118 }
119}
120
121impl<T, W, O, C> FillWithWeighted<T, W> for WeightedMean<T, W, O, C>
122where
123 T: Copy + AddAssign + Mul<W, Output = T> + Mul<T, Output = T>,
124 W: Copy + AddAssign + Mul<W, Output = W>,
125 C: AddAssign + One,
126{
127 #[inline]
128 fn fill_with_weighted(&mut self, value: T, weight: W) {
129 self.sumwt += value * weight;
130 self.sumwt2 += value * value * weight;
131 self.sumw += weight;
132 self.sumw2 += weight * weight;
133 self.count += C::one();
134 }
135}