// polars_compute/if_then_else/mod.rs

1use std::mem::MaybeUninit;
2
3use arrow::array::{Array, PrimitiveArray};
4use arrow::bitmap::utils::SlicesIterator;
5use arrow::bitmap::{self, Bitmap};
6use arrow::datatypes::ArrowDataType;
7
8use crate::NotSimdPrimitive;
9
10mod array;
11mod boolean;
12mod list;
13mod scalar;
14#[cfg(feature = "simd")]
15mod simd;
16mod view;
17
/// Kernel for the element-wise select `out[i] = mask[i] ? if_true[i] : if_false[i]`.
///
/// Four entry points are provided so that a branch consisting of a single
/// scalar does not have to be materialized as a full array first:
/// - [`Self::if_then_else`]: both branches are arrays.
/// - [`Self::if_then_else_broadcast_true`]: the true branch is one scalar.
/// - [`Self::if_then_else_broadcast_false`]: the false branch is one scalar.
/// - [`Self::if_then_else_broadcast_both`]: both branches are scalars; the
///   output dtype is passed explicitly since no array operand carries it.
///
/// NOTE(review): implementations appear to assume `mask.len()` equals the
/// length of the array operands (the primitive impl's helpers assert this) —
/// confirm for other implementors.
pub trait IfThenElseKernel: Sized + Array {
    /// Scalar representation used by the broadcasting variants.
    type Scalar<'a>;

    /// Select per element between two equal-length arrays.
    fn if_then_else(mask: &Bitmap, if_true: &Self, if_false: &Self) -> Self;
    /// Select between a broadcast scalar (where the mask is set) and an array.
    fn if_then_else_broadcast_true(
        mask: &Bitmap,
        if_true: Self::Scalar<'_>,
        if_false: &Self,
    ) -> Self;
    /// Select between an array (where the mask is set) and a broadcast scalar.
    fn if_then_else_broadcast_false(
        mask: &Bitmap,
        if_true: &Self,
        if_false: Self::Scalar<'_>,
    ) -> Self;
    /// Select between two broadcast scalars; `dtype` determines the output type.
    fn if_then_else_broadcast_both(
        dtype: ArrowDataType,
        mask: &Bitmap,
        if_true: Self::Scalar<'_>,
        if_false: Self::Scalar<'_>,
    ) -> Self;
}
39
/// Scalar (non-SIMD) implementation for primitive arrays.
///
/// Values are selected by the per-word kernels in the `scalar` module, and
/// the validity bitmap is combined separately via `if_then_else_validity`.
impl<T: NotSimdPrimitive> IfThenElseKernel for PrimitiveArray<T> {
    type Scalar<'a> = T;

    fn if_then_else(mask: &Bitmap, if_true: &Self, if_false: &Self) -> Self {
        // Select the raw values first, then attach the combined validity.
        let values = if_then_else_loop(
            mask,
            if_true.values(),
            if_false.values(),
            scalar::if_then_else_scalar_rest,
            scalar::if_then_else_scalar_64,
        );
        let validity = if_then_else_validity(mask, if_true.validity(), if_false.validity());
        PrimitiveArray::from_vec(values).with_validity(validity)
    }

    fn if_then_else_broadcast_true(
        mask: &Bitmap,
        if_true: Self::Scalar<'_>,
        if_false: &Self,
    ) -> Self {
        // Reuses the broadcast-false loop with the mask inverted (first
        // argument `true`), so the scalar is selected where the mask is set.
        let values = if_then_else_loop_broadcast_false(
            true,
            mask,
            if_false.values(),
            if_true,
            scalar::if_then_else_broadcast_false_scalar_64,
        );
        // A scalar true-branch contributes no nulls, hence `None` for its validity.
        let validity = if_then_else_validity(mask, None, if_false.validity());
        PrimitiveArray::from_vec(values).with_validity(validity)
    }

    fn if_then_else_broadcast_false(
        mask: &Bitmap,
        if_true: &Self,
        if_false: Self::Scalar<'_>,
    ) -> Self {
        let values = if_then_else_loop_broadcast_false(
            false,
            mask,
            if_true.values(),
            if_false,
            scalar::if_then_else_broadcast_false_scalar_64,
        );
        // A scalar false-branch contributes no nulls, hence `None` for its validity.
        let validity = if_then_else_validity(mask, if_true.validity(), None);
        PrimitiveArray::from_vec(values).with_validity(validity)
    }

    fn if_then_else_broadcast_both(
        _dtype: ArrowDataType,
        mask: &Bitmap,
        if_true: Self::Scalar<'_>,
        if_false: Self::Scalar<'_>,
    ) -> Self {
        // The dtype parameter is unused here: `from_vec` derives it from `T`.
        // Two non-null scalars also mean the result has no validity bitmap.
        let values = if_then_else_loop_broadcast_both(
            mask,
            if_true,
            if_false,
            scalar::if_then_else_broadcast_both_scalar_64,
        );
        PrimitiveArray::from_vec(values)
    }
}
102
103pub fn if_then_else_validity(
104    mask: &Bitmap,
105    if_true: Option<&Bitmap>,
106    if_false: Option<&Bitmap>,
107) -> Option<Bitmap> {
108    match (if_true, if_false) {
109        (None, None) => None,
110        (None, Some(f)) => Some(mask | f),
111        (Some(t), None) => Some(bitmap::binary(mask, t, |m, t| !m | t)),
112        (Some(t), Some(f)) => Some(bitmap::ternary(mask, t, f, |m, t, f| (m & t) | (!m & f))),
113    }
114}
115
116fn if_then_else_extend<G, ET: Fn(&mut G, usize, usize), EF: Fn(&mut G, usize, usize)>(
117    growable: &mut G,
118    mask: &Bitmap,
119    extend_true: ET,
120    extend_false: EF,
121) {
122    let mut last_true_end = 0;
123    for (start, len) in SlicesIterator::new(mask) {
124        if start != last_true_end {
125            extend_false(growable, last_true_end, start - last_true_end);
126        };
127        extend_true(growable, start, len);
128        last_true_end = start + len;
129    }
130    if last_true_end != mask.len() {
131        extend_false(growable, last_true_end, mask.len() - last_true_end)
132    }
133}
134
/// Selects `if_true[i]` or `if_false[i]` per mask bit into a freshly
/// allocated `Vec`, writing directly into uninitialized spare capacity.
///
/// The mask is viewed through its u64-aligned form: a ragged prefix and
/// suffix are handled by `process_var` (taking a mask word whose low bits
/// correspond to the slice elements), while the aligned middle is processed
/// in exact 64-element chunks by `process_chunk`, allowing the chunk kernel
/// to be branchless and bounds-check free.
fn if_then_else_loop<T, F, F64>(
    mask: &Bitmap,
    if_true: &[T],
    if_false: &[T],
    process_var: F,
    process_chunk: F64,
) -> Vec<T>
where
    T: Copy,
    F: Fn(u64, &[T], &[T], &mut [MaybeUninit<T>]),
    F64: Fn(u64, &[T; 64], &[T; 64], &mut [MaybeUninit<T>; 64]),
{
    // All three sequences must be the same length; this also lets the
    // split/chunk arithmetic below line up exactly.
    assert_eq!(mask.len(), if_true.len());
    assert_eq!(mask.len(), if_false.len());

    // Write into the spare capacity and set_len once everything is initialized.
    let mut ret = Vec::with_capacity(mask.len());
    let out = &mut ret.spare_capacity_mut()[..mask.len()];

    // Handle prefix: the bits before the first u64-aligned mask word.
    let aligned = mask.aligned::<u64>();
    let (start_true, rest_true) = if_true.split_at(aligned.prefix_bitlen());
    let (start_false, rest_false) = if_false.split_at(aligned.prefix_bitlen());
    let (start_out, rest_out) = out.split_at_mut(aligned.prefix_bitlen());
    if aligned.prefix_bitlen() > 0 {
        process_var(aligned.prefix(), start_true, start_false, start_out);
    }

    // Handle bulk: full 64-element chunks, one aligned mask word per chunk.
    let mut true_chunks = rest_true.chunks_exact(64);
    let mut false_chunks = rest_false.chunks_exact(64);
    let mut out_chunks = rest_out.chunks_exact_mut(64);
    let combined = true_chunks
        .by_ref()
        .zip(false_chunks.by_ref())
        .zip(out_chunks.by_ref());
    for (i, ((tc, fc), oc)) in combined.enumerate() {
        // SAFETY: the zipped iterators yield one 64-element chunk per bulk
        // mask word, so i < aligned.bulk().len().
        let m = unsafe { *aligned.bulk().get_unchecked(i) };
        process_chunk(
            m,
            tc.try_into().unwrap(),
            fc.try_into().unwrap(),
            oc.try_into().unwrap(),
        );
    }

    // Handle suffix: the leftover elements shorter than one full chunk.
    if aligned.suffix_bitlen() > 0 {
        process_var(
            aligned.suffix(),
            true_chunks.remainder(),
            false_chunks.remainder(),
            out_chunks.into_remainder(),
        );
    }

    // SAFETY: prefix + bulk + suffix together initialized all mask.len()
    // elements of the spare capacity.
    unsafe {
        ret.set_len(mask.len());
    }
    ret
}
195
/// Like `if_then_else_loop`, but the false branch is a single broadcast
/// scalar instead of a slice.
///
/// With `invert_mask == true` the mask is bit-inverted before use, which
/// turns this into a broadcast-*true* select: the scalar is then taken where
/// the original mask is set. This lets one loop serve both broadcast cases.
fn if_then_else_loop_broadcast_false<T, F64>(
    invert_mask: bool, // Allows code reuse for both false and true broadcasts.
    mask: &Bitmap,
    if_true: &[T],
    if_false: T,
    process_chunk: F64,
) -> Vec<T>
where
    T: Copy,
    F64: Fn(u64, &[T; 64], T, &mut [MaybeUninit<T>; 64]),
{
    assert_eq!(mask.len(), if_true.len());

    // Write into the spare capacity and set_len once everything is initialized.
    let mut ret = Vec::with_capacity(mask.len());
    let out = &mut ret.spare_capacity_mut()[..mask.len()];

    // XOR with all 1's inverts the mask; XOR with 0 is the identity.
    let xor_inverter = if invert_mask { u64::MAX } else { 0 };

    // Handle prefix: the bits before the first u64-aligned mask word.
    let aligned = mask.aligned::<u64>();
    let (start_true, rest_true) = if_true.split_at(aligned.prefix_bitlen());
    let (start_out, rest_out) = out.split_at_mut(aligned.prefix_bitlen());
    if aligned.prefix_bitlen() > 0 {
        scalar::if_then_else_broadcast_false_scalar_rest(
            aligned.prefix() ^ xor_inverter,
            start_true,
            if_false,
            start_out,
        );
    }

    // Handle bulk: full 64-element chunks, one aligned mask word per chunk.
    let mut true_chunks = rest_true.chunks_exact(64);
    let mut out_chunks = rest_out.chunks_exact_mut(64);
    let combined = true_chunks.by_ref().zip(out_chunks.by_ref());
    for (i, (tc, oc)) in combined.enumerate() {
        // SAFETY: the zipped iterators yield one 64-element chunk per bulk
        // mask word, so i < aligned.bulk().len().
        let m = unsafe { *aligned.bulk().get_unchecked(i) } ^ xor_inverter;
        process_chunk(m, tc.try_into().unwrap(), if_false, oc.try_into().unwrap());
    }

    // Handle suffix: the leftover elements shorter than one full chunk.
    if aligned.suffix_bitlen() > 0 {
        scalar::if_then_else_broadcast_false_scalar_rest(
            aligned.suffix() ^ xor_inverter,
            true_chunks.remainder(),
            if_false,
            out_chunks.into_remainder(),
        );
    }

    // SAFETY: prefix + bulk + suffix together initialized all mask.len()
    // elements of the spare capacity.
    unsafe {
        ret.set_len(mask.len());
    }
    ret
}
252
/// Like `if_then_else_loop`, but both branches are broadcast scalars: the
/// output is purely generated from the mask, `if_true` and `if_false`.
fn if_then_else_loop_broadcast_both<T, F64>(
    mask: &Bitmap,
    if_true: T,
    if_false: T,
    generate_chunk: F64,
) -> Vec<T>
where
    T: Copy,
    F64: Fn(u64, T, T, &mut [MaybeUninit<T>; 64]),
{
    // Write into the spare capacity and set_len once everything is initialized.
    let mut ret = Vec::with_capacity(mask.len());
    let out = &mut ret.spare_capacity_mut()[..mask.len()];

    // Handle prefix: the bits before the first u64-aligned mask word.
    // NOTE(review): unlike the sibling loops there is no emptiness guard
    // here — presumably the rest helper is a no-op on an empty slice.
    let aligned = mask.aligned::<u64>();
    let (start_out, rest_out) = out.split_at_mut(aligned.prefix_bitlen());
    scalar::if_then_else_broadcast_both_scalar_rest(aligned.prefix(), if_true, if_false, start_out);

    // Handle bulk: full 64-element chunks, one aligned mask word per chunk.
    let mut out_chunks = rest_out.chunks_exact_mut(64);
    for (i, oc) in out_chunks.by_ref().enumerate() {
        // SAFETY: the chunk iterator yields one 64-element chunk per bulk
        // mask word, so i < aligned.bulk().len().
        let m = unsafe { *aligned.bulk().get_unchecked(i) };
        generate_chunk(m, if_true, if_false, oc.try_into().unwrap());
    }

    // Handle suffix: the leftover elements shorter than one full chunk.
    if aligned.suffix_bitlen() > 0 {
        scalar::if_then_else_broadcast_both_scalar_rest(
            aligned.suffix(),
            if_true,
            if_false,
            out_chunks.into_remainder(),
        );
    }

    // SAFETY: prefix + bulk + suffix together initialized all mask.len()
    // elements of the spare capacity.
    unsafe {
        ret.set_len(mask.len());
    }
    ret
}