polars_compute/gather/
mod.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
//! Defines the gather (take) kernels for [`Array`]s.
19
20use arrow::array::{
21    self, new_empty_array, Array, ArrayCollectIterExt, ArrayFromIterDtype, NullArray, StaticArray,
22    Utf8ViewArray,
23};
24use arrow::datatypes::{ArrowDataType, IdxArr};
25use arrow::types::Index;
26
27use crate::gather::binview::take_binview_unchecked;
28
29pub mod binary;
30pub mod binview;
31pub mod bitmap;
32pub mod boolean;
33pub mod fixed_size_list;
34pub mod generic_binary;
35pub mod list;
36pub mod primitive;
37pub mod structure;
38pub mod sublist;
39
40use arrow::with_match_primitive_type_full;
41
42/// Returns a new [`Array`] with only indices at `indices`. Null indices are taken as nulls.
43/// The returned array has a length equal to `indices.len()`.
44/// # Safety
45/// Doesn't do bound checks
46pub unsafe fn take_unchecked(values: &dyn Array, indices: &IdxArr) -> Box<dyn Array> {
47    if indices.len() == 0 {
48        return new_empty_array(values.dtype().clone());
49    }
50
51    use arrow::datatypes::PhysicalType::*;
52    match values.dtype().to_physical_type() {
53        Null => Box::new(NullArray::new(values.dtype().clone(), indices.len())),
54        Boolean => {
55            let values = values.as_any().downcast_ref().unwrap();
56            Box::new(boolean::take_unchecked(values, indices))
57        },
58        Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| {
59            let values = values.as_any().downcast_ref().unwrap();
60            Box::new(primitive::take_primitive_unchecked::<$T>(&values, indices))
61        }),
62        LargeBinary => {
63            let values = values.as_any().downcast_ref().unwrap();
64            Box::new(binary::take_unchecked::<i64, _>(values, indices))
65        },
66        Struct => {
67            let array = values.as_any().downcast_ref().unwrap();
68            structure::take_unchecked(array, indices).boxed()
69        },
70        LargeList => {
71            let array = values.as_any().downcast_ref().unwrap();
72            Box::new(list::take_unchecked::<i64>(array, indices))
73        },
74        FixedSizeList => {
75            let array = values.as_any().downcast_ref().unwrap();
76            fixed_size_list::take_unchecked(array, indices)
77        },
78        BinaryView => {
79            take_binview_unchecked(values.as_any().downcast_ref().unwrap(), indices).boxed()
80        },
81        Utf8View => {
82            let arr: &Utf8ViewArray = values.as_any().downcast_ref().unwrap();
83            take_binview_unchecked(&arr.to_binview(), indices)
84                .to_utf8view_unchecked()
85                .boxed()
86        },
87        t => unimplemented!("Take not supported for data type {:?}", t),
88    }
89}
90
91/// Naive default implementation
92unsafe fn take_unchecked_impl_generic<T>(
93    values: &T,
94    indices: &IdxArr,
95    new_null_func: &dyn Fn(ArrowDataType, usize) -> T,
96) -> T
97where
98    T: StaticArray + ArrayFromIterDtype<std::option::Option<Box<dyn array::Array>>>,
99{
100    if values.null_count() == values.len() || indices.null_count() == indices.len() {
101        return new_null_func(values.dtype().clone(), indices.len());
102    }
103
104    match (indices.has_nulls(), values.has_nulls()) {
105        (true, true) => {
106            let values_validity = values.validity().unwrap();
107
108            indices
109                .iter()
110                .map(|i| {
111                    if let Some(i) = i {
112                        let i = *i as usize;
113                        if values_validity.get_bit_unchecked(i) {
114                            return Some(values.value_unchecked(i));
115                        }
116                    }
117                    None
118                })
119                .collect_arr_trusted_with_dtype(values.dtype().clone())
120        },
121        (true, false) => indices
122            .iter()
123            .map(|i| {
124                if let Some(i) = i {
125                    let i = *i as usize;
126                    return Some(values.value_unchecked(i));
127                }
128                None
129            })
130            .collect_arr_trusted_with_dtype(values.dtype().clone()),
131        (false, true) => {
132            let values_validity = values.validity().unwrap();
133
134            indices
135                .values_iter()
136                .map(|i| {
137                    let i = *i as usize;
138                    if values_validity.get_bit_unchecked(i) {
139                        return Some(values.value_unchecked(i));
140                    }
141                    None
142                })
143                .collect_arr_trusted_with_dtype(values.dtype().clone())
144        },
145        (false, false) => indices
146            .values_iter()
147            .map(|i| {
148                let i = *i as usize;
149                Some(values.value_unchecked(i))
150            })
151            .collect_arr_trusted_with_dtype(values.dtype().clone()),
152    }
153}