polars_compute/comparisons/
list.rs

1use arrow::array::{
2    Array, BinaryArray, BinaryViewArray, BooleanArray, DictionaryArray, FixedSizeBinaryArray,
3    ListArray, NullArray, PrimitiveArray, StructArray, Utf8Array, Utf8ViewArray,
4};
5use arrow::bitmap::Bitmap;
6use arrow::legacy::utils::CustomIterTools;
7use arrow::types::{days_ms, f16, i256, months_days_ns, Offset};
8
9use super::TotalEqKernel;
10
/// Compares two list arrays row by row and yields one `bool` per row.
///
/// Arguments:
/// * `$lhs`, `$rhs`: the two list arrays. Their lengths and dtypes are asserted to be equal.
/// * `$op`: kernel called on the two per-row slices of the child values.
/// * `$true_op`: callable that reduces the kernel's output to a single `bool`.
/// * `$ineq_len_rv`: literal returned for rows whose two lists differ in length.
/// * `$invalid_rv`: literal returned for rows where either side is null according to its
///   validity bitmap.
///
/// The per-row results are gathered with `collect_trusted()`, so the resulting collection type
/// is inferred from the call site.
///
/// `$lhs` and `$rhs` are expanded several times (including once per row inside the closure), so
/// they should be cheap, side-effect free expressions.
///
/// # Panics
///
/// Panics when the lengths or dtypes differ, when the child values cannot be downcast to the
/// array type selected from their physical type, and (via `todo!`) for `Union` and `Map`
/// children, as well as `FixedSizeList` children when the `dtype-array` feature is disabled.
macro_rules! compare {
    (
        $lhs:expr, $rhs:expr,
        $op:path, $true_op:expr,
        $ineq_len_rv:literal, $invalid_rv:literal
    ) => {{
        let lhs = $lhs;
        let rhs = $rhs;

        assert_eq!(lhs.len(), rhs.len());
        assert_eq!(lhs.dtype(), rhs.dtype());

        // Runs the row-wise comparison with the child values downcast to the concrete array
        // type `$T`.
        macro_rules! call_binary {
            ($T:ty) => {{
                let lhs_values: &$T = $lhs.values().as_any().downcast_ref().unwrap();
                let rhs_values: &$T = $rhs.values().as_any().downcast_ref().unwrap();

                (0..$lhs.len())
                    .map(|i| {
                        // A missing validity bitmap means every row is valid.
                        let lval = $lhs.validity().is_none_or(|v| v.get(i).unwrap());
                        let rval = $rhs.validity().is_none_or(|v| v.get(i).unwrap());

                        if !lval || !rval {
                            return $invalid_rv;
                        }

                        // SAFETY: ListArray's invariant offsets.len_proxy() == len, and
                        // `i < lhs.len()`. The same holds for `rhs` because both lengths were
                        // asserted equal above.
                        let (lstart, lend) = unsafe { $lhs.offsets().start_end_unchecked(i) };
                        let (rstart, rend) = unsafe { $rhs.offsets().start_end_unchecked(i) };

                        // Lists of different lengths never need an element-wise comparison.
                        if lend - lstart != rend - rstart {
                            return $ineq_len_rv;
                        }

                        // Narrow clones of the child arrays down to this row's elements.
                        let mut lhs_values = lhs_values.clone();
                        lhs_values.slice(lstart, lend - lstart);
                        let mut rhs_values = rhs_values.clone();
                        rhs_values.slice(rstart, rend - rstart);

                        $true_op($op(&lhs_values, &rhs_values))
                    })
                    .collect_trusted()
            }};
        }

        // Dispatch on the physical type of the child values to pick the concrete array type.
        use arrow::datatypes::{IntegerType as I, PhysicalType as PH, PrimitiveType as PR};
        match lhs.values().dtype().to_physical_type() {
            PH::Boolean => call_binary!(BooleanArray),
            PH::BinaryView => call_binary!(BinaryViewArray),
            PH::Utf8View => call_binary!(Utf8ViewArray),
            PH::Primitive(PR::Int8) => call_binary!(PrimitiveArray<i8>),
            PH::Primitive(PR::Int16) => call_binary!(PrimitiveArray<i16>),
            PH::Primitive(PR::Int32) => call_binary!(PrimitiveArray<i32>),
            PH::Primitive(PR::Int64) => call_binary!(PrimitiveArray<i64>),
            PH::Primitive(PR::Int128) => call_binary!(PrimitiveArray<i128>),
            PH::Primitive(PR::UInt8) => call_binary!(PrimitiveArray<u8>),
            PH::Primitive(PR::UInt16) => call_binary!(PrimitiveArray<u16>),
            PH::Primitive(PR::UInt32) => call_binary!(PrimitiveArray<u32>),
            PH::Primitive(PR::UInt64) => call_binary!(PrimitiveArray<u64>),
            PH::Primitive(PR::UInt128) => call_binary!(PrimitiveArray<u128>),
            PH::Primitive(PR::Float16) => call_binary!(PrimitiveArray<f16>),
            PH::Primitive(PR::Float32) => call_binary!(PrimitiveArray<f32>),
            PH::Primitive(PR::Float64) => call_binary!(PrimitiveArray<f64>),
            PH::Primitive(PR::Int256) => call_binary!(PrimitiveArray<i256>),
            PH::Primitive(PR::DaysMs) => call_binary!(PrimitiveArray<days_ms>),
            PH::Primitive(PR::MonthDayNano) => {
                call_binary!(PrimitiveArray<months_days_ns>)
            },

            #[cfg(feature = "dtype-array")]
            PH::FixedSizeList => call_binary!(arrow::array::FixedSizeListArray),
            #[cfg(not(feature = "dtype-array"))]
            PH::FixedSizeList => todo!(
                "Comparison of FixedSizeListArray is not supported without dtype-array feature"
            ),

            PH::Null => call_binary!(NullArray),
            PH::FixedSizeBinary => call_binary!(FixedSizeBinaryArray),
            PH::Binary => call_binary!(BinaryArray<i32>),
            PH::LargeBinary => call_binary!(BinaryArray<i64>),
            PH::Utf8 => call_binary!(Utf8Array<i32>),
            PH::LargeUtf8 => call_binary!(Utf8Array<i64>),
            PH::List => call_binary!(ListArray<i32>),
            PH::LargeList => call_binary!(ListArray<i64>),
            PH::Struct => call_binary!(StructArray),
            PH::Union => todo!("Comparison of UnionArrays is not yet supported"),
            PH::Map => todo!("Comparison of MapArrays is not yet supported"),
            PH::Dictionary(I::Int8) => call_binary!(DictionaryArray<i8>),
            PH::Dictionary(I::Int16) => call_binary!(DictionaryArray<i16>),
            PH::Dictionary(I::Int32) => call_binary!(DictionaryArray<i32>),
            PH::Dictionary(I::Int64) => call_binary!(DictionaryArray<i64>),
            PH::Dictionary(I::Int128) => call_binary!(DictionaryArray<i128>),
            PH::Dictionary(I::UInt8) => call_binary!(DictionaryArray<u8>),
            PH::Dictionary(I::UInt16) => call_binary!(DictionaryArray<u16>),
            PH::Dictionary(I::UInt32) => call_binary!(DictionaryArray<u32>),
            PH::Dictionary(I::UInt64) => call_binary!(DictionaryArray<u64>),
        }
    }};
}
110
/// Compares every row of a list array against one scalar list and yields one `bool` per row.
///
/// Arguments:
/// * `$lhs`: the child values of the list array, as a dynamically typed array.
/// * `$rhs`: the scalar list's elements, as a dynamically typed array of the same dtype as
///   `$lhs` (asserted).
/// * `$offsets`: the list array's offsets; one row is produced per offset pair.
/// * `$validity`: the list array's optional validity bitmap.
/// * `$op`: kernel called on a row's slice of `$lhs` and the whole scalar.
/// * `$true_op`: callable that reduces the kernel's output to a single `bool`.
/// * `$ineq_len_rv`: literal returned for rows whose length differs from the scalar's length.
/// * `$invalid_rv`: literal returned for rows that are null according to `$validity`.
///
/// The per-row results are gathered with `collect_trusted()`, so the resulting collection type
/// is inferred from the call site.
///
/// `$lhs`, `$rhs`, `$offsets` and `$validity` are expanded several times (some of them once per
/// row inside the closure), so they should be cheap, side-effect free expressions.
///
/// # Panics
///
/// Panics when the dtypes of `$lhs` and `$rhs` differ, when either cannot be downcast to the
/// array type selected from the physical type, and (via `todo!`) for `Union` and `Map` values,
/// as well as `FixedSizeList` values when the `dtype-array` feature is disabled.
macro_rules! compare_broadcast {
    (
        $lhs:expr, $rhs:expr,
        $offsets:expr, $validity:expr,
        $op:path, $true_op:expr,
        $ineq_len_rv:literal, $invalid_rv:literal
    ) => {{
        let lhs = $lhs;
        let rhs = $rhs;

        // Runs the row-wise comparison with both sides downcast to the concrete array type `$T`.
        macro_rules! call_binary {
            ($T:ty) => {{
                let values: &$T = $lhs.as_any().downcast_ref().unwrap();
                let scalar: &$T = $rhs.as_any().downcast_ref().unwrap();

                let length = $offsets.len_proxy();

                (0..length)
                    .map(move |i| {
                        // A missing validity bitmap means every row is valid.
                        let v = $validity.is_none_or(|v| v.get(i).unwrap());

                        if !v {
                            return $invalid_rv;
                        }

                        // SAFETY: `i` iterates over `0..$offsets.len_proxy()`, so it is always
                        // below `len_proxy()`.
                        let (start, end) = unsafe { $offsets.start_end_unchecked(i) };

                        // A row can only match the scalar when it has the same number of
                        // elements.
                        if end - start != scalar.len() {
                            return $ineq_len_rv;
                        }

                        // @TODO: I feel like there is a better way to do this.
                        // Narrow a clone of the child array down to this row's elements.
                        let mut values: $T = values.clone();
                        <$T>::slice(&mut values, start, end - start);

                        $true_op($op(&values, scalar))
                    })
                    .collect_trusted()
            }};
        }

        assert_eq!(lhs.dtype(), rhs.dtype());

        // Dispatch on the physical type of the values to pick the concrete array type.
        use arrow::datatypes::{IntegerType as I, PhysicalType as PH, PrimitiveType as PR};
        match lhs.dtype().to_physical_type() {
            PH::Boolean => call_binary!(BooleanArray),
            PH::BinaryView => call_binary!(BinaryViewArray),
            PH::Utf8View => call_binary!(Utf8ViewArray),
            PH::Primitive(PR::Int8) => call_binary!(PrimitiveArray<i8>),
            PH::Primitive(PR::Int16) => call_binary!(PrimitiveArray<i16>),
            PH::Primitive(PR::Int32) => call_binary!(PrimitiveArray<i32>),
            PH::Primitive(PR::Int64) => call_binary!(PrimitiveArray<i64>),
            PH::Primitive(PR::Int128) => call_binary!(PrimitiveArray<i128>),
            PH::Primitive(PR::UInt8) => call_binary!(PrimitiveArray<u8>),
            PH::Primitive(PR::UInt16) => call_binary!(PrimitiveArray<u16>),
            PH::Primitive(PR::UInt32) => call_binary!(PrimitiveArray<u32>),
            PH::Primitive(PR::UInt64) => call_binary!(PrimitiveArray<u64>),
            PH::Primitive(PR::UInt128) => call_binary!(PrimitiveArray<u128>),
            PH::Primitive(PR::Float16) => call_binary!(PrimitiveArray<f16>),
            PH::Primitive(PR::Float32) => call_binary!(PrimitiveArray<f32>),
            PH::Primitive(PR::Float64) => call_binary!(PrimitiveArray<f64>),
            PH::Primitive(PR::Int256) => call_binary!(PrimitiveArray<i256>),
            PH::Primitive(PR::DaysMs) => call_binary!(PrimitiveArray<days_ms>),
            PH::Primitive(PR::MonthDayNano) => {
                call_binary!(PrimitiveArray<months_days_ns>)
            },

            #[cfg(feature = "dtype-array")]
            PH::FixedSizeList => call_binary!(arrow::array::FixedSizeListArray),
            #[cfg(not(feature = "dtype-array"))]
            PH::FixedSizeList => todo!(
                "Comparison of FixedSizeListArray is not supported without dtype-array feature"
            ),

            PH::Null => call_binary!(NullArray),
            PH::FixedSizeBinary => call_binary!(FixedSizeBinaryArray),
            PH::Binary => call_binary!(BinaryArray<i32>),
            PH::LargeBinary => call_binary!(BinaryArray<i64>),
            PH::Utf8 => call_binary!(Utf8Array<i32>),
            PH::LargeUtf8 => call_binary!(Utf8Array<i64>),
            PH::List => call_binary!(ListArray<i32>),
            PH::LargeList => call_binary!(ListArray<i64>),
            PH::Struct => call_binary!(StructArray),
            PH::Union => todo!("Comparison of UnionArrays is not yet supported"),
            PH::Map => todo!("Comparison of MapArrays is not yet supported"),
            PH::Dictionary(I::Int8) => call_binary!(DictionaryArray<i8>),
            PH::Dictionary(I::Int16) => call_binary!(DictionaryArray<i16>),
            PH::Dictionary(I::Int32) => call_binary!(DictionaryArray<i32>),
            PH::Dictionary(I::Int64) => call_binary!(DictionaryArray<i64>),
            PH::Dictionary(I::Int128) => call_binary!(DictionaryArray<i128>),
            PH::Dictionary(I::UInt8) => call_binary!(DictionaryArray<u8>),
            PH::Dictionary(I::UInt16) => call_binary!(DictionaryArray<u16>),
            PH::Dictionary(I::UInt32) => call_binary!(DictionaryArray<u32>),
            PH::Dictionary(I::UInt64) => call_binary!(DictionaryArray<u64>),
        }
    }};
}
208
209impl<O: Offset> TotalEqKernel for ListArray<O> {
210    type Scalar = Box<dyn Array>;
211
212    fn tot_eq_kernel(&self, other: &Self) -> Bitmap {
213        compare!(
214            self,
215            other,
216            TotalEqKernel::tot_eq_missing_kernel,
217            |bm: Bitmap| bm.unset_bits() == 0,
218            false,
219            true
220        )
221    }
222
223    fn tot_ne_kernel(&self, other: &Self) -> Bitmap {
224        compare!(
225            self,
226            other,
227            TotalEqKernel::tot_ne_missing_kernel,
228            |bm: Bitmap| bm.set_bits() > 0,
229            true,
230            false
231        )
232    }
233
234    fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
235        compare_broadcast!(
236            self.values().as_ref(),
237            other.as_ref(),
238            self.offsets(),
239            self.validity(),
240            TotalEqKernel::tot_eq_missing_kernel,
241            |bm: Bitmap| bm.unset_bits() == 0,
242            false,
243            true
244        )
245    }
246
247    fn tot_ne_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
248        compare_broadcast!(
249            self.values().as_ref(),
250            other.as_ref(),
251            self.offsets(),
252            self.validity(),
253            TotalEqKernel::tot_ne_missing_kernel,
254            |bm: Bitmap| bm.set_bits() > 0,
255            true,
256            false
257        )
258    }
259}