statrs/distribution/bernoulli.rs

1use crate::distribution::{Binomial, Discrete, DiscreteCDF};
2use crate::statistics::*;
3use crate::Result;
4use rand::Rng;
5
6/// Implements the
7/// [Bernoulli](https://en.wikipedia.org/wiki/Bernoulli_distribution)
8/// distribution which is a special case of the
9/// [Binomial](https://en.wikipedia.org/wiki/Binomial_distribution)
10/// distribution where `n = 1` (referenced [Here](./struct.Binomial.html))
11///
12/// # Examples
13///
14/// ```
15/// use statrs::distribution::{Bernoulli, Discrete};
16/// use statrs::statistics::Distribution;
17///
18/// let n = Bernoulli::new(0.5).unwrap();
19/// assert_eq!(n.mean().unwrap(), 0.5);
20/// assert_eq!(n.pmf(0), 0.5);
21/// assert_eq!(n.pmf(1), 0.5);
22/// ```
23#[derive(Debug, Copy, Clone, PartialEq)]
24pub struct Bernoulli {
25    b: Binomial,
26}
27
28impl Bernoulli {
29    /// Constructs a new bernoulli distribution with
30    /// the given `p` probability of success.
31    ///
32    /// # Errors
33    ///
34    /// Returns an error if `p` is `NaN`, less than `0.0`
35    /// or greater than `1.0`
36    ///
37    /// # Examples
38    ///
39    /// ```
40    /// use statrs::distribution::Bernoulli;
41    ///
42    /// let mut result = Bernoulli::new(0.5);
43    /// assert!(result.is_ok());
44    ///
45    /// result = Bernoulli::new(-0.5);
46    /// assert!(result.is_err());
47    /// ```
48    pub fn new(p: f64) -> Result<Bernoulli> {
49        Binomial::new(p, 1).map(|b| Bernoulli { b })
50    }
51
52    /// Returns the probability of success `p` of the
53    /// bernoulli distribution.
54    ///
55    /// # Examples
56    ///
57    /// ```
58    /// use statrs::distribution::Bernoulli;
59    ///
60    /// let n = Bernoulli::new(0.5).unwrap();
61    /// assert_eq!(n.p(), 0.5);
62    /// ```
63    pub fn p(&self) -> f64 {
64        self.b.p()
65    }
66
67    /// Returns the number of trials `n` of the
68    /// bernoulli distribution. Will always be `1.0`.
69    ///
70    /// # Examples
71    ///
72    /// ```
73    /// use statrs::distribution::Bernoulli;
74    ///
75    /// let n = Bernoulli::new(0.5).unwrap();
76    /// assert_eq!(n.n(), 1);
77    /// ```
78    pub fn n(&self) -> u64 {
79        1
80    }
81}
82
83impl ::rand::distributions::Distribution<f64> for Bernoulli {
84    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
85        rng.gen_bool(self.p()) as u8 as f64
86    }
87}
88
89impl DiscreteCDF<u64, f64> for Bernoulli {
90    /// Calculates the cumulative distribution
91    /// function for the bernoulli distribution at `x`.
92    ///
93    /// # Formula
94    ///
95    /// ```ignore
96    /// if x < 0 { 0 }
97    /// else if x >= 1 { 1 }
98    /// else { 1 - p }
99    /// ```
100    fn cdf(&self, x: u64) -> f64 {
101        self.b.cdf(x)
102    }
103
104    /// Calculates the survival function for the 
105    /// bernoulli distribution at `x`.
106    ///
107    /// # Formula
108    ///
109    /// ```ignore
110    /// if x < 0 { 1 }
111    /// else if x >= 1 { 0 }
112    /// else { p }
113    /// ```
114    fn sf(&self, x: u64) -> f64 {
115        self.b.sf(x)
116    }
117}
118
119impl Min<u64> for Bernoulli {
120    /// Returns the minimum value in the domain of the
121    /// bernoulli distribution representable by a 64-
122    /// bit integer
123    ///
124    /// # Formula
125    ///
126    /// ```ignore
127    /// 0
128    /// ```
129    fn min(&self) -> u64 {
130        0
131    }
132}
133
134impl Max<u64> for Bernoulli {
135    /// Returns the maximum value in the domain of the
136    /// bernoulli distribution representable by a 64-
137    /// bit integer
138    ///
139    /// # Formula
140    ///
141    /// ```ignore
142    /// 1
143    /// ```
144    fn max(&self) -> u64 {
145        1
146    }
147}
148
149impl Distribution<f64> for Bernoulli {
150    /// Returns the mean of the bernoulli
151    /// distribution
152    ///
153    /// # Formula
154    ///
155    /// ```ignore
156    /// p
157    /// ```
158    fn mean(&self) -> Option<f64> {
159        self.b.mean()
160    }
161    /// Returns the variance of the bernoulli
162    /// distribution
163    ///
164    /// # Formula
165    ///
166    /// ```ignore
167    /// p * (1 - p)
168    /// ```
169    fn variance(&self) -> Option<f64> {
170        self.b.variance()
171    }
172    /// Returns the entropy of the bernoulli
173    /// distribution
174    ///
175    /// # Formula
176    ///
177    /// ```ignore
178    /// q = (1 - p)
179    /// -q * ln(q) - p * ln(p)
180    /// ```
181    fn entropy(&self) -> Option<f64> {
182        self.b.entropy()
183    }
184    /// Returns the skewness of the bernoulli
185    /// distribution
186    ///
187    /// # Formula
188    ///
189    /// ```ignore
190    /// q = (1 - p)
191    /// (1 - 2p) / sqrt(p * q)
192    /// ```
193    fn skewness(&self) -> Option<f64> {
194        self.b.skewness()
195    }
196}
197
198impl Median<f64> for Bernoulli {
199    /// Returns the median of the bernoulli
200    /// distribution
201    ///
202    /// # Formula
203    ///
204    /// ```ignore
205    /// if p < 0.5 { 0 }
206    /// else if p > 0.5 { 1 }
207    /// else { 0.5 }
208    /// ```
209    fn median(&self) -> f64 {
210        self.b.median()
211    }
212}
213
214impl Mode<Option<u64>> for Bernoulli {
215    /// Returns the mode of the bernoulli distribution
216    ///
217    /// # Formula
218    ///
219    /// ```ignore
220    /// if p < 0.5 { 0 }
221    /// else { 1 }
222    /// ```
223    fn mode(&self) -> Option<u64> {
224        self.b.mode()
225    }
226}
227
228impl Discrete<u64, f64> for Bernoulli {
229    /// Calculates the probability mass function for the
230    /// bernoulli distribution at `x`.
231    ///
232    /// # Formula
233    ///
234    /// ```ignore
235    /// if x == 0 { 1 - p }
236    /// else { p }
237    /// ```
238    fn pmf(&self, x: u64) -> f64 {
239        self.b.pmf(x)
240    }
241
242    /// Calculates the log probability mass function for the
243    /// bernoulli distribution at `x`.
244    ///
245    /// # Formula
246    ///
247    /// ```ignore
248    /// else if x == 0 { ln(1 - p) }
249    /// else { ln(p) }
250    /// ```
251    fn ln_pmf(&self, x: u64) -> f64 {
252        self.b.ln_pmf(x)
253    }
254}
255
#[rustfmt::skip]
#[cfg(all(test, feature = "nightly"))]
mod testing {
    use std::fmt::Debug;
    use crate::distribution::{Discrete, DiscreteCDF};
    use super::Bernoulli;

    fn try_create(p: f64) -> Bernoulli {
        let n = Bernoulli::new(p);
        assert!(n.is_ok());
        n.unwrap()
    }

    fn create_case(p: f64) {
        let dist = try_create(p);
        assert_eq!(p, dist.p());
    }

    fn bad_create_case(p: f64) {
        let n = Bernoulli::new(p);
        assert!(n.is_err());
    }

    fn get_value<T, F>(p: f64, eval: F) -> T
        where T: PartialEq + Debug,
              F: Fn(Bernoulli) -> T
    {
        let n = try_create(p);
        eval(n)
    }

    fn test_case<T, F>(p: f64, expected: T, eval: F)
        where T: PartialEq + Debug,
              F: Fn(Bernoulli) -> T
    {
        let x = get_value(p, eval);
        assert_eq!(expected, x);
    }

    fn test_almost<F>(p: f64, expected: f64, acc: f64, eval: F)
        where F: Fn(Bernoulli) -> f64
    {
        let x = get_value(p, eval);
        assert_almost_eq!(expected, x, acc);
    }

    #[test]
    fn test_create() {
        create_case(0.0);
        create_case(0.3);
        create_case(1.0);
    }

    #[test]
    fn test_bad_create() {
        bad_create_case(f64::NAN);
        bad_create_case(-1.0);
        bad_create_case(2.0);
    }

    #[test]
    fn test_cdf_upper_bound() {
        let cdf = |arg: u64| move |x: Bernoulli| x.cdf(arg);
        test_case(0.3, 1., cdf(1));
    }

    #[test]
    fn test_sf_upper_bound() {
        let sf = |arg: u64| move |x: Bernoulli| x.sf(arg);
        test_case(0.3, 0., sf(1));
    }

    #[test]
    fn test_cdf() {
        let cdf = |arg: u64| move |x: Bernoulli| x.cdf(arg);
        test_case(0.0, 1.0, cdf(0));
        test_case(0.0, 1.0, cdf(1));
        test_almost(0.3, 0.7, 1e-15, cdf(0));
        test_almost(0.7, 0.3, 1e-15, cdf(0));
    }

    #[test]
    fn test_sf() {
        let sf = |arg: u64| move |x: Bernoulli| x.sf(arg);
        test_case(0.0, 0.0, sf(0));
        test_case(0.0, 0.0, sf(1));
        test_almost(0.3, 0.3, 1e-15, sf(0));
        test_almost(0.7, 0.7, 1e-15, sf(0));
    }

    // The pmf must be `1 - p` at 0, `p` at 1 and vanish outside the support {0, 1}.
    #[test]
    fn test_pmf() {
        let pmf = |arg: u64| move |x: Bernoulli| x.pmf(arg);
        test_case(0.0, 1.0, pmf(0));
        test_case(0.0, 0.0, pmf(1));
        test_almost(0.3, 0.7, 1e-15, pmf(0));
        test_almost(0.3, 0.3, 1e-15, pmf(1));
        test_case(0.3, 0.0, pmf(2));
    }

    // ln_pmf mirrors pmf; outside the support the log-probability is -inf.
    #[test]
    fn test_ln_pmf() {
        let ln_pmf = |arg: u64| move |x: Bernoulli| x.ln_pmf(arg);
        test_almost(0.3, 0.7f64.ln(), 1e-14, ln_pmf(0));
        test_almost(0.3, 0.3f64.ln(), 1e-14, ln_pmf(1));
        test_case(0.3, f64::NEG_INFINITY, ln_pmf(2));
    }
}