/// Returns `a * b mod m`, widening to 128 bits so the product can never
/// overflow.
///
/// # Panics
///
/// Panics if `m == 0`.
fn mod_mul_(a: u64, b: u64, m: u64) -> u64 {
    let wide_product = u128::from(a) * u128::from(b);
    let remainder = wide_product % u128::from(m);
    // `remainder < m <= u64::MAX`, so narrowing back to 64 bits is lossless.
    remainder as u64
}
4
5fn mod_mul(a: u64, b: u64, m: u64) -> u64 {
6 match a.checked_mul(b) {
7 Some(r) => if r >= m { r % m } else { r },
8 None => mod_mul_(a, b, m),
9 }
10}
11
12fn mod_sqr(a: u64, m: u64) -> u64 {
13 if a < (1 << 32) {
14 let r = a * a;
15 if r >= m {
16 r % m
17 } else {
18 r
19 }
20 } else {
21 mod_mul_(a, a, m)
22 }
23}
24
25fn mod_exp(mut x: u64, mut d: u64, n: u64) -> u64 {
26 let mut ret: u64 = 1;
27 while d != 0 {
28 if d % 2 == 1 {
29 ret = mod_mul(ret, x, n)
30 }
31 d /= 2;
32 x = mod_sqr(x, n);
33 }
34 ret
35}
36
37pub fn miller_rabin(n: u64) -> bool {
58 const HINT: &[u64] = &[2];
59
60 const WITNESSES: &[(u64, &[u64])] = &[
67 (2_046, HINT),
68 (1_373_652, &[2, 3]),
69 (9_080_190, &[31, 73]),
70 (25_326_000, &[2, 3, 5]),
71 (4_759_123_140, &[2, 7, 61]),
72 (1_112_004_669_632, &[2, 13, 23, 1662803]),
73 (2_152_302_898_746, &[2, 3, 5, 7, 11]),
74 (3_474_749_660_382, &[2, 3, 5, 7, 11, 13]),
75 (341_550_071_728_320, &[2, 3, 5, 7, 11, 13, 17]),
76 (3_825_123_056_546_413_050, &[2, 3, 5, 7, 11, 13, 17, 19, 23]),
77 (std::u64::MAX, &[2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37]),
78 ];
79
80 if n % 2 == 0 { return n == 2 }
81 if n == 1 { return false }
82
83 let mut d = n - 1;
84 let mut s = 0;
85 while d % 2 == 0 { d /= 2; s += 1 }
86
87 let witnesses =
88 WITNESSES.iter().find(|&&(hi, _)| hi >= n)
89 .map(|&(_, wtnss)| wtnss).unwrap();
90 'next_witness: for &a in witnesses.iter() {
91 let mut power = mod_exp(a, d, n);
92 assert!(power < n);
93 if power == 1 || power == n - 1 { continue 'next_witness }
94
95 for _r in 0..s {
96 power = mod_sqr(power, n);
97 assert!(power < n);
98 if power == 1 { return false }
99 if power == n - 1 {
100 continue 'next_witness
101 }
102 }
103 return false
104 }
105
106 true
107}
108
#[cfg(test)]
mod tests {
    use primal::Sieve;

    /// Spot-checks `mod_mul`, including products that overflow a `u64`.
    #[test]
    fn mod_mul() {
        // (a, b, m, expected a * b mod m)
        let cases: [(u64, u64, u64, u64); 7] = [
            (1 << 63, 1 << 32, 3, 2),
            (1 << 31, 1 << 31, (1 << 32) - 7, 3221225479),
            (1 << 32, 1 << 32, (1 << 32) - 7, 49),
            (1 << 32, 1 << 32, (1 << 32) + 7, 49),
            (1 << 63, 1 << 32, (1 << 32) + 7, 2_147_483_480),
            (1 << 63, 1 << 32, (1 << 63) + 7, 9_223_372_006_790_004_743),
            (1 << 32, 1 << 32, !0, 1),
        ];
        for &(a, b, m, expected) in cases.iter() {
            assert_eq!(super::mod_mul(a, b, m), expected);
        }
    }

    /// Cross-checks `miller_rabin` against a sieve for every value below `LIMIT`.
    #[test]
    fn miller_rabin() {
        const LIMIT: usize = 1_000_000;
        let sieve = Sieve::new(LIMIT);
        for candidate in 0..LIMIT {
            let by_sieve = sieve.is_prime(candidate);
            let by_mr = super::miller_rabin(candidate as u64);
            assert!(
                by_sieve == by_mr,
                "miller_rabin {} mismatches sieve {} for {}",
                by_mr,
                by_sieve,
                candidate
            );
        }
    }

    /// Checks primes near 2^32, the square of one of them, and `u64::MAX`.
    #[test]
    fn miller_rabin_large() {
        let cases: [(u64, bool); 4] = [
            (4_294_967_311, true),
            (4_294_967_291, true),
            (4_294_967_291 * 4_294_967_291, false),
            (!0, false),
        ];
        for &(n, is_prime) in cases.iter() {
            assert!(
                super::miller_rabin(n) == is_prime,
                "mismatch for {} (should be {})",
                n,
                is_prime
            );
        }
    }

    /// Every listed term of OEIS A014233 is composite and must be rejected.
    #[test]
    fn oeis_a014233() {
        const A014233: [u64; 9] = [
            2047,
            1373653,
            25326001,
            3215031751,
            2152302898747,
            3474749660383,
            341550071728321,
            341550071728321,
            3825123056546413051,
        ];
        for n in A014233.iter().cloned() {
            assert!(!super::miller_rabin(n), "{} is composite!", n);
        }
    }
}