ethnum/intrinsics/native/
divmod.rs

//! This module contains a Rust port of the `__u?divmodti4` compiler builtins
//! that are typically used for implementing 128-bit signed and unsigned
//! division on 64-bit platforms.
4//!
5//! This port is adapted to use 128-bit high and low words in order to implement
6//! 256-bit division.
7//!
8//! This source is ported from LLVM project from C:
9//! - signed division: <https://github.com/llvm/llvm-project/blob/main/compiler-rt/lib/builtins/divmodti4.c>
10//! - unsigned division: <https://github.com/llvm/llvm-project/blob/main/compiler-rt/lib/builtins/udivmodti4.c>
11
12use crate::{int::I256, uint::U256};
13use core::{mem::MaybeUninit, num::NonZeroU128};
14
/// Divides the 256-bit value `u1 * 2^128 + u0` by the 128-bit divisor `v`.
///
/// Returns the 128-bit quotient and stores the remainder in `*r`. This is a
/// two-digit schoolbook long division where every digit is 64 bits wide, i.e.
/// the number base is `2^64`.
///
/// The caller must guarantee `u1 < v` so that the quotient fits in 128 bits;
/// this is only verified in debug builds.
#[inline(always)]
fn udiv256_by_128_to_128(u1: u128, u0: u128, mut v: NonZeroU128, r: &mut u128) -> u128 {
    const N_UDWORD_BITS: u32 = 128;
    const HALF_BITS: u32 = N_UDWORD_BITS / 2;
    const B: u128 = 1 << HALF_BITS; // Number base (one digit is 64 bits)
    const DIGIT_MASK: u128 = 0xFFFF_FFFF_FFFF_FFFF;

    /// Estimates the quotient digit of `num * B + next_digit` divided by the
    /// normalized divisor whose digits are `vn1` (high) and `vn0` (low).
    #[inline]
    fn quotient_digit(num: u128, next_digit: u128, vn1: NonZeroU128, vn0: u128) -> u128 {
        let mut q = num / vn1;
        let mut rhat = num - q * vn1.get();

        // The first guess is at most 2 too large, so this corrects it in no
        // more than 2 iterations.
        loop {
            if q < B && q * vn0 <= B * rhat + next_digit {
                break;
            }
            q -= 1;
            rhat += vn1.get();
            if rhat >= B {
                break;
            }
        }

        q
    }

    /// Computes `hi * B + digit - q * v` modulo `2^128`.
    #[inline]
    fn partial_remainder(hi: u128, digit: u128, q: u128, v: NonZeroU128) -> u128 {
        hi.wrapping_mul(B)
            .wrapping_add(digit)
            .wrapping_sub(q.wrapping_mul(v.get()))
    }

    debug_assert!(v.get() > u1);

    let s = v.leading_zeros();
    debug_assert_ne!(s, N_UDWORD_BITS);

    // Normalize: shift divisor and dividend left until the divisor's top bit
    // is set. A zero shift is handled separately to avoid `u0 >> 128`.
    let (un128, un10) = if s == 0 {
        (u1, u0)
    } else {
        let shifted = v.get() << s;
        debug_assert_ne!(shifted, 0);
        // SAFETY: `v` is non-zero and `s` is exactly its number of leading
        // zeros, so no set bit is shifted out.
        v = unsafe { NonZeroU128::new_unchecked(shifted) };
        ((u1 << s) | (u0 >> (N_UDWORD_BITS - s)), u0 << s)
    };

    // Break the divisor up into two 64-bit digits.
    let high_digit = v.get() >> HALF_BITS;
    debug_assert_ne!(high_digit, 0);
    // SAFETY: after normalization the top bit of `v` is set, so its upper
    // half is non-zero.
    let vn1 = unsafe { NonZeroU128::new_unchecked(high_digit) };
    let vn0 = v.get() & DIGIT_MASK;

    // Break the low word of the dividend into two 64-bit digits.
    let un1 = un10 >> HALF_BITS;
    let un0 = un10 & DIGIT_MASK;

    // First quotient digit and the partial remainder it leaves behind.
    let q1 = quotient_digit(un128, un1, vn1, vn0);
    let un21 = partial_remainder(un128, un1, q1, v);

    // Second quotient digit; what is left is the (still shifted) remainder.
    let q0 = quotient_digit(un21, un0, vn1, vn0);
    *r = partial_remainder(un21, un0, q0, v) >> s;

    q1 * B + q0
}
102
103#[allow(clippy::many_single_char_names)]
104pub fn udivmod4(
105    res: &mut MaybeUninit<U256>,
106    a: &U256,
107    b: &U256,
108    rem: Option<&mut MaybeUninit<U256>>,
109) {
110    // In the LLVM version on the x86_64 platform, `udiv256_by_128_to_128` would
111    // defer to `divq` instruction, which divides a 128-bit value by a 64-bit
112    // one returning a 64-bit value, making it very performant when dividing
113    // small values:
114    // ```
115    //   du_int result;
116    //   __asm__("divq %[v]"
117    //           : "=a"(result), "=d"(*r)
118    //           : [ v ] "r"(v), "a"(u0), "d"(u1));
119    //   return result;
120    // ```
121    // Unfortunately, there is no 256-bit equivalent on x86_64, but we can still
122    // shortcut if the high and low values of the operands are 0:
123    if a.high() | b.high() == 0 {
124        res.write(U256::from_words(0, a.low() / b.low()));
125        if let Some(rem) = rem {
126            rem.write(U256::from_words(0, a.low() % b.low()));
127        }
128        return;
129    }
130
131    let dividend = *a;
132    let divisor = *b;
133    let quotient: U256;
134    let mut remainder: U256;
135
136    if divisor > dividend {
137        if let Some(rem) = rem {
138            rem.write(dividend);
139        }
140        res.write(U256::ZERO);
141        return;
142    }
143    // When the divisor fits in 128 bits, we can use an optimized path.
144    if *divisor.high() == 0 {
145        remainder = U256::ZERO;
146        if dividend.high() < divisor.low() {
147            // The result fits in 128 bits.
148            quotient = U256::from_words(
149                0,
150                udiv256_by_128_to_128(
151                    *dividend.high(),
152                    *dividend.low(),
153                    // SAFETY: dividend.high() < divisor.low()
154                    unsafe { NonZeroU128::new_unchecked(*divisor.low()) },
155                    remainder.low_mut(),
156                ),
157            );
158        } else {
159            // First, divide with the high part to get the remainder in dividend.s.high.
160            // After that dividend.s.high < divisor.s.low.
161            quotient = U256::from_words(
162                dividend.high() / divisor.low(),
163                udiv256_by_128_to_128(
164                    dividend.high() % divisor.low(),
165                    *dividend.low(),
166                    // SAFETY: dividend.high() / divisor.low()
167                    unsafe { NonZeroU128::new_unchecked(*divisor.low()) },
168                    remainder.low_mut(),
169                ),
170            );
171        }
172        if let Some(rem) = rem {
173            rem.write(remainder);
174        }
175        res.write(quotient);
176        return;
177    }
178
179    // SAFETY: `*divisor.high() != 0`
180    (quotient, remainder) = unsafe { div_mod_knuth(&dividend, &divisor) };
181
182    if let Some(rem) = rem {
183        rem.write(remainder);
184    }
185    res.write(quotient);
186}
187
188// See Knuth, TAOCP, Volume 2, section 4.3.1, Algorithm D.
189// https://skanthak.homepage.t-online.de/division.html
190// SAFETY: The high word of v (the divisor) must be non-zero.
191#[inline]
192unsafe fn div_mod_knuth(u: &U256, v: &U256) -> (U256, U256) {
193    const N_UDWORD_BITS: u32 = 128;
194    debug_assert_ne!(
195        *u.high(),
196        0,
197        "The second operand must be greater than u128::MAX"
198    );
199    if *u.high() == 0 {
200        unsafe { core::hint::unreachable_unchecked() }
201    }
202
203    #[inline]
204    fn full_shl(a: &U256, shift: u32) -> [u128; 3] {
205        debug_assert!(shift < N_UDWORD_BITS);
206        let mut u = [0_u128; 3];
207        let u_lo = a.low() << shift;
208        let u_hi = a >> (N_UDWORD_BITS - shift);
209        u[0] = u_lo;
210        u[1] = *u_hi.low();
211        u[2] = *u_hi.high();
212
213        u
214    }
215
216    #[inline]
217    fn full_shr(u: &[u128; 3], shift: u32) -> U256 {
218        debug_assert!(shift < N_UDWORD_BITS);
219        let mut res = U256::ZERO;
220        *res.low_mut() = u[0] >> shift;
221        *res.high_mut() = u[1] >> shift;
222        // carry
223        if shift > 0 {
224            let sh = N_UDWORD_BITS - shift;
225            *res.low_mut() |= u[1] << sh;
226            *res.high_mut() |= u[2] << sh;
227        }
228
229        res
230    }
231
232    // returns (lo, hi)
233    #[inline]
234    const fn split_u128_to_u128(a: u128) -> (u128, u128) {
235        (a & 0xFFFFFFFFFFFFFFFF, a >> (N_UDWORD_BITS / 2))
236    }
237
238    // returns (lo, hi)
239    #[inline]
240    const fn fullmul_u128(a: u128, b: u128) -> (u128, u128) {
241        let (a0, a1) = split_u128_to_u128(a);
242        let (b0, b1) = split_u128_to_u128(b);
243
244        let mut t = a0 * b0;
245        let mut k: u128;
246        let w3: u128;
247        (w3, k) = split_u128_to_u128(t);
248
249        t = a1 * b0 + k;
250        let (w1, w2) = split_u128_to_u128(t);
251        t = a0 * b1 + w1;
252        k = t >> 64;
253
254        let w_hi = a1 * b1 + w2 + k;
255        let w_lo = (t << 64) + w3;
256
257        (w_lo, w_hi)
258    }
259
260    #[inline]
261    fn fullmul_u256_u128(a: &U256, b: u128) -> [u128; 3] {
262        let mut acc = [0_u128; 3];
263        let mut lo: u128;
264        let mut carry: u128;
265        let c: bool;
266        if b != 0 {
267            (lo, carry) = fullmul_u128(*a.low(), b);
268            acc[0] = lo;
269            acc[1] = carry;
270            (lo, carry) = fullmul_u128(*a.high(), b);
271            (acc[1], c) = acc[1].overflowing_add(lo);
272            acc[2] = carry + c as u128;
273        }
274
275        acc
276    }
277
278    #[inline]
279    const fn add_carry(a: u128, b: u128, c: bool) -> (u128, bool) {
280        let (res1, overflow1) = b.overflowing_add(c as u128);
281        let (res2, overflow2) = u128::overflowing_add(a, res1);
282
283        (res2, overflow1 || overflow2)
284    }
285
286    #[inline]
287    const fn sub_carry(a: u128, b: u128, c: bool) -> (u128, bool) {
288        let (res1, overflow1) = b.overflowing_add(c as u128);
289        let (res2, overflow2) = u128::overflowing_sub(a, res1);
290
291        (res2, overflow1 || overflow2)
292    }
293
294    // D1.
295    // Make sure 128th bit in v's highest word is set.
296    // If we shift both u and v, it won't affect the quotient
297    // and the remainder will only need to be shifted back.
298    let shift = v.high().leading_zeros();
299    debug_assert!(shift < N_UDWORD_BITS);
300    let v = v << shift;
301    // u will store the remainder (shifted)
302    let mut u = full_shl(u, shift);
303
304    // quotient
305    let mut q = U256::ZERO;
306    let v_n_1 = *v.high();
307    let v_n_2 = *v.low();
308
309    if v_n_1 >> (N_UDWORD_BITS - 1) != 1 {
310        debug_assert!(false);
311
312        // SAFETY: `v_n_1` must be normalized because input `v` has
313        // been checked to be non-zero.
314        unsafe { core::hint::unreachable_unchecked() }
315    }
316
317    // D2. D7. - unrolled loop j == 0, n == 2, m == 0 (only one possible iteration)
318    let mut r_hat: u128 = 0;
319    let u_jn = u[2];
320
321    // D3.
322    // q_hat is our guess for the j-th quotient digit
323    // q_hat = min(b - 1, (u_{j+n} * b + u_{j+n-1}) / v_{n-1})
324    // b = 1 << WORD_BITS
325    // Theorem B: q_hat >= q_j >= q_hat - 2
326    let mut q_hat = if u_jn < v_n_1 {
327        //let (mut q_hat, mut r_hat) = _div_mod_u128(u_jn, u[j + n - 1], v_n_1);
328        let mut q_hat = udiv256_by_128_to_128(
329            u_jn,
330            u[1],
331            unsafe { NonZeroU128::new_unchecked(v_n_1) },
332            &mut r_hat,
333        );
334        let mut overflow: bool;
335        // this loop takes at most 2 iterations
336        loop {
337            let another_iteration = {
338                // check if q_hat * v_{n-2} > b * r_hat + u_{j+n-2}
339                let (lo, hi) = fullmul_u128(q_hat, v_n_2);
340                hi > r_hat || (hi == r_hat && lo > u[0])
341            };
342            if !another_iteration {
343                break;
344            }
345            q_hat -= 1;
346            (r_hat, overflow) = r_hat.overflowing_add(v_n_1);
347            // if r_hat overflowed, we're done
348            if overflow {
349                break;
350            }
351        }
352        q_hat
353    } else {
354        // here q_hat >= q_j >= q_hat - 1
355        u128::MAX
356    };
357
358    // ex. 20:
359    // since q_hat * v_{n-2} <= b * r_hat + u_{j+n-2},
360    // either q_hat == q_j, or q_hat == q_j + 1
361
362    // D4.
363    // let's assume optimistically q_hat == q_j
364    // subtract (q_hat * v) from u[j..]
365    let q_hat_v = fullmul_u256_u128(&v, q_hat);
366    // u[j..] -= q_hat_v;
367    let mut c = false;
368    (u[0], c) = sub_carry(u[0], q_hat_v[0], c);
369    (u[1], c) = sub_carry(u[1], q_hat_v[1], c);
370    (u[2], c) = sub_carry(u[2], q_hat_v[2], c);
371
372    // D6.
373    // actually, q_hat == q_j + 1 and u[j..] has overflowed
374    // highly unlikely ~ (1 / 2^127)
375    if c {
376        q_hat -= 1;
377        // add v to u[j..]
378        c = false;
379        (u[0], c) = add_carry(u[0], *v.low(), c);
380        (u[1], c) = add_carry(u[1], *v.high(), c);
381        u[2] = u[2].wrapping_add(c as u128);
382    }
383
384    // D5.
385    *q.low_mut() = q_hat;
386
387    // D8.
388    let remainder = full_shr(&u, shift);
389
390    (q, remainder)
391}
392
393#[inline]
394pub fn udiv2(r: &mut U256, a: &U256) {
395    let (a, b) = (*r, a);
396    // SAFETY: `udivmod4` does not write `MaybeUninit::uninit()` to `res` and
397    // `U256` does not implement `Drop`.
398    let res = unsafe { &mut *(r as *mut U256).cast() };
399    udivmod4(res, &a, b, None);
400}
401
402#[inline]
403pub fn udiv3(r: &mut MaybeUninit<U256>, a: &U256, b: &U256) {
404    udivmod4(r, a, b, None);
405}
406
407#[inline]
408pub fn urem2(r: &mut U256, a: &U256) {
409    let mut res = MaybeUninit::uninit();
410    let (a, b) = (*r, a);
411    // SAFETY: `udivmod4` does not write `MaybeUninit::uninit()` to `rem` and
412    // `U256` does not implement `Drop`.
413    let r = unsafe { &mut *(r as *mut U256).cast() };
414    udivmod4(&mut res, &a, b, Some(r));
415}
416
417#[inline]
418pub fn urem3(r: &mut MaybeUninit<U256>, a: &U256, b: &U256) {
419    let mut res = MaybeUninit::uninit();
420    udivmod4(&mut res, a, b, Some(r));
421}
422
423pub fn idivmod4(
424    res: &mut MaybeUninit<I256>,
425    a: &I256,
426    b: &I256,
427    mut rem: Option<&mut MaybeUninit<I256>>,
428) {
429    const BITS_IN_TWORD_M1: u32 = 255;
430    let s_a = a >> BITS_IN_TWORD_M1; // s_a = a < 0 ? -1 : 0
431    let mut s_b = b >> BITS_IN_TWORD_M1; // s_b = b < 0 ? -1 : 0
432    let a = (a ^ s_a).wrapping_sub(s_a); // negate if s_a == -1
433    let b = (b ^ s_b).wrapping_sub(s_b); // negate if s_b == -1
434    s_b ^= s_a; // sign of quotient
435    udivmod4(
436        cast!(uninit: res),
437        cast!(ref: &a),
438        cast!(ref: &b),
439        cast!(optuninit: rem),
440    );
441    let q = unsafe { res.assume_init_ref() };
442    let q = (q ^ s_b).wrapping_sub(s_b); // negate if s_b == -1
443    res.write(q);
444    if let Some(rem) = rem {
445        let r = unsafe { rem.assume_init_ref() };
446        let r = (r ^ s_a).wrapping_sub(s_a);
447        rem.write(r);
448    }
449}
450
451#[inline]
452pub fn idiv2(r: &mut I256, a: &I256) {
453    let (a, b) = (*r, a);
454    // SAFETY: `udivmod4` does not write `MaybeUninit::uninit()` to `res` and
455    // `U256` does not implement `Drop`.
456    let res = unsafe { &mut *(r as *mut I256).cast() };
457    idivmod4(res, &a, b, None);
458}
459
460#[inline]
461pub fn idiv3(r: &mut MaybeUninit<I256>, a: &I256, b: &I256) {
462    idivmod4(r, a, b, None);
463}
464
465#[inline]
466pub fn irem2(r: &mut I256, a: &I256) {
467    let mut res = MaybeUninit::uninit();
468    let (a, b) = (*r, a);
469    // SAFETY: `udivmod4` does not write `MaybeUninit::uninit()` to `rem` and
470    // `U256` does not implement `Drop`.
471    let r = unsafe { &mut *(r as *mut I256).cast() };
472    idivmod4(&mut res, &a, b, Some(r));
473}
474
475#[inline]
476pub fn irem3(r: &mut MaybeUninit<I256>, a: &I256, b: &I256) {
477    let mut res = MaybeUninit::uninit();
478    idivmod4(&mut res, a, b, Some(r));
479}
480
// Unit tests for the unsigned division and remainder entry points.
//
// The small diagrams in the test bodies show the (high, low) word shape of
// the dividend over the divisor. Judging by the inputs used, `0` is a zero
// word, `K` a non-zero word and `X` an arbitrary word.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::AsU256;

    // Computes `a / b` through `udiv3` and returns the quotient by value.
    fn udiv(a: impl AsU256, b: impl AsU256) -> U256 {
        let mut r = MaybeUninit::uninit();
        udiv3(&mut r, &a.as_u256(), &b.as_u256());
        // `udiv3` either panics or writes the quotient to `r`.
        unsafe { r.assume_init() }
    }

    // Computes `a % b` through `urem3` and returns the remainder by value.
    fn urem(a: impl AsU256, b: impl AsU256) -> U256 {
        let mut r = MaybeUninit::uninit();
        urem3(&mut r, &a.as_u256(), &b.as_u256());
        // `urem3` either panics or writes the remainder to `r`.
        unsafe { r.assume_init() }
    }

    // Checks quotients for each combination of operand word shapes.
    #[test]
    fn division() {
        // 0 X
        // ---
        // 0 X
        assert_eq!(udiv(100, 9), 11);

        // 0 X
        // ---
        // K X
        assert_eq!(udiv(!0u128, U256::ONE << 128u32), 0);

        // K 0
        // ---
        // K 0
        assert_eq!(udiv(U256::from_words(100, 0), U256::from_words(10, 0)), 10);

        // K K
        // ---
        // K 0
        assert_eq!(udiv(U256::from_words(100, 1337), U256::ONE << 130u32), 25);
        assert_eq!(
            udiv(U256::from_words(1337, !0), U256::from_words(63, 0)),
            21
        );

        // K X
        // ---
        // 0 K
        assert_eq!(
            udiv(U256::from_words(42, 0), U256::ONE),
            U256::from_words(42, 0),
        );
        assert_eq!(
            udiv(U256::from_words(42, 42), U256::ONE << 42),
            42u128 << (128 - 42),
        );
        assert_eq!(
            udiv(U256::from_words(1337, !0), 0xc0ffee),
            35996389033280467545299711090127855,
        );
        assert_eq!(
            udiv(U256::from_words(42, 0), 99),
            144362216269489045105674075880144089708,
        );

        // K X
        // ---
        // K K
        assert_eq!(
            udiv(U256::from_words(100, 100), U256::from_words(1000, 1000)),
            0,
        );
        assert_eq!(
            udiv(U256::from_words(1337, !0), U256::from_words(43, !0)),
            30,
        );
    }

    // Dividing by zero must panic rather than return a value.
    #[test]
    #[should_panic]
    fn division_by_zero() {
        udiv(1, 0);
    }

    // Checks remainders for the same operand pairs as `division`.
    #[test]
    fn remainder() {
        // 0 X
        // ---
        // 0 X
        assert_eq!(urem(100, 9), 1);

        // 0 X
        // ---
        // K X
        assert_eq!(urem(!0u128, U256::ONE << 128u32), !0u128);

        // K 0
        // ---
        // K 0
        assert_eq!(urem(U256::from_words(100, 0), U256::from_words(10, 0)), 0);

        // K K
        // ---
        // K 0
        assert_eq!(urem(U256::from_words(100, 1337), U256::ONE << 130u32), 1337);
        assert_eq!(
            urem(U256::from_words(1337, !0), U256::from_words(63, 0)),
            U256::from_words(14, !0),
        );

        // K X
        // ---
        // 0 K
        assert_eq!(urem(U256::from_words(42, 0), U256::ONE), 0);
        assert_eq!(urem(U256::from_words(42, 42), U256::ONE << 42), 42);
        assert_eq!(urem(U256::from_words(1337, !0), 0xc0ffee), 1910477);
        assert_eq!(urem(U256::from_words(42, 0), 99), 60);

        // K X
        // ---
        // K K
        assert_eq!(
            urem(U256::from_words(100, 100), U256::from_words(1000, 1000)),
            U256::from_words(100, 100),
        );
        assert_eq!(
            urem(U256::from_words(1337, !0), U256::from_words(43, !0)),
            U256::from_words(18, 29),
        );
    }

    // Taking a remainder by zero must panic rather than return a value.
    #[test]
    #[should_panic]
    fn remainder_by_zero() {
        urem(1, 0);
    }
}