ethnum/intrinsics/native/
mul.rs1use crate::{int::I256, uint::U256};
11use core::mem::MaybeUninit;
12
13#[inline]
14pub fn umulddi3(a: &u128, b: &u128) -> U256 {
15 const BITS_IN_DWORD_2: u32 = 64;
16 const LOWER_MASK: u128 = u128::MAX >> BITS_IN_DWORD_2;
17
18 let mut low = (a & LOWER_MASK) * (b & LOWER_MASK);
19 let mut t = low >> BITS_IN_DWORD_2;
20 low &= LOWER_MASK;
21 t += (a >> BITS_IN_DWORD_2) * (b & LOWER_MASK);
22 low += (t & LOWER_MASK) << BITS_IN_DWORD_2;
23 let mut high = t >> BITS_IN_DWORD_2;
24 t = low >> BITS_IN_DWORD_2;
25 low &= LOWER_MASK;
26 t += (b >> BITS_IN_DWORD_2) * (a & LOWER_MASK);
27 low += (t & LOWER_MASK) << BITS_IN_DWORD_2;
28 high += t >> BITS_IN_DWORD_2;
29 high += (a >> BITS_IN_DWORD_2) * (b >> BITS_IN_DWORD_2);
30
31 U256::from_words(high, low)
32}
33
34#[inline]
35pub fn mul2(r: &mut U256, a: &U256) {
36 let (a, b) = (*r, a);
37 let res = unsafe { &mut *(r as *mut U256).cast() };
40 mul3(res, &a, b);
41}
42
43#[inline]
44pub fn mul3(res: &mut MaybeUninit<U256>, a: &U256, b: &U256) {
45 let mut r = umulddi3(a.low(), b.low());
46
47 let hi_lo = a.high().wrapping_mul(*b.low());
48 let lo_hi = a.low().wrapping_mul(*b.high());
49 *r.high_mut() = r.high().wrapping_add(hi_lo.wrapping_add(lo_hi));
50
51 res.write(r);
52}
53
54#[inline]
55pub fn umulc(r: &mut MaybeUninit<U256>, a: &U256, b: &U256) -> bool {
56 let mut res = umulddi3(a.low(), b.low());
57
58 let (hi_lo, overflow_hi_lo) = a.high().overflowing_mul(*b.low());
59 let (lo_hi, overflow_lo_hi) = a.low().overflowing_mul(*b.high());
60 let (hi, overflow_hi) = hi_lo.overflowing_add(lo_hi);
61 let (high, overflow_high) = res.high().overflowing_add(hi);
62 *res.high_mut() = high;
63
64 let overflow_hi_hi = (*a.high() != 0) & (*b.high() != 0);
65
66 r.write(res);
67 overflow_hi_lo | overflow_lo_hi | overflow_hi | overflow_high | overflow_hi_hi
68}
69
70#[inline]
71pub fn imulc(res: &mut MaybeUninit<I256>, a: &I256, b: &I256) -> bool {
72 mul3(cast!(uninit: res), cast!(ref: a), cast!(ref: b));
73 if *a == I256::MIN {
74 return *b != 0 && *b != 1;
75 }
76 if *b == I256::MIN {
77 return *a != 0 && *a != 1;
78 }
79 let sa = a >> (I256::BITS - 1);
80 let abs_a = (a ^ sa).wrapping_sub(sa);
81 let sb = b >> (I256::BITS - 1);
82 let abs_b = (b ^ sb).wrapping_sub(sb);
83 if abs_a < 2 || abs_b < 2 {
84 return false;
85 }
86 if sa == sb {
87 abs_a > I256::MAX / abs_b
88 } else {
89 abs_a > I256::MIN / -abs_b
90 }
91}
92
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{AsI256, AsU256};

    /// Runs `umulc` and returns `(wrapped product, overflowed)`.
    fn umul(a: impl AsU256, b: impl AsU256) -> (U256, bool) {
        let mut r = MaybeUninit::uninit();
        let overflow = umulc(&mut r, &a.as_u256(), &b.as_u256());
        // SAFETY: `umulc` unconditionally initialises `r` via `write`.
        (unsafe { r.assume_init() }, overflow)
    }

    /// Runs `imulc` and returns `(wrapped product, overflowed)`.
    fn imul(a: impl AsI256, b: impl AsI256) -> (I256, bool) {
        let mut r = MaybeUninit::uninit();
        let overflow = imulc(&mut r, &a.as_i256(), &b.as_i256());
        // SAFETY: `imulc` initialises `r` through `mul3` before any early
        // return.
        (unsafe { r.assume_init() }, overflow)
    }

    #[test]
    fn multiplication() {
        assert_eq!(umul(6, 7), (42.as_u256(), false));

        assert_eq!(umul(U256::MAX, 1), (U256::MAX, false));
        assert_eq!(umul(1, U256::MAX), (U256::MAX, false));
        assert_eq!(umul(U256::MAX, 0), (U256::ZERO, false));
        assert_eq!(umul(0, U256::MAX), (U256::ZERO, false));

        assert_eq!(umul(U256::MAX, 5), (U256::MAX ^ 4, true));
        assert_eq!(
            umul(u128::MAX, u128::MAX),
            (U256::from_words(!0 << 1, 1), false),
        );
    }

    #[test]
    fn wrapping_multiplication_in_place() {
        let mut r = 6.as_u256();
        mul2(&mut r, &7.as_u256());
        assert_eq!(r, 42.as_u256());

        // MAX * 5 wraps to MAX - 4.
        let mut r = U256::MAX;
        mul2(&mut r, &5.as_u256());
        assert_eq!(r, U256::MAX ^ 4);
    }

    #[test]
    fn signed_multiplication() {
        assert_eq!(imul(6, 7), (42.as_i256(), false));
        assert_eq!(imul(-6, 7), ((-42).as_i256(), false));
        assert_eq!(imul(6, -7), ((-42).as_i256(), false));
        assert_eq!(imul(-6, -7), (42.as_i256(), false));

        // `MIN` only survives multiplication by 0 and 1.
        assert_eq!(imul(I256::MIN, 0), (0.as_i256(), false));
        assert_eq!(imul(0, I256::MIN), (0.as_i256(), false));
        assert_eq!(imul(I256::MIN, 1), (I256::MIN, false));
        assert_eq!(imul(1, I256::MIN), (I256::MIN, false));
        assert_eq!(imul(I256::MIN, -1), (I256::MIN, true));
        assert_eq!(imul(-1, I256::MIN), (I256::MIN, true));
        assert_eq!(imul(I256::MIN, I256::MIN), (0.as_i256(), true));

        // Magnitude bounds on either side of the range.
        assert_eq!(imul(I256::MAX, -1), (-I256::MAX, false));
        assert_eq!(imul(I256::MAX, 2), ((-2).as_i256(), true));
        assert_eq!(imul(I256::MAX, -2), (2.as_i256(), true));
    }
}
119}