polars_compute/gather/primitive.rs

1use arrow::array::PrimitiveArray;
2use arrow::bitmap::utils::set_bit_unchecked;
3use arrow::bitmap::{Bitmap, MutableBitmap};
4use arrow::legacy::index::IdxArr;
5use arrow::legacy::utils::CustomIterTools;
6use arrow::types::NativeType;
7use polars_utils::index::NullCount;
8
9pub(super) unsafe fn take_values_and_validity_unchecked<T: NativeType>(
10    values: &[T],
11    validity_values: Option<&Bitmap>,
12    indices: &IdxArr,
13) -> (Vec<T>, Option<Bitmap>) {
14    let index_values = indices.values().as_slice();
15
16    let null_count = validity_values.map(|b| b.unset_bits()).unwrap_or(0);
17
18    // first take the values, these are always needed
19    let values: Vec<T> = if indices.null_count() == 0 {
20        index_values
21            .iter()
22            .map(|idx| *values.get_unchecked(*idx as usize))
23            .collect_trusted()
24    } else {
25        indices
26            .iter()
27            .map(|idx| match idx {
28                Some(idx) => *values.get_unchecked(*idx as usize),
29                None => T::default(),
30            })
31            .collect_trusted()
32    };
33
34    if null_count > 0 {
35        let validity_values = validity_values.unwrap();
36        // the validity buffer we will fill with all valid. And we unset the ones that are null
37        // in later checks
38        // this is in the assumption that most values will be valid.
39        // Maybe we could add another branch based on the null count
40        let mut validity = MutableBitmap::with_capacity(indices.len());
41        validity.extend_constant(indices.len(), true);
42        let validity_slice = validity.as_mut_slice();
43
44        if let Some(validity_indices) = indices.validity().as_ref() {
45            index_values.iter().enumerate().for_each(|(i, idx)| {
46                // i is iteration count
47                // idx is the index that we take from the values array.
48                let idx = *idx as usize;
49                if !validity_indices.get_bit_unchecked(i) || !validity_values.get_bit_unchecked(idx)
50                {
51                    set_bit_unchecked(validity_slice, i, false);
52                }
53            });
54        } else {
55            index_values.iter().enumerate().for_each(|(i, idx)| {
56                let idx = *idx as usize;
57                if !validity_values.get_bit_unchecked(idx) {
58                    set_bit_unchecked(validity_slice, i, false);
59                }
60            });
61        };
62        (values, Some(validity.freeze()))
63    } else {
64        (values, indices.validity().cloned())
65    }
66}
67
68/// Take kernel for single chunk with nulls and arrow array as index that may have nulls.
69/// # Safety
70/// caller must ensure indices are in bounds
71pub unsafe fn take_primitive_unchecked<T: NativeType>(
72    arr: &PrimitiveArray<T>,
73    indices: &IdxArr,
74) -> PrimitiveArray<T> {
75    let (values, validity) =
76        take_values_and_validity_unchecked(arr.values(), arr.validity(), indices);
77    PrimitiveArray::new_unchecked(arr.dtype().clone(), values.into(), validity)
78}