1use super::Alternative;
2use crate::distribution::{Discrete, DiscreteCDF, Hypergeometric, HypergeometricError};
3
4const EPSILON: f64 = 1.0 - 1e-4;
5
6fn binary_search(
8 n: u64,
9 n1: u64,
10 n2: u64,
11 mode: u64,
12 p_exact: f64,
13 epsilon: f64,
14 upper: bool,
15) -> u64 {
16 let (mut min_val, mut max_val) = {
17 if upper {
18 (mode, n)
19 } else {
20 (0, mode)
21 }
22 };
23
24 let population = n1 + n2;
25 let successes = n1;
26 let draws = n;
27 let dist = Hypergeometric::new(population, successes, draws).unwrap();
28
29 let mut guess = 0;
30 loop {
31 if max_val - min_val <= 1 {
32 break;
33 }
34 guess = {
35 if max_val == min_val + 1 && guess == min_val {
36 max_val
37 } else {
38 (max_val + min_val) / 2
39 }
40 };
41
42 let ng = {
43 if upper {
44 guess - 1
45 } else {
46 guess + 1
47 }
48 };
49
50 let pmf_comp = dist.pmf(ng);
51 let p_guess = dist.pmf(guess);
52 if p_guess <= p_exact && p_exact < pmf_comp {
53 break;
54 }
55 if p_guess < p_exact {
56 max_val = guess
57 } else {
58 min_val = guess
59 }
60 }
61
62 if guess == 0 {
63 guess = min_val
64 }
65 if upper {
66 loop {
67 if guess > 0 && dist.pmf(guess) < p_exact * epsilon {
68 guess -= 1;
69 } else {
70 break;
71 }
72 }
73 loop {
74 if dist.pmf(guess) > p_exact / epsilon {
75 guess += 1;
76 } else {
77 break;
78 }
79 }
80 } else {
81 loop {
82 if dist.pmf(guess) < p_exact * epsilon {
83 guess += 1;
84 } else {
85 break;
86 }
87 }
88 loop {
89 if guess > 0 && dist.pmf(guess) > p_exact / epsilon {
90 guess -= 1;
91 } else {
92 break;
93 }
94 }
95 }
96 guess
97}
98
/// Error type returned by [`fishers_exact`] and
/// [`fishers_exact_with_odds_ratio`].
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
#[non_exhaustive]
pub enum FishersExactTestError {
    /// A [`Hypergeometric`] distribution could not be constructed from the
    /// contingency table; carries the underlying [`HypergeometricError`].
    TableInvalidForHypergeometric(HypergeometricError),
}
106
107impl std::fmt::Display for FishersExactTestError {
108 #[cfg_attr(coverage_nightly, coverage(off))]
109 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
110 match self {
111 FishersExactTestError::TableInvalidForHypergeometric(hg_err) => {
112 writeln!(f, "Cannot create a Hypergeometric distribution from the data in the contingency table.")?;
113 writeln!(f, "Is it in row-major order?")?;
114 write!(f, "Inner error: '{hg_err}'")
115 }
116 }
117 }
118}
119
// Marker implementation: the trait's default methods are used unchanged.
impl std::error::Error for FishersExactTestError {}
121
122impl From<HypergeometricError> for FishersExactTestError {
123 fn from(value: HypergeometricError) -> Self {
124 Self::TableInvalidForHypergeometric(value)
125 }
126}
127
128pub fn fishers_exact_with_odds_ratio(
141 table: &[u64; 4],
142 alternative: Alternative,
143) -> Result<(f64, f64), FishersExactTestError> {
144 match table {
146 [0, _, 0, _] | [_, 0, _, 0] => return Ok((f64::NAN, 1.0)), [0, 0, _, _] | [_, _, 0, 0] => return Ok((f64::NAN, 1.0)), _ => (), }
150
151 let odds_ratio = {
152 if table[1] > 0 && table[2] > 0 {
153 (table[0] * table[3]) as f64 / (table[1] * table[2]) as f64
154 } else {
155 f64::INFINITY
156 }
157 };
158
159 let p_value = fishers_exact(table, alternative)?;
160 Ok((odds_ratio, p_value))
161}
162
163pub fn fishers_exact(
176 table: &[u64; 4],
177 alternative: Alternative,
178) -> Result<f64, FishersExactTestError> {
179 match table {
181 [0, _, 0, _] | [_, 0, _, 0] => return Ok(1.0), [0, 0, _, _] | [_, _, 0, 0] => return Ok(1.0), _ => (), }
185
186 let n1 = table[0] + table[1];
187 let n2 = table[2] + table[3];
188 let n = table[0] + table[2];
189
190 let p_value = {
191 let population = n1 + n2;
192 let successes = n1;
193
194 match alternative {
195 Alternative::Less => {
196 let draws = n;
197 let dist = Hypergeometric::new(population, successes, draws)?;
198 dist.cdf(table[0])
199 }
200 Alternative::Greater => {
201 let draws = table[1] + table[3];
202 let dist = Hypergeometric::new(population, successes, draws)?;
203 dist.cdf(table[1])
204 }
205 Alternative::TwoSided => {
206 let draws = n;
207 let dist = Hypergeometric::new(population, successes, draws)?;
208
209 let p_exact = dist.pmf(table[0]);
210 let mode = ((n + 1) * (n1 + 1)) / (n1 + n2 + 2);
211 let p_mode = dist.pmf(mode);
212
213 if (p_exact - p_mode).abs() / p_exact.max(p_mode) <= 1.0 - EPSILON {
214 return Ok(1.0);
215 }
216
217 if table[0] < mode {
218 let p_lower = dist.cdf(table[0]);
219 if dist.pmf(n) > p_exact / EPSILON {
220 return Ok(p_lower);
221 }
222 let guess = binary_search(n, n1, n2, mode, p_exact, EPSILON, true);
223 return Ok(p_lower + 1.0 - dist.cdf(guess - 1));
224 }
225
226 let p_upper = 1.0 - dist.cdf(table[0] - 1);
227 if dist.pmf(0) > p_exact / EPSILON {
228 return Ok(p_upper);
229 }
230
231 let guess = binary_search(n, n1, n2, mode, p_exact, EPSILON, false);
232 p_upper + dist.cdf(guess)
233 }
234 }
235 };
236
237 Ok(p_value.min(1.0))
238}
239
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prec;

    /// Checks all three alternatives against reference p-values.
    ///
    /// Each case is `(table, expected_less, expected_greater, expected_two_sided)`.
    #[test]
    fn test_fishers_exact() {
        let cases = [
            (
                [3, 5, 4, 50],
                0.9963034765672599,
                0.03970749246529277,
                0.03970749246529276,
            ),
            (
                [61, 118, 2, 1],
                0.27535061623455315,
                0.9598172545684959,
                0.27535061623455315,
            ),
            (
                [172, 46, 90, 127],
                1.0,
                6.662405187351769e-16,
                9.041009036528785e-16,
            ),
            (
                [127, 38, 112, 43],
                0.8637599357870167,
                0.20040942958644145,
                0.3687862842650179,
            ),
            (
                [186, 177, 111, 154],
                0.9918518696328176,
                0.012550663906725129,
                0.023439141644624434,
            ),
            (
                [137, 49, 135, 183],
                0.999999999998533,
                5.6517533666400615e-12,
                8.870999836202932e-12,
            ),
            (
                [37, 115, 37, 152],
                0.8834621182590621,
                0.17638403366123565,
                0.29400927608021704,
            ),
            (
                [124, 117, 119, 175],
                0.9956704915461392,
                0.007134712391455461,
                0.011588218284387445,
            ),
            (
                [70, 114, 41, 118],
                0.9945558498544903,
                0.010384865876586255,
                0.020438291037108678,
            ),
            (
                [173, 21, 89, 7],
                0.2303739114068352,
                0.8808002774812677,
                0.4027047267306024,
            ),
            (
                [18, 147, 123, 58],
                4.077820702304103e-29,
                0.9999999999999817,
                0.0,
            ),
            (
                [116, 20, 92, 186],
                0.9999999999998267,
                6.598118571034892e-25,
                8.164831402188242e-25,
            ),
            (
                [9, 22, 44, 38],
                0.01584272038710196,
                0.9951463496539362,
                0.021581786662999272,
            ),
            (
                [9, 101, 135, 7],
                3.3336213533847776e-50,
                1.0,
                3.3336213533847776e-50,
            ),
            (
                [153, 27, 191, 144],
                0.9999999999950817,
                2.473736787266208e-11,
                3.185816623300107e-11,
            ),
            (
                [111, 195, 189, 69],
                6.665245982898848e-19,
                0.9999999999994574,
                1.0735744915712542e-18,
            ),
            (
                [125, 21, 31, 131],
                0.99999999999974,
                9.720661317939016e-34,
                1.0352129312860277e-33,
            ),
            (
                [201, 192, 69, 179],
                0.9999999988714893,
                3.1477232259550017e-09,
                4.761075937088169e-09,
            ),
            (
                [124, 138, 159, 160],
                0.30153826772785475,
                0.7538974235759873,
                0.5601766196310243,
            ),
        ];

        for &(table, less, greater, two_sided) in cases.iter() {
            // Pair each alternative with the reference value it must reproduce.
            let expectations = [
                (Alternative::Less, less),
                (Alternative::Greater, greater),
                (Alternative::TwoSided, two_sided),
            ];
            for (alternative, expected) in expectations {
                let p_value = fishers_exact(&table, alternative).unwrap();
                assert!(prec::almost_eq(p_value, expected, 1e-12));
            }
        }
    }

    /// Tables with an all-zero row or column must yield a p-value of exactly 1.
    #[test]
    fn test_fishers_exact_for_trivial() {
        let cases = [[0, 0, 1, 2], [1, 2, 0, 0], [1, 0, 2, 0], [0, 1, 0, 2]];

        for table in &cases {
            let p_value = fishers_exact(table, Alternative::Less).unwrap();
            assert_eq!(p_value, 1.0);
        }
    }

    /// The combined entry point returns both the odds ratio and the p-value.
    #[test]
    fn test_fishers_exact_with_odds() {
        let table = [3, 5, 4, 50];
        let result = fishers_exact_with_odds_ratio(&table, Alternative::Less).unwrap();
        let (odds_ratio, p_value) = result;

        assert!(prec::almost_eq(p_value, 0.9963034765672599, 1e-12));
        assert!(prec::almost_eq(odds_ratio, 7.5, 1e-1));
    }
}