polars_arrow/compute/
arity.rs

//! Defines kernels suitable for performing operations on primitive arrays.
2
3use super::utils::{check_same_len, combine_validities_and};
4use crate::array::PrimitiveArray;
5use crate::datatypes::ArrowDataType;
6use crate::types::NativeType;
7
8/// Applies an unary and infallible function to a [`PrimitiveArray`].
9///
10/// This is the /// fastest way to perform an operation on a [`PrimitiveArray`] when the benefits
11/// of a vectorized operation outweighs the cost of branching nulls and non-nulls.
12///
13/// # Implementation
14/// This will apply the function for all values, including those on null slots.
15/// This implies that the operation must be infallible for any value of the
16/// corresponding type or this function may panic.
17#[inline]
18pub fn unary<I, F, O>(array: &PrimitiveArray<I>, op: F, dtype: ArrowDataType) -> PrimitiveArray<O>
19where
20    I: NativeType,
21    O: NativeType,
22    F: Fn(I) -> O,
23{
24    let values = array.values().iter().map(|v| op(*v)).collect::<Vec<_>>();
25
26    PrimitiveArray::<O>::new(dtype, values.into(), array.validity().cloned())
27}
28
29/// Applies a binary operations to two primitive arrays.
30///
31/// This is the fastest way to perform an operation on two primitive array when the benefits of a
32/// vectorized operation outweighs the cost of branching nulls and non-nulls.
33///
34/// # Errors
35/// This function errors iff the arrays have a different length.
36///
37/// # Implementation
38/// This will apply the function for all values, including those on null slots.
39/// This implies that the operation must be infallible for any value of the
40/// corresponding type.
41/// The types of the arrays are not checked with this operation. The closure
42/// "op" needs to handle the different types in the arrays. The datatype for the
43/// resulting array has to be selected by the implementer of the function as
44/// an argument for the function.
45#[inline]
46pub fn binary<T, D, F>(
47    lhs: &PrimitiveArray<T>,
48    rhs: &PrimitiveArray<D>,
49    dtype: ArrowDataType,
50    op: F,
51) -> PrimitiveArray<T>
52where
53    T: NativeType,
54    D: NativeType,
55    F: Fn(T, D) -> T,
56{
57    check_same_len(lhs, rhs).unwrap();
58
59    let validity = combine_validities_and(lhs.validity(), rhs.validity());
60
61    let values = lhs
62        .values()
63        .iter()
64        .zip(rhs.values().iter())
65        .map(|(l, r)| op(*l, *r))
66        .collect::<Vec<_>>()
67        .into();
68
69    PrimitiveArray::<T>::new(dtype, values, validity)
70}