use std::mem::ManuallyDrop;

use arrow::array::{Array, ArrayRef, FixedSizeListArray, PrimitiveArray, StaticArray};
use arrow::bitmap::MutableBitmap;
use arrow::compute::utils::combine_validities_and;
use arrow::datatypes::reshape::{Dimension, ReshapeDimension};
use arrow::datatypes::{ArrowDataType, IdxArr, PhysicalType};
use arrow::legacy::prelude::FromData;
use arrow::with_match_primitive_type;
use polars_utils::itertools::Itertools;

use super::Index;
use crate::gather::bitmap::{take_bitmap_nulls_unchecked, take_bitmap_unchecked};
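
/// Walks nested `FixedSizeList` dtypes, multiplying the widths of all levels.
/// Returns the number of leaf elements per outer row (the stride) together
/// with the innermost (leaf) dtype.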
fn get_stride_and_leaf_type(dtype: &ArrowDataType, size: usize) -> (usize, &ArrowDataType) {
    if let ArrowDataType::FixedSizeList(inner, size_inner) = dtype {
        get_stride_and_leaf_type(inner.dtype(), *size_inner * size)
    } else {
        (size, dtype)
    }
}
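
/// Drills through nested `FixedSizeListArray`s and returns the innermost
/// (leaf) values array.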
fn get_leaves(array: &FixedSizeListArray) -> &dyn Array {
    if let Some(array) = array.values().as_any().downcast_ref::<FixedSizeListArray>() {
        get_leaves(array)
    } else {
        &**array.values()
    }
}
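
/// Returns the raw value buffer of a primitive leaf array as bytes, together
/// with the byte width of a single value.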
fn get_buffer_and_size(array: &dyn Array) -> (&[u8], usize) {
    match array.dtype().to_physical_type() {
        PhysicalType::Primitive(primitive) => with_match_primitive_type!(primitive, |$T| {
            let arr = array.as_any().downcast_ref::<PrimitiveArray<$T>>().unwrap();
            let values = arr.values();
            (bytemuck::cast_slice(values), size_of::<$T>())
        }),
        _ => {
            unimplemented!()
        },
    }
}
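
/// Reinterprets a raw byte buffer as a primitive array of `dtype`.
///
/// # Safety
/// The buffer must have been allocated by `aligned_vec` for the same dtype, so
/// that pointer alignment, length and capacity are valid for the target
/// primitive type.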
unsafe fn from_buffer(mut buf: ManuallyDrop<Vec<u8>>, dtype: &ArrowDataType) -> ArrayRef {
    match dtype.to_physical_type() {
        PhysicalType::Primitive(primitive) => with_match_primitive_type!(primitive, |$T| {
            let ptr = buf.as_mut_ptr();
            let len_units = buf.len();
            let cap_units = buf.capacity();

            let buf = Vec::from_raw_parts(
                ptr as *mut $T,
                len_units / size_of::<$T>(),
                cap_units / size_of::<$T>(),
            );

            PrimitiveArray::from_data_default(buf.into(), None).boxed()
        }),
        _ => {
            unimplemented!()
        },
    }
}
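
/// Allocates a byte `Vec` whose backing allocation is created as a `Vec` of
/// the primitive type of `dt`, so it is properly aligned for that type and can
/// later be reinterpreted by `from_buffer`.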
unsafe fn aligned_vec(dt: &ArrowDataType, n_bytes: usize) -> Vec<u8> {
    match dt.to_physical_type() {
        PhysicalType::Primitive(primitive) => with_match_primitive_type!(primitive, |$T| {
            let n_units = (n_bytes / size_of::<$T>()) + 1;

            let mut aligned: Vec<$T> = Vec::with_capacity(n_units);

            let ptr = aligned.as_mut_ptr();
            let len_units = aligned.len();
            let cap_units = aligned.capacity();

            std::mem::forget(aligned);

            Vec::from_raw_parts(
                ptr as *mut u8,
                len_units * size_of::<$T>(),
                cap_units * size_of::<$T>(),
            )
        }),
        _ => {
            unimplemented!()
        },
    }
}
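
/// Returns `true` if neither `arr` nor any nested `FixedSizeListArray` level
/// below it has a validity bitmap.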
fn arr_no_validities_recursive(arr: &dyn Array) -> bool {
    arr.validity().is_none()
        && arr
            .as_any()
            .downcast_ref::<FixedSizeListArray>()
            .is_none_or(|x| arr_no_validities_recursive(x.values().as_ref()))
}
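
/// Gathers the rows of a `FixedSizeListArray` at `indices`.
///
/// # Safety
/// All non-null `indices` must be in bounds for `values`.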
pub(super) unsafe fn take_unchecked(values: &FixedSizeListArray, indices: &IdxArr) -> ArrayRef {
    let (stride, leaf_type) = get_stride_and_leaf_type(values.dtype(), 1);
    if leaf_type.to_physical_type().is_primitive()
        && arr_no_validities_recursive(values.values().as_ref())
    {
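        // Fast path: the leaves are primitive and carry no validity at any
        // nesting level, so every outer row is a contiguous run of `stride`
        // leaf values that can be gathered with a single memcpy per index.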
        let leaves = get_leaves(values);

        let (leaves_buf, leave_size) = get_buffer_and_size(leaves);
        let bytes_per_element = leave_size * stride;

        let n_idx = indices.len();
        let total_bytes = bytes_per_element * n_idx;

        let mut buf = ManuallyDrop::new(aligned_vec(leaves.dtype(), total_bytes));
        let dst = buf.spare_capacity_mut();
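
        // Copy `bytes_per_element` bytes per gathered row into the buffer's
        // uninitialized spare capacity; `count` tracks the number of rows
        // written so far.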
        let mut count = 0;
        let outer_validity = if indices.null_count() == 0 {
            for i in indices.values().iter() {
                let i = i.to_usize();

                std::ptr::copy_nonoverlapping(
                    leaves_buf.as_ptr().add(i * bytes_per_element),
                    dst.as_mut_ptr().add(count * bytes_per_element) as *mut _,
                    bytes_per_element,
                );
                count += 1;
            }
            None
        } else {
            let mut new_validity = MutableBitmap::with_capacity(indices.len());
            new_validity.extend_constant(indices.len(), true);
            for i in indices.iter() {
                if let Some(i) = i {
                    let i = i.to_usize();
                    std::ptr::copy_nonoverlapping(
                        leaves_buf.as_ptr().add(i * bytes_per_element),
                        dst.as_mut_ptr().add(count * bytes_per_element) as *mut _,
                        bytes_per_element,
                    );
                } else {
                    new_validity.set_unchecked(count, false);
                    std::ptr::write_bytes(
                        dst.as_mut_ptr().add(count * bytes_per_element) as *mut _,
                        0,
                        bytes_per_element,
                    );
                }

                count += 1;
            }
            Some(new_validity.freeze())
        };

        assert_eq!(count * bytes_per_element, total_bytes);
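        // SAFETY: the assert above guarantees that exactly `total_bytes`
        // bytes of the buffer have been initialized.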
        buf.set_len(total_bytes);
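
        // Combine the nulls introduced by null indices with the gathered
        // outer (list-level) validity of `values`.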
        let outer_validity = combine_validities_and(
            outer_validity.as_ref(),
            values
                .validity()
                .map(|x| {
                    if indices.has_nulls() {
                        take_bitmap_nulls_unchecked(x, indices)
                    } else {
                        take_bitmap_unchecked(x, indices.as_slice().unwrap())
                    }
                })
                .as_ref(),
        );
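
        // Reinterpret the gathered bytes as the leaf primitive array, then
        // rebuild the FixedSizeList structure with the original inner dims
        // and the new outer length.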
        let leaves = from_buffer(buf, leaves.dtype());
        let mut shape = values.get_dims();
        shape[0] = Dimension::new(indices.len() as _);
        let shape = shape
            .into_iter()
            .map(ReshapeDimension::Specified)
            .collect_vec();

        FixedSizeListArray::from_shape(leaves.clone(), &shape)
            .unwrap()
            .with_validity(outer_validity)
    } else {
        super::take_unchecked_impl_generic(values, indices, &FixedSizeListArray::new_null).boxed()
    }
}

#[cfg(test)]
mod tests {
    use arrow::array::StaticArray;
    use arrow::datatypes::ArrowDataType;
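
    // Gathering via the primitive fast path must keep the outer (list-level)
    // validity; regression test for the issue referenced in the test name.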
    #[test]
    fn test_arr_gather_nulls_outer_validity_19482() {
        use arrow::array::{FixedSizeListArray, Int64Array, PrimitiveArray};
        use arrow::bitmap::Bitmap;
        use arrow::datatypes::reshape::{Dimension, ReshapeDimension};
        use polars_utils::IdxSize;

        use super::take_unchecked;

        unsafe {
            let dyn_arr = FixedSizeListArray::from_shape(
                Box::new(Int64Array::from_slice([1, 2, 3, 4])),
                &[
                    ReshapeDimension::Specified(Dimension::new(2)),
                    ReshapeDimension::Specified(Dimension::new(2)),
                ],
            )
            .unwrap()
            .with_validity(Some(Bitmap::from_iter([true, false])));

            let arr = dyn_arr
                .as_any()
                .downcast_ref::<FixedSizeListArray>()
                .unwrap();

            assert_eq!(
                [arr.validity().is_some(), arr.values().validity().is_some()],
                [true, false]
            );

            assert_eq!(
                take_unchecked(arr, &PrimitiveArray::<IdxSize>::from_slice([0, 1])),
                dyn_arr
            )
        }
    }
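
    // When the leaves carry their own validity the fast path is skipped; the
    // generic fallback must preserve the inner nulls.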
    #[test]
    fn test_arr_gather_nulls_inner_validity() {
        use arrow::array::{FixedSizeListArray, Int64Array, PrimitiveArray};
        use arrow::datatypes::reshape::{Dimension, ReshapeDimension};
        use polars_utils::IdxSize;

        use super::take_unchecked;

        unsafe {
            let dyn_arr = FixedSizeListArray::from_shape(
                Box::new(Int64Array::full_null(4, ArrowDataType::Int64)),
                &[
                    ReshapeDimension::Specified(Dimension::new(2)),
                    ReshapeDimension::Specified(Dimension::new(2)),
                ],
            )
            .unwrap();

            let arr = dyn_arr
                .as_any()
                .downcast_ref::<FixedSizeListArray>()
                .unwrap();

            assert_eq!(
                [arr.validity().is_some(), arr.values().validity().is_some()],
                [false, true]
            );

            assert_eq!(
                take_unchecked(arr, &PrimitiveArray::<IdxSize>::from_slice([0, 1])),
                dyn_arr
            )
        }
    }
}