statrs/distribution/
bernoulli.rs

1use crate::distribution::{Binomial, BinomialError, Discrete, DiscreteCDF};
2use crate::statistics::*;
3
4/// Implements the
5/// [Bernoulli](https://en.wikipedia.org/wiki/Bernoulli_distribution)
6/// distribution which is a special case of the
7/// [Binomial](https://en.wikipedia.org/wiki/Binomial_distribution)
8/// distribution where `n = 1` (referenced [Here](./struct.Binomial.html))
9///
10/// # Examples
11///
12/// ```
13/// use statrs::distribution::{Bernoulli, Discrete};
14/// use statrs::statistics::Distribution;
15///
16/// let n = Bernoulli::new(0.5).unwrap();
17/// assert_eq!(n.mean().unwrap(), 0.5);
18/// assert_eq!(n.pmf(0), 0.5);
19/// assert_eq!(n.pmf(1), 0.5);
20/// ```
21#[derive(Copy, Clone, PartialEq, Debug)]
22pub struct Bernoulli {
23    b: Binomial,
24}
25
26impl Bernoulli {
27    /// Constructs a new bernoulli distribution with
28    /// the given `p` probability of success.
29    ///
30    /// # Errors
31    ///
32    /// Returns an error if `p` is `NaN`, less than `0.0`
33    /// or greater than `1.0`
34    ///
35    /// # Examples
36    ///
37    /// ```
38    /// use statrs::distribution::Bernoulli;
39    ///
40    /// let mut result = Bernoulli::new(0.5);
41    /// assert!(result.is_ok());
42    ///
43    /// result = Bernoulli::new(-0.5);
44    /// assert!(result.is_err());
45    /// ```
46    pub fn new(p: f64) -> Result<Bernoulli, BinomialError> {
47        Binomial::new(p, 1).map(|b| Bernoulli { b })
48    }
49
50    /// Returns the probability of success `p` of the
51    /// bernoulli distribution.
52    ///
53    /// # Examples
54    ///
55    /// ```
56    /// use statrs::distribution::Bernoulli;
57    ///
58    /// let n = Bernoulli::new(0.5).unwrap();
59    /// assert_eq!(n.p(), 0.5);
60    /// ```
61    pub fn p(&self) -> f64 {
62        self.b.p()
63    }
64
65    /// Returns the number of trials `n` of the
66    /// bernoulli distribution. Will always be `1.0`.
67    ///
68    /// # Examples
69    ///
70    /// ```
71    /// use statrs::distribution::Bernoulli;
72    ///
73    /// let n = Bernoulli::new(0.5).unwrap();
74    /// assert_eq!(n.n(), 1);
75    /// ```
76    pub fn n(&self) -> u64 {
77        1
78    }
79}
80
81impl std::fmt::Display for Bernoulli {
82    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83        write!(f, "Bernoulli({})", self.p())
84    }
85}
86
#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<bool> for Bernoulli {
    /// Draws a single trial: `true` (success) with probability `p`,
    /// `false` otherwise.
    fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> bool {
        let success_probability = self.p();
        rng.gen_bool(success_probability)
    }
}
94
#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<f64> for Bernoulli {
    /// Draws a single trial via the `bool` sampler and maps it to a float:
    /// `1.0` for success, `0.0` for failure.
    fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 {
        let success = rng.sample::<bool, _>(self);
        if success {
            1.0
        } else {
            0.0
        }
    }
}
102
103impl DiscreteCDF<u64, f64> for Bernoulli {
104    /// Calculates the cumulative distribution
105    /// function for the bernoulli distribution at `x`.
106    ///
107    /// # Formula
108    ///
109    /// ```text
110    /// if x < 0 { 0 }
111    /// else if x >= 1 { 1 }
112    /// else { 1 - p }
113    /// ```
114    fn cdf(&self, x: u64) -> f64 {
115        self.b.cdf(x)
116    }
117
118    /// Calculates the survival function for the
119    /// bernoulli distribution at `x`.
120    ///
121    /// # Formula
122    ///
123    /// ```text
124    /// if x < 0 { 1 }
125    /// else if x >= 1 { 0 }
126    /// else { p }
127    /// ```
128    fn sf(&self, x: u64) -> f64 {
129        self.b.sf(x)
130    }
131}
132
133impl Min<u64> for Bernoulli {
134    /// Returns the minimum value in the domain of the
135    /// bernoulli distribution representable by a 64-
136    /// bit integer
137    ///
138    /// # Formula
139    ///
140    /// ```text
141    /// 0
142    /// ```
143    fn min(&self) -> u64 {
144        0
145    }
146}
147
148impl Max<u64> for Bernoulli {
149    /// Returns the maximum value in the domain of the
150    /// bernoulli distribution representable by a 64-
151    /// bit integer
152    ///
153    /// # Formula
154    ///
155    /// ```text
156    /// 1
157    /// ```
158    fn max(&self) -> u64 {
159        1
160    }
161}
162
163impl Distribution<f64> for Bernoulli {
164    /// Returns the mean of the bernoulli
165    /// distribution
166    ///
167    /// # Formula
168    ///
169    /// ```text
170    /// p
171    /// ```
172    fn mean(&self) -> Option<f64> {
173        self.b.mean()
174    }
175
176    /// Returns the variance of the bernoulli
177    /// distribution
178    ///
179    /// # Formula
180    ///
181    /// ```text
182    /// p * (1 - p)
183    /// ```
184    fn variance(&self) -> Option<f64> {
185        self.b.variance()
186    }
187
188    /// Returns the entropy of the bernoulli
189    /// distribution
190    ///
191    /// # Formula
192    ///
193    /// ```text
194    /// q = (1 - p)
195    /// -q * ln(q) - p * ln(p)
196    /// ```
197    fn entropy(&self) -> Option<f64> {
198        self.b.entropy()
199    }
200
201    /// Returns the skewness of the bernoulli
202    /// distribution
203    ///
204    /// # Formula
205    ///
206    /// ```text
207    /// q = (1 - p)
208    /// (1 - 2p) / sqrt(p * q)
209    /// ```
210    fn skewness(&self) -> Option<f64> {
211        self.b.skewness()
212    }
213}
214
215impl Median<f64> for Bernoulli {
216    /// Returns the median of the bernoulli
217    /// distribution
218    ///
219    /// # Formula
220    ///
221    /// ```text
222    /// if p < 0.5 { 0 }
223    /// else if p > 0.5 { 1 }
224    /// else { 0.5 }
225    /// ```
226    fn median(&self) -> f64 {
227        self.b.median()
228    }
229}
230
231impl Mode<Option<u64>> for Bernoulli {
232    /// Returns the mode of the bernoulli distribution
233    ///
234    /// # Formula
235    ///
236    /// ```text
237    /// if p < 0.5 { 0 }
238    /// else { 1 }
239    /// ```
240    fn mode(&self) -> Option<u64> {
241        self.b.mode()
242    }
243}
244
245impl Discrete<u64, f64> for Bernoulli {
246    /// Calculates the probability mass function for the
247    /// bernoulli distribution at `x`.
248    ///
249    /// # Formula
250    ///
251    /// ```text
252    /// if x == 0 { 1 - p }
253    /// else { p }
254    /// ```
255    fn pmf(&self, x: u64) -> f64 {
256        self.b.pmf(x)
257    }
258
259    /// Calculates the log probability mass function for the
260    /// bernoulli distribution at `x`.
261    ///
262    /// # Formula
263    ///
264    /// ```text
265    /// else if x == 0 { ln(1 - p) }
266    /// else { ln(p) }
267    /// ```
268    fn ln_pmf(&self, x: u64) -> f64 {
269        self.b.ln_pmf(x)
270    }
271}
272
#[rustfmt::skip]
#[cfg(test)]
mod testing {
    use super::*;
    use crate::testing_boiler;

    testing_boiler!(p: f64; Bernoulli; BinomialError);

    #[test]
    fn test_create() {
        create_ok(0.0);
        create_ok(0.3);
        create_ok(1.0);
    }

    #[test]
    fn test_bad_create() {
        create_err(f64::NAN);
        create_err(-1.0);
        create_err(2.0);
    }

    #[test]
    fn test_cdf_upper_bound() {
        let cdf = |arg: u64| move |x: Bernoulli| x.cdf(arg);
        test_relative(0.3, 1., cdf(1));
    }

    #[test]
    fn test_sf_upper_bound() {
        let sf = |arg: u64| move |x: Bernoulli| x.sf(arg);
        test_relative(0.3, 0., sf(1));
    }

    #[test]
    fn test_cdf() {
        let cdf = |arg: u64| move |x: Bernoulli| x.cdf(arg);
        test_relative(0.0, 1.0, cdf(0));
        test_relative(0.0, 1.0, cdf(1));
        test_absolute(0.3, 0.7, 1e-15, cdf(0));
        test_absolute(0.7, 0.3, 1e-15, cdf(0));
    }

    #[test]
    fn test_sf() {
        let sf = |arg: u64| move |x: Bernoulli| x.sf(arg);
        test_relative(0.0, 0.0, sf(0));
        test_relative(0.0, 0.0, sf(1));
        test_absolute(0.3, 0.3, 1e-15, sf(0));
        test_absolute(0.7, 0.7, 1e-15, sf(0));
    }

    // `x` is a u64, so values past the support {0, 1} are valid inputs and
    // must saturate rather than misbehave.
    #[test]
    fn test_cdf_sf_beyond_support() {
        let cdf = |arg: u64| move |x: Bernoulli| x.cdf(arg);
        let sf = |arg: u64| move |x: Bernoulli| x.sf(arg);
        test_relative(0.3, 1.0, cdf(2));
        test_relative(0.3, 1.0, cdf(u64::MAX));
        test_relative(0.3, 0.0, sf(2));
        test_relative(0.3, 0.0, sf(u64::MAX));
    }

    #[test]
    fn test_pmf_beyond_support() {
        let dist = Bernoulli::new(0.3).unwrap();
        assert_eq!(dist.pmf(2), 0.0);
        assert_eq!(dist.ln_pmf(2), f64::NEG_INFINITY);
    }

    #[test]
    fn test_n_is_one() {
        assert_eq!(Bernoulli::new(0.3).unwrap().n(), 1);
    }

    #[test]
    fn test_display() {
        assert_eq!(Bernoulli::new(0.3).unwrap().to_string(), "Bernoulli(0.3)");
    }
}