statrs/distribution/bernoulli.rs
1use crate::distribution::{Binomial, BinomialError, Discrete, DiscreteCDF};
2use crate::statistics::*;
3
4/// Implements the
5/// [Bernoulli](https://en.wikipedia.org/wiki/Bernoulli_distribution)
6/// distribution which is a special case of the
7/// [Binomial](https://en.wikipedia.org/wiki/Binomial_distribution)
8/// distribution where `n = 1` (referenced [Here](./struct.Binomial.html))
9///
10/// # Examples
11///
12/// ```
13/// use statrs::distribution::{Bernoulli, Discrete};
14/// use statrs::statistics::Distribution;
15///
16/// let n = Bernoulli::new(0.5).unwrap();
17/// assert_eq!(n.mean().unwrap(), 0.5);
18/// assert_eq!(n.pmf(0), 0.5);
19/// assert_eq!(n.pmf(1), 0.5);
20/// ```
21#[derive(Copy, Clone, PartialEq, Debug)]
22pub struct Bernoulli {
23 b: Binomial,
24}
25
26impl Bernoulli {
27 /// Constructs a new bernoulli distribution with
28 /// the given `p` probability of success.
29 ///
30 /// # Errors
31 ///
32 /// Returns an error if `p` is `NaN`, less than `0.0`
33 /// or greater than `1.0`
34 ///
35 /// # Examples
36 ///
37 /// ```
38 /// use statrs::distribution::Bernoulli;
39 ///
40 /// let mut result = Bernoulli::new(0.5);
41 /// assert!(result.is_ok());
42 ///
43 /// result = Bernoulli::new(-0.5);
44 /// assert!(result.is_err());
45 /// ```
46 pub fn new(p: f64) -> Result<Bernoulli, BinomialError> {
47 Binomial::new(p, 1).map(|b| Bernoulli { b })
48 }
49
50 /// Returns the probability of success `p` of the
51 /// bernoulli distribution.
52 ///
53 /// # Examples
54 ///
55 /// ```
56 /// use statrs::distribution::Bernoulli;
57 ///
58 /// let n = Bernoulli::new(0.5).unwrap();
59 /// assert_eq!(n.p(), 0.5);
60 /// ```
61 pub fn p(&self) -> f64 {
62 self.b.p()
63 }
64
65 /// Returns the number of trials `n` of the
66 /// bernoulli distribution. Will always be `1.0`.
67 ///
68 /// # Examples
69 ///
70 /// ```
71 /// use statrs::distribution::Bernoulli;
72 ///
73 /// let n = Bernoulli::new(0.5).unwrap();
74 /// assert_eq!(n.n(), 1);
75 /// ```
76 pub fn n(&self) -> u64 {
77 1
78 }
79}
80
81impl std::fmt::Display for Bernoulli {
82 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83 write!(f, "Bernoulli({})", self.p())
84 }
85}
86
#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<bool> for Bernoulli {
    /// Draws a single trial, yielding `true` (success) with probability `p`.
    ///
    /// `p` is validated by `Bernoulli::new` to lie in `[0, 1]`, which is the
    /// range `Rng::gen_bool` accepts.
    fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> bool {
        let success_probability = self.p();
        rng.gen_bool(success_probability)
    }
}
94
#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<f64> for Bernoulli {
    /// Draws a single trial and encodes it numerically:
    /// `1.0` for a success and `0.0` for a failure.
    fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 {
        // Reuse the boolean sampler so both impls share one source of truth.
        let success: bool = rng.sample(self);
        if success {
            1.0
        } else {
            0.0
        }
    }
}
102
103impl DiscreteCDF<u64, f64> for Bernoulli {
104 /// Calculates the cumulative distribution
105 /// function for the bernoulli distribution at `x`.
106 ///
107 /// # Formula
108 ///
109 /// ```text
110 /// if x < 0 { 0 }
111 /// else if x >= 1 { 1 }
112 /// else { 1 - p }
113 /// ```
114 fn cdf(&self, x: u64) -> f64 {
115 self.b.cdf(x)
116 }
117
118 /// Calculates the survival function for the
119 /// bernoulli distribution at `x`.
120 ///
121 /// # Formula
122 ///
123 /// ```text
124 /// if x < 0 { 1 }
125 /// else if x >= 1 { 0 }
126 /// else { p }
127 /// ```
128 fn sf(&self, x: u64) -> f64 {
129 self.b.sf(x)
130 }
131}
132
133impl Min<u64> for Bernoulli {
134 /// Returns the minimum value in the domain of the
135 /// bernoulli distribution representable by a 64-
136 /// bit integer
137 ///
138 /// # Formula
139 ///
140 /// ```text
141 /// 0
142 /// ```
143 fn min(&self) -> u64 {
144 0
145 }
146}
147
148impl Max<u64> for Bernoulli {
149 /// Returns the maximum value in the domain of the
150 /// bernoulli distribution representable by a 64-
151 /// bit integer
152 ///
153 /// # Formula
154 ///
155 /// ```text
156 /// 1
157 /// ```
158 fn max(&self) -> u64 {
159 1
160 }
161}
162
163impl Distribution<f64> for Bernoulli {
164 /// Returns the mean of the bernoulli
165 /// distribution
166 ///
167 /// # Formula
168 ///
169 /// ```text
170 /// p
171 /// ```
172 fn mean(&self) -> Option<f64> {
173 self.b.mean()
174 }
175
176 /// Returns the variance of the bernoulli
177 /// distribution
178 ///
179 /// # Formula
180 ///
181 /// ```text
182 /// p * (1 - p)
183 /// ```
184 fn variance(&self) -> Option<f64> {
185 self.b.variance()
186 }
187
188 /// Returns the entropy of the bernoulli
189 /// distribution
190 ///
191 /// # Formula
192 ///
193 /// ```text
194 /// q = (1 - p)
195 /// -q * ln(q) - p * ln(p)
196 /// ```
197 fn entropy(&self) -> Option<f64> {
198 self.b.entropy()
199 }
200
201 /// Returns the skewness of the bernoulli
202 /// distribution
203 ///
204 /// # Formula
205 ///
206 /// ```text
207 /// q = (1 - p)
208 /// (1 - 2p) / sqrt(p * q)
209 /// ```
210 fn skewness(&self) -> Option<f64> {
211 self.b.skewness()
212 }
213}
214
215impl Median<f64> for Bernoulli {
216 /// Returns the median of the bernoulli
217 /// distribution
218 ///
219 /// # Formula
220 ///
221 /// ```text
222 /// if p < 0.5 { 0 }
223 /// else if p > 0.5 { 1 }
224 /// else { 0.5 }
225 /// ```
226 fn median(&self) -> f64 {
227 self.b.median()
228 }
229}
230
231impl Mode<Option<u64>> for Bernoulli {
232 /// Returns the mode of the bernoulli distribution
233 ///
234 /// # Formula
235 ///
236 /// ```text
237 /// if p < 0.5 { 0 }
238 /// else { 1 }
239 /// ```
240 fn mode(&self) -> Option<u64> {
241 self.b.mode()
242 }
243}
244
245impl Discrete<u64, f64> for Bernoulli {
246 /// Calculates the probability mass function for the
247 /// bernoulli distribution at `x`.
248 ///
249 /// # Formula
250 ///
251 /// ```text
252 /// if x == 0 { 1 - p }
253 /// else { p }
254 /// ```
255 fn pmf(&self, x: u64) -> f64 {
256 self.b.pmf(x)
257 }
258
259 /// Calculates the log probability mass function for the
260 /// bernoulli distribution at `x`.
261 ///
262 /// # Formula
263 ///
264 /// ```text
265 /// else if x == 0 { ln(1 - p) }
266 /// else { ln(p) }
267 /// ```
268 fn ln_pmf(&self, x: u64) -> f64 {
269 self.b.ln_pmf(x)
270 }
271}
272
#[rustfmt::skip]
#[cfg(test)]
mod testing {
    use super::*;
    use crate::testing_boiler;

    testing_boiler!(p: f64; Bernoulli; BinomialError);

    #[test]
    fn test_create() {
        create_ok(0.0);
        create_ok(0.3);
        create_ok(1.0);
    }

    #[test]
    fn test_bad_create() {
        create_err(f64::NAN);
        create_err(-1.0);
        create_err(2.0);
    }

    #[test]
    fn test_cdf_upper_bound() {
        let cdf = |arg: u64| move |x: Bernoulli| x.cdf(arg);
        test_relative(0.3, 1., cdf(1));
    }

    #[test]
    fn test_sf_upper_bound() {
        let sf = |arg: u64| move |x: Bernoulli| x.sf(arg);
        test_relative(0.3, 0., sf(1));
    }

    #[test]
    fn test_cdf() {
        let cdf = |arg: u64| move |x: Bernoulli| x.cdf(arg);
        test_relative(0.0, 1.0, cdf(0));
        test_relative(0.0, 1.0, cdf(1));
        test_absolute(0.3, 0.7, 1e-15, cdf(0));
        test_absolute(0.7, 0.3, 1e-15, cdf(0));
    }

    #[test]
    fn test_sf() {
        let sf = |arg: u64| move |x: Bernoulli| x.sf(arg);
        test_relative(0.0, 0.0, sf(0));
        test_relative(0.0, 0.0, sf(1));
        test_absolute(0.3, 0.3, 1e-15, sf(0));
        test_absolute(0.7, 0.7, 1e-15, sf(0));
    }

    #[test]
    fn test_pmf() {
        let pmf = |arg: u64| move |x: Bernoulli| x.pmf(arg);
        test_absolute(0.3, 0.7, 1e-15, pmf(0));
        test_absolute(0.3, 0.3, 1e-15, pmf(1));
        test_absolute(0.7, 0.3, 1e-15, pmf(0));
        test_absolute(0.7, 0.7, 1e-15, pmf(1));
    }

    #[test]
    fn test_ln_pmf() {
        let ln_pmf = |arg: u64| move |x: Bernoulli| x.ln_pmf(arg);
        test_absolute(0.3, 0.7f64.ln(), 1e-15, ln_pmf(0));
        test_absolute(0.3, 0.3f64.ln(), 1e-15, ln_pmf(1));
    }

    #[test]
    fn test_pmf_outside_support() {
        // Only 0 and 1 are possible outcomes; everything else has zero mass.
        let b = Bernoulli::new(0.3).unwrap();
        assert_eq!(b.pmf(2), 0.0);
        assert_eq!(b.ln_pmf(2), f64::NEG_INFINITY);
    }
}
324}