polars_compute/
float_sum.rs1use std::ops::{Add, IndexMut};
2#[cfg(feature = "simd")]
3use std::simd::{prelude::*, *};
4
5use arrow::array::{Array, PrimitiveArray};
6use arrow::bitmap::bitmask::BitMask;
7use arrow::bitmap::Bitmap;
8use arrow::types::NativeType;
9use num_traits::{AsPrimitive, Float};
10
/// Number of independent running sums (lanes) a block sum accumulates side by side.
const STRIPE: usize = 1 << 4;
/// Number of elements in one leaf block of the pairwise summation. Defined as a
/// multiple of `STRIPE` so every block splits into whole `STRIPE`-sized chunks.
const PAIRWISE_RECURSION_LIMIT: usize = 8 * STRIPE;
13
#[cfg(feature = "simd")]
/// Casts a `Simd<_, N>` vector to another element type through a trait.
///
/// Generic code can then request the cast with just a `Simd<T, N>: SimdCastGeneric<N>`
/// bound, instead of naming a concrete element type. The implementations in this file
/// forward to `Simd::cast`.
pub trait SimdCastGeneric<const N: usize>
where
    LaneCount<N>: SupportedLaneCount,
{
    /// Converts every lane of `self` to `Target`.
    fn cast_generic<Target: SimdCast>(self) -> Simd<Target, N>;
}
22
// Implements `SimdCastGeneric` for `Simd<$elem, N>` for every listed element type by
// forwarding to the inherent `Simd::cast`. Each generated impl carries
// `#[cfg(feature = "simd")]`, so without that feature the macro expands to nothing.
macro_rules! impl_cast_custom {
    ($($elem:ty),+ $(,)?) => {
        $(
            #[cfg(feature = "simd")]
            impl<const N: usize> SimdCastGeneric<N> for Simd<$elem, N>
            where
                LaneCount<N>: SupportedLaneCount,
            {
                fn cast_generic<U: SimdCast>(self) -> Simd<U, N> {
                    self.cast::<U>()
                }
            }
        )+
    };
}

impl_cast_custom!(u8, u16, u32, u64, i8, i16, i32, i64, f32, f64);
47
48fn vector_horizontal_sum<V, T>(mut v: V) -> T
49where
50 V: IndexMut<usize, Output = T>,
51 T: Add<T, Output = T> + Sized + Copy,
52{
53 let mut width = STRIPE;
59 while width > 4 {
60 for j in 0..width / 2 {
61 v[j] = v[j] + v[width / 2 + j];
62 }
63 width /= 2;
64 }
65
66 (v[0] + v[2]) + (v[1] + v[3])
67}
68
69pub trait SumBlock<F> {
71 fn sum_block_vectorized(&self) -> F;
72 fn sum_block_vectorized_with_mask(&self, mask: BitMask<'_>) -> F;
73}
74
/// SIMD implementation of `SumBlock` for element types that `std::simd` can hold.
///
/// The block is read as consecutive `STRIPE`-lane vectors (`chunks_exact(STRIPE)`).
/// Each vector is cast to `F`, and the vectors are added lane-wise into one
/// accumulator, which `vector_horizontal_sum` then reduces to a scalar.
#[cfg(feature = "simd")]
impl<T, F> SumBlock<F> for [T; PAIRWISE_RECURSION_LIMIT]
where
    T: SimdElement,
    F: SimdElement + SimdCast + Add<Output = F> + Default,
    // The cast from `T` lanes to `F` lanes comes from `SimdCastGeneric`, which
    // `impl_cast_custom!` implements for the primitive element types.
    Simd<T, STRIPE>: SimdCastGeneric<STRIPE>,
    // Lets `.sum()` add a sequence of vectors lane-wise.
    Simd<F, STRIPE>: std::iter::Sum,
{
    fn sum_block_vectorized(&self) -> F {
        let vsum = self
            .chunks_exact(STRIPE)
            .map(|a| Simd::<T, STRIPE>::from_slice(a).cast_generic::<F>())
            .sum::<Simd<F, STRIPE>>();
        vector_horizontal_sum(vsum)
    }

    /// Same as `sum_block_vectorized`, except lanes whose bit in `mask` is unset
    /// are replaced by `zero` (`Simd::default()`) before accumulation.
    fn sum_block_vectorized_with_mask(&self, mask: BitMask<'_>) -> F {
        let zero = Simd::default();
        let vsum = self
            .chunks_exact(STRIPE)
            .enumerate()
            .map(|(i, a)| {
                // Mask bits for the `STRIPE` elements of chunk `i`; the mask's element
                // type is inferred from the `select` call below.
                let m: Mask<_, STRIPE> = mask.get_simd(i * STRIPE);
                // `select` (rather than multiplying by 0/1) ensures a masked-out NaN or
                // infinity can never leak into the sum.
                m.select(Simd::from_slice(a).cast_generic::<F>(), zero)
            })
            .sum::<Simd<F, STRIPE>>();
        vector_horizontal_sum(vsum)
    }
}
104
/// Scalar block sum for `i128`.
///
/// `i128` is not a `std::simd` element type, so the generic SIMD implementation does
/// not cover it. Each value is converted to `F` and the results are summed in order.
#[cfg(feature = "simd")]
impl<F> SumBlock<F> for [i128; PAIRWISE_RECURSION_LIMIT]
where
    i128: AsPrimitive<F>,
    F: Float + std::iter::Sum + 'static,
{
    fn sum_block_vectorized(&self) -> F {
        self.iter().copied().map(AsPrimitive::<F>::as_).sum()
    }

    fn sum_block_vectorized_with_mask(&self, mask: BitMask<'_>) -> F {
        let masked = self.iter().enumerate().map(|(pos, &value)| {
            if mask.get(pos) {
                value.as_()
            } else {
                F::zero()
            }
        });
        masked.sum()
    }
}
122
123#[cfg(not(feature = "simd"))]
124impl<T, F> SumBlock<F> for [T; PAIRWISE_RECURSION_LIMIT]
125where
126 T: AsPrimitive<F> + 'static,
127 F: Default + Add<Output = F> + Copy + 'static,
128{
129 fn sum_block_vectorized(&self) -> F {
130 let mut vsum = [F::default(); STRIPE];
131 for chunk in self.chunks_exact(STRIPE) {
132 for j in 0..STRIPE {
133 vsum[j] = vsum[j] + chunk[j].as_();
134 }
135 }
136 vector_horizontal_sum(vsum)
137 }
138
139 fn sum_block_vectorized_with_mask(&self, mask: BitMask<'_>) -> F {
140 let mut vsum = [F::default(); STRIPE];
141 for (i, chunk) in self.chunks_exact(STRIPE).enumerate() {
142 for j in 0..STRIPE {
143 let addend = if mask.get(i * STRIPE + j) {
145 chunk[j].as_()
146 } else {
147 F::default()
148 };
149 vsum[j] = vsum[j] + addend;
150 }
151 }
152 vector_horizontal_sum(vsum)
153 }
154}
155
156unsafe fn pairwise_sum<F, T>(f: &[T]) -> F
158where
159 [T; PAIRWISE_RECURSION_LIMIT]: SumBlock<F>,
160 F: Add<Output = F>,
161{
162 debug_assert!(!f.is_empty() && f.len() % PAIRWISE_RECURSION_LIMIT == 0);
163
164 let block: Option<&[T; PAIRWISE_RECURSION_LIMIT]> = f.try_into().ok();
165 if let Some(block) = block {
166 return block.sum_block_vectorized();
167 }
168
169 unsafe {
174 let blocks = f.len() / PAIRWISE_RECURSION_LIMIT;
175 let left_len = (blocks / 2) * PAIRWISE_RECURSION_LIMIT;
176 let (left, right) = (f.get_unchecked(..left_len), f.get_unchecked(left_len..));
177 pairwise_sum(left) + pairwise_sum(right)
178 }
179}
180
181unsafe fn pairwise_sum_with_mask<F, T>(f: &[T], mask: BitMask<'_>) -> F
184where
185 [T; PAIRWISE_RECURSION_LIMIT]: SumBlock<F>,
186 F: Add<Output = F>,
187{
188 debug_assert!(!f.is_empty() && f.len() % PAIRWISE_RECURSION_LIMIT == 0);
189 debug_assert!(f.len() == mask.len());
190
191 let block: Option<&[T; PAIRWISE_RECURSION_LIMIT]> = f.try_into().ok();
192 if let Some(block) = block {
193 return block.sum_block_vectorized_with_mask(mask);
194 }
195
196 unsafe {
198 let blocks = f.len() / PAIRWISE_RECURSION_LIMIT;
199 let left_len = (blocks / 2) * PAIRWISE_RECURSION_LIMIT;
200 let (left, right) = (f.get_unchecked(..left_len), f.get_unchecked(left_len..));
201 let (left_mask, right_mask) = mask.split_at_unchecked(left_len);
202 pairwise_sum_with_mask(left, left_mask) + pairwise_sum_with_mask(right, right_mask)
203 }
204}
205
206pub trait FloatSum<F>: Sized {
207 fn sum(f: &[Self]) -> F;
208 fn sum_with_validity(f: &[Self], validity: &Bitmap) -> F;
209}
210
211impl<T, F> FloatSum<F> for T
212where
213 F: Float + std::iter::Sum + 'static,
214 T: AsPrimitive<F>,
215 [T; PAIRWISE_RECURSION_LIMIT]: SumBlock<F>,
216{
217 fn sum(f: &[Self]) -> F {
218 let remainder = f.len() % PAIRWISE_RECURSION_LIMIT;
219 let (rest, main) = f.split_at(remainder);
220 let mainsum = if f.len() > remainder {
221 unsafe { pairwise_sum(main) }
222 } else {
223 F::zero()
224 };
225 let restsum: F = rest.iter().map(|x| x.as_()).sum();
227 mainsum + restsum
228 }
229
230 fn sum_with_validity(f: &[Self], validity: &Bitmap) -> F {
231 let mask = BitMask::from_bitmap(validity);
232 assert!(f.len() == mask.len());
233
234 let remainder = f.len() % PAIRWISE_RECURSION_LIMIT;
235 let (rest, main) = f.split_at(remainder);
236 let (rest_mask, main_mask) = mask.split_at(remainder);
237 let mainsum = if f.len() > remainder {
238 unsafe { pairwise_sum_with_mask(main, main_mask) }
239 } else {
240 F::zero()
241 };
242 let restsum: F = rest
244 .iter()
245 .enumerate()
246 .map(|(i, x)| {
247 if rest_mask.get(i) {
249 x.as_()
250 } else {
251 F::zero()
252 }
253 })
254 .sum();
255 mainsum + restsum
256 }
257}
258
259pub fn sum_arr_as_f32<T>(arr: &PrimitiveArray<T>) -> f32
260where
261 T: NativeType + FloatSum<f32>,
262{
263 let validity = arr.validity().filter(|_| arr.null_count() > 0);
264 if let Some(mask) = validity {
265 FloatSum::sum_with_validity(arr.values(), mask)
266 } else {
267 FloatSum::sum(arr.values())
268 }
269}
270
271pub fn sum_arr_as_f64<T>(arr: &PrimitiveArray<T>) -> f64
272where
273 T: NativeType + FloatSum<f64>,
274{
275 let validity = arr.validity().filter(|_| arr.null_count() > 0);
276 if let Some(mask) = validity {
277 FloatSum::sum_with_validity(arr.values(), mask)
278 } else {
279 FloatSum::sum(arr.values())
280 }
281}