polars_compute/if_then_else/
view.rs1use std::mem::MaybeUninit;
2use std::ops::Deref;
3use std::sync::Arc;
4
5use arrow::array::{Array, BinaryViewArray, MutablePlBinary, Utf8ViewArray, View};
6use arrow::bitmap::Bitmap;
7use arrow::buffer::Buffer;
8use arrow::datatypes::ArrowDataType;
9use polars_utils::aliases::{InitHashMaps, PlHashSet};
10
11use super::IfThenElseKernel;
12use crate::if_then_else::scalar::if_then_else_broadcast_both_scalar_64;
13
14fn make_buffer_and_views<const N: usize>(
17 strings: [&[u8]; N],
18 buffer_idx: u32,
19) -> ([View; N], Option<Buffer<u8>>) {
20 let mut buf_data = Vec::new();
21 let views = strings.map(|s| {
22 let offset = buf_data.len().try_into().unwrap();
23 if s.len() > 12 {
24 buf_data.extend(s);
25 }
26 View::new_from_bytes(s, buffer_idx, offset)
27 });
28 let buf = (!buf_data.is_empty()).then(|| buf_data.into());
29 (views, buf)
30}
31
32fn has_duplicate_buffers(bufs: &[Buffer<u8>]) -> bool {
33 let mut has_duplicate_buffers = false;
34 let mut bufset = PlHashSet::new();
35 for buf in bufs {
36 if !bufset.insert(buf.as_ptr()) {
37 has_duplicate_buffers = true;
38 break;
39 }
40 }
41 has_duplicate_buffers
42}
43
44impl IfThenElseKernel for BinaryViewArray {
45 type Scalar<'a> = &'a [u8];
46
47 fn if_then_else(mask: &Bitmap, if_true: &Self, if_false: &Self) -> Self {
48 let combined_buffers: Arc<_>;
49 let false_buffer_idx_offset: u32;
50 let mut has_duplicate_bufs = false;
51 if Arc::ptr_eq(if_true.data_buffers(), if_false.data_buffers()) {
52 combined_buffers = if_true.data_buffers().clone();
54 false_buffer_idx_offset = 0;
55 } else {
56 let true_buffers = if_true.data_buffers().iter().cloned();
58 let false_buffers = if_false.data_buffers().iter().cloned();
59
60 combined_buffers = true_buffers.chain(false_buffers).collect();
61 has_duplicate_bufs = has_duplicate_buffers(&combined_buffers);
62 false_buffer_idx_offset = if_true.data_buffers().len() as u32;
63 }
64
65 let views = super::if_then_else_loop(
66 mask,
67 if_true.views(),
68 if_false.views(),
69 |m, t, f, o| if_then_else_view_rest(m, t, f, o, false_buffer_idx_offset),
70 |m, t, f, o| if_then_else_view_64(m, t, f, o, false_buffer_idx_offset),
71 );
72
73 let validity = super::if_then_else_validity(mask, if_true.validity(), if_false.validity());
74
75 let mut builder = MutablePlBinary::with_capacity(views.len());
76
77 if has_duplicate_bufs {
78 unsafe {
79 builder.extend_non_null_views_unchecked_dedupe(
80 views.into_iter(),
81 combined_buffers.deref(),
82 )
83 };
84 } else {
85 unsafe {
86 builder.extend_non_null_views_unchecked(views.into_iter(), combined_buffers.deref())
87 };
88 }
89 builder
90 .freeze_with_dtype(if_true.dtype().clone())
91 .with_validity(validity)
92 }
93
94 fn if_then_else_broadcast_true(
95 mask: &Bitmap,
96 if_true: Self::Scalar<'_>,
97 if_false: &Self,
98 ) -> Self {
99 let false_buffers = if_false.data_buffers().iter().cloned();
101 let true_buffer_idx_offset: u32 = if_false.data_buffers().len() as u32;
102 let ([true_view], true_buffer) = make_buffer_and_views([if_true], true_buffer_idx_offset);
103 let combined_buffers: Arc<_> = false_buffers.chain(true_buffer).collect();
104
105 let views = super::if_then_else_loop_broadcast_false(
106 true, mask,
108 if_false.views(),
109 true_view,
110 if_then_else_broadcast_false_view_64,
111 );
112
113 let validity = super::if_then_else_validity(mask, None, if_false.validity());
114
115 let mut builder = MutablePlBinary::with_capacity(views.len());
116
117 unsafe {
118 if has_duplicate_buffers(&combined_buffers) {
119 builder.extend_non_null_views_unchecked_dedupe(
120 views.into_iter(),
121 combined_buffers.deref(),
122 )
123 } else {
124 builder.extend_non_null_views_unchecked(views.into_iter(), combined_buffers.deref())
125 }
126 }
127 builder
128 .freeze_with_dtype(if_false.dtype().clone())
129 .with_validity(validity)
130 }
131
132 fn if_then_else_broadcast_false(
133 mask: &Bitmap,
134 if_true: &Self,
135 if_false: Self::Scalar<'_>,
136 ) -> Self {
137 let true_buffers = if_true.data_buffers().iter().cloned();
139 let false_buffer_idx_offset: u32 = if_true.data_buffers().len() as u32;
140 let ([false_view], false_buffer) =
141 make_buffer_and_views([if_false], false_buffer_idx_offset);
142 let combined_buffers: Arc<_> = true_buffers.chain(false_buffer).collect();
143
144 let views = super::if_then_else_loop_broadcast_false(
145 false,
146 mask,
147 if_true.views(),
148 false_view,
149 if_then_else_broadcast_false_view_64,
150 );
151
152 let validity = super::if_then_else_validity(mask, if_true.validity(), None);
153
154 let mut builder = MutablePlBinary::with_capacity(views.len());
155 unsafe {
156 if has_duplicate_buffers(&combined_buffers) {
157 builder.extend_non_null_views_unchecked_dedupe(
158 views.into_iter(),
159 combined_buffers.deref(),
160 )
161 } else {
162 builder.extend_non_null_views_unchecked(views.into_iter(), combined_buffers.deref())
163 }
164 };
165 builder
166 .freeze_with_dtype(if_true.dtype().clone())
167 .with_validity(validity)
168 }
169
170 fn if_then_else_broadcast_both(
171 dtype: ArrowDataType,
172 mask: &Bitmap,
173 if_true: Self::Scalar<'_>,
174 if_false: Self::Scalar<'_>,
175 ) -> Self {
176 let ([true_view, false_view], buffer) = make_buffer_and_views([if_true, if_false], 0);
177 let buffers: Arc<_> = buffer.into_iter().collect();
178 let views = super::if_then_else_loop_broadcast_both(
179 mask,
180 true_view,
181 false_view,
182 if_then_else_broadcast_both_scalar_64,
183 );
184
185 let mut builder = MutablePlBinary::with_capacity(views.len());
186 unsafe {
187 if has_duplicate_buffers(&buffers) {
188 builder.extend_non_null_views_unchecked_dedupe(views.into_iter(), buffers.deref())
189 } else {
190 builder.extend_non_null_views_unchecked(views.into_iter(), buffers.deref())
191 }
192 };
193 builder.freeze_with_dtype(dtype)
194 }
195}
196
197impl IfThenElseKernel for Utf8ViewArray {
198 type Scalar<'a> = &'a str;
199
200 fn if_then_else(mask: &Bitmap, if_true: &Self, if_false: &Self) -> Self {
201 let ret =
202 IfThenElseKernel::if_then_else(mask, &if_true.to_binview(), &if_false.to_binview());
203 unsafe { ret.to_utf8view_unchecked() }
204 }
205
206 fn if_then_else_broadcast_true(
207 mask: &Bitmap,
208 if_true: Self::Scalar<'_>,
209 if_false: &Self,
210 ) -> Self {
211 let ret = IfThenElseKernel::if_then_else_broadcast_true(
212 mask,
213 if_true.as_bytes(),
214 &if_false.to_binview(),
215 );
216 unsafe { ret.to_utf8view_unchecked() }
217 }
218
219 fn if_then_else_broadcast_false(
220 mask: &Bitmap,
221 if_true: &Self,
222 if_false: Self::Scalar<'_>,
223 ) -> Self {
224 let ret = IfThenElseKernel::if_then_else_broadcast_false(
225 mask,
226 &if_true.to_binview(),
227 if_false.as_bytes(),
228 );
229 unsafe { ret.to_utf8view_unchecked() }
230 }
231
232 fn if_then_else_broadcast_both(
233 dtype: ArrowDataType,
234 mask: &Bitmap,
235 if_true: Self::Scalar<'_>,
236 if_false: Self::Scalar<'_>,
237 ) -> Self {
238 let ret: BinaryViewArray = IfThenElseKernel::if_then_else_broadcast_both(
239 dtype,
240 mask,
241 if_true.as_bytes(),
242 if_false.as_bytes(),
243 );
244 unsafe { ret.to_utf8view_unchecked() }
245 }
246}
247
248pub fn if_then_else_view_rest(
249 mask: u64,
250 if_true: &[View],
251 if_false: &[View],
252 out: &mut [MaybeUninit<View>],
253 false_buffer_idx_offset: u32,
254) {
255 assert!(if_true.len() <= out.len()); let true_it = if_true.iter();
257 let false_it = if_false.iter();
258 for (i, (t, f)) in true_it.zip(false_it).enumerate() {
259 let m = (mask >> i) & 1 != 0;
262 let src = if m { t } else { f };
263 let mut v = *src;
264 let offset = if m | (v.length <= 12) {
265 0
267 } else {
268 false_buffer_idx_offset
269 };
270 v.buffer_idx += offset;
271 out[i] = MaybeUninit::new(v);
272 }
273}
274
275pub fn if_then_else_view_64(
276 mask: u64,
277 if_true: &[View; 64],
278 if_false: &[View; 64],
279 out: &mut [MaybeUninit<View>; 64],
280 false_buffer_idx_offset: u32,
281) {
282 if_then_else_view_rest(mask, if_true, if_false, out, false_buffer_idx_offset)
283}
284
285pub fn if_then_else_broadcast_false_view_64(
288 mask: u64,
289 if_true: &[View; 64],
290 if_false: View,
291 out: &mut [MaybeUninit<View>; 64],
292) {
293 assert!(if_true.len() == out.len()); for (i, t) in if_true.iter().enumerate() {
295 let src = if (mask >> i) & 1 != 0 { t } else { &if_false };
296 out[i] = MaybeUninit::new(*src);
297 }
298}