polars_compute/bitwise/
mod.rs1use std::convert::identity;
2
3use arrow::array::{Array, BooleanArray, PrimitiveArray};
4use arrow::datatypes::ArrowDataType;
5use arrow::legacy::utils::CustomIterTools;
6
7pub trait BitwiseKernel {
8 type Scalar;
9
10 fn count_ones(&self) -> PrimitiveArray<u32>;
11 fn count_zeros(&self) -> PrimitiveArray<u32>;
12
13 fn leading_ones(&self) -> PrimitiveArray<u32>;
14 fn leading_zeros(&self) -> PrimitiveArray<u32>;
15
16 fn trailing_ones(&self) -> PrimitiveArray<u32>;
17 fn trailing_zeros(&self) -> PrimitiveArray<u32>;
18
19 fn reduce_and(&self) -> Option<Self::Scalar>;
20 fn reduce_or(&self) -> Option<Self::Scalar>;
21 fn reduce_xor(&self) -> Option<Self::Scalar>;
22
23 fn bit_and(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar;
24 fn bit_or(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar;
25 fn bit_xor(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar;
26}
27
/// Implements [`BitwiseKernel`] for `PrimitiveArray<$T>` for every listed type.
///
/// Each entry is `(type, to_bits, from_bits)`:
/// * `to_bits` maps a value to the integer whose bits are counted/combined;
/// * `from_bits` maps a combined integer back to the element type.
///
/// Integer types pass `identity` for both. Other types supply a conversion
/// such as `f32::to_bits` / `f32::from_bits`, so the kernels operate on the
/// raw bit pattern.
macro_rules! impl_bitwise_kernel {
    ($(($T:ty, $to_bits:expr, $from_bits:expr)),+ $(,)?) => {
        $(
            impl BitwiseKernel for PrimitiveArray<$T> {
                type Scalar = $T;

                // The element-wise kernels below all follow the same shape: they
                // map every slot of the value buffer (null slots included) and
                // reuse the input validity, so null inputs remain null outputs.

                #[inline(never)]
                fn count_ones(&self) -> PrimitiveArray<u32> {
                    PrimitiveArray::new(
                        ArrowDataType::UInt32,
                        self.values_iter()
                            .map(|&v| $to_bits(v).count_ones())
                            .collect_trusted::<Vec<_>>()
                            .into(),
                        self.validity().cloned(),
                    )
                }

                #[inline(never)]
                fn count_zeros(&self) -> PrimitiveArray<u32> {
                    PrimitiveArray::new(
                        ArrowDataType::UInt32,
                        self.values_iter()
                            .map(|&v| $to_bits(v).count_zeros())
                            .collect_trusted::<Vec<_>>()
                            .into(),
                        self.validity().cloned(),
                    )
                }

                #[inline(never)]
                fn leading_ones(&self) -> PrimitiveArray<u32> {
                    PrimitiveArray::new(
                        ArrowDataType::UInt32,
                        self.values_iter()
                            .map(|&v| $to_bits(v).leading_ones())
                            .collect_trusted::<Vec<_>>()
                            .into(),
                        self.validity().cloned(),
                    )
                }

                #[inline(never)]
                fn leading_zeros(&self) -> PrimitiveArray<u32> {
                    PrimitiveArray::new(
                        ArrowDataType::UInt32,
                        self.values_iter()
                            .map(|&v| $to_bits(v).leading_zeros())
                            .collect_trusted::<Vec<_>>()
                            .into(),
                        self.validity().cloned(),
                    )
                }

                #[inline(never)]
                fn trailing_ones(&self) -> PrimitiveArray<u32> {
                    PrimitiveArray::new(
                        ArrowDataType::UInt32,
                        self.values_iter()
                            .map(|&v| $to_bits(v).trailing_ones())
                            .collect_trusted::<Vec<_>>()
                            .into(),
                        self.validity().cloned(),
                    )
                }

                #[inline(never)]
                fn trailing_zeros(&self) -> PrimitiveArray<u32> {
                    PrimitiveArray::new(
                        ArrowDataType::UInt32,
                        // Same iterator as the sibling kernels (previously
                        // `self.values().iter()`).
                        self.values_iter()
                            .map(|&v| $to_bits(v).trailing_zeros())
                            .collect_trusted::<Vec<_>>()
                            .into(),
                        self.validity().cloned(),
                    )
                }

                // The reductions fold only non-null values; an empty or all-null
                // array yields `None` because `Iterator::reduce` sees no items.
                // The null-free path is kept separate, presumably so the plain
                // slice loop can be optimized/vectorized — NOTE(review): confirm.

                #[inline(never)]
                fn reduce_and(&self) -> Option<Self::Scalar> {
                    if !self.has_nulls() {
                        self.values_iter().copied().map($to_bits).reduce(|a, b| a & b).map($from_bits)
                    } else {
                        self.non_null_values_iter().map($to_bits).reduce(|a, b| a & b).map($from_bits)
                    }
                }

                #[inline(never)]
                fn reduce_or(&self) -> Option<Self::Scalar> {
                    if !self.has_nulls() {
                        self.values_iter().copied().map($to_bits).reduce(|a, b| a | b).map($from_bits)
                    } else {
                        self.non_null_values_iter().map($to_bits).reduce(|a, b| a | b).map($from_bits)
                    }
                }

                #[inline(never)]
                fn reduce_xor(&self) -> Option<Self::Scalar> {
                    if !self.has_nulls() {
                        self.values_iter().copied().map($to_bits).reduce(|a, b| a ^ b).map($from_bits)
                    } else {
                        self.non_null_values_iter().map($to_bits).reduce(|a, b| a ^ b).map($from_bits)
                    }
                }

                // Scalar combinators: round-trip through the bit representation
                // so the operators are available for every listed element type.

                fn bit_and(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar {
                    $from_bits($to_bits(lhs) & $to_bits(rhs))
                }
                fn bit_or(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar {
                    $from_bits($to_bits(lhs) | $to_bits(rhs))
                }
                fn bit_xor(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar {
                    $from_bits($to_bits(lhs) ^ $to_bits(rhs))
                }
            }
        )+
    };
}
146
147impl_bitwise_kernel! {
148 (i8, identity, identity),
149 (i16, identity, identity),
150 (i32, identity, identity),
151 (i64, identity, identity),
152 (u8, identity, identity),
153 (u16, identity, identity),
154 (u32, identity, identity),
155 (u64, identity, identity),
156 (f32, f32::to_bits, f32::from_bits),
157 (f64, f64::to_bits, f64::from_bits),
158}
159
// 128-bit signed integers are only supported with the `dtype-i128` feature.
#[cfg(feature = "dtype-i128")]
impl_bitwise_kernel!((i128, identity, identity));
164
165impl BitwiseKernel for BooleanArray {
166 type Scalar = bool;
167
168 #[inline(never)]
169 fn count_ones(&self) -> PrimitiveArray<u32> {
170 PrimitiveArray::new(
171 ArrowDataType::UInt32,
172 self.values_iter()
173 .map(u32::from)
174 .collect_trusted::<Vec<_>>()
175 .into(),
176 self.validity().cloned(),
177 )
178 }
179
180 #[inline(never)]
181 fn count_zeros(&self) -> PrimitiveArray<u32> {
182 PrimitiveArray::new(
183 ArrowDataType::UInt32,
184 self.values_iter()
185 .map(|v| u32::from(!v))
186 .collect_trusted::<Vec<_>>()
187 .into(),
188 self.validity().cloned(),
189 )
190 }
191
192 #[inline(always)]
193 fn leading_ones(&self) -> PrimitiveArray<u32> {
194 self.count_ones()
195 }
196
197 #[inline(always)]
198 fn leading_zeros(&self) -> PrimitiveArray<u32> {
199 self.count_zeros()
200 }
201
202 #[inline(always)]
203 fn trailing_ones(&self) -> PrimitiveArray<u32> {
204 self.count_ones()
205 }
206
207 #[inline(always)]
208 fn trailing_zeros(&self) -> PrimitiveArray<u32> {
209 self.count_zeros()
210 }
211
212 fn reduce_and(&self) -> Option<Self::Scalar> {
213 if self.len() == self.null_count() {
214 None
215 } else if !self.has_nulls() {
216 Some(self.values().unset_bits() == 0)
217 } else {
218 Some((self.values() & self.validity().unwrap()).unset_bits() == 0)
219 }
220 }
221
222 fn reduce_or(&self) -> Option<Self::Scalar> {
223 if self.len() == self.null_count() {
224 None
225 } else if !self.has_nulls() {
226 Some(self.values().set_bits() > 0)
227 } else {
228 Some((self.values() & self.validity().unwrap()).set_bits() > 0)
229 }
230 }
231
232 fn reduce_xor(&self) -> Option<Self::Scalar> {
233 if self.len() == self.null_count() {
234 None
235 } else if !self.has_nulls() {
236 Some(self.values().set_bits() % 2 == 1)
237 } else {
238 Some((self.values() & self.validity().unwrap()).set_bits() % 2 == 1)
239 }
240 }
241
242 fn bit_and(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar {
243 lhs & rhs
244 }
245 fn bit_or(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar {
246 lhs | rhs
247 }
248 fn bit_xor(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar {
249 lhs ^ rhs
250 }
251}