1use crate::{int::I256, uint::U256};
13use core::{mem::MaybeUninit, num::NonZeroU128};
14
/// Divides the 256-bit value `u1 * 2^128 + u0` by the 128-bit divisor `v`.
///
/// Returns the 128-bit quotient and stores the remainder in `*r`. The caller
/// must guarantee `v > u1` (checked by a debug assertion); that condition is
/// what makes the quotient fit in 128 bits.
///
/// The division works in base `2^64`: the divisor is shifted left until its
/// top bit is set, the dividend is shifted by the same amount, and the two
/// 64-bit quotient digits are estimated and corrected one after the other.
#[inline(always)]
fn udiv256_by_128_to_128(u1: u128, u0: u128, mut v: NonZeroU128, r: &mut u128) -> u128 {
    const N_UDWORD_BITS: u32 = 128;
    const HALF_BITS: u32 = N_UDWORD_BITS / 2;
    // Number base used for the quotient digits (2^64).
    const B: u128 = 1 << HALF_BITS;
    const LOW_HALF: u128 = 0xFFFF_FFFF_FFFF_FFFF;

    /// Left shift whose result the caller promises is non-zero.
    #[inline]
    unsafe fn shl_nz(x: NonZeroU128, n: u32) -> NonZeroU128 {
        debug_assert!(n < N_UDWORD_BITS);
        let shifted: u128 = x.get() << n;
        debug_assert_ne!(shifted, 0);
        NonZeroU128::new_unchecked(shifted)
    }

    /// Right shift whose result the caller promises is non-zero.
    #[inline]
    unsafe fn shr_nz(x: NonZeroU128, n: u32) -> NonZeroU128 {
        debug_assert!(n < N_UDWORD_BITS);
        let shifted: u128 = x.get() >> n;
        debug_assert_ne!(shifted, 0);
        NonZeroU128::new_unchecked(shifted)
    }

    /// Estimates one base-`B` quotient digit of `num / vn1` and lowers the
    /// estimate while it is too large with respect to the divisor's low
    /// digit `vn0` and the next dividend digit `next`.
    #[inline]
    fn quotient_digit(num: u128, next: u128, vn1: NonZeroU128, vn0: u128) -> u128 {
        let mut digit = num / vn1;
        // Cannot underflow: `digit * vn1 <= num` by construction.
        let mut rhat = num - digit * vn1.get();
        while digit >= B || digit * vn0 > B * rhat + next {
            digit -= 1;
            rhat += vn1.get();
            if rhat >= B {
                break;
            }
        }
        digit
    }

    debug_assert!(v.get() > u1);

    // Normalize: shift the divisor so that its most significant bit is set,
    // and shift the dividend by the same amount.
    let s = v.leading_zeros();
    debug_assert_ne!(s, N_UDWORD_BITS);
    let (un128, un10) = if s > 0 {
        // SAFETY: shifting left by the number of leading zeros keeps the
        // highest set bit, so the result is non-zero.
        v = unsafe { shl_nz(v, s) };
        ((u1 << s) | (u0 >> (N_UDWORD_BITS - s)), u0 << s)
    } else {
        (u1, u0)
    };

    // Split the normalized divisor and the low dividend word into 64-bit digits.
    // SAFETY: after normalization the top bit of `v` is set, so its upper
    // half is non-zero.
    let vn1 = unsafe { shr_nz(v, HALF_BITS) };
    let vn0 = v.get() & LOW_HALF;
    let un1 = un10 >> HALF_BITS;
    let un0 = un10 & LOW_HALF;

    // High quotient digit, then the partial remainder it leaves behind.
    let q1 = quotient_digit(un128, un1, vn1, vn0);
    let un21 = un128
        .wrapping_mul(B)
        .wrapping_add(un1)
        .wrapping_sub(q1.wrapping_mul(v.get()));

    // Low quotient digit.
    let q0 = quotient_digit(un21, un0, vn1, vn0);

    // Undo the normalization shift on the final remainder.
    *r = un21
        .wrapping_mul(B)
        .wrapping_add(un0)
        .wrapping_sub(q0.wrapping_mul(v.get()))
        >> s;
    q1 * B + q0
}
102
103#[allow(clippy::many_single_char_names)]
104pub fn udivmod4(
105 res: &mut MaybeUninit<U256>,
106 a: &U256,
107 b: &U256,
108 rem: Option<&mut MaybeUninit<U256>>,
109) {
110 if a.high() | b.high() == 0 {
124 res.write(U256::from_words(0, a.low() / b.low()));
125 if let Some(rem) = rem {
126 rem.write(U256::from_words(0, a.low() % b.low()));
127 }
128 return;
129 }
130
131 let dividend = *a;
132 let divisor = *b;
133 let quotient: U256;
134 let mut remainder: U256;
135
136 if divisor > dividend {
137 if let Some(rem) = rem {
138 rem.write(dividend);
139 }
140 res.write(U256::ZERO);
141 return;
142 }
143 if *divisor.high() == 0 {
145 remainder = U256::ZERO;
146 if dividend.high() < divisor.low() {
147 quotient = U256::from_words(
149 0,
150 udiv256_by_128_to_128(
151 *dividend.high(),
152 *dividend.low(),
153 unsafe { NonZeroU128::new_unchecked(*divisor.low()) },
155 remainder.low_mut(),
156 ),
157 );
158 } else {
159 quotient = U256::from_words(
162 dividend.high() / divisor.low(),
163 udiv256_by_128_to_128(
164 dividend.high() % divisor.low(),
165 *dividend.low(),
166 unsafe { NonZeroU128::new_unchecked(*divisor.low()) },
168 remainder.low_mut(),
169 ),
170 );
171 }
172 if let Some(rem) = rem {
173 rem.write(remainder);
174 }
175 res.write(quotient);
176 return;
177 }
178
179 (quotient, remainder) = unsafe { div_mod_knuth(÷nd, &divisor) };
181
182 if let Some(rem) = rem {
183 rem.write(remainder);
184 }
185 res.write(quotient);
186}
187
188#[inline]
192unsafe fn div_mod_knuth(u: &U256, v: &U256) -> (U256, U256) {
193 const N_UDWORD_BITS: u32 = 128;
194 debug_assert_ne!(
195 *u.high(),
196 0,
197 "The second operand must be greater than u128::MAX"
198 );
199 if *u.high() == 0 {
200 unsafe { core::hint::unreachable_unchecked() }
201 }
202
203 #[inline]
204 fn full_shl(a: &U256, shift: u32) -> [u128; 3] {
205 debug_assert!(shift < N_UDWORD_BITS);
206 let mut u = [0_u128; 3];
207 let u_lo = a.low() << shift;
208 let u_hi = a >> (N_UDWORD_BITS - shift);
209 u[0] = u_lo;
210 u[1] = *u_hi.low();
211 u[2] = *u_hi.high();
212
213 u
214 }
215
216 #[inline]
217 fn full_shr(u: &[u128; 3], shift: u32) -> U256 {
218 debug_assert!(shift < N_UDWORD_BITS);
219 let mut res = U256::ZERO;
220 *res.low_mut() = u[0] >> shift;
221 *res.high_mut() = u[1] >> shift;
222 if shift > 0 {
224 let sh = N_UDWORD_BITS - shift;
225 *res.low_mut() |= u[1] << sh;
226 *res.high_mut() |= u[2] << sh;
227 }
228
229 res
230 }
231
232 #[inline]
234 const fn split_u128_to_u128(a: u128) -> (u128, u128) {
235 (a & 0xFFFFFFFFFFFFFFFF, a >> (N_UDWORD_BITS / 2))
236 }
237
238 #[inline]
240 const fn fullmul_u128(a: u128, b: u128) -> (u128, u128) {
241 let (a0, a1) = split_u128_to_u128(a);
242 let (b0, b1) = split_u128_to_u128(b);
243
244 let mut t = a0 * b0;
245 let mut k: u128;
246 let w3: u128;
247 (w3, k) = split_u128_to_u128(t);
248
249 t = a1 * b0 + k;
250 let (w1, w2) = split_u128_to_u128(t);
251 t = a0 * b1 + w1;
252 k = t >> 64;
253
254 let w_hi = a1 * b1 + w2 + k;
255 let w_lo = (t << 64) + w3;
256
257 (w_lo, w_hi)
258 }
259
260 #[inline]
261 fn fullmul_u256_u128(a: &U256, b: u128) -> [u128; 3] {
262 let mut acc = [0_u128; 3];
263 let mut lo: u128;
264 let mut carry: u128;
265 let c: bool;
266 if b != 0 {
267 (lo, carry) = fullmul_u128(*a.low(), b);
268 acc[0] = lo;
269 acc[1] = carry;
270 (lo, carry) = fullmul_u128(*a.high(), b);
271 (acc[1], c) = acc[1].overflowing_add(lo);
272 acc[2] = carry + c as u128;
273 }
274
275 acc
276 }
277
278 #[inline]
279 const fn add_carry(a: u128, b: u128, c: bool) -> (u128, bool) {
280 let (res1, overflow1) = b.overflowing_add(c as u128);
281 let (res2, overflow2) = u128::overflowing_add(a, res1);
282
283 (res2, overflow1 || overflow2)
284 }
285
286 #[inline]
287 const fn sub_carry(a: u128, b: u128, c: bool) -> (u128, bool) {
288 let (res1, overflow1) = b.overflowing_add(c as u128);
289 let (res2, overflow2) = u128::overflowing_sub(a, res1);
290
291 (res2, overflow1 || overflow2)
292 }
293
294 let shift = v.high().leading_zeros();
299 debug_assert!(shift < N_UDWORD_BITS);
300 let v = v << shift;
301 let mut u = full_shl(u, shift);
303
304 let mut q = U256::ZERO;
306 let v_n_1 = *v.high();
307 let v_n_2 = *v.low();
308
309 if v_n_1 >> (N_UDWORD_BITS - 1) != 1 {
310 debug_assert!(false);
311
312 unsafe { core::hint::unreachable_unchecked() }
315 }
316
317 let mut r_hat: u128 = 0;
319 let u_jn = u[2];
320
321 let mut q_hat = if u_jn < v_n_1 {
327 let mut q_hat = udiv256_by_128_to_128(
329 u_jn,
330 u[1],
331 unsafe { NonZeroU128::new_unchecked(v_n_1) },
332 &mut r_hat,
333 );
334 let mut overflow: bool;
335 loop {
337 let another_iteration = {
338 let (lo, hi) = fullmul_u128(q_hat, v_n_2);
340 hi > r_hat || (hi == r_hat && lo > u[0])
341 };
342 if !another_iteration {
343 break;
344 }
345 q_hat -= 1;
346 (r_hat, overflow) = r_hat.overflowing_add(v_n_1);
347 if overflow {
349 break;
350 }
351 }
352 q_hat
353 } else {
354 u128::MAX
356 };
357
358 let q_hat_v = fullmul_u256_u128(&v, q_hat);
366 let mut c = false;
368 (u[0], c) = sub_carry(u[0], q_hat_v[0], c);
369 (u[1], c) = sub_carry(u[1], q_hat_v[1], c);
370 (u[2], c) = sub_carry(u[2], q_hat_v[2], c);
371
372 if c {
376 q_hat -= 1;
377 c = false;
379 (u[0], c) = add_carry(u[0], *v.low(), c);
380 (u[1], c) = add_carry(u[1], *v.high(), c);
381 u[2] = u[2].wrapping_add(c as u128);
382 }
383
384 *q.low_mut() = q_hat;
386
387 let remainder = full_shr(&u, shift);
389
390 (q, remainder)
391}
392
393#[inline]
394pub fn udiv2(r: &mut U256, a: &U256) {
395 let (a, b) = (*r, a);
396 let res = unsafe { &mut *(r as *mut U256).cast() };
399 udivmod4(res, &a, b, None);
400}
401
402#[inline]
403pub fn udiv3(r: &mut MaybeUninit<U256>, a: &U256, b: &U256) {
404 udivmod4(r, a, b, None);
405}
406
407#[inline]
408pub fn urem2(r: &mut U256, a: &U256) {
409 let mut res = MaybeUninit::uninit();
410 let (a, b) = (*r, a);
411 let r = unsafe { &mut *(r as *mut U256).cast() };
414 udivmod4(&mut res, &a, b, Some(r));
415}
416
417#[inline]
418pub fn urem3(r: &mut MaybeUninit<U256>, a: &U256, b: &U256) {
419 let mut res = MaybeUninit::uninit();
420 udivmod4(&mut res, a, b, Some(r));
421}
422
423pub fn idivmod4(
424 res: &mut MaybeUninit<I256>,
425 a: &I256,
426 b: &I256,
427 mut rem: Option<&mut MaybeUninit<I256>>,
428) {
429 const BITS_IN_TWORD_M1: u32 = 255;
430 let s_a = a >> BITS_IN_TWORD_M1; let mut s_b = b >> BITS_IN_TWORD_M1; let a = (a ^ s_a).wrapping_sub(s_a); let b = (b ^ s_b).wrapping_sub(s_b); s_b ^= s_a; udivmod4(
436 cast!(uninit: res),
437 cast!(ref: &a),
438 cast!(ref: &b),
439 cast!(optuninit: rem),
440 );
441 let q = unsafe { res.assume_init_ref() };
442 let q = (q ^ s_b).wrapping_sub(s_b); res.write(q);
444 if let Some(rem) = rem {
445 let r = unsafe { rem.assume_init_ref() };
446 let r = (r ^ s_a).wrapping_sub(s_a);
447 rem.write(r);
448 }
449}
450
451#[inline]
452pub fn idiv2(r: &mut I256, a: &I256) {
453 let (a, b) = (*r, a);
454 let res = unsafe { &mut *(r as *mut I256).cast() };
457 idivmod4(res, &a, b, None);
458}
459
460#[inline]
461pub fn idiv3(r: &mut MaybeUninit<I256>, a: &I256, b: &I256) {
462 idivmod4(r, a, b, None);
463}
464
465#[inline]
466pub fn irem2(r: &mut I256, a: &I256) {
467 let mut res = MaybeUninit::uninit();
468 let (a, b) = (*r, a);
469 let r = unsafe { &mut *(r as *mut I256).cast() };
472 idivmod4(&mut res, &a, b, Some(r));
473}
474
475#[inline]
476pub fn irem3(r: &mut MaybeUninit<I256>, a: &I256, b: &I256) {
477 let mut res = MaybeUninit::uninit();
478 idivmod4(&mut res, a, b, Some(r));
479}
480
#[cfg(test)]
mod tests {
    use super::*;
    use crate::AsU256;

    /// Runs `udiv3` on the converted operands and returns the quotient.
    fn udiv(a: impl AsU256, b: impl AsU256) -> U256 {
        let (lhs, rhs) = (a.as_u256(), b.as_u256());
        let mut quotient = MaybeUninit::uninit();
        udiv3(&mut quotient, &lhs, &rhs);
        // SAFETY: `udivmod4` writes the quotient on every path that returns.
        unsafe { quotient.assume_init() }
    }

    /// Runs `urem3` on the converted operands and returns the remainder.
    fn urem(a: impl AsU256, b: impl AsU256) -> U256 {
        let (lhs, rhs) = (a.as_u256(), b.as_u256());
        let mut remainder = MaybeUninit::uninit();
        urem3(&mut remainder, &lhs, &rhs);
        // SAFETY: `udivmod4` writes the remainder whenever one is requested.
        unsafe { remainder.assume_init() }
    }

    #[test]
    fn division() {
        let wide = U256::from_words(1337, !0);
        let forty_two_hi = U256::from_words(42, 0);

        // Both operands fit in 128 bits.
        assert_eq!(udiv(100, 9), 11);

        // 128-bit dividend, divisor wider than 128 bits.
        assert_eq!(udiv(!0u128, U256::ONE << 128u32), 0);

        // Both operands have only high-word bits set.
        assert_eq!(udiv(U256::from_words(100, 0), U256::from_words(10, 0)), 10);

        // Dividend uses both words, divisor only the high word.
        assert_eq!(udiv(U256::from_words(100, 1337), U256::ONE << 130u32), 25);
        assert_eq!(udiv(wide, U256::from_words(63, 0)), 21);

        // Divisor fits in 128 bits.
        assert_eq!(udiv(forty_two_hi, U256::ONE), U256::from_words(42, 0));
        assert_eq!(
            udiv(U256::from_words(42, 42), U256::ONE << 42),
            42u128 << (128 - 42),
        );
        assert_eq!(udiv(wide, 0xc0ffee), 35996389033280467545299711090127855);
        assert_eq!(
            udiv(forty_two_hi, 99),
            144362216269489045105674075880144089708,
        );

        // Both operands use both words.
        assert_eq!(
            udiv(U256::from_words(100, 100), U256::from_words(1000, 1000)),
            0,
        );
        assert_eq!(udiv(wide, U256::from_words(43, !0)), 30);
    }

    #[test]
    #[should_panic]
    fn division_by_zero() {
        udiv(1, 0);
    }

    #[test]
    fn remainder() {
        let wide = U256::from_words(1337, !0);
        let forty_two_hi = U256::from_words(42, 0);

        // Both operands fit in 128 bits.
        assert_eq!(urem(100, 9), 1);

        // 128-bit dividend, divisor wider than 128 bits.
        assert_eq!(urem(!0u128, U256::ONE << 128u32), !0u128);

        // Both operands have only high-word bits set.
        assert_eq!(urem(U256::from_words(100, 0), U256::from_words(10, 0)), 0);

        // Dividend uses both words, divisor only the high word.
        assert_eq!(urem(U256::from_words(100, 1337), U256::ONE << 130u32), 1337);
        assert_eq!(
            urem(wide, U256::from_words(63, 0)),
            U256::from_words(14, !0),
        );

        // Divisor fits in 128 bits.
        assert_eq!(urem(forty_two_hi, U256::ONE), 0);
        assert_eq!(urem(U256::from_words(42, 42), U256::ONE << 42), 42);
        assert_eq!(urem(wide, 0xc0ffee), 1910477);
        assert_eq!(urem(forty_two_hi, 99), 60);

        // Both operands use both words.
        assert_eq!(
            urem(U256::from_words(100, 100), U256::from_words(1000, 1000)),
            U256::from_words(100, 100),
        );
        assert_eq!(
            urem(wide, U256::from_words(43, !0)),
            U256::from_words(18, 29),
        );
    }

    #[test]
    #[should_panic]
    fn remainder_by_zero() {
        urem(1, 0);
    }
}