polars_compute/gather/
fixed_size_list.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use std::mem::ManuallyDrop;
19
20use arrow::array::{Array, ArrayRef, FixedSizeListArray, PrimitiveArray, StaticArray};
21use arrow::bitmap::MutableBitmap;
22use arrow::compute::utils::combine_validities_and;
23use arrow::datatypes::reshape::{Dimension, ReshapeDimension};
24use arrow::datatypes::{ArrowDataType, IdxArr, PhysicalType};
25use arrow::legacy::prelude::FromData;
26use arrow::with_match_primitive_type;
27use polars_utils::itertools::Itertools;
28
29use super::Index;
30use crate::gather::bitmap::{take_bitmap_nulls_unchecked, take_bitmap_unchecked};
31
32fn get_stride_and_leaf_type(dtype: &ArrowDataType, size: usize) -> (usize, &ArrowDataType) {
33    if let ArrowDataType::FixedSizeList(inner, size_inner) = dtype {
34        get_stride_and_leaf_type(inner.dtype(), *size_inner * size)
35    } else {
36        (size, dtype)
37    }
38}
39
40fn get_leaves(array: &FixedSizeListArray) -> &dyn Array {
41    if let Some(array) = array.values().as_any().downcast_ref::<FixedSizeListArray>() {
42        get_leaves(array)
43    } else {
44        &**array.values()
45    }
46}
47
48fn get_buffer_and_size(array: &dyn Array) -> (&[u8], usize) {
49    match array.dtype().to_physical_type() {
50        PhysicalType::Primitive(primitive) => with_match_primitive_type!(primitive, |$T| {
51
52            let arr = array.as_any().downcast_ref::<PrimitiveArray<$T>>().unwrap();
53            let values = arr.values();
54            (bytemuck::cast_slice(values), size_of::<$T>())
55
56        }),
57        _ => {
58            unimplemented!()
59        },
60    }
61}
62
63unsafe fn from_buffer(mut buf: ManuallyDrop<Vec<u8>>, dtype: &ArrowDataType) -> ArrayRef {
64    match dtype.to_physical_type() {
65        PhysicalType::Primitive(primitive) => with_match_primitive_type!(primitive, |$T| {
66            let ptr = buf.as_mut_ptr();
67            let len_units = buf.len();
68            let cap_units = buf.capacity();
69
70            let buf = Vec::from_raw_parts(
71                ptr as *mut $T,
72                len_units / size_of::<$T>(),
73                cap_units / size_of::<$T>(),
74            );
75
76            PrimitiveArray::from_data_default(buf.into(), None).boxed()
77
78        }),
79        _ => {
80            unimplemented!()
81        },
82    }
83}
84
85unsafe fn aligned_vec(dt: &ArrowDataType, n_bytes: usize) -> Vec<u8> {
86    match dt.to_physical_type() {
87        PhysicalType::Primitive(primitive) => with_match_primitive_type!(primitive, |$T| {
88
89        let n_units = (n_bytes / size_of::<$T>()) + 1;
90
91        let mut aligned: Vec<$T> = Vec::with_capacity(n_units);
92
93        let ptr = aligned.as_mut_ptr();
94        let len_units = aligned.len();
95        let cap_units = aligned.capacity();
96
97        std::mem::forget(aligned);
98
99        Vec::from_raw_parts(
100            ptr as *mut u8,
101            len_units * size_of::<$T>(),
102            cap_units * size_of::<$T>(),
103        )
104
105            }),
106        _ => {
107            unimplemented!()
108        },
109    }
110}
111
112fn arr_no_validities_recursive(arr: &dyn Array) -> bool {
113    arr.validity().is_none()
114        && arr
115            .as_any()
116            .downcast_ref::<FixedSizeListArray>()
117            .is_none_or(|x| arr_no_validities_recursive(x.values().as_ref()))
118}
119
120/// `take` implementation for FixedSizeListArrays
121pub(super) unsafe fn take_unchecked(values: &FixedSizeListArray, indices: &IdxArr) -> ArrayRef {
122    let (stride, leaf_type) = get_stride_and_leaf_type(values.dtype(), 1);
123    if leaf_type.to_physical_type().is_primitive()
124        && arr_no_validities_recursive(values.values().as_ref())
125    {
126        let leaves = get_leaves(values);
127
128        let (leaves_buf, leave_size) = get_buffer_and_size(leaves);
129        let bytes_per_element = leave_size * stride;
130
131        let n_idx = indices.len();
132        let total_bytes = bytes_per_element * n_idx;
133
134        let mut buf = ManuallyDrop::new(aligned_vec(leaves.dtype(), total_bytes));
135        let dst = buf.spare_capacity_mut();
136
137        let mut count = 0;
138        let outer_validity = if indices.null_count() == 0 {
139            for i in indices.values().iter() {
140                let i = i.to_usize();
141
142                std::ptr::copy_nonoverlapping(
143                    leaves_buf.as_ptr().add(i * bytes_per_element),
144                    dst.as_mut_ptr().add(count * bytes_per_element) as *mut _,
145                    bytes_per_element,
146                );
147                count += 1;
148            }
149            None
150        } else {
151            let mut new_validity = MutableBitmap::with_capacity(indices.len());
152            new_validity.extend_constant(indices.len(), true);
153            for i in indices.iter() {
154                if let Some(i) = i {
155                    let i = i.to_usize();
156                    std::ptr::copy_nonoverlapping(
157                        leaves_buf.as_ptr().add(i * bytes_per_element),
158                        dst.as_mut_ptr().add(count * bytes_per_element) as *mut _,
159                        bytes_per_element,
160                    );
161                } else {
162                    new_validity.set_unchecked(count, false);
163                    std::ptr::write_bytes(
164                        dst.as_mut_ptr().add(count * bytes_per_element) as *mut _,
165                        0,
166                        bytes_per_element,
167                    );
168                }
169
170                count += 1;
171            }
172            Some(new_validity.freeze())
173        };
174
175        assert_eq!(count * bytes_per_element, total_bytes);
176        buf.set_len(total_bytes);
177
178        let outer_validity = combine_validities_and(
179            outer_validity.as_ref(),
180            values
181                .validity()
182                .map(|x| {
183                    if indices.has_nulls() {
184                        take_bitmap_nulls_unchecked(x, indices)
185                    } else {
186                        take_bitmap_unchecked(x, indices.as_slice().unwrap())
187                    }
188                })
189                .as_ref(),
190        );
191
192        let leaves = from_buffer(buf, leaves.dtype());
193        let mut shape = values.get_dims();
194        shape[0] = Dimension::new(indices.len() as _);
195        let shape = shape
196            .into_iter()
197            .map(ReshapeDimension::Specified)
198            .collect_vec();
199
200        FixedSizeListArray::from_shape(leaves.clone(), &shape)
201            .unwrap()
202            .with_validity(outer_validity)
203    } else {
204        super::take_unchecked_impl_generic(values, indices, &FixedSizeListArray::new_null).boxed()
205    }
206}
207
#[cfg(test)]
mod tests {
    use arrow::array::StaticArray;
    use arrow::datatypes::ArrowDataType;

    /// Test gather for FixedSizeListArray with outer validity but no inner validities.
    #[test]
    fn test_arr_gather_nulls_outer_validity_19482() {
        use arrow::array::{FixedSizeListArray, Int64Array, PrimitiveArray};
        use arrow::bitmap::Bitmap;
        use arrow::datatypes::reshape::{Dimension, ReshapeDimension};
        use polars_utils::IdxSize;

        use super::take_unchecked;

        unsafe {
            let dyn_arr = FixedSizeListArray::from_shape(
                Box::new(Int64Array::from_slice([1, 2, 3, 4])),
                &[
                    ReshapeDimension::Specified(Dimension::new(2)),
                    ReshapeDimension::Specified(Dimension::new(2)),
                ],
            )
            .unwrap()
            .with_validity(Some(Bitmap::from_iter([true, false]))); // FixedSizeListArray[[1, 2], None]

            let arr = dyn_arr
                .as_any()
                .downcast_ref::<FixedSizeListArray>()
                .unwrap();

            assert_eq!(
                [arr.validity().is_some(), arr.values().validity().is_some()],
                [true, false]
            );

            assert_eq!(
                take_unchecked(arr, &PrimitiveArray::<IdxSize>::from_slice([0, 1])),
                dyn_arr
            )
        }
    }

    #[test]
    fn test_arr_gather_nulls_inner_validity() {
        use arrow::array::{FixedSizeListArray, Int64Array, PrimitiveArray};
        use arrow::datatypes::reshape::{Dimension, ReshapeDimension};
        use polars_utils::IdxSize;

        use super::take_unchecked;

        unsafe {
            let dyn_arr = FixedSizeListArray::from_shape(
                Box::new(Int64Array::full_null(4, ArrowDataType::Int64)),
                &[
                    ReshapeDimension::Specified(Dimension::new(2)),
                    ReshapeDimension::Specified(Dimension::new(2)),
                ],
            )
            .unwrap(); // FixedSizeListArray[[None, None], [None, None]]

            let arr = dyn_arr
                .as_any()
                .downcast_ref::<FixedSizeListArray>()
                .unwrap();

            assert_eq!(
                [arr.validity().is_some(), arr.values().validity().is_some()],
                [false, true]
            );

            assert_eq!(
                take_unchecked(arr, &PrimitiveArray::<IdxSize>::from_slice([0, 1])),
                dyn_arr
            )
        }
    }

    /// Test the fast path with null indices, combined with an outer validity on the source.
    #[test]
    fn test_arr_gather_null_indices_fast_path() {
        use arrow::array::{FixedSizeListArray, Int64Array, PrimitiveArray};
        use arrow::bitmap::Bitmap;
        use arrow::datatypes::reshape::{Dimension, ReshapeDimension};
        use polars_utils::IdxSize;

        use super::take_unchecked;

        unsafe {
            // FixedSizeListArray[[1, 2], None]
            let dyn_arr = FixedSizeListArray::from_shape(
                Box::new(Int64Array::from_slice([1, 2, 3, 4])),
                &[
                    ReshapeDimension::Specified(Dimension::new(2)),
                    ReshapeDimension::Specified(Dimension::new(2)),
                ],
            )
            .unwrap()
            .with_validity(Some(Bitmap::from_iter([true, false])));

            let arr = dyn_arr
                .as_any()
                .downcast_ref::<FixedSizeListArray>()
                .unwrap();

            // Indices [1, None, 0]; the value under the null slot is irrelevant.
            let indices = PrimitiveArray::<IdxSize>::from_slice([1, 0, 0])
                .with_validity(Some(Bitmap::from_iter([true, false, true])));

            // Expected: [None (source row 1 is null), None (null index), [1, 2]]
            let expected = FixedSizeListArray::from_shape(
                Box::new(Int64Array::from_slice([3, 4, 0, 0, 1, 2])),
                &[
                    ReshapeDimension::Specified(Dimension::new(3)),
                    ReshapeDimension::Specified(Dimension::new(2)),
                ],
            )
            .unwrap()
            .with_validity(Some(Bitmap::from_iter([false, false, true])));

            assert_eq!(take_unchecked(arr, &indices), expected)
        }
    }
}