statrs/distribution/bernoulli.rs
1use crate::distribution::{Binomial, Discrete, DiscreteCDF};
2use crate::statistics::*;
3use crate::Result;
4use rand::Rng;
5
6/// Implements the
7/// [Bernoulli](https://en.wikipedia.org/wiki/Bernoulli_distribution)
8/// distribution which is a special case of the
9/// [Binomial](https://en.wikipedia.org/wiki/Binomial_distribution)
10/// distribution where `n = 1` (referenced [Here](./struct.Binomial.html))
11///
12/// # Examples
13///
14/// ```
15/// use statrs::distribution::{Bernoulli, Discrete};
16/// use statrs::statistics::Distribution;
17///
18/// let n = Bernoulli::new(0.5).unwrap();
19/// assert_eq!(n.mean().unwrap(), 0.5);
20/// assert_eq!(n.pmf(0), 0.5);
21/// assert_eq!(n.pmf(1), 0.5);
22/// ```
23#[derive(Debug, Copy, Clone, PartialEq)]
24pub struct Bernoulli {
25 b: Binomial,
26}
27
28impl Bernoulli {
29 /// Constructs a new bernoulli distribution with
30 /// the given `p` probability of success.
31 ///
32 /// # Errors
33 ///
34 /// Returns an error if `p` is `NaN`, less than `0.0`
35 /// or greater than `1.0`
36 ///
37 /// # Examples
38 ///
39 /// ```
40 /// use statrs::distribution::Bernoulli;
41 ///
42 /// let mut result = Bernoulli::new(0.5);
43 /// assert!(result.is_ok());
44 ///
45 /// result = Bernoulli::new(-0.5);
46 /// assert!(result.is_err());
47 /// ```
48 pub fn new(p: f64) -> Result<Bernoulli> {
49 Binomial::new(p, 1).map(|b| Bernoulli { b })
50 }
51
52 /// Returns the probability of success `p` of the
53 /// bernoulli distribution.
54 ///
55 /// # Examples
56 ///
57 /// ```
58 /// use statrs::distribution::Bernoulli;
59 ///
60 /// let n = Bernoulli::new(0.5).unwrap();
61 /// assert_eq!(n.p(), 0.5);
62 /// ```
63 pub fn p(&self) -> f64 {
64 self.b.p()
65 }
66
67 /// Returns the number of trials `n` of the
68 /// bernoulli distribution. Will always be `1.0`.
69 ///
70 /// # Examples
71 ///
72 /// ```
73 /// use statrs::distribution::Bernoulli;
74 ///
75 /// let n = Bernoulli::new(0.5).unwrap();
76 /// assert_eq!(n.n(), 1);
77 /// ```
78 pub fn n(&self) -> u64 {
79 1
80 }
81}
82
83impl ::rand::distributions::Distribution<f64> for Bernoulli {
84 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
85 rng.gen_bool(self.p()) as u8 as f64
86 }
87}
88
89impl DiscreteCDF<u64, f64> for Bernoulli {
90 /// Calculates the cumulative distribution
91 /// function for the bernoulli distribution at `x`.
92 ///
93 /// # Formula
94 ///
95 /// ```ignore
96 /// if x < 0 { 0 }
97 /// else if x >= 1 { 1 }
98 /// else { 1 - p }
99 /// ```
100 fn cdf(&self, x: u64) -> f64 {
101 self.b.cdf(x)
102 }
103
104 /// Calculates the survival function for the
105 /// bernoulli distribution at `x`.
106 ///
107 /// # Formula
108 ///
109 /// ```ignore
110 /// if x < 0 { 1 }
111 /// else if x >= 1 { 0 }
112 /// else { p }
113 /// ```
114 fn sf(&self, x: u64) -> f64 {
115 self.b.sf(x)
116 }
117}
118
119impl Min<u64> for Bernoulli {
120 /// Returns the minimum value in the domain of the
121 /// bernoulli distribution representable by a 64-
122 /// bit integer
123 ///
124 /// # Formula
125 ///
126 /// ```ignore
127 /// 0
128 /// ```
129 fn min(&self) -> u64 {
130 0
131 }
132}
133
134impl Max<u64> for Bernoulli {
135 /// Returns the maximum value in the domain of the
136 /// bernoulli distribution representable by a 64-
137 /// bit integer
138 ///
139 /// # Formula
140 ///
141 /// ```ignore
142 /// 1
143 /// ```
144 fn max(&self) -> u64 {
145 1
146 }
147}
148
149impl Distribution<f64> for Bernoulli {
150 /// Returns the mean of the bernoulli
151 /// distribution
152 ///
153 /// # Formula
154 ///
155 /// ```ignore
156 /// p
157 /// ```
158 fn mean(&self) -> Option<f64> {
159 self.b.mean()
160 }
161 /// Returns the variance of the bernoulli
162 /// distribution
163 ///
164 /// # Formula
165 ///
166 /// ```ignore
167 /// p * (1 - p)
168 /// ```
169 fn variance(&self) -> Option<f64> {
170 self.b.variance()
171 }
172 /// Returns the entropy of the bernoulli
173 /// distribution
174 ///
175 /// # Formula
176 ///
177 /// ```ignore
178 /// q = (1 - p)
179 /// -q * ln(q) - p * ln(p)
180 /// ```
181 fn entropy(&self) -> Option<f64> {
182 self.b.entropy()
183 }
184 /// Returns the skewness of the bernoulli
185 /// distribution
186 ///
187 /// # Formula
188 ///
189 /// ```ignore
190 /// q = (1 - p)
191 /// (1 - 2p) / sqrt(p * q)
192 /// ```
193 fn skewness(&self) -> Option<f64> {
194 self.b.skewness()
195 }
196}
197
198impl Median<f64> for Bernoulli {
199 /// Returns the median of the bernoulli
200 /// distribution
201 ///
202 /// # Formula
203 ///
204 /// ```ignore
205 /// if p < 0.5 { 0 }
206 /// else if p > 0.5 { 1 }
207 /// else { 0.5 }
208 /// ```
209 fn median(&self) -> f64 {
210 self.b.median()
211 }
212}
213
214impl Mode<Option<u64>> for Bernoulli {
215 /// Returns the mode of the bernoulli distribution
216 ///
217 /// # Formula
218 ///
219 /// ```ignore
220 /// if p < 0.5 { 0 }
221 /// else { 1 }
222 /// ```
223 fn mode(&self) -> Option<u64> {
224 self.b.mode()
225 }
226}
227
228impl Discrete<u64, f64> for Bernoulli {
229 /// Calculates the probability mass function for the
230 /// bernoulli distribution at `x`.
231 ///
232 /// # Formula
233 ///
234 /// ```ignore
235 /// if x == 0 { 1 - p }
236 /// else { p }
237 /// ```
238 fn pmf(&self, x: u64) -> f64 {
239 self.b.pmf(x)
240 }
241
242 /// Calculates the log probability mass function for the
243 /// bernoulli distribution at `x`.
244 ///
245 /// # Formula
246 ///
247 /// ```ignore
248 /// else if x == 0 { ln(1 - p) }
249 /// else { ln(p) }
250 /// ```
251 fn ln_pmf(&self, x: u64) -> f64 {
252 self.b.ln_pmf(x)
253 }
254}
255
#[rustfmt::skip]
#[cfg(all(test, feature = "nightly"))]
mod testing {
    use std::fmt::Debug;
    use crate::distribution::DiscreteCDF;
    use super::Bernoulli;

    /// Builds a distribution that must be valid, failing the test otherwise.
    fn try_create(p: f64) -> Bernoulli {
        let created = Bernoulli::new(p);
        assert!(created.is_ok());
        created.unwrap()
    }

    /// Asserts that construction succeeds and `p` is stored unchanged.
    fn create_case(p: f64) {
        let dist = try_create(p);
        assert_eq!(p, dist.p());
    }

    /// Asserts that construction with `p` is rejected.
    fn bad_create_case(p: f64) {
        assert!(Bernoulli::new(p).is_err());
    }

    /// Builds a valid distribution from `p` and applies `eval` to it.
    fn get_value<T, F>(p: f64, eval: F) -> T
        where T: PartialEq + Debug,
              F: Fn(Bernoulli) -> T
    {
        eval(try_create(p))
    }

    /// Asserts that `eval` applied to the distribution yields exactly `expected`.
    fn test_case<T, F>(p: f64, expected: T, eval: F)
        where T: PartialEq + Debug,
              F: Fn(Bernoulli) -> T
    {
        let actual = get_value(p, eval);
        assert_eq!(expected, actual);
    }

    /// Asserts that `eval` applied to the distribution is within `acc` of `expected`.
    fn test_almost<F>(p: f64, expected: f64, acc: f64, eval: F)
        where F: Fn(Bernoulli) -> f64
    {
        let actual = get_value(p, eval);
        assert_almost_eq!(expected, actual, acc);
    }

    #[test]
    fn test_create() {
        for &p in &[0.0, 0.3, 1.0] {
            create_case(p);
        }
    }

    #[test]
    fn test_bad_create() {
        for &p in &[f64::NAN, -1.0, 2.0] {
            bad_create_case(p);
        }
    }

    #[test]
    fn test_cdf_upper_bound() {
        let cdf_at = |k: u64| move |d: Bernoulli| d.cdf(k);
        test_case(0.3, 1., cdf_at(1));
    }

    #[test]
    fn test_sf_upper_bound() {
        let sf_at = |k: u64| move |d: Bernoulli| d.sf(k);
        test_case(0.3, 0., sf_at(1));
    }

    #[test]
    fn test_cdf() {
        let cdf_at = |k: u64| move |d: Bernoulli| d.cdf(k);
        test_case(0.0, 1.0, cdf_at(0));
        test_case(0.0, 1.0, cdf_at(1));
        test_almost(0.3, 0.7, 1e-15, cdf_at(0));
        test_almost(0.7, 0.3, 1e-15, cdf_at(0));
    }

    #[test]
    fn test_sf() {
        let sf_at = |k: u64| move |d: Bernoulli| d.sf(k);
        test_case(0.0, 0.0, sf_at(0));
        test_case(0.0, 0.0, sf_at(1));
        test_almost(0.3, 0.3, 1e-15, sf_at(0));
        test_almost(0.7, 0.7, 1e-15, sf_at(0));
    }
}