polars_compute/if_then_else/
view.rs

1use std::mem::MaybeUninit;
2use std::ops::Deref;
3use std::sync::Arc;
4
5use arrow::array::{Array, BinaryViewArray, MutablePlBinary, Utf8ViewArray, View};
6use arrow::bitmap::Bitmap;
7use arrow::buffer::Buffer;
8use arrow::datatypes::ArrowDataType;
9use polars_utils::aliases::{InitHashMaps, PlHashSet};
10
11use super::IfThenElseKernel;
12use crate::if_then_else::scalar::if_then_else_broadcast_both_scalar_64;
13
14// Makes a buffer and a set of views into that buffer from a set of strings.
15// Does not allocate a buffer if not necessary.
16fn make_buffer_and_views<const N: usize>(
17    strings: [&[u8]; N],
18    buffer_idx: u32,
19) -> ([View; N], Option<Buffer<u8>>) {
20    let mut buf_data = Vec::new();
21    let views = strings.map(|s| {
22        let offset = buf_data.len().try_into().unwrap();
23        if s.len() > 12 {
24            buf_data.extend(s);
25        }
26        View::new_from_bytes(s, buffer_idx, offset)
27    });
28    let buf = (!buf_data.is_empty()).then(|| buf_data.into());
29    (views, buf)
30}
31
32fn has_duplicate_buffers(bufs: &[Buffer<u8>]) -> bool {
33    let mut has_duplicate_buffers = false;
34    let mut bufset = PlHashSet::new();
35    for buf in bufs {
36        if !bufset.insert(buf.as_ptr()) {
37            has_duplicate_buffers = true;
38            break;
39        }
40    }
41    has_duplicate_buffers
42}
43
44impl IfThenElseKernel for BinaryViewArray {
45    type Scalar<'a> = &'a [u8];
46
47    fn if_then_else(mask: &Bitmap, if_true: &Self, if_false: &Self) -> Self {
48        let combined_buffers: Arc<_>;
49        let false_buffer_idx_offset: u32;
50        let mut has_duplicate_bufs = false;
51        if Arc::ptr_eq(if_true.data_buffers(), if_false.data_buffers()) {
52            // Share exact same buffers, no need to combine.
53            combined_buffers = if_true.data_buffers().clone();
54            false_buffer_idx_offset = 0;
55        } else {
56            // Put false buffers after true buffers.
57            let true_buffers = if_true.data_buffers().iter().cloned();
58            let false_buffers = if_false.data_buffers().iter().cloned();
59
60            combined_buffers = true_buffers.chain(false_buffers).collect();
61            has_duplicate_bufs = has_duplicate_buffers(&combined_buffers);
62            false_buffer_idx_offset = if_true.data_buffers().len() as u32;
63        }
64
65        let views = super::if_then_else_loop(
66            mask,
67            if_true.views(),
68            if_false.views(),
69            |m, t, f, o| if_then_else_view_rest(m, t, f, o, false_buffer_idx_offset),
70            |m, t, f, o| if_then_else_view_64(m, t, f, o, false_buffer_idx_offset),
71        );
72
73        let validity = super::if_then_else_validity(mask, if_true.validity(), if_false.validity());
74
75        let mut builder = MutablePlBinary::with_capacity(views.len());
76
77        if has_duplicate_bufs {
78            unsafe {
79                builder.extend_non_null_views_unchecked_dedupe(
80                    views.into_iter(),
81                    combined_buffers.deref(),
82                )
83            };
84        } else {
85            unsafe {
86                builder.extend_non_null_views_unchecked(views.into_iter(), combined_buffers.deref())
87            };
88        }
89        builder
90            .freeze_with_dtype(if_true.dtype().clone())
91            .with_validity(validity)
92    }
93
94    fn if_then_else_broadcast_true(
95        mask: &Bitmap,
96        if_true: Self::Scalar<'_>,
97        if_false: &Self,
98    ) -> Self {
99        // It's cheaper if we put the false buffers first, that way we don't need to modify any views in the loop.
100        let false_buffers = if_false.data_buffers().iter().cloned();
101        let true_buffer_idx_offset: u32 = if_false.data_buffers().len() as u32;
102        let ([true_view], true_buffer) = make_buffer_and_views([if_true], true_buffer_idx_offset);
103        let combined_buffers: Arc<_> = false_buffers.chain(true_buffer).collect();
104
105        let views = super::if_then_else_loop_broadcast_false(
106            true, // Invert the mask so we effectively broadcast true.
107            mask,
108            if_false.views(),
109            true_view,
110            if_then_else_broadcast_false_view_64,
111        );
112
113        let validity = super::if_then_else_validity(mask, None, if_false.validity());
114
115        let mut builder = MutablePlBinary::with_capacity(views.len());
116
117        unsafe {
118            if has_duplicate_buffers(&combined_buffers) {
119                builder.extend_non_null_views_unchecked_dedupe(
120                    views.into_iter(),
121                    combined_buffers.deref(),
122                )
123            } else {
124                builder.extend_non_null_views_unchecked(views.into_iter(), combined_buffers.deref())
125            }
126        }
127        builder
128            .freeze_with_dtype(if_false.dtype().clone())
129            .with_validity(validity)
130    }
131
132    fn if_then_else_broadcast_false(
133        mask: &Bitmap,
134        if_true: &Self,
135        if_false: Self::Scalar<'_>,
136    ) -> Self {
137        // It's cheaper if we put the true buffers first, that way we don't need to modify any views in the loop.
138        let true_buffers = if_true.data_buffers().iter().cloned();
139        let false_buffer_idx_offset: u32 = if_true.data_buffers().len() as u32;
140        let ([false_view], false_buffer) =
141            make_buffer_and_views([if_false], false_buffer_idx_offset);
142        let combined_buffers: Arc<_> = true_buffers.chain(false_buffer).collect();
143
144        let views = super::if_then_else_loop_broadcast_false(
145            false,
146            mask,
147            if_true.views(),
148            false_view,
149            if_then_else_broadcast_false_view_64,
150        );
151
152        let validity = super::if_then_else_validity(mask, if_true.validity(), None);
153
154        let mut builder = MutablePlBinary::with_capacity(views.len());
155        unsafe {
156            if has_duplicate_buffers(&combined_buffers) {
157                builder.extend_non_null_views_unchecked_dedupe(
158                    views.into_iter(),
159                    combined_buffers.deref(),
160                )
161            } else {
162                builder.extend_non_null_views_unchecked(views.into_iter(), combined_buffers.deref())
163            }
164        };
165        builder
166            .freeze_with_dtype(if_true.dtype().clone())
167            .with_validity(validity)
168    }
169
170    fn if_then_else_broadcast_both(
171        dtype: ArrowDataType,
172        mask: &Bitmap,
173        if_true: Self::Scalar<'_>,
174        if_false: Self::Scalar<'_>,
175    ) -> Self {
176        let ([true_view, false_view], buffer) = make_buffer_and_views([if_true, if_false], 0);
177        let buffers: Arc<_> = buffer.into_iter().collect();
178        let views = super::if_then_else_loop_broadcast_both(
179            mask,
180            true_view,
181            false_view,
182            if_then_else_broadcast_both_scalar_64,
183        );
184
185        let mut builder = MutablePlBinary::with_capacity(views.len());
186        unsafe {
187            if has_duplicate_buffers(&buffers) {
188                builder.extend_non_null_views_unchecked_dedupe(views.into_iter(), buffers.deref())
189            } else {
190                builder.extend_non_null_views_unchecked(views.into_iter(), buffers.deref())
191            }
192        };
193        builder.freeze_with_dtype(dtype)
194    }
195}
196
197impl IfThenElseKernel for Utf8ViewArray {
198    type Scalar<'a> = &'a str;
199
200    fn if_then_else(mask: &Bitmap, if_true: &Self, if_false: &Self) -> Self {
201        let ret =
202            IfThenElseKernel::if_then_else(mask, &if_true.to_binview(), &if_false.to_binview());
203        unsafe { ret.to_utf8view_unchecked() }
204    }
205
206    fn if_then_else_broadcast_true(
207        mask: &Bitmap,
208        if_true: Self::Scalar<'_>,
209        if_false: &Self,
210    ) -> Self {
211        let ret = IfThenElseKernel::if_then_else_broadcast_true(
212            mask,
213            if_true.as_bytes(),
214            &if_false.to_binview(),
215        );
216        unsafe { ret.to_utf8view_unchecked() }
217    }
218
219    fn if_then_else_broadcast_false(
220        mask: &Bitmap,
221        if_true: &Self,
222        if_false: Self::Scalar<'_>,
223    ) -> Self {
224        let ret = IfThenElseKernel::if_then_else_broadcast_false(
225            mask,
226            &if_true.to_binview(),
227            if_false.as_bytes(),
228        );
229        unsafe { ret.to_utf8view_unchecked() }
230    }
231
232    fn if_then_else_broadcast_both(
233        dtype: ArrowDataType,
234        mask: &Bitmap,
235        if_true: Self::Scalar<'_>,
236        if_false: Self::Scalar<'_>,
237    ) -> Self {
238        let ret: BinaryViewArray = IfThenElseKernel::if_then_else_broadcast_both(
239            dtype,
240            mask,
241            if_true.as_bytes(),
242            if_false.as_bytes(),
243        );
244        unsafe { ret.to_utf8view_unchecked() }
245    }
246}
247
248pub fn if_then_else_view_rest(
249    mask: u64,
250    if_true: &[View],
251    if_false: &[View],
252    out: &mut [MaybeUninit<View>],
253    false_buffer_idx_offset: u32,
254) {
255    assert!(if_true.len() <= out.len()); // Removes bounds checks in inner loop.
256    let true_it = if_true.iter();
257    let false_it = if_false.iter();
258    for (i, (t, f)) in true_it.zip(false_it).enumerate() {
259        // Written like this, this loop *should* be branchless.
260        // Unfortunately we're still dependent on the compiler.
261        let m = (mask >> i) & 1 != 0;
262        let src = if m { t } else { f };
263        let mut v = *src;
264        let offset = if m | (v.length <= 12) {
265            // Yes, | instead of || is intentional.
266            0
267        } else {
268            false_buffer_idx_offset
269        };
270        v.buffer_idx += offset;
271        out[i] = MaybeUninit::new(v);
272    }
273}
274
275pub fn if_then_else_view_64(
276    mask: u64,
277    if_true: &[View; 64],
278    if_false: &[View; 64],
279    out: &mut [MaybeUninit<View>; 64],
280    false_buffer_idx_offset: u32,
281) {
282    if_then_else_view_rest(mask, if_true, if_false, out, false_buffer_idx_offset)
283}
284
285// Using the scalar variant of this works, but was slower, we want to select a source pointer and
286// then copy it. Using this version for the integers results in branches.
287pub fn if_then_else_broadcast_false_view_64(
288    mask: u64,
289    if_true: &[View; 64],
290    if_false: View,
291    out: &mut [MaybeUninit<View>; 64],
292) {
293    assert!(if_true.len() == out.len()); // Removes bounds checks in inner loop.
294    for (i, t) in if_true.iter().enumerate() {
295        let src = if (mask >> i) & 1 != 0 { t } else { &if_false };
296        out[i] = MaybeUninit::new(*src);
297    }
298}