statrs/stats_tests/
fisher.rs

1use super::Alternative;
2use crate::distribution::{Discrete, DiscreteCDF, Hypergeometric, HypergeometricError};
3
/// Relative tolerance factor (slightly below 1) used by the two-sided test when
/// comparing probability masses: a mass `p` is compared against `p * EPSILON`
/// and `p / EPSILON`, and `1.0 - EPSILON` is the relative difference below which
/// the observed mass and the mass at the mode are treated as equal.
const EPSILON: f64 = 1.0 - 1e-4;
5
6/// Binary search in two-sided test with starting bound as argument
7fn binary_search(
8    n: u64,
9    n1: u64,
10    n2: u64,
11    mode: u64,
12    p_exact: f64,
13    epsilon: f64,
14    upper: bool,
15) -> u64 {
16    let (mut min_val, mut max_val) = {
17        if upper {
18            (mode, n)
19        } else {
20            (0, mode)
21        }
22    };
23
24    let population = n1 + n2;
25    let successes = n1;
26    let draws = n;
27    let dist = Hypergeometric::new(population, successes, draws).unwrap();
28
29    let mut guess = 0;
30    loop {
31        if max_val - min_val <= 1 {
32            break;
33        }
34        guess = {
35            if max_val == min_val + 1 && guess == min_val {
36                max_val
37            } else {
38                (max_val + min_val) / 2
39            }
40        };
41
42        let ng = {
43            if upper {
44                guess - 1
45            } else {
46                guess + 1
47            }
48        };
49
50        let pmf_comp = dist.pmf(ng);
51        let p_guess = dist.pmf(guess);
52        if p_guess <= p_exact && p_exact < pmf_comp {
53            break;
54        }
55        if p_guess < p_exact {
56            max_val = guess
57        } else {
58            min_val = guess
59        }
60    }
61
62    if guess == 0 {
63        guess = min_val
64    }
65    if upper {
66        loop {
67            if guess > 0 && dist.pmf(guess) < p_exact * epsilon {
68                guess -= 1;
69            } else {
70                break;
71            }
72        }
73        loop {
74            if dist.pmf(guess) > p_exact / epsilon {
75                guess += 1;
76            } else {
77                break;
78            }
79        }
80    } else {
81        loop {
82            if dist.pmf(guess) < p_exact * epsilon {
83                guess += 1;
84            } else {
85                break;
86            }
87        }
88        loop {
89            if guess > 0 && dist.pmf(guess) > p_exact / epsilon {
90                guess -= 1;
91            } else {
92                break;
93            }
94        }
95    }
96    guess
97}
98
/// Error returned by [`fishers_exact`] and [`fishers_exact_with_odds_ratio`]
/// when the test cannot be evaluated for the given contingency table.
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
#[non_exhaustive]
pub enum FishersExactTestError {
    /// The table does not describe a valid [`Hypergeometric`] distribution.
    /// Make sure that the contingency table stores the data in row-major order.
    TableInvalidForHypergeometric(HypergeometricError),
}
106
107impl std::fmt::Display for FishersExactTestError {
108    #[cfg_attr(coverage_nightly, coverage(off))]
109    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
110        match self {
111            FishersExactTestError::TableInvalidForHypergeometric(hg_err) => {
112                writeln!(f, "Cannot create a Hypergeometric distribution from the data in the contingency table.")?;
113                writeln!(f, "Is it in row-major order?")?;
114                write!(f, "Inner error: '{hg_err}'")
115            }
116        }
117    }
118}
119
// Marks the type as a standard error; all trait methods keep their defaults.
impl std::error::Error for FishersExactTestError {}
121
122impl From<HypergeometricError> for FishersExactTestError {
123    fn from(value: HypergeometricError) -> Self {
124        Self::TableInvalidForHypergeometric(value)
125    }
126}
127
128/// Perform a Fisher exact test on a 2x2 contingency table.
129/// Based on scipy's fisher test: <https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.fisher_exact.html#scipy-stats-fisher-exact>
130/// Expects a table in row-major order
131/// Returns the [odds ratio](https://en.wikipedia.org/wiki/Odds_ratio) and p_value
132/// # Examples
133///
134/// ```
135/// use statrs::stats_tests::fishers_exact_with_odds_ratio;
136/// use statrs::stats_tests::Alternative;
137/// let table = [3, 5, 4, 50];
138/// let (odds_ratio, p_value) = fishers_exact_with_odds_ratio(&table, Alternative::Less).unwrap();
139/// ```
140pub fn fishers_exact_with_odds_ratio(
141    table: &[u64; 4],
142    alternative: Alternative,
143) -> Result<(f64, f64), FishersExactTestError> {
144    // If both values in a row or column are zero, p-value is 1 and odds ratio is NaN.
145    match table {
146        [0, _, 0, _] | [_, 0, _, 0] => return Ok((f64::NAN, 1.0)), // both 0 in a row
147        [0, 0, _, _] | [_, _, 0, 0] => return Ok((f64::NAN, 1.0)), // both 0 in a column
148        _ => (),                                                   // continue
149    }
150
151    let odds_ratio = {
152        if table[1] > 0 && table[2] > 0 {
153            (table[0] * table[3]) as f64 / (table[1] * table[2]) as f64
154        } else {
155            f64::INFINITY
156        }
157    };
158
159    let p_value = fishers_exact(table, alternative)?;
160    Ok((odds_ratio, p_value))
161}
162
163/// Perform a Fisher exact test on a 2x2 contingency table.
164/// Based on scipy's fisher test: <https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.fisher_exact.html#scipy-stats-fisher-exact>
165/// Expects a table in row-major order
166/// Returns only the p_value
167/// # Examples
168///
169/// ```
170/// use statrs::stats_tests::fishers_exact;
171/// use statrs::stats_tests::Alternative;
172/// let table = [3, 5, 4, 50];
173/// let p_value = fishers_exact(&table, Alternative::Less).unwrap();
174/// ```
175pub fn fishers_exact(
176    table: &[u64; 4],
177    alternative: Alternative,
178) -> Result<f64, FishersExactTestError> {
179    // If both values in a row or column are zero, the p-value is 1 and the odds ratio is NaN.
180    match table {
181        [0, _, 0, _] | [_, 0, _, 0] => return Ok(1.0), // both 0 in a row
182        [0, 0, _, _] | [_, _, 0, 0] => return Ok(1.0), // both 0 in a column
183        _ => (),                                       // continue
184    }
185
186    let n1 = table[0] + table[1];
187    let n2 = table[2] + table[3];
188    let n = table[0] + table[2];
189
190    let p_value = {
191        let population = n1 + n2;
192        let successes = n1;
193
194        match alternative {
195            Alternative::Less => {
196                let draws = n;
197                let dist = Hypergeometric::new(population, successes, draws)?;
198                dist.cdf(table[0])
199            }
200            Alternative::Greater => {
201                let draws = table[1] + table[3];
202                let dist = Hypergeometric::new(population, successes, draws)?;
203                dist.cdf(table[1])
204            }
205            Alternative::TwoSided => {
206                let draws = n;
207                let dist = Hypergeometric::new(population, successes, draws)?;
208
209                let p_exact = dist.pmf(table[0]);
210                let mode = ((n + 1) * (n1 + 1)) / (n1 + n2 + 2);
211                let p_mode = dist.pmf(mode);
212
213                if (p_exact - p_mode).abs() / p_exact.max(p_mode) <= 1.0 - EPSILON {
214                    return Ok(1.0);
215                }
216
217                if table[0] < mode {
218                    let p_lower = dist.cdf(table[0]);
219                    if dist.pmf(n) > p_exact / EPSILON {
220                        return Ok(p_lower);
221                    }
222                    let guess = binary_search(n, n1, n2, mode, p_exact, EPSILON, true);
223                    return Ok(p_lower + 1.0 - dist.cdf(guess - 1));
224                }
225
226                let p_upper = 1.0 - dist.cdf(table[0] - 1);
227                if dist.pmf(0) > p_exact / EPSILON {
228                    return Ok(p_upper);
229                }
230
231                let guess = binary_search(n, n1, n2, mode, p_exact, EPSILON, false);
232                p_upper + dist.cdf(guess)
233            }
234        }
235    };
236
237    Ok(p_value.min(1.0))
238}
239
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prec;

    /// Compare `fishers_exact` against reference p-values from scipy for all
    /// three alternatives.
    #[test]
    fn test_fishers_exact() {
        // Each case: (table, p_less, p_greater, p_two_sided)
        let cases = [
            (
                [3, 5, 4, 50],
                0.9963034765672599,
                0.03970749246529277,
                0.03970749246529276,
            ),
            (
                [61, 118, 2, 1],
                0.27535061623455315,
                0.9598172545684959,
                0.27535061623455315,
            ),
            (
                [172, 46, 90, 127],
                1.0,
                6.662405187351769e-16,
                9.041009036528785e-16,
            ),
            (
                [127, 38, 112, 43],
                0.8637599357870167,
                0.20040942958644145,
                0.3687862842650179,
            ),
            (
                [186, 177, 111, 154],
                0.9918518696328176,
                0.012550663906725129,
                0.023439141644624434,
            ),
            (
                [137, 49, 135, 183],
                0.999999999998533,
                5.6517533666400615e-12,
                8.870999836202932e-12,
            ),
            (
                [37, 115, 37, 152],
                0.8834621182590621,
                0.17638403366123565,
                0.29400927608021704,
            ),
            (
                [124, 117, 119, 175],
                0.9956704915461392,
                0.007134712391455461,
                0.011588218284387445,
            ),
            (
                [70, 114, 41, 118],
                0.9945558498544903,
                0.010384865876586255,
                0.020438291037108678,
            ),
            (
                [173, 21, 89, 7],
                0.2303739114068352,
                0.8808002774812677,
                0.4027047267306024,
            ),
            (
                [18, 147, 123, 58],
                4.077820702304103e-29,
                0.9999999999999817,
                0.0,
            ),
            (
                [116, 20, 92, 186],
                0.9999999999998267,
                6.598118571034892e-25,
                8.164831402188242e-25,
            ),
            (
                [9, 22, 44, 38],
                0.01584272038710196,
                0.9951463496539362,
                0.021581786662999272,
            ),
            (
                [9, 101, 135, 7],
                3.3336213533847776e-50,
                1.0,
                3.3336213533847776e-50,
            ),
            (
                [153, 27, 191, 144],
                0.9999999999950817,
                2.473736787266208e-11,
                3.185816623300107e-11,
            ),
            (
                [111, 195, 189, 69],
                6.665245982898848e-19,
                0.9999999999994574,
                1.0735744915712542e-18,
            ),
            (
                [125, 21, 31, 131],
                0.99999999999974,
                9.720661317939016e-34,
                1.0352129312860277e-33,
            ),
            (
                [201, 192, 69, 179],
                0.9999999988714893,
                3.1477232259550017e-09,
                4.761075937088169e-09,
            ),
            (
                [124, 138, 159, 160],
                0.30153826772785475,
                0.7538974235759873,
                0.5601766196310243,
            ),
        ];

        for &(table, p_less, p_greater, p_two_sided) in &cases {
            let expectations = [
                (Alternative::Less, p_less),
                (Alternative::Greater, p_greater),
                (Alternative::TwoSided, p_two_sided),
            ];

            for &(alternative, expected) in &expectations {
                let p_value = fishers_exact(&table, alternative).unwrap();
                assert!(prec::almost_eq(p_value, expected, 1e-12));
            }
        }
    }

    /// Tables with an all-zero row or column yield a p-value of exactly 1.
    #[test]
    fn test_fishers_exact_for_trivial() {
        let cases = [[0, 0, 1, 2], [1, 2, 0, 0], [1, 0, 2, 0], [0, 1, 0, 2]];

        for table in &cases {
            let p_value = fishers_exact(table, Alternative::Less).unwrap();
            assert_eq!(p_value, 1.0)
        }
    }

    /// The combined entry point reports both the odds ratio and the p-value.
    #[test]
    fn test_fishers_exact_with_odds() {
        let table = [3, 5, 4, 50];
        let result = fishers_exact_with_odds_ratio(&table, Alternative::Less);
        let (odds_ratio, p_value) = result.unwrap();
        assert!(prec::almost_eq(p_value, 0.9963034765672599, 1e-12));
        assert!(prec::almost_eq(odds_ratio, 7.5, 1e-1));
    }
}