polars_compute/comparisons/
list.rs1use arrow::array::{
2 Array, BinaryArray, BinaryViewArray, BooleanArray, DictionaryArray, FixedSizeBinaryArray,
3 ListArray, NullArray, PrimitiveArray, StructArray, Utf8Array, Utf8ViewArray,
4};
5use arrow::bitmap::Bitmap;
6use arrow::legacy::utils::CustomIterTools;
7use arrow::types::{days_ms, f16, i256, months_days_ns, Offset};
8
9use super::TotalEqKernel;
10
/// Row-wise comparison of two list arrays, yielding one `bool` per row.
///
/// Arguments:
/// * `$lhs`, `$rhs` - the two list arrays. Each expression is evaluated exactly once. They must
///   agree in length and dtype, otherwise the expansion panics through `assert_eq!`.
/// * `$op` - kernel applied to the two child-value slices that make up a row.
/// * `$true_op` - reduces the value returned by `$op` to the `bool` emitted for that row.
/// * `$ineq_len_rv` - literal emitted when the two sub-lists of a row differ in length.
/// * `$invalid_rv` - literal emitted when either side is marked invalid for that row.
///
/// The per-row results are gathered with `collect_trusted`.
///
/// Panics (`todo!`) when the child values are union or map arrays, and for fixed-size-list
/// children when the `dtype-array` feature is disabled.
macro_rules! compare {
    (
        $lhs:expr, $rhs:expr,
        $op:path, $true_op:expr,
        $ineq_len_rv:literal, $invalid_rv:literal
    ) => {{
        // Bind the arguments once; everything below (including the nested macro) goes through
        // these locals so the caller's expressions are not re-evaluated per row.
        let lhs = $lhs;
        let rhs = $rhs;

        assert_eq!(lhs.len(), rhs.len());
        assert_eq!(lhs.dtype(), rhs.dtype());

        // Runs the comparison with the child values downcast to the concrete array type `$T`.
        macro_rules! call_binary {
            ($T:ty) => {{
                // `$T` is selected by the physical-type match below, so the downcast is expected
                // to succeed for `lhs`; `rhs` was asserted to share the same dtype.
                let lhs_values: &$T = lhs.values().as_any().downcast_ref().unwrap();
                let rhs_values: &$T = rhs.values().as_any().downcast_ref().unwrap();

                (0..lhs.len())
                    .map(|i| {
                        // A missing validity bitmap means every row is valid.
                        let lval = lhs.validity().is_none_or(|v| v.get(i).unwrap());
                        let rval = rhs.validity().is_none_or(|v| v.get(i).unwrap());

                        if !lval || !rval {
                            return $invalid_rv;
                        }

                        // SAFETY: `i < lhs.len() == rhs.len()` (asserted above). This relies on
                        // each list array carrying an offset pair for every one of its `len()`
                        // rows. NOTE(review): confirm against `ListArray`'s invariants.
                        let (lstart, lend) = unsafe { lhs.offsets().start_end_unchecked(i) };
                        let (rstart, rend) = unsafe { rhs.offsets().start_end_unchecked(i) };

                        // Sub-lists of different lengths never need an element-wise comparison.
                        if lend - lstart != rend - rstart {
                            return $ineq_len_rv;
                        }

                        // Narrow a clone of each child array to the row's sub-list.
                        let mut lhs_values = lhs_values.clone();
                        lhs_values.slice(lstart, lend - lstart);
                        let mut rhs_values = rhs_values.clone();
                        rhs_values.slice(rstart, rend - rstart);

                        $true_op($op(&lhs_values, &rhs_values))
                    })
                    .collect_trusted()
            }};
        }

        use arrow::datatypes::{IntegerType as I, PhysicalType as PH, PrimitiveType as PR};
        match lhs.values().dtype().to_physical_type() {
            PH::Boolean => call_binary!(BooleanArray),
            PH::BinaryView => call_binary!(BinaryViewArray),
            PH::Utf8View => call_binary!(Utf8ViewArray),
            PH::Primitive(PR::Int8) => call_binary!(PrimitiveArray<i8>),
            PH::Primitive(PR::Int16) => call_binary!(PrimitiveArray<i16>),
            PH::Primitive(PR::Int32) => call_binary!(PrimitiveArray<i32>),
            PH::Primitive(PR::Int64) => call_binary!(PrimitiveArray<i64>),
            PH::Primitive(PR::Int128) => call_binary!(PrimitiveArray<i128>),
            PH::Primitive(PR::UInt8) => call_binary!(PrimitiveArray<u8>),
            PH::Primitive(PR::UInt16) => call_binary!(PrimitiveArray<u16>),
            PH::Primitive(PR::UInt32) => call_binary!(PrimitiveArray<u32>),
            PH::Primitive(PR::UInt64) => call_binary!(PrimitiveArray<u64>),
            PH::Primitive(PR::UInt128) => call_binary!(PrimitiveArray<u128>),
            PH::Primitive(PR::Float16) => call_binary!(PrimitiveArray<f16>),
            PH::Primitive(PR::Float32) => call_binary!(PrimitiveArray<f32>),
            PH::Primitive(PR::Float64) => call_binary!(PrimitiveArray<f64>),
            PH::Primitive(PR::Int256) => call_binary!(PrimitiveArray<i256>),
            PH::Primitive(PR::DaysMs) => call_binary!(PrimitiveArray<days_ms>),
            PH::Primitive(PR::MonthDayNano) => {
                call_binary!(PrimitiveArray<months_days_ns>)
            },

            #[cfg(feature = "dtype-array")]
            PH::FixedSizeList => call_binary!(arrow::array::FixedSizeListArray),
            #[cfg(not(feature = "dtype-array"))]
            PH::FixedSizeList => todo!(
                "Comparison of FixedSizeListArray is not supported without dtype-array feature"
            ),

            PH::Null => call_binary!(NullArray),
            PH::FixedSizeBinary => call_binary!(FixedSizeBinaryArray),
            PH::Binary => call_binary!(BinaryArray<i32>),
            PH::LargeBinary => call_binary!(BinaryArray<i64>),
            PH::Utf8 => call_binary!(Utf8Array<i32>),
            PH::LargeUtf8 => call_binary!(Utf8Array<i64>),
            PH::List => call_binary!(ListArray<i32>),
            PH::LargeList => call_binary!(ListArray<i64>),
            PH::Struct => call_binary!(StructArray),
            PH::Union => todo!("Comparison of UnionArrays is not yet supported"),
            PH::Map => todo!("Comparison of MapArrays is not yet supported"),
            PH::Dictionary(I::Int8) => call_binary!(DictionaryArray<i8>),
            PH::Dictionary(I::Int16) => call_binary!(DictionaryArray<i16>),
            PH::Dictionary(I::Int32) => call_binary!(DictionaryArray<i32>),
            PH::Dictionary(I::Int64) => call_binary!(DictionaryArray<i64>),
            PH::Dictionary(I::Int128) => call_binary!(DictionaryArray<i128>),
            PH::Dictionary(I::UInt8) => call_binary!(DictionaryArray<u8>),
            PH::Dictionary(I::UInt16) => call_binary!(DictionaryArray<u16>),
            PH::Dictionary(I::UInt32) => call_binary!(DictionaryArray<u32>),
            PH::Dictionary(I::UInt64) => call_binary!(DictionaryArray<u64>),
        }
    }};
}
110
/// Compares every sub-list described by `$offsets` against one scalar array, yielding one
/// `bool` per row.
///
/// Arguments:
/// * `$lhs` - the child values array that the offsets index into.
/// * `$rhs` - the scalar array every sub-list is compared with; must share `$lhs`'s dtype,
///   otherwise the expansion panics through `assert_eq!`.
/// * `$offsets` - offsets delimiting the sub-lists inside `$lhs`.
/// * `$validity` - optional validity bitmap for the rows; `None` is treated as "all valid".
/// * `$op` - kernel applied to a row's value slice and the scalar.
/// * `$true_op` - reduces the value returned by `$op` to the `bool` emitted for that row.
/// * `$ineq_len_rv` - literal emitted when a sub-list's length differs from the scalar's.
/// * `$invalid_rv` - literal emitted when the row is marked invalid.
///
/// `$lhs`, `$rhs`, `$offsets` and `$validity` are each evaluated exactly once. The per-row
/// results are gathered with `collect_trusted`.
///
/// Panics (`todo!`) when the values are union or map arrays, and for fixed-size-list values
/// when the `dtype-array` feature is disabled.
macro_rules! compare_broadcast {
    (
        $lhs:expr, $rhs:expr,
        $offsets:expr, $validity:expr,
        $op:path, $true_op:expr,
        $ineq_len_rv:literal, $invalid_rv:literal
    ) => {{
        // Bind the arguments once; the nested macro below uses these locals so the caller's
        // expressions are not re-evaluated for every row.
        let lhs = $lhs;
        let rhs = $rhs;

        assert_eq!(lhs.dtype(), rhs.dtype());

        let offsets = $offsets;
        let validity = $validity;

        // Runs the comparison with both arrays downcast to the concrete array type `$T`.
        macro_rules! call_binary {
            ($T:ty) => {{
                // `$T` is selected by the physical-type match below, so the downcast is expected
                // to succeed for `lhs`; `rhs` was asserted to share the same dtype.
                let values: &$T = lhs.as_any().downcast_ref().unwrap();
                let scalar: &$T = rhs.as_any().downcast_ref().unwrap();

                let length = offsets.len_proxy();

                (0..length)
                    .map(move |i| {
                        // A missing validity bitmap means every row is valid.
                        let v = validity.is_none_or(|v| v.get(i).unwrap());

                        if !v {
                            return $invalid_rv;
                        }

                        // SAFETY: `i < length`, and `length` is `offsets.len_proxy()`.
                        // NOTE(review): presumably that is exactly the bound
                        // `start_end_unchecked` requires; confirm against its documentation.
                        let (start, end) = unsafe { offsets.start_end_unchecked(i) };

                        // A sub-list whose length differs from the scalar's needs no
                        // element-wise comparison.
                        if end - start != scalar.len() {
                            return $ineq_len_rv;
                        }

                        // Narrow a clone of the values array to the row's sub-list.
                        let mut values: $T = values.clone();
                        <$T>::slice(&mut values, start, end - start);

                        $true_op($op(&values, scalar))
                    })
                    .collect_trusted()
            }};
        }

        use arrow::datatypes::{IntegerType as I, PhysicalType as PH, PrimitiveType as PR};
        match lhs.dtype().to_physical_type() {
            PH::Boolean => call_binary!(BooleanArray),
            PH::BinaryView => call_binary!(BinaryViewArray),
            PH::Utf8View => call_binary!(Utf8ViewArray),
            PH::Primitive(PR::Int8) => call_binary!(PrimitiveArray<i8>),
            PH::Primitive(PR::Int16) => call_binary!(PrimitiveArray<i16>),
            PH::Primitive(PR::Int32) => call_binary!(PrimitiveArray<i32>),
            PH::Primitive(PR::Int64) => call_binary!(PrimitiveArray<i64>),
            PH::Primitive(PR::Int128) => call_binary!(PrimitiveArray<i128>),
            PH::Primitive(PR::UInt8) => call_binary!(PrimitiveArray<u8>),
            PH::Primitive(PR::UInt16) => call_binary!(PrimitiveArray<u16>),
            PH::Primitive(PR::UInt32) => call_binary!(PrimitiveArray<u32>),
            PH::Primitive(PR::UInt64) => call_binary!(PrimitiveArray<u64>),
            PH::Primitive(PR::UInt128) => call_binary!(PrimitiveArray<u128>),
            PH::Primitive(PR::Float16) => call_binary!(PrimitiveArray<f16>),
            PH::Primitive(PR::Float32) => call_binary!(PrimitiveArray<f32>),
            PH::Primitive(PR::Float64) => call_binary!(PrimitiveArray<f64>),
            PH::Primitive(PR::Int256) => call_binary!(PrimitiveArray<i256>),
            PH::Primitive(PR::DaysMs) => call_binary!(PrimitiveArray<days_ms>),
            PH::Primitive(PR::MonthDayNano) => {
                call_binary!(PrimitiveArray<months_days_ns>)
            },

            #[cfg(feature = "dtype-array")]
            PH::FixedSizeList => call_binary!(arrow::array::FixedSizeListArray),
            #[cfg(not(feature = "dtype-array"))]
            PH::FixedSizeList => todo!(
                "Comparison of FixedSizeListArray is not supported without dtype-array feature"
            ),

            PH::Null => call_binary!(NullArray),
            PH::FixedSizeBinary => call_binary!(FixedSizeBinaryArray),
            PH::Binary => call_binary!(BinaryArray<i32>),
            PH::LargeBinary => call_binary!(BinaryArray<i64>),
            PH::Utf8 => call_binary!(Utf8Array<i32>),
            PH::LargeUtf8 => call_binary!(Utf8Array<i64>),
            PH::List => call_binary!(ListArray<i32>),
            PH::LargeList => call_binary!(ListArray<i64>),
            PH::Struct => call_binary!(StructArray),
            PH::Union => todo!("Comparison of UnionArrays is not yet supported"),
            PH::Map => todo!("Comparison of MapArrays is not yet supported"),
            PH::Dictionary(I::Int8) => call_binary!(DictionaryArray<i8>),
            PH::Dictionary(I::Int16) => call_binary!(DictionaryArray<i16>),
            PH::Dictionary(I::Int32) => call_binary!(DictionaryArray<i32>),
            PH::Dictionary(I::Int64) => call_binary!(DictionaryArray<i64>),
            PH::Dictionary(I::Int128) => call_binary!(DictionaryArray<i128>),
            PH::Dictionary(I::UInt8) => call_binary!(DictionaryArray<u8>),
            PH::Dictionary(I::UInt16) => call_binary!(DictionaryArray<u16>),
            PH::Dictionary(I::UInt32) => call_binary!(DictionaryArray<u32>),
            PH::Dictionary(I::UInt64) => call_binary!(DictionaryArray<u64>),
        }
    }};
}
208
209impl<O: Offset> TotalEqKernel for ListArray<O> {
210 type Scalar = Box<dyn Array>;
211
212 fn tot_eq_kernel(&self, other: &Self) -> Bitmap {
213 compare!(
214 self,
215 other,
216 TotalEqKernel::tot_eq_missing_kernel,
217 |bm: Bitmap| bm.unset_bits() == 0,
218 false,
219 true
220 )
221 }
222
223 fn tot_ne_kernel(&self, other: &Self) -> Bitmap {
224 compare!(
225 self,
226 other,
227 TotalEqKernel::tot_ne_missing_kernel,
228 |bm: Bitmap| bm.set_bits() > 0,
229 true,
230 false
231 )
232 }
233
234 fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
235 compare_broadcast!(
236 self.values().as_ref(),
237 other.as_ref(),
238 self.offsets(),
239 self.validity(),
240 TotalEqKernel::tot_eq_missing_kernel,
241 |bm: Bitmap| bm.unset_bits() == 0,
242 false,
243 true
244 )
245 }
246
247 fn tot_ne_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
248 compare_broadcast!(
249 self.values().as_ref(),
250 other.as_ref(),
251 self.offsets(),
252 self.validity(),
253 TotalEqKernel::tot_ne_missing_kernel,
254 |bm: Bitmap| bm.set_bits() > 0,
255 true,
256 false
257 )
258 }
259}