polars_arrow/legacy/kernels/
set.rs1use std::ops::BitOr;
2
3use polars_error::polars_err;
4use polars_utils::IdxSize;
5
6use crate::array::*;
7use crate::datatypes::ArrowDataType;
8use crate::legacy::array::default_arrays::FromData;
9use crate::legacy::error::PolarsResult;
10use crate::legacy::kernels::BinaryMaskedSliceIterator;
11use crate::legacy::trusted_len::TrustedLenPush;
12use crate::types::NativeType;
13
14pub fn set_at_nulls<T>(array: &PrimitiveArray<T>, value: T) -> PrimitiveArray<T>
17where
18 T: NativeType,
19{
20 let values = array.values();
21 if array.null_count() == 0 {
22 return array.clone();
23 }
24
25 let validity = array.validity().unwrap();
26 let validity = BooleanArray::from_data_default(validity.clone(), None);
27
28 let mut av = Vec::with_capacity(array.len());
29 BinaryMaskedSliceIterator::new(&validity).for_each(|(lower, upper, truthy)| {
30 if truthy {
31 av.extend_from_slice(&values[lower..upper])
32 } else {
33 av.extend_trusted_len(std::iter::repeat(value).take(upper - lower))
34 }
35 });
36
37 PrimitiveArray::new(array.dtype().clone(), av.into(), None)
38}
39
40pub fn set_with_mask<T: NativeType>(
42 array: &PrimitiveArray<T>,
43 mask: &BooleanArray,
44 value: T,
45 dtype: ArrowDataType,
46) -> PrimitiveArray<T> {
47 let values = array.values();
48
49 let mut buf = Vec::with_capacity(array.len());
50 BinaryMaskedSliceIterator::new(mask).for_each(|(lower, upper, truthy)| {
51 if truthy {
52 buf.extend_trusted_len(std::iter::repeat(value).take(upper - lower))
53 } else {
54 buf.extend_from_slice(&values[lower..upper])
55 }
56 });
57 let validity = array.validity().as_ref().map(|valid| {
60 let mask_bitmap = mask.values();
61 valid.bitor(mask_bitmap)
62 });
63
64 PrimitiveArray::new(dtype, buf.into(), validity)
65}
66
67pub fn scatter_single_non_null<T, I>(
70 array: &PrimitiveArray<T>,
71 idx: I,
72 set_value: T,
73 dtype: ArrowDataType,
74) -> PolarsResult<PrimitiveArray<T>>
75where
76 T: NativeType,
77 I: IntoIterator<Item = IdxSize>,
78{
79 let mut buf = Vec::with_capacity(array.len());
80 buf.extend_from_slice(array.values().as_slice());
81 let mut_slice = buf.as_mut_slice();
82
83 idx.into_iter().try_for_each::<_, PolarsResult<_>>(|idx| {
84 let val = mut_slice
85 .get_mut(idx as usize)
86 .ok_or_else(|| polars_err!(ComputeError: "index is out of bounds"))?;
87 *val = set_value;
88 Ok(())
89 })?;
90
91 Ok(PrimitiveArray::new(
92 dtype,
93 buf.into(),
94 array.validity().cloned(),
95 ))
96}
97
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_set_mask() {
        // Positions 69..=84 are selected; the final slot (85) is explicitly excluded.
        let mask = BooleanArray::from_iter((0..86).map(|i| Some(i > 68 && i != 85)));
        let input = UInt32Array::from_iter((0..86).map(Some));
        let out = set_with_mask(&input, &mask, 100, ArrowDataType::UInt32);
        let values = out.values();
        let n = out.len();

        assert_eq!(values[0], 0);
        assert_eq!(values[1], 1);
        assert_eq!(values[67], 67);
        assert_eq!(values[68], 68);
        assert_eq!(values[n - 2], 100);
        assert_eq!(values[n - 1], 85);

        // Alternating mask over an all-zero array.
        let mask = BooleanArray::from_slice([
            false, true, false, true, false, true, false, true, false, false,
        ]);
        let input = UInt32Array::from_slice([0; 10]);
        let out = set_with_mask(&input, &mask, 1, ArrowDataType::UInt32);
        let expected: [u32; 10] = [0, 1, 0, 1, 0, 1, 0, 1, 0, 0];
        assert_eq!(out.values().as_slice(), &expected);

        // All-null input: selected slots become valid, the null-masked slot stays null.
        let input = UInt32Array::from(&[None, None, None]);
        let mask = BooleanArray::from(&[Some(true), Some(true), None]);
        let out = set_with_mask(&input, &mask, 1, ArrowDataType::UInt32);
        let got: Vec<Option<u32>> = out.iter().map(|v| v.copied()).collect();
        assert_eq!(got, vec![Some(1), Some(1), None]);
    }

    #[test]
    fn test_scatter_single_non_null() {
        let input = UInt32Array::from_slice([1, 2, 3]);

        let in_bounds =
            scatter_single_non_null(&input, std::iter::once(1), 100, ArrowDataType::UInt32)
                .unwrap();
        assert_eq!(in_bounds.values().as_slice(), &[1, 100, 3]);

        let out_of_bounds =
            scatter_single_non_null(&input, std::iter::once(100), 100, ArrowDataType::UInt32);
        assert!(out_of_bounds.is_err());
    }
}