polars_arrow/compute/arity.rs
//! Defines kernels for performing operations on primitive arrays.
2
3use super::utils::{check_same_len, combine_validities_and};
4use crate::array::PrimitiveArray;
5use crate::datatypes::ArrowDataType;
6use crate::types::NativeType;
7
8/// Applies an unary and infallible function to a [`PrimitiveArray`].
9///
10/// This is the /// fastest way to perform an operation on a [`PrimitiveArray`] when the benefits
11/// of a vectorized operation outweighs the cost of branching nulls and non-nulls.
12///
13/// # Implementation
14/// This will apply the function for all values, including those on null slots.
15/// This implies that the operation must be infallible for any value of the
16/// corresponding type or this function may panic.
17#[inline]
18pub fn unary<I, F, O>(array: &PrimitiveArray<I>, op: F, dtype: ArrowDataType) -> PrimitiveArray<O>
19where
20 I: NativeType,
21 O: NativeType,
22 F: Fn(I) -> O,
23{
24 let values = array.values().iter().map(|v| op(*v)).collect::<Vec<_>>();
25
26 PrimitiveArray::<O>::new(dtype, values.into(), array.validity().cloned())
27}
28
29/// Applies a binary operations to two primitive arrays.
30///
31/// This is the fastest way to perform an operation on two primitive array when the benefits of a
32/// vectorized operation outweighs the cost of branching nulls and non-nulls.
33///
34/// # Errors
35/// This function errors iff the arrays have a different length.
36///
37/// # Implementation
38/// This will apply the function for all values, including those on null slots.
39/// This implies that the operation must be infallible for any value of the
40/// corresponding type.
41/// The types of the arrays are not checked with this operation. The closure
42/// "op" needs to handle the different types in the arrays. The datatype for the
43/// resulting array has to be selected by the implementer of the function as
44/// an argument for the function.
45#[inline]
46pub fn binary<T, D, F>(
47 lhs: &PrimitiveArray<T>,
48 rhs: &PrimitiveArray<D>,
49 dtype: ArrowDataType,
50 op: F,
51) -> PrimitiveArray<T>
52where
53 T: NativeType,
54 D: NativeType,
55 F: Fn(T, D) -> T,
56{
57 check_same_len(lhs, rhs).unwrap();
58
59 let validity = combine_validities_and(lhs.validity(), rhs.validity());
60
61 let values = lhs
62 .values()
63 .iter()
64 .zip(rhs.values().iter())
65 .map(|(l, r)| op(*l, *r))
66 .collect::<Vec<_>>()
67 .into();
68
69 PrimitiveArray::<T>::new(dtype, values, validity)
70}