polars_core/chunked_array/ops/
gather.rs1use arrow::bitmap::bitmask::BitMask;
2use arrow::bitmap::Bitmap;
3use polars_compute::gather::take_unchecked;
4use polars_error::polars_ensure;
5use polars_utils::index::check_bounds;
6
7use crate::prelude::*;
8use crate::series::IsSorted;
9
/// Maximum number of chunks that the gather kernels in this file address directly.
///
/// Arrays with more chunks are rechunked before gathering, and the cumulative-length
/// table searched by `resolve_chunked_idx` has exactly this many entries. The three-step
/// search in `resolve_chunked_idx` probes offsets 4, 2 and 1, so it relies on this being 8.
const BINARY_SEARCH_LIMIT: usize = 8;
11
12pub fn check_bounds_nulls(idx: &PrimitiveArray<IdxSize>, len: IdxSize) -> PolarsResult<()> {
13 let mask = BitMask::from_bitmap(idx.validity().unwrap());
14
15 for (block_idx, block) in idx.values().chunks(32).enumerate() {
17 let mut in_bounds = 0;
18 for (i, x) in block.iter().enumerate() {
19 in_bounds |= ((*x < len) as u32) << i;
20 }
21 let m = mask.get_u32(32 * block_idx);
22 polars_ensure!(m == m & in_bounds, ComputeError: "gather indices are out of bounds");
23 }
24 Ok(())
25}
26
27pub fn check_bounds_ca(indices: &IdxCa, len: IdxSize) -> PolarsResult<()> {
28 let all_valid = indices.downcast_iter().all(|a| {
29 if a.null_count() == 0 {
30 check_bounds(a.values(), len).is_ok()
31 } else {
32 check_bounds_nulls(a, len).is_ok()
33 }
34 });
35 polars_ensure!(all_valid, OutOfBounds: "gather indices are out of bounds");
36 Ok(())
37}
38
39impl<T: PolarsDataType, I: AsRef<[IdxSize]> + ?Sized> ChunkTake<I> for ChunkedArray<T>
40where
41 ChunkedArray<T>: ChunkTakeUnchecked<I>,
42{
43 fn take(&self, indices: &I) -> PolarsResult<Self> {
45 check_bounds(indices.as_ref(), self.len() as IdxSize)?;
46
47 Ok(unsafe { self.take_unchecked(indices) })
49 }
50}
51
52impl<T: PolarsDataType> ChunkTake<IdxCa> for ChunkedArray<T>
53where
54 ChunkedArray<T>: ChunkTakeUnchecked<IdxCa>,
55{
56 fn take(&self, indices: &IdxCa) -> PolarsResult<Self> {
58 check_bounds_ca(indices, self.len() as IdxSize)?;
59
60 Ok(unsafe { self.take_unchecked(indices) })
62 }
63}
64
65fn cumulative_lengths<A: StaticArray>(arrs: &[&A]) -> [IdxSize; BINARY_SEARCH_LIMIT] {
70 assert!(arrs.len() <= BINARY_SEARCH_LIMIT);
71 let mut ret = [IdxSize::MAX; BINARY_SEARCH_LIMIT];
72 ret[0] = 0;
73 for i in 1..arrs.len() {
74 ret[i] = ret[i - 1] + arrs[i - 1].len() as IdxSize;
75 }
76 ret
77}
78
79#[rustfmt::skip]
80#[inline]
81fn resolve_chunked_idx(idx: IdxSize, cumlens: &[IdxSize; BINARY_SEARCH_LIMIT]) -> (usize, usize) {
82 let mut chunk_idx = 0;
84 chunk_idx += if idx >= cumlens[chunk_idx + 0b100] { 0b0100 } else { 0 };
85 chunk_idx += if idx >= cumlens[chunk_idx + 0b010] { 0b0010 } else { 0 };
86 chunk_idx += if idx >= cumlens[chunk_idx + 0b001] { 0b0001 } else { 0 };
87 (chunk_idx, (idx - cumlens[chunk_idx]) as usize)
88}
89
90#[inline]
91unsafe fn target_value_unchecked<'a, A: StaticArray>(
92 targets: &[&'a A],
93 cumlens: &[IdxSize; BINARY_SEARCH_LIMIT],
94 idx: IdxSize,
95) -> A::ValueT<'a> {
96 let (chunk_idx, arr_idx) = resolve_chunked_idx(idx, cumlens);
97 let arr = targets.get_unchecked(chunk_idx);
98 arr.value_unchecked(arr_idx)
99}
100
101#[inline]
102unsafe fn target_get_unchecked<'a, A: StaticArray>(
103 targets: &[&'a A],
104 cumlens: &[IdxSize; BINARY_SEARCH_LIMIT],
105 idx: IdxSize,
106) -> Option<A::ValueT<'a>> {
107 let (chunk_idx, arr_idx) = resolve_chunked_idx(idx, cumlens);
108 let arr = targets.get_unchecked(chunk_idx);
109 arr.get_unchecked(arr_idx)
110}
111
112unsafe fn gather_idx_array_unchecked<A: StaticArray>(
113 dtype: ArrowDataType,
114 targets: &[&A],
115 has_nulls: bool,
116 indices: &[IdxSize],
117) -> A {
118 let it = indices.iter().copied();
119 if targets.len() == 1 {
120 let target = targets.first().unwrap();
121 if has_nulls {
122 it.map(|i| target.get_unchecked(i as usize))
123 .collect_arr_trusted_with_dtype(dtype)
124 } else if let Some(sl) = target.as_slice() {
125 it.map(|i| sl.get_unchecked(i as usize).clone())
127 .collect_arr_trusted_with_dtype(dtype)
128 } else {
129 it.map(|i| target.value_unchecked(i as usize))
130 .collect_arr_trusted_with_dtype(dtype)
131 }
132 } else {
133 let cumlens = cumulative_lengths(targets);
134 if has_nulls {
135 it.map(|i| target_get_unchecked(targets, &cumlens, i))
136 .collect_arr_trusted_with_dtype(dtype)
137 } else {
138 it.map(|i| target_value_unchecked(targets, &cumlens, i))
139 .collect_arr_trusted_with_dtype(dtype)
140 }
141 }
142}
143
144impl<T: PolarsDataType, I: AsRef<[IdxSize]> + ?Sized> ChunkTakeUnchecked<I> for ChunkedArray<T>
145where
146 T: PolarsDataType<HasViews = FalseT, IsStruct = FalseT, IsNested = FalseT>,
147{
148 unsafe fn take_unchecked(&self, indices: &I) -> Self {
150 let rechunked;
151 let mut ca = self;
152 if self.chunks().len() > BINARY_SEARCH_LIMIT {
153 rechunked = self.rechunk();
154 ca = &rechunked;
155 }
156 let targets: Vec<_> = ca.downcast_iter().collect();
157 let arr = gather_idx_array_unchecked(
158 ca.dtype().to_arrow(CompatLevel::newest()),
159 &targets,
160 ca.null_count() > 0,
161 indices.as_ref(),
162 );
163 ChunkedArray::from_chunk_iter_like(ca, [arr])
164 }
165}
166
167pub fn _update_gather_sorted_flag(sorted_arr: IsSorted, sorted_idx: IsSorted) -> IsSorted {
168 use crate::series::IsSorted::*;
169 match (sorted_arr, sorted_idx) {
170 (_, Not) => Not,
171 (Not, _) => Not,
172 (Ascending, Ascending) => Ascending,
173 (Ascending, Descending) => Descending,
174 (Descending, Ascending) => Descending,
175 (Descending, Descending) => Ascending,
176 }
177}
178
179impl<T: PolarsDataType> ChunkTakeUnchecked<IdxCa> for ChunkedArray<T>
180where
181 T: PolarsDataType<HasViews = FalseT, IsStruct = FalseT, IsNested = FalseT>,
182{
183 unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self {
185 let rechunked;
186 let mut ca = self;
187 if self.chunks().len() > BINARY_SEARCH_LIMIT {
188 rechunked = self.rechunk();
189 ca = &rechunked;
190 }
191 let targets_have_nulls = ca.null_count() > 0;
192 let targets: Vec<_> = ca.downcast_iter().collect();
193
194 let chunks = indices.downcast_iter().map(|idx_arr| {
195 let dtype = ca.dtype().to_arrow(CompatLevel::newest());
196 if idx_arr.null_count() == 0 {
197 gather_idx_array_unchecked(dtype, &targets, targets_have_nulls, idx_arr.values())
198 } else if targets.len() == 1 {
199 let target = targets.first().unwrap();
200 if targets_have_nulls {
201 idx_arr
202 .iter()
203 .map(|i| target.get_unchecked(*i? as usize))
204 .collect_arr_trusted_with_dtype(dtype)
205 } else {
206 idx_arr
207 .iter()
208 .map(|i| Some(target.value_unchecked(*i? as usize)))
209 .collect_arr_trusted_with_dtype(dtype)
210 }
211 } else {
212 let cumlens = cumulative_lengths(&targets);
213 if targets_have_nulls {
214 idx_arr
215 .iter()
216 .map(|i| target_get_unchecked(&targets, &cumlens, *i?))
217 .collect_arr_trusted_with_dtype(dtype)
218 } else {
219 idx_arr
220 .iter()
221 .map(|i| Some(target_value_unchecked(&targets, &cumlens, *i?)))
222 .collect_arr_trusted_with_dtype(dtype)
223 }
224 }
225 });
226
227 let mut out = ChunkedArray::from_chunk_iter_like(ca, chunks);
228 let sorted_flag = _update_gather_sorted_flag(ca.is_sorted_flag(), indices.is_sorted_flag());
229
230 out.set_sorted_flag(sorted_flag);
231 out
232 }
233}
234
235impl ChunkTakeUnchecked<IdxCa> for BinaryChunked {
236 unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self {
238 let rechunked = self.rechunk();
239 let indices = indices.rechunk();
240 let indices_arr = indices.downcast_iter().next().unwrap();
241 let chunks = rechunked
242 .chunks()
243 .iter()
244 .map(|arr| take_unchecked(arr.as_ref(), indices_arr))
245 .collect::<Vec<_>>();
246
247 let mut out = ChunkedArray::from_chunks(self.name().clone(), chunks);
248
249 let sorted_flag =
250 _update_gather_sorted_flag(self.is_sorted_flag(), indices.is_sorted_flag());
251 out.set_sorted_flag(sorted_flag);
252 out
253 }
254}
255
256impl ChunkTakeUnchecked<IdxCa> for StringChunked {
257 unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self {
258 self.as_binary()
259 .take_unchecked(indices)
260 .to_string_unchecked()
261 }
262}
263
264impl<I: AsRef<[IdxSize]> + ?Sized> ChunkTakeUnchecked<I> for BinaryChunked {
265 unsafe fn take_unchecked(&self, indices: &I) -> Self {
267 let indices = IdxCa::mmap_slice(PlSmallStr::EMPTY, indices.as_ref());
268 self.take_unchecked(&indices)
269 }
270}
271
272impl<I: AsRef<[IdxSize]> + ?Sized> ChunkTakeUnchecked<I> for StringChunked {
273 unsafe fn take_unchecked(&self, indices: &I) -> Self {
275 self.as_binary()
276 .take_unchecked(indices)
277 .to_string_unchecked()
278 }
279}
280
#[cfg(feature = "dtype-struct")]
impl ChunkTakeUnchecked<IdxCa> for StructChunked {
    /// Gathers struct rows at `indices` using the `take_unchecked` compute kernel.
    ///
    /// Both sides are rechunked first and their chunks are then processed pairwise.
    ///
    /// # Safety
    ///
    /// Indices are not bounds-checked; every non-null index must be valid for `self`.
    unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self {
        let values = self.rechunk();
        let idx = indices.rechunk();

        let mut chunks = Vec::new();
        for (arr, idx_arr) in values.downcast_iter().zip(idx.downcast_iter()) {
            chunks.push(take_unchecked(arr, idx_arr));
        }
        self.copy_with_chunks(chunks)
    }
}
295
#[cfg(feature = "dtype-struct")]
impl<I: AsRef<[IdxSize]> + ?Sized> ChunkTakeUnchecked<I> for StructChunked {
    /// Gathers struct rows at a slice of indices by wrapping the slice in a temporary
    /// `IdxCa` and delegating to the `IdxCa` implementation.
    ///
    /// # Safety
    ///
    /// Indices are not bounds-checked; every index must be valid for `self`.
    unsafe fn take_unchecked(&self, indices: &I) -> Self {
        let idx_slice: &[IdxSize] = indices.as_ref();
        let idx_ca = IdxCa::mmap_slice(PlSmallStr::EMPTY, idx_slice);
        self.take_unchecked(&idx_ca)
    }
}
303
304impl IdxCa {
305 pub fn with_nullable_idx<T, F: FnOnce(&IdxCa) -> T>(idx: &[NullableIdxSize], f: F) -> T {
306 let validity: Bitmap = idx.iter().map(|idx| !idx.is_null_idx()).collect_trusted();
307 let idx = bytemuck::cast_slice::<_, IdxSize>(idx);
308 let arr = unsafe { arrow::ffi::mmap::slice(idx) };
309 let arr = arr.with_validity_typed(Some(validity));
310 let ca = IdxCa::with_chunk(PlSmallStr::EMPTY, arr);
311
312 f(&ca)
313 }
314}
315
#[cfg(feature = "dtype-array")]
impl ChunkTakeUnchecked<IdxCa> for ArrayChunked {
    /// Gathers fixed-size-list rows at `indices` using the `take_unchecked` compute
    /// kernel on the rechunked values and indices, producing a single chunk.
    ///
    /// # Safety
    ///
    /// Indices are not bounds-checked; every non-null index must be valid for `self`.
    unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self {
        let values = self.rechunk().downcast_into_array();
        let idx = indices.rechunk().downcast_into_array();
        let gathered = take_unchecked(&values, &idx);
        self.copy_with_chunks(vec![gathered])
    }
}
326
#[cfg(feature = "dtype-array")]
impl<I: AsRef<[IdxSize]> + ?Sized> ChunkTakeUnchecked<I> for ArrayChunked {
    /// Gathers fixed-size-list rows at a slice of indices by wrapping the slice in a
    /// temporary `IdxCa` and delegating to the `IdxCa` implementation.
    ///
    /// # Safety
    ///
    /// Indices are not bounds-checked; every index must be valid for `self`.
    unsafe fn take_unchecked(&self, indices: &I) -> Self {
        let idx_slice: &[IdxSize] = indices.as_ref();
        let idx_ca = IdxCa::mmap_slice(PlSmallStr::EMPTY, idx_slice);
        self.take_unchecked(&idx_ca)
    }
}
334
335impl ChunkTakeUnchecked<IdxCa> for ListChunked {
336 unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self {
337 let chunks = vec![take_unchecked(
338 &self.rechunk().downcast_into_array(),
339 &indices.rechunk().downcast_into_array(),
340 )];
341 self.copy_with_chunks(chunks)
342 }
343}
344
345impl<I: AsRef<[IdxSize]> + ?Sized> ChunkTakeUnchecked<I> for ListChunked {
346 unsafe fn take_unchecked(&self, indices: &I) -> Self {
347 let idx = IdxCa::mmap_slice(PlSmallStr::EMPTY, indices.as_ref());
348 self.take_unchecked(&idx)
349 }
350}