1use std::ops::Add;
2#[cfg(feature = "simd")]
3use std::simd::prelude::*;
4
5use arrow::array::{Array, PrimitiveArray};
6use arrow::bitmap::bitmask::BitMask;
7use arrow::types::NativeType;
8use num_traits::Zero;
9
/// Implements a `WrappingAdd`-style trait for one or more primitive types.
///
/// * `$trait_name` — the trait to implement; it must declare
///   `fn wrapping_add(&self, v: &Self) -> Self`.
/// * `$method` — the associated function on each type that performs the
///   addition, called as `<$t>::$method(lhs, rhs)` (e.g. the inherent
///   `wrapping_add` for integers, or `add` from `std::ops::Add` for floats,
///   which must then be in scope).
/// * the remaining arguments — the types to implement the trait for; a
///   trailing comma is accepted, and a single type works as before.
macro_rules! wrapping_impl {
    ($trait_name:ident, $method:ident, $($t:ty),+ $(,)?) => {
        $(
            impl $trait_name for $t {
                #[inline(always)]
                fn wrapping_add(&self, v: &Self) -> Self {
                    <$t>::$method(*self, *v)
                }
            }
        )+
    };
}
20
/// Addition that wraps around on integer overflow instead of panicking.
///
/// In this module it is implemented through `wrapping_impl!` for the
/// primitive integers (delegating to their inherent `wrapping_add`) and for
/// `f32`/`f64` (delegating to ordinary `Add::add`).
pub trait WrappingAdd
where
    Self: Sized,
{
    /// Returns `self + v`, wrapping on overflow for integer types.
    fn wrapping_add(&self, v: &Self) -> Self;
}
29
30wrapping_impl!(WrappingAdd, wrapping_add, u8);
31wrapping_impl!(WrappingAdd, wrapping_add, u16);
32wrapping_impl!(WrappingAdd, wrapping_add, u32);
33wrapping_impl!(WrappingAdd, wrapping_add, u64);
34wrapping_impl!(WrappingAdd, wrapping_add, usize);
35wrapping_impl!(WrappingAdd, wrapping_add, u128);
36
37wrapping_impl!(WrappingAdd, wrapping_add, i8);
38wrapping_impl!(WrappingAdd, wrapping_add, i16);
39wrapping_impl!(WrappingAdd, wrapping_add, i32);
40wrapping_impl!(WrappingAdd, wrapping_add, i64);
41wrapping_impl!(WrappingAdd, wrapping_add, isize);
42wrapping_impl!(WrappingAdd, wrapping_add, i128);
43
44wrapping_impl!(WrappingAdd, add, f32);
45wrapping_impl!(WrappingAdd, add, f64);
46
/// Number of lanes per SIMD vector (`Simd<T, STRIPE>`) used by the SIMD
/// `wrapping_sum_with_validity` implementation; input is processed in
/// chunks of exactly this many elements.
#[cfg(feature = "simd")]
const STRIPE: usize = 16;
49
50fn wrapping_sum_with_mask_scalar<T: Zero + WrappingAdd + Copy>(vals: &[T], mask: &BitMask) -> T {
51 assert!(vals.len() == mask.len());
52 vals.iter()
53 .enumerate()
54 .map(|(i, x)| {
55 if mask.get(i) {
57 *x
58 } else {
59 T::zero()
60 }
61 })
62 .fold(T::zero(), |a, b| a.wrapping_add(&b))
63}
64
65#[cfg(not(feature = "simd"))]
66impl<T> WrappingSum for T
67where
68 T: NativeType + WrappingAdd + Zero,
69{
70 fn wrapping_sum(vals: &[Self]) -> Self {
71 vals.iter()
72 .copied()
73 .fold(T::zero(), |a, b| a.wrapping_add(&b))
74 }
75
76 fn wrapping_sum_with_validity(vals: &[Self], mask: &BitMask) -> Self {
77 wrapping_sum_with_mask_scalar(vals, mask)
78 }
79}
80
/// SIMD-enabled blanket implementation of [`WrappingSum`] for primitives
/// that also implement `crate::SimdPrimitive` (only built with the `simd`
/// feature).
#[cfg(feature = "simd")]
impl<T> WrappingSum for T
where
    T: NativeType + WrappingAdd + Zero + crate::SimdPrimitive,
{
    /// Sequential scalar fold over all values, starting from zero.
    fn wrapping_sum(vals: &[Self]) -> Self {
        vals.iter()
            .copied()
            .fold(T::zero(), |a, b| a.wrapping_add(&b))
    }

    /// Masked wrapping sum processed `STRIPE` lanes at a time.
    ///
    /// The first `vals.len() % STRIPE` elements are summed by the scalar
    /// kernel. The remaining elements (an exact multiple of `STRIPE`) are
    /// loaded into SIMD vectors with masked-out lanes replaced by `zero`,
    /// accumulated lane-wise, and finally reduced across lanes.
    ///
    /// For floating-point `T` the association order differs from
    /// `wrapping_sum`, so results may differ in the last bits.
    ///
    /// # Panics
    ///
    /// Panics if `vals.len() != mask.len()`.
    fn wrapping_sum_with_validity(vals: &[Self], mask: &BitMask) -> Self {
        assert!(vals.len() == mask.len());
        // Peel the ragged part off the *front* so `main` is an exact multiple
        // of STRIPE; the mask is split at the same index to stay aligned.
        let remainder = vals.len() % STRIPE;
        let (rest, main) = vals.split_at(remainder);
        let (rest_mask, main_mask) = mask.split_at(remainder);
        let zero: Simd<T, STRIPE> = Simd::default();

        // Per-lane running sum: each chunk has its masked-out lanes replaced
        // by `zero`, then is added lane-by-lane to the accumulator vector.
        let vsum = main
            .chunks_exact(STRIPE)
            .enumerate()
            .map(|(i, a)| {
                let m: Mask<_, STRIPE> = main_mask.get_simd(i * STRIPE);
                m.select(Simd::from_slice(a), zero)
            })
            .fold(zero, |a, b| {
                // Lane-wise add through the scalar `WrappingAdd` impl,
                // round-tripping via arrays rather than a SIMD add operator.
                let a = a.to_array();
                let b = b.to_array();
                Simd::from_array(std::array::from_fn(|i| a[i].wrapping_add(&b[i])))
            });

        // Horizontal reduction of the STRIPE per-lane partial sums.
        let mainsum = vsum
            .to_array()
            .into_iter()
            .fold(T::zero(), |a, b| a.wrapping_add(&b));

        // Leading ragged elements are handled by the scalar kernel.
        let restsum = wrapping_sum_with_mask_scalar(rest, &rest_mask);
        mainsum.wrapping_add(&restsum)
    }
}
122
/// Scalar [`WrappingSum`] for `u128` in `simd` builds. The SIMD blanket impl
/// requires `crate::SimdPrimitive`, which `u128` presumably does not
/// implement (otherwise these impls would overlap) — verify.
#[cfg(feature = "simd")]
impl WrappingSum for u128 {
    fn wrapping_sum(vals: &[Self]) -> Self {
        let mut total: u128 = 0;
        for &v in vals {
            total = total.wrapping_add(v);
        }
        total
    }

    fn wrapping_sum_with_validity(vals: &[Self], mask: &BitMask) -> Self {
        // No SIMD path for 128-bit integers; use the scalar masked kernel.
        wrapping_sum_with_mask_scalar(vals, mask)
    }
}
133
/// Scalar [`WrappingSum`] for `i128` in `simd` builds. The SIMD blanket impl
/// requires `crate::SimdPrimitive`, which `i128` presumably does not
/// implement (otherwise these impls would overlap) — verify.
#[cfg(feature = "simd")]
impl WrappingSum for i128 {
    fn wrapping_sum(vals: &[Self]) -> Self {
        vals.iter().fold(0i128, |acc, &x| acc.wrapping_add(x))
    }

    fn wrapping_sum_with_validity(vals: &[Self], mask: &BitMask) -> Self {
        // No SIMD path for 128-bit integers; use the scalar masked kernel.
        wrapping_sum_with_mask_scalar(vals, mask)
    }
}
144
145pub trait WrappingSum: Sized {
146 fn wrapping_sum(vals: &[Self]) -> Self;
147 fn wrapping_sum_with_validity(vals: &[Self], mask: &BitMask) -> Self;
148}
149
150pub fn wrapping_sum_arr<T>(arr: &PrimitiveArray<T>) -> T
151where
152 T: NativeType + WrappingSum,
153{
154 let validity = arr.validity().filter(|_| arr.null_count() > 0);
155 if let Some(mask) = validity {
156 WrappingSum::wrapping_sum_with_validity(arr.values(), &BitMask::from_bitmap(mask))
157 } else {
158 WrappingSum::wrapping_sum(arr.values())
159 }
160}