polars_compute/
arity.rs

1use arrow::array::PrimitiveArray;
2use arrow::compute::utils::combine_validities_and;
3use arrow::types::NativeType;
4
/// To reduce codegen we use these helpers where the input and output arrays
/// may overlap. These are marked to never be inlined, this way only a single
/// unrolled kernel gets generated, even if we call it in multiple ways.
///
/// Writes `op(arr[i])` into `out[i]` for every `i` in `0..len`. Element `i` is
/// read before element `i` is written, which is what makes the in-place use
/// (`out` pointing at the same address as `arr`) sound.
///
/// # Safety
///  - arr must point to a readable slice of length len.
///  - out must point to a writable (and properly aligned) slice of length len.
///  - if the two ranges overlap, `out` must start at the same address as `arr`
///    and `I` and `O` must have identical size and alignment.
#[inline(never)]
unsafe fn ptr_apply_unary_kernel<I: Copy, O, F: Fn(I) -> O>(
    arr: *const I,
    out: *mut O,
    len: usize,
    op: F,
) {
    for i in 0..len {
        // SAFETY: `i < len` and the caller guarantees `arr` is readable for
        // `len` elements.
        let ret = op(unsafe { arr.add(i).read() });
        // SAFETY: `i < len` and the caller guarantees `out` is writable for
        // `len` elements; if `out` aliases `arr`, element `i` was already read.
        unsafe { out.add(i).write(ret) };
    }
}
24
/// Writes `op(left[i], right[i])` into `out[i]` for every `i` in `0..len`.
///
/// Both inputs for index `i` are read before `out[i]` is written, so `out`
/// may alias either `left` or `right` exactly (in-place update).
///
/// # Safety
///  - left must point to a readable slice of length len.
///  - right must point to a readable slice of length len.
///  - out must point to a writable (and properly aligned) slice of length len.
///  - if `out` overlaps an input, it must start at that input's address and
///    the input's element type must have the same size and alignment as `O`.
#[inline(never)]
unsafe fn ptr_apply_binary_kernel<L: Copy, R: Copy, O, F: Fn(L, R) -> O>(
    left: *const L,
    right: *const R,
    out: *mut O,
    len: usize,
    op: F,
) {
    for i in 0..len {
        // SAFETY: `i < len` and the caller guarantees both `left` and `right`
        // are readable for `len` elements.
        let (l, r) = unsafe { (left.add(i).read(), right.add(i).read()) };
        let ret = op(l, r);
        // SAFETY: `i < len` and the caller guarantees `out` is writable for
        // `len` elements; any aliased input element `i` was already read.
        unsafe { out.add(i).write(ret) };
    }
}
42
43/// Applies a function to all the values (regardless of nullability).
44///
45/// May reuse the memory of the array if possible.
46pub fn prim_unary_values<I, O, F>(mut arr: PrimitiveArray<I>, op: F) -> PrimitiveArray<O>
47where
48    I: NativeType,
49    O: NativeType,
50    F: Fn(I) -> O,
51{
52    let len = arr.len();
53
54    // Reuse memory if possible.
55    if size_of::<I>() == size_of::<O>() && align_of::<I>() == align_of::<O>() {
56        if let Some(values) = arr.get_mut_values() {
57            let ptr = values.as_mut_ptr();
58            // SAFETY: checked same size & alignment I/O, NativeType is always Pod.
59            unsafe { ptr_apply_unary_kernel(ptr, ptr as *mut O, len, op) }
60            return arr.transmute::<O>();
61        }
62    }
63
64    let mut out = Vec::with_capacity(len);
65    unsafe {
66        // SAFETY: checked pointers point to slices of length len.
67        ptr_apply_unary_kernel(arr.values().as_ptr(), out.as_mut_ptr(), len, op);
68        out.set_len(len);
69    }
70    PrimitiveArray::from_vec(out).with_validity(arr.take_validity())
71}
72
73/// Apply a binary function to all the values (regardless of nullability)
74/// in (lhs, rhs). Combines the validities with a bitand.
75///
76/// May reuse the memory of one of its arguments if possible.
77pub fn prim_binary_values<L, R, O, F>(
78    mut lhs: PrimitiveArray<L>,
79    mut rhs: PrimitiveArray<R>,
80    op: F,
81) -> PrimitiveArray<O>
82where
83    L: NativeType,
84    R: NativeType,
85    O: NativeType,
86    F: Fn(L, R) -> O,
87{
88    assert_eq!(lhs.len(), rhs.len());
89    let len = lhs.len();
90
91    let validity = combine_validities_and(lhs.validity(), rhs.validity());
92
93    // Reuse memory if possible.
94    if size_of::<L>() == size_of::<O>() && align_of::<L>() == align_of::<O>() {
95        if let Some(lv) = lhs.get_mut_values() {
96            let lp = lv.as_mut_ptr();
97            let rp = rhs.values().as_ptr();
98            unsafe {
99                // SAFETY: checked same size & alignment L/O, NativeType is always Pod.
100                ptr_apply_binary_kernel(lp, rp, lp as *mut O, len, op);
101            }
102            return lhs.transmute::<O>().with_validity(validity);
103        }
104    }
105    if size_of::<R>() == size_of::<O>() && align_of::<R>() == align_of::<O>() {
106        if let Some(rv) = rhs.get_mut_values() {
107            let lp = lhs.values().as_ptr();
108            let rp = rv.as_mut_ptr();
109            unsafe {
110                // SAFETY: checked same size & alignment R/O, NativeType is always Pod.
111                ptr_apply_binary_kernel(lp, rp, rp as *mut O, len, op);
112            }
113            return rhs.transmute::<O>().with_validity(validity);
114        }
115    }
116
117    let mut out = Vec::with_capacity(len);
118    unsafe {
119        // SAFETY: checked pointers point to slices of length len.
120        let lp = lhs.values().as_ptr();
121        let rp = rhs.values().as_ptr();
122        ptr_apply_binary_kernel(lp, rp, out.as_mut_ptr(), len, op);
123        out.set_len(len);
124    }
125    PrimitiveArray::from_vec(out).with_validity(validity)
126}