polars_arrow/legacy/kernels/
set.rs

1use std::ops::BitOr;
2
3use polars_error::polars_err;
4use polars_utils::IdxSize;
5
6use crate::array::*;
7use crate::datatypes::ArrowDataType;
8use crate::legacy::array::default_arrays::FromData;
9use crate::legacy::error::PolarsResult;
10use crate::legacy::kernels::BinaryMaskedSliceIterator;
11use crate::legacy::trusted_len::TrustedLenPush;
12use crate::types::NativeType;
13
14/// Set values in a primitive array where the primitive array has null values.
15/// this is faster because we don't have to invert and combine bitmaps
16pub fn set_at_nulls<T>(array: &PrimitiveArray<T>, value: T) -> PrimitiveArray<T>
17where
18    T: NativeType,
19{
20    let values = array.values();
21    if array.null_count() == 0 {
22        return array.clone();
23    }
24
25    let validity = array.validity().unwrap();
26    let validity = BooleanArray::from_data_default(validity.clone(), None);
27
28    let mut av = Vec::with_capacity(array.len());
29    BinaryMaskedSliceIterator::new(&validity).for_each(|(lower, upper, truthy)| {
30        if truthy {
31            av.extend_from_slice(&values[lower..upper])
32        } else {
33            av.extend_trusted_len(std::iter::repeat(value).take(upper - lower))
34        }
35    });
36
37    PrimitiveArray::new(array.dtype().clone(), av.into(), None)
38}
39
40/// Set values in a primitive array based on a mask array. This is fast when large chunks of bits are set or unset.
41pub fn set_with_mask<T: NativeType>(
42    array: &PrimitiveArray<T>,
43    mask: &BooleanArray,
44    value: T,
45    dtype: ArrowDataType,
46) -> PrimitiveArray<T> {
47    let values = array.values();
48
49    let mut buf = Vec::with_capacity(array.len());
50    BinaryMaskedSliceIterator::new(mask).for_each(|(lower, upper, truthy)| {
51        if truthy {
52            buf.extend_trusted_len(std::iter::repeat(value).take(upper - lower))
53        } else {
54            buf.extend_from_slice(&values[lower..upper])
55        }
56    });
57    // make sure that where the mask is set to true, the validity buffer is also set to valid
58    // after we have applied the or operation we have new buffer with no offsets
59    let validity = array.validity().as_ref().map(|valid| {
60        let mask_bitmap = mask.values();
61        valid.bitor(mask_bitmap)
62    });
63
64    PrimitiveArray::new(dtype, buf.into(), validity)
65}
66
67/// Efficiently sets value at the indices from the iterator to `set_value`.
68/// The new array is initialized with a `memcpy` from the old values.
69pub fn scatter_single_non_null<T, I>(
70    array: &PrimitiveArray<T>,
71    idx: I,
72    set_value: T,
73    dtype: ArrowDataType,
74) -> PolarsResult<PrimitiveArray<T>>
75where
76    T: NativeType,
77    I: IntoIterator<Item = IdxSize>,
78{
79    let mut buf = Vec::with_capacity(array.len());
80    buf.extend_from_slice(array.values().as_slice());
81    let mut_slice = buf.as_mut_slice();
82
83    idx.into_iter().try_for_each::<_, PolarsResult<_>>(|idx| {
84        let val = mut_slice
85            .get_mut(idx as usize)
86            .ok_or_else(|| polars_err!(ComputeError: "index is out of bounds"))?;
87        *val = set_value;
88        Ok(())
89    })?;
90
91    Ok(PrimitiveArray::new(
92        dtype,
93        buf.into(),
94        array.validity().cloned(),
95    ))
96}
97
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_set_at_nulls() {
        // Null slots are filled and the result has no nulls left.
        let val = UInt32Array::from(&[Some(1), None, Some(3), None, None]);
        let out = set_at_nulls(&val, 0);
        let out: Vec<_> = out.iter().map(|v| v.copied()).collect();
        assert_eq!(out, &[Some(1), Some(0), Some(3), Some(0), Some(0)]);

        // Without nulls the input comes back unchanged.
        let val = UInt32Array::from_slice([4, 5, 6]);
        let out = set_at_nulls(&val, 0);
        assert_eq!(out.values().as_slice(), &[4, 5, 6]);
    }

    #[test]
    fn test_set_mask() {
        let mask = BooleanArray::from_iter((0..86).map(|v| v > 68 && v != 85).map(Some));
        let val = UInt32Array::from_iter((0..86).map(Some));
        let a = set_with_mask(&val, &mask, 100, ArrowDataType::UInt32);
        let slice = a.values();

        assert_eq!(slice[a.len() - 1], 85);
        assert_eq!(slice[a.len() - 2], 100);
        assert_eq!(slice[67], 67);
        assert_eq!(slice[68], 68);
        assert_eq!(slice[1], 1);
        assert_eq!(slice[0], 0);

        let mask = BooleanArray::from_slice([
            false, true, false, true, false, true, false, true, false, false,
        ]);
        let val = UInt32Array::from_slice([0; 10]);
        let out = set_with_mask(&val, &mask, 1, ArrowDataType::UInt32);
        assert_eq!(out.values().as_slice(), &[0, 1, 0, 1, 0, 1, 0, 1, 0, 0]);

        let val = UInt32Array::from(&[None, None, None]);
        let mask = BooleanArray::from(&[Some(true), Some(true), None]);
        let out = set_with_mask(&val, &mask, 1, ArrowDataType::UInt32);
        let out: Vec<_> = out.iter().map(|v| v.copied()).collect();
        assert_eq!(out, &[Some(1), Some(1), None])
    }

    #[test]
    fn test_scatter_single_non_null() {
        let val = UInt32Array::from_slice([1, 2, 3]);
        let out =
            scatter_single_non_null(&val, std::iter::once(1), 100, ArrowDataType::UInt32).unwrap();
        assert_eq!(out.values().as_slice(), &[1, 100, 3]);
        let out = scatter_single_non_null(&val, std::iter::once(100), 100, ArrowDataType::UInt32);
        assert!(out.is_err())
    }
}