use std::mem::MaybeUninit;

use arrow::array::{Array, PrimitiveArray};
use arrow::bitmap::utils::SlicesIterator;
use arrow::bitmap::{self, Bitmap};
use arrow::datatypes::ArrowDataType;

use crate::NotSimdPrimitive;

mod array;
mod boolean;
mod list;
mod scalar;
#[cfg(feature = "simd")]
mod simd;
mod view;

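/// Kernel for `if_then_else` (zip/select) operations on a single array type.
///
/// For each index `i` the output takes `if_true[i]` where `mask` is set and
/// `if_false[i]` otherwise. The `broadcast` variants take a scalar instead of
/// a full array for one or both branches; `broadcast_both` needs the dtype
/// since there is no input array to take it from.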
pub trait IfThenElseKernel: Sized + Array {
    type Scalar<'a>;

    fn if_then_else(mask: &Bitmap, if_true: &Self, if_false: &Self) -> Self;
    fn if_then_else_broadcast_true(
        mask: &Bitmap,
        if_true: Self::Scalar<'_>,
        if_false: &Self,
    ) -> Self;
    fn if_then_else_broadcast_false(
        mask: &Bitmap,
        if_true: &Self,
        if_false: Self::Scalar<'_>,
    ) -> Self;
    fn if_then_else_broadcast_both(
        dtype: ArrowDataType,
        mask: &Bitmap,
        if_true: Self::Scalar<'_>,
        if_false: Self::Scalar<'_>,
    ) -> Self;
}

impl<T: NotSimdPrimitive> IfThenElseKernel for PrimitiveArray<T> {
    type Scalar<'a> = T;

    fn if_then_else(mask: &Bitmap, if_true: &Self, if_false: &Self) -> Self {
        let values = if_then_else_loop(
            mask,
            if_true.values(),
            if_false.values(),
            scalar::if_then_else_scalar_rest,
            scalar::if_then_else_scalar_64,
        );
        let validity = if_then_else_validity(mask, if_true.validity(), if_false.validity());
        PrimitiveArray::from_vec(values).with_validity(validity)
    }

    fn if_then_else_broadcast_true(
        mask: &Bitmap,
        if_true: Self::Scalar<'_>,
        if_false: &Self,
    ) -> Self {
        let values = if_then_else_loop_broadcast_false(
            true,
            mask,
            if_false.values(),
            if_true,
            scalar::if_then_else_broadcast_false_scalar_64,
        );
        let validity = if_then_else_validity(mask, None, if_false.validity());
        PrimitiveArray::from_vec(values).with_validity(validity)
    }

    fn if_then_else_broadcast_false(
        mask: &Bitmap,
        if_true: &Self,
        if_false: Self::Scalar<'_>,
    ) -> Self {
        let values = if_then_else_loop_broadcast_false(
            false,
            mask,
            if_true.values(),
            if_false,
            scalar::if_then_else_broadcast_false_scalar_64,
        );
        let validity = if_then_else_validity(mask, if_true.validity(), None);
        PrimitiveArray::from_vec(values).with_validity(validity)
    }

    fn if_then_else_broadcast_both(
        _dtype: ArrowDataType,
        mask: &Bitmap,
        if_true: Self::Scalar<'_>,
        if_false: Self::Scalar<'_>,
    ) -> Self {
        let values = if_then_else_loop_broadcast_both(
            mask,
            if_true,
            if_false,
            scalar::if_then_else_broadcast_both_scalar_64,
        );
        PrimitiveArray::from_vec(values)
    }
}

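/// Combines the validity bitmaps of both branches according to the mask: the
/// result is valid at `i` if the branch selected by `mask` is valid there.
/// A `None` bitmap means "all valid", which lets two arms simplify:
/// `m | f` (true branch always valid) and `!m | t` (false branch always valid).
///
/// Worked example on 4 bits (LSB first): with m = 0b0101, t = 0b0011 and
/// f = 0b1110, the result is (m & t) | (!m & f) = 0b0001 | 0b1010 = 0b1011.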
pub fn if_then_else_validity(
    mask: &Bitmap,
    if_true: Option<&Bitmap>,
    if_false: Option<&Bitmap>,
) -> Option<Bitmap> {
    match (if_true, if_false) {
        (None, None) => None,
        (None, Some(f)) => Some(mask | f),
        (Some(t), None) => Some(bitmap::binary(mask, t, |m, t| !m | t)),
        (Some(t), Some(f)) => Some(bitmap::ternary(mask, t, f, |m, t, f| (m & t) | (!m & f))),
    }
}

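// Dispatches contiguous runs of the mask to the appropriate extend callback:
// `SlicesIterator` yields the (start, len) of each run of set bits, and the
// gaps between those runs (plus the tail) are filled via `extend_false`.
// Intended for growable-based kernels (the generic `growable` parameter),
// where copying whole runs is cheaper than deciding per element.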
fn if_then_else_extend<G, ET: Fn(&mut G, usize, usize), EF: Fn(&mut G, usize, usize)>(
    growable: &mut G,
    mask: &Bitmap,
    extend_true: ET,
    extend_false: EF,
) {
    let mut last_true_end = 0;
    for (start, len) in SlicesIterator::new(mask) {
        if start != last_true_end {
            extend_false(growable, last_true_end, start - last_true_end);
        }
        extend_true(growable, start, len);
        last_true_end = start + len;
    }
    if last_true_end != mask.len() {
        extend_false(growable, last_true_end, mask.len() - last_true_end)
    }
}

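// Computes the element-wise select into a fresh `Vec` by walking the mask in
// three phases: an unaligned bit prefix, a bulk of whole 64-bit mask words
// processed 64 elements at a time via `process_chunk`, and a bit suffix. The
// output is written through `MaybeUninit` spare capacity, and the length is
// only set once every element has been initialized.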
fn if_then_else_loop<T, F, F64>(
    mask: &Bitmap,
    if_true: &[T],
    if_false: &[T],
    process_var: F,
    process_chunk: F64,
) -> Vec<T>
where
    T: Copy,
    F: Fn(u64, &[T], &[T], &mut [MaybeUninit<T>]),
    F64: Fn(u64, &[T; 64], &[T; 64], &mut [MaybeUninit<T>; 64]),
{
    assert_eq!(mask.len(), if_true.len());
    assert_eq!(mask.len(), if_false.len());

    let mut ret = Vec::with_capacity(mask.len());
    let out = &mut ret.spare_capacity_mut()[..mask.len()];

    // Handle the unaligned prefix of the mask bit-by-bit.
    let aligned = mask.aligned::<u64>();
    let (start_true, rest_true) = if_true.split_at(aligned.prefix_bitlen());
    let (start_false, rest_false) = if_false.split_at(aligned.prefix_bitlen());
    let (start_out, rest_out) = out.split_at_mut(aligned.prefix_bitlen());
    if aligned.prefix_bitlen() > 0 {
        process_var(aligned.prefix(), start_true, start_false, start_out);
    }

    // Process the aligned bulk one full 64-bit mask word at a time.
    let mut true_chunks = rest_true.chunks_exact(64);
    let mut false_chunks = rest_false.chunks_exact(64);
    let mut out_chunks = rest_out.chunks_exact_mut(64);
    let combined = true_chunks
        .by_ref()
        .zip(false_chunks.by_ref())
        .zip(out_chunks.by_ref());
    for (i, ((tc, fc), oc)) in combined.enumerate() {
        let m = unsafe { *aligned.bulk().get_unchecked(i) };
        process_chunk(
            m,
            tc.try_into().unwrap(),
            fc.try_into().unwrap(),
            oc.try_into().unwrap(),
        );
    }

    // Handle the remaining suffix bits.
    if aligned.suffix_bitlen() > 0 {
        process_var(
            aligned.suffix(),
            true_chunks.remainder(),
            false_chunks.remainder(),
            out_chunks.into_remainder(),
        );
    }

    // SAFETY: all mask.len() elements were initialized above.
    unsafe {
        ret.set_len(mask.len());
    }
    ret
}

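// Like `if_then_else_loop`, but broadcasts a scalar for the false branch.
// Passing `invert_mask = true` flips every mask word with XOR, which is how
// `if_then_else_broadcast_true` reuses this kernel: selecting a broadcast
// true value is the same as selecting a broadcast false value under the
// inverted mask with the array arguments swapped.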
fn if_then_else_loop_broadcast_false<T, F64>(
    invert_mask: bool,
    mask: &Bitmap,
    if_true: &[T],
    if_false: T,
    process_chunk: F64,
) -> Vec<T>
where
    T: Copy,
    F64: Fn(u64, &[T; 64], T, &mut [MaybeUninit<T>; 64]),
{
    assert_eq!(mask.len(), if_true.len());

    let mut ret = Vec::with_capacity(mask.len());
    let out = &mut ret.spare_capacity_mut()[..mask.len()];

    // XOR with all-ones inverts the mask words; XOR with zero is a no-op.
    let xor_inverter = if invert_mask { u64::MAX } else { 0 };

    // Handle the unaligned prefix of the mask bit-by-bit.
    let aligned = mask.aligned::<u64>();
    let (start_true, rest_true) = if_true.split_at(aligned.prefix_bitlen());
    let (start_out, rest_out) = out.split_at_mut(aligned.prefix_bitlen());
    if aligned.prefix_bitlen() > 0 {
        scalar::if_then_else_broadcast_false_scalar_rest(
            aligned.prefix() ^ xor_inverter,
            start_true,
            if_false,
            start_out,
        );
    }

    // Process the aligned bulk one full 64-bit mask word at a time.
    let mut true_chunks = rest_true.chunks_exact(64);
    let mut out_chunks = rest_out.chunks_exact_mut(64);
    let combined = true_chunks.by_ref().zip(out_chunks.by_ref());
    for (i, (tc, oc)) in combined.enumerate() {
        let m = unsafe { *aligned.bulk().get_unchecked(i) } ^ xor_inverter;
        process_chunk(m, tc.try_into().unwrap(), if_false, oc.try_into().unwrap());
    }

    // Handle the remaining suffix bits.
    if aligned.suffix_bitlen() > 0 {
        scalar::if_then_else_broadcast_false_scalar_rest(
            aligned.suffix() ^ xor_inverter,
            true_chunks.remainder(),
            if_false,
            out_chunks.into_remainder(),
        );
    }

    // SAFETY: all mask.len() elements were initialized above.
    unsafe {
        ret.set_len(mask.len());
    }
    ret
}

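// Both branches are scalars here, so no input arrays are needed at all: each
// output chunk is generated directly from 64 mask bits and the two values.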
fn if_then_else_loop_broadcast_both<T, F64>(
    mask: &Bitmap,
    if_true: T,
    if_false: T,
    generate_chunk: F64,
) -> Vec<T>
where
    T: Copy,
    F64: Fn(u64, T, T, &mut [MaybeUninit<T>; 64]),
{
    let mut ret = Vec::with_capacity(mask.len());
    let out = &mut ret.spare_capacity_mut()[..mask.len()];

    // Handle the unaligned prefix of the mask bit-by-bit.
    let aligned = mask.aligned::<u64>();
    let (start_out, rest_out) = out.split_at_mut(aligned.prefix_bitlen());
    scalar::if_then_else_broadcast_both_scalar_rest(aligned.prefix(), if_true, if_false, start_out);

    // Generate the aligned bulk one full 64-bit mask word at a time.
    let mut out_chunks = rest_out.chunks_exact_mut(64);
    for (i, oc) in out_chunks.by_ref().enumerate() {
        let m = unsafe { *aligned.bulk().get_unchecked(i) };
        generate_chunk(m, if_true, if_false, oc.try_into().unwrap());
    }

    // Handle the remaining suffix bits.
    if aligned.suffix_bitlen() > 0 {
        scalar::if_then_else_broadcast_both_scalar_rest(
            aligned.suffix(),
            if_true,
            if_false,
            out_chunks.into_remainder(),
        );
    }

    // SAFETY: all mask.len() elements were initialized above.
    unsafe {
        ret.set_len(mask.len());
    }
    ret
}
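
// A minimal sanity check of the validity-combining logic. This is a sketch
// assuming the polars-arrow `Bitmap` API (`FromIterator<bool>` and `iter()`);
// adjust to the crate's test conventions if they differ.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn validity_follows_mask() {
        let mask: Bitmap = [true, false, true, false].into_iter().collect();
        let t: Bitmap = [true, true, false, false].into_iter().collect();
        let f: Bitmap = [false, true, true, true].into_iter().collect();

        // Each output bit should come from `t` where the mask is set, else from `f`.
        let out = if_then_else_validity(&mask, Some(&t), Some(&f)).unwrap();
        assert_eq!(out.iter().collect::<Vec<_>>(), vec![true, true, false, true]);

        // With no validity bitmap on either side, there is nothing to combine.
        assert!(if_then_else_validity(&mask, None, None).is_none());
    }
}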