polars_core/chunked_array/ops/
gather.rs

1use arrow::bitmap::bitmask::BitMask;
2use arrow::bitmap::Bitmap;
3use polars_compute::gather::take_unchecked;
4use polars_error::polars_ensure;
5use polars_utils::index::check_bounds;
6
7use crate::prelude::*;
8use crate::series::IsSorted;
9
/// Maximum number of target chunks the multi-chunk gather path handles
/// directly. It sizes the cumulative-length table, and arrays with more
/// chunks than this are rechunked before gathering.
const BINARY_SEARCH_LIMIT: usize = 8;
11
12pub fn check_bounds_nulls(idx: &PrimitiveArray<IdxSize>, len: IdxSize) -> PolarsResult<()> {
13    let mask = BitMask::from_bitmap(idx.validity().unwrap());
14
15    // We iterate in chunks to make the inner loop branch-free.
16    for (block_idx, block) in idx.values().chunks(32).enumerate() {
17        let mut in_bounds = 0;
18        for (i, x) in block.iter().enumerate() {
19            in_bounds |= ((*x < len) as u32) << i;
20        }
21        let m = mask.get_u32(32 * block_idx);
22        polars_ensure!(m == m & in_bounds, ComputeError: "gather indices are out of bounds");
23    }
24    Ok(())
25}
26
27pub fn check_bounds_ca(indices: &IdxCa, len: IdxSize) -> PolarsResult<()> {
28    let all_valid = indices.downcast_iter().all(|a| {
29        if a.null_count() == 0 {
30            check_bounds(a.values(), len).is_ok()
31        } else {
32            check_bounds_nulls(a, len).is_ok()
33        }
34    });
35    polars_ensure!(all_valid, OutOfBounds: "gather indices are out of bounds");
36    Ok(())
37}
38
39impl<T: PolarsDataType, I: AsRef<[IdxSize]> + ?Sized> ChunkTake<I> for ChunkedArray<T>
40where
41    ChunkedArray<T>: ChunkTakeUnchecked<I>,
42{
43    /// Gather values from ChunkedArray by index.
44    fn take(&self, indices: &I) -> PolarsResult<Self> {
45        check_bounds(indices.as_ref(), self.len() as IdxSize)?;
46
47        // SAFETY: we just checked the indices are valid.
48        Ok(unsafe { self.take_unchecked(indices) })
49    }
50}
51
52impl<T: PolarsDataType> ChunkTake<IdxCa> for ChunkedArray<T>
53where
54    ChunkedArray<T>: ChunkTakeUnchecked<IdxCa>,
55{
56    /// Gather values from ChunkedArray by index.
57    fn take(&self, indices: &IdxCa) -> PolarsResult<Self> {
58        check_bounds_ca(indices, self.len() as IdxSize)?;
59
60        // SAFETY: we just checked the indices are valid.
61        Ok(unsafe { self.take_unchecked(indices) })
62    }
63}
64
65/// Computes cumulative lengths for efficient branchless binary search
66/// lookup. The first element is always 0, and the last length of arrs
67/// is always ignored (as we already checked that all indices are
68/// in-bounds we don't need to check against the last length).
69fn cumulative_lengths<A: StaticArray>(arrs: &[&A]) -> [IdxSize; BINARY_SEARCH_LIMIT] {
70    assert!(arrs.len() <= BINARY_SEARCH_LIMIT);
71    let mut ret = [IdxSize::MAX; BINARY_SEARCH_LIMIT];
72    ret[0] = 0;
73    for i in 1..arrs.len() {
74        ret[i] = ret[i - 1] + arrs[i - 1].len() as IdxSize;
75    }
76    ret
77}
78
79#[rustfmt::skip]
80#[inline]
81fn resolve_chunked_idx(idx: IdxSize, cumlens: &[IdxSize; BINARY_SEARCH_LIMIT]) -> (usize, usize) {
82    // Branchless bitwise binary search.
83    let mut chunk_idx = 0;
84    chunk_idx += if idx >= cumlens[chunk_idx + 0b100] { 0b0100 } else { 0 };
85    chunk_idx += if idx >= cumlens[chunk_idx + 0b010] { 0b0010 } else { 0 };
86    chunk_idx += if idx >= cumlens[chunk_idx + 0b001] { 0b0001 } else { 0 };
87    (chunk_idx, (idx - cumlens[chunk_idx]) as usize)
88}
89
90#[inline]
91unsafe fn target_value_unchecked<'a, A: StaticArray>(
92    targets: &[&'a A],
93    cumlens: &[IdxSize; BINARY_SEARCH_LIMIT],
94    idx: IdxSize,
95) -> A::ValueT<'a> {
96    let (chunk_idx, arr_idx) = resolve_chunked_idx(idx, cumlens);
97    let arr = targets.get_unchecked(chunk_idx);
98    arr.value_unchecked(arr_idx)
99}
100
101#[inline]
102unsafe fn target_get_unchecked<'a, A: StaticArray>(
103    targets: &[&'a A],
104    cumlens: &[IdxSize; BINARY_SEARCH_LIMIT],
105    idx: IdxSize,
106) -> Option<A::ValueT<'a>> {
107    let (chunk_idx, arr_idx) = resolve_chunked_idx(idx, cumlens);
108    let arr = targets.get_unchecked(chunk_idx);
109    arr.get_unchecked(arr_idx)
110}
111
112unsafe fn gather_idx_array_unchecked<A: StaticArray>(
113    dtype: ArrowDataType,
114    targets: &[&A],
115    has_nulls: bool,
116    indices: &[IdxSize],
117) -> A {
118    let it = indices.iter().copied();
119    if targets.len() == 1 {
120        let target = targets.first().unwrap();
121        if has_nulls {
122            it.map(|i| target.get_unchecked(i as usize))
123                .collect_arr_trusted_with_dtype(dtype)
124        } else if let Some(sl) = target.as_slice() {
125            // Avoid the Arc overhead from value_unchecked.
126            it.map(|i| sl.get_unchecked(i as usize).clone())
127                .collect_arr_trusted_with_dtype(dtype)
128        } else {
129            it.map(|i| target.value_unchecked(i as usize))
130                .collect_arr_trusted_with_dtype(dtype)
131        }
132    } else {
133        let cumlens = cumulative_lengths(targets);
134        if has_nulls {
135            it.map(|i| target_get_unchecked(targets, &cumlens, i))
136                .collect_arr_trusted_with_dtype(dtype)
137        } else {
138            it.map(|i| target_value_unchecked(targets, &cumlens, i))
139                .collect_arr_trusted_with_dtype(dtype)
140        }
141    }
142}
143
144impl<T: PolarsDataType, I: AsRef<[IdxSize]> + ?Sized> ChunkTakeUnchecked<I> for ChunkedArray<T>
145where
146    T: PolarsDataType<HasViews = FalseT, IsStruct = FalseT, IsNested = FalseT>,
147{
148    /// Gather values from ChunkedArray by index.
149    unsafe fn take_unchecked(&self, indices: &I) -> Self {
150        let rechunked;
151        let mut ca = self;
152        if self.chunks().len() > BINARY_SEARCH_LIMIT {
153            rechunked = self.rechunk();
154            ca = &rechunked;
155        }
156        let targets: Vec<_> = ca.downcast_iter().collect();
157        let arr = gather_idx_array_unchecked(
158            ca.dtype().to_arrow(CompatLevel::newest()),
159            &targets,
160            ca.null_count() > 0,
161            indices.as_ref(),
162        );
163        ChunkedArray::from_chunk_iter_like(ca, [arr])
164    }
165}
166
167pub fn _update_gather_sorted_flag(sorted_arr: IsSorted, sorted_idx: IsSorted) -> IsSorted {
168    use crate::series::IsSorted::*;
169    match (sorted_arr, sorted_idx) {
170        (_, Not) => Not,
171        (Not, _) => Not,
172        (Ascending, Ascending) => Ascending,
173        (Ascending, Descending) => Descending,
174        (Descending, Ascending) => Descending,
175        (Descending, Descending) => Ascending,
176    }
177}
178
179impl<T: PolarsDataType> ChunkTakeUnchecked<IdxCa> for ChunkedArray<T>
180where
181    T: PolarsDataType<HasViews = FalseT, IsStruct = FalseT, IsNested = FalseT>,
182{
183    /// Gather values from ChunkedArray by index.
184    unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self {
185        let rechunked;
186        let mut ca = self;
187        if self.chunks().len() > BINARY_SEARCH_LIMIT {
188            rechunked = self.rechunk();
189            ca = &rechunked;
190        }
191        let targets_have_nulls = ca.null_count() > 0;
192        let targets: Vec<_> = ca.downcast_iter().collect();
193
194        let chunks = indices.downcast_iter().map(|idx_arr| {
195            let dtype = ca.dtype().to_arrow(CompatLevel::newest());
196            if idx_arr.null_count() == 0 {
197                gather_idx_array_unchecked(dtype, &targets, targets_have_nulls, idx_arr.values())
198            } else if targets.len() == 1 {
199                let target = targets.first().unwrap();
200                if targets_have_nulls {
201                    idx_arr
202                        .iter()
203                        .map(|i| target.get_unchecked(*i? as usize))
204                        .collect_arr_trusted_with_dtype(dtype)
205                } else {
206                    idx_arr
207                        .iter()
208                        .map(|i| Some(target.value_unchecked(*i? as usize)))
209                        .collect_arr_trusted_with_dtype(dtype)
210                }
211            } else {
212                let cumlens = cumulative_lengths(&targets);
213                if targets_have_nulls {
214                    idx_arr
215                        .iter()
216                        .map(|i| target_get_unchecked(&targets, &cumlens, *i?))
217                        .collect_arr_trusted_with_dtype(dtype)
218                } else {
219                    idx_arr
220                        .iter()
221                        .map(|i| Some(target_value_unchecked(&targets, &cumlens, *i?)))
222                        .collect_arr_trusted_with_dtype(dtype)
223                }
224            }
225        });
226
227        let mut out = ChunkedArray::from_chunk_iter_like(ca, chunks);
228        let sorted_flag = _update_gather_sorted_flag(ca.is_sorted_flag(), indices.is_sorted_flag());
229
230        out.set_sorted_flag(sorted_flag);
231        out
232    }
233}
234
235impl ChunkTakeUnchecked<IdxCa> for BinaryChunked {
236    /// Gather values from ChunkedArray by index.
237    unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self {
238        let rechunked = self.rechunk();
239        let indices = indices.rechunk();
240        let indices_arr = indices.downcast_iter().next().unwrap();
241        let chunks = rechunked
242            .chunks()
243            .iter()
244            .map(|arr| take_unchecked(arr.as_ref(), indices_arr))
245            .collect::<Vec<_>>();
246
247        let mut out = ChunkedArray::from_chunks(self.name().clone(), chunks);
248
249        let sorted_flag =
250            _update_gather_sorted_flag(self.is_sorted_flag(), indices.is_sorted_flag());
251        out.set_sorted_flag(sorted_flag);
252        out
253    }
254}
255
256impl ChunkTakeUnchecked<IdxCa> for StringChunked {
257    unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self {
258        self.as_binary()
259            .take_unchecked(indices)
260            .to_string_unchecked()
261    }
262}
263
264impl<I: AsRef<[IdxSize]> + ?Sized> ChunkTakeUnchecked<I> for BinaryChunked {
265    /// Gather values from ChunkedArray by index.
266    unsafe fn take_unchecked(&self, indices: &I) -> Self {
267        let indices = IdxCa::mmap_slice(PlSmallStr::EMPTY, indices.as_ref());
268        self.take_unchecked(&indices)
269    }
270}
271
272impl<I: AsRef<[IdxSize]> + ?Sized> ChunkTakeUnchecked<I> for StringChunked {
273    /// Gather values from ChunkedArray by index.
274    unsafe fn take_unchecked(&self, indices: &I) -> Self {
275        self.as_binary()
276            .take_unchecked(indices)
277            .to_string_unchecked()
278    }
279}
280
#[cfg(feature = "dtype-struct")]
impl ChunkTakeUnchecked<IdxCa> for StructChunked {
    /// Gather struct rows by index, without bounds checks.
    ///
    /// Values and indices are both rechunked, then their chunks are paired up
    /// and handed to the compute kernel.
    unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self {
        let values = self.rechunk();
        let idx_ca = indices.rechunk();

        let value_chunks = values.downcast_iter();
        let idx_chunks = idx_ca.downcast_iter();
        let gathered: Vec<_> = value_chunks
            .zip(idx_chunks)
            .map(|(arr, idx)| take_unchecked(arr, idx))
            .collect();
        self.copy_with_chunks(gathered)
    }
}
295
#[cfg(feature = "dtype-struct")]
impl<I: AsRef<[IdxSize]> + ?Sized> ChunkTakeUnchecked<I> for StructChunked {
    /// Gather struct rows by index, without bounds checks.
    ///
    /// Wraps the index slice in an `IdxCa` via `mmap_slice` and forwards to
    /// the `IdxCa` implementation.
    unsafe fn take_unchecked(&self, indices: &I) -> Self {
        let idx_ca = IdxCa::mmap_slice(PlSmallStr::EMPTY, indices.as_ref());
        self.take_unchecked(&idx_ca)
    }
}
303
304impl IdxCa {
305    pub fn with_nullable_idx<T, F: FnOnce(&IdxCa) -> T>(idx: &[NullableIdxSize], f: F) -> T {
306        let validity: Bitmap = idx.iter().map(|idx| !idx.is_null_idx()).collect_trusted();
307        let idx = bytemuck::cast_slice::<_, IdxSize>(idx);
308        let arr = unsafe { arrow::ffi::mmap::slice(idx) };
309        let arr = arr.with_validity_typed(Some(validity));
310        let ca = IdxCa::with_chunk(PlSmallStr::EMPTY, arr);
311
312        f(&ca)
313    }
314}
315
#[cfg(feature = "dtype-array")]
impl ChunkTakeUnchecked<IdxCa> for ArrayChunked {
    /// Gather fixed-size-list rows by index, without bounds checks.
    ///
    /// Values and indices are each rechunked and downcast to one array, and
    /// the compute kernel produces the single output chunk.
    unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self {
        let values = self.rechunk().downcast_into_array();
        let idx = indices.rechunk().downcast_into_array();
        let gathered = take_unchecked(&values, &idx);
        self.copy_with_chunks(vec![gathered])
    }
}
326
#[cfg(feature = "dtype-array")]
impl<I: AsRef<[IdxSize]> + ?Sized> ChunkTakeUnchecked<I> for ArrayChunked {
    /// Gather fixed-size-list rows by index, without bounds checks.
    ///
    /// Wraps the index slice in an `IdxCa` via `mmap_slice` and forwards to
    /// the `IdxCa` implementation.
    unsafe fn take_unchecked(&self, indices: &I) -> Self {
        let idx_ca = IdxCa::mmap_slice(PlSmallStr::EMPTY, indices.as_ref());
        self.take_unchecked(&idx_ca)
    }
}
334
335impl ChunkTakeUnchecked<IdxCa> for ListChunked {
336    unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self {
337        let chunks = vec![take_unchecked(
338            &self.rechunk().downcast_into_array(),
339            &indices.rechunk().downcast_into_array(),
340        )];
341        self.copy_with_chunks(chunks)
342    }
343}
344
345impl<I: AsRef<[IdxSize]> + ?Sized> ChunkTakeUnchecked<I> for ListChunked {
346    unsafe fn take_unchecked(&self, indices: &I) -> Self {
347        let idx = IdxCa::mmap_slice(PlSmallStr::EMPTY, indices.as_ref());
348        self.take_unchecked(&idx)
349    }
350}