// polars_compute/filter/boolean.rs

1use arrow::bitmap::Bitmap;
2use polars_utils::clmul::prefix_xorsum;
3
// Mask of the low 56 bits; the bulk kernels consume bitmaps in u56 chunks
// (see `fast_iter_u56` below) so a full u64 write never loses output bits.
const U56_MAX: u64 = (1 << 56) - 1;
5
6fn pext64_polyfill(mut v: u64, mut m: u64, m_popcnt: u32) -> u64 {
7    // Fast path: popcount is low.
8    if m_popcnt <= 4 {
9        // Not a "while m != 0" but a for loop instead so the compiler fully
10        // unrolls the loop, this makes bit << i much faster.
11        let mut out = 0;
12        for i in 0..4 {
13            if m == 0 {
14                break;
15            };
16
17            let bit = (v >> m.trailing_zeros()) & 1;
18            out |= bit << i;
19            m &= m.wrapping_sub(1); // Clear least significant bit.
20        }
21        return out;
22    }
23
24    // Fast path: all the masked bits in v are 0 or 1.
25    // Despite this fast path being simpler than the above popcount-based one,
26    // we do it afterwards because if m has a low popcount these branches become
27    // very unpredictable.
28    v &= m;
29    if v == 0 {
30        return 0;
31    } else if v == m {
32        return (1 << m_popcnt) - 1;
33    }
34
35    // This algorithm is too involved to explain here, see https://github.com/zwegner/zp7.
36    // That is an optimized version of Hacker's Delight Chapter 7-4, parallel suffix method for compress().
37    let mut invm = !m;
38
39    for i in 0..6 {
40        let shift = 1 << i;
41        let prefix_count_bit = if i < 5 {
42            prefix_xorsum(invm)
43        } else {
44            invm.wrapping_neg() << 1
45        };
46        let keep_in_place = v & !prefix_count_bit;
47        let shift_down = v & prefix_count_bit;
48        v = keep_in_place | (shift_down >> shift);
49        invm &= prefix_count_bit;
50    }
51    v
52}
53
/// Returns a bitmap containing, in order, the bits of `values` at the
/// positions where `mask` is set. The result has `mask.set_bits()` bits.
///
/// # Panics
/// Panics if `values` and `mask` differ in length.
pub fn filter_boolean_kernel(values: &Bitmap, mask: &Bitmap) -> Bitmap {
    assert_eq!(values.len(), mask.len());
    let mask_bits_set = mask.set_bits();

    // Fast path: values is all-0s or all-1s.
    // lazy_set_bits only returns Some when the count is already known,
    // so this doesn't force a popcount pass over values.
    if let Some(num_values_bits) = values.lazy_set_bits() {
        if num_values_bits == 0 || num_values_bits == values.len() {
            return Bitmap::new_with_value(num_values_bits == values.len(), mask_bits_set);
        }
    }

    // Fast path: mask is all-0s or all-1s.
    if mask_bits_set == 0 {
        return Bitmap::new();
    } else if mask_bits_set == mask.len() {
        return values.clone();
    }

    // Overallocate by 1 u64 so we can always do a full u64 write.
    let num_words = mask_bits_set.div_ceil(64);
    let num_bytes = 8 * (num_words + 1);
    let mut out_vec: Vec<u8> = Vec::with_capacity(num_bytes);

    unsafe {
        if mask_bits_set <= mask.len() / (64 * 4) {
            // Less than one in 1 in 4 words has a bit set on average, use sparse kernel.
            filter_boolean_kernel_sparse(values, mask, out_vec.as_mut_ptr());
        } else if polars_utils::cpuid::has_fast_bmi2() {
            // NOTE(review): on non-x86_64 the cfg strips this statement,
            // leaving the branch empty; assumes has_fast_bmi2() is always
            // false off x86_64 — verify in polars_utils::cpuid.
            #[cfg(target_arch = "x86_64")]
            filter_boolean_kernel_pext::<true, _>(values, mask, out_vec.as_mut_ptr(), |v, m, _| {
                // SAFETY: has_fast_bmi2 ensures this is a legal instruction.
                core::arch::x86_64::_pext_u64(v, m)
            });
        } else {
            filter_boolean_kernel_pext::<false, _>(
                values,
                mask,
                out_vec.as_mut_ptr(),
                pext64_polyfill,
            )
        }

        // SAFETY: the above filters must have initialized these bytes.
        out_vec.set_len(mask_bits_set.div_ceil(8));
    }

    Bitmap::from_u8_vec(out_vec, mask_bits_set)
}
102
/// Filter kernel for a sparse mask: walks only the set bits of each mask word
/// instead of extracting a full word at a time, which is cheaper when very few
/// mask bits are set.
///
/// # Safety
/// out_ptr must point to a buffer of length >= 8 + 8 * ceil(mask.set_bits() / 64).
/// This function will initialize at least the first ceil(mask.set_bits() / 8) bytes.
unsafe fn filter_boolean_kernel_sparse(values: &Bitmap, mask: &Bitmap, mut out_ptr: *mut u8) {
    assert_eq!(values.len(), mask.len());

    // value_idx tracks the bit offset of the current mask word within the
    // bitmap; (word, bits_in_word) accumulate output bits until a full u64
    // can be flushed to out_ptr.
    let mut value_idx = 0;
    let mut bits_in_word = 0usize;
    let mut word = 0u64;

    macro_rules! loop_body {
        ($m: expr) => {{
            let mut m = $m;
            // For every set bit in the mask word, append the corresponding
            // value bit to the output accumulator.
            while m > 0 {
                let idx_in_m = m.trailing_zeros() as usize;
                let bit = unsafe { values.get_bit_unchecked(value_idx + idx_in_m) };
                word |= (bit as u64) << bits_in_word;
                bits_in_word += 1;

                if bits_in_word == 64 {
                    unsafe {
                        out_ptr.cast::<u64>().write_unaligned(word.to_le());
                        out_ptr = out_ptr.add(8);
                        bits_in_word = 0;
                        word = 0;
                    }
                }

                m &= m.wrapping_sub(1); // Clear least significant bit.
            }
        }};
    }

    // Process the mask in three aligned parts: unaligned prefix, bulk u64
    // words, unaligned suffix.
    let mask_aligned = mask.aligned::<u64>();
    if mask_aligned.prefix_bitlen() > 0 {
        loop_body!(mask_aligned.prefix());
        value_idx += mask_aligned.prefix_bitlen();
    }

    for m in mask_aligned.bulk_iter() {
        loop_body!(m);
        value_idx += 64;
    }

    if mask_aligned.suffix_bitlen() > 0 {
        loop_body!(mask_aligned.suffix());
    }

    // Flush any remaining partial word; the buffer is overallocated so a
    // full u64 write is always in-bounds.
    if bits_in_word > 0 {
        unsafe {
            out_ptr.cast::<u64>().write_unaligned(word.to_le());
        }
    }
}
157
/// Filter kernel built on a pext (parallel bits extract) primitive: for each
/// u56 chunk, the value bits selected by the mask chunk are compacted and
/// appended to the output. `pext(v, m, m_popcnt)` must behave like BMI2
/// `_pext_u64`; HAS_NATIVE_PEXT selects whether software-only fast paths are
/// worth taking.
///
/// # Safety
/// See filter_boolean_kernel_sparse.
unsafe fn filter_boolean_kernel_pext<const HAS_NATIVE_PEXT: bool, F: Fn(u64, u64, u32) -> u64>(
    values: &Bitmap,
    mask: &Bitmap,
    mut out_ptr: *mut u8,
    pext: F,
) {
    assert_eq!(values.len(), mask.len());

    // Output accumulator: `word` holds the pending output bits; between
    // iterations bits_in_word < 8, i.e. only a sub-byte remainder is carried.
    let mut bits_in_word = 0usize;
    let mut word = 0u64;

    macro_rules! loop_body {
        ($v: expr, $m: expr) => {{
            let (v, m) = ($v, $m);

            // Fast-path, all-0 mask.
            if m == 0 {
                continue;
            }

            // Fast path, all-1 mask.
            // This is only worth it if we don't have a native pext.
            if !HAS_NATIVE_PEXT && m == U56_MAX {
                // All 56 value bits are kept: append them, flush 7 whole
                // bytes, keep the sub-byte remainder (bits_in_word unchanged).
                word |= v << bits_in_word;
                unsafe {
                    out_ptr.cast::<u64>().write_unaligned(word.to_le());
                    out_ptr = out_ptr.add(7);
                }
                word >>= 56;
                continue;
            }

            let mask_popcnt = m.count_ones();
            let bits = pext(v, m, mask_popcnt);

            // Because we keep bits_in_word < 8 and we iterate over u56s,
            // this never loses output bits.
            word |= bits << bits_in_word;
            bits_in_word += mask_popcnt as usize;
            unsafe {
                // Write the full u64 (the buffer is overallocated, see the
                // safety contract), then advance only past the bytes that
                // are completely filled and keep the rest in `word`.
                out_ptr.cast::<u64>().write_unaligned(word.to_le());

                let full_bytes_written = bits_in_word / 8;
                out_ptr = out_ptr.add(full_bytes_written);
                word >>= full_bytes_written * 8;
                bits_in_word %= 8;
            }
        }};
    }

    let mut v_iter = values.fast_iter_u56();
    let mut m_iter = mask.fast_iter_u56();
    for v in &mut v_iter {
        // SAFETY: we checked values and mask have same length.
        let m = unsafe { m_iter.next().unwrap_unchecked() };
        loop_body!(v, m);
    }
    // Handle the trailing < 56-bit remainders of both iterators.
    let mut v_rem = v_iter.remainder().0;
    let mut m_rem = m_iter.remainder().0;
    while m_rem != 0 {
        let v = v_rem & U56_MAX;
        let m = m_rem & U56_MAX;
        v_rem >>= 56;
        m_rem >>= 56;
        loop_body!(v, m); // Careful, contains 'continue', increment loop variables first.
    }
}
227
228pub fn filter_bitmap_and_validity(
229    values: &Bitmap,
230    validity: Option<&Bitmap>,
231    mask: &Bitmap,
232) -> (Bitmap, Option<Bitmap>) {
233    let filtered_values = filter_boolean_kernel(values, mask);
234    if let Some(validity) = validity {
235        // TODO: we could theoretically be faster by computing these two filters
236        // at once. Unsure if worth duplicating all the code above.
237        let filtered_validity = filter_boolean_kernel(validity, mask);
238        (filtered_values, Some(filtered_validity))
239    } else {
240        (filtered_values, None)
241    }
242}
243
#[cfg(test)]
mod test {
    use rand::prelude::*;

    use super::*;

    /// Bit-by-bit reference implementation of pext (parallel bits extract),
    /// obviously correct, used to validate the optimized polyfill.
    fn naive_pext64(word: u64, mask: u64) -> u64 {
        let mut out = 0;
        let mut out_idx = 0;

        for i in 0..64 {
            let ith_mask_bit = (mask >> i) & 1;
            let ith_word_bit = (word >> i) & 1;
            if ith_mask_bit == 1 {
                out |= ith_word_bit << out_idx;
                out_idx += 1;
            }
        }

        out
    }

    #[test]
    fn test_pext64() {
        // Verify polyfill against naive implementation.
        let mut rng = StdRng::seed_from_u64(0xdeadbeef);
        for _ in 0..100 {
            let x = rng.gen();
            let y = rng.gen();
            assert_eq!(naive_pext64(x, y), pext64_polyfill(x, y, y.count_ones()));

            // Test all-zeros and all-ones.
            assert_eq!(naive_pext64(0, y), pext64_polyfill(0, y, y.count_ones()));
            assert_eq!(
                naive_pext64(u64::MAX, y),
                pext64_polyfill(u64::MAX, y, y.count_ones())
            );
            assert_eq!(naive_pext64(x, 0), pext64_polyfill(x, 0, 0));
            assert_eq!(naive_pext64(x, u64::MAX), pext64_polyfill(x, u64::MAX, 64));

            // Test low popcount mask.
            // Union the bits with | rather than summing them: drawing the
            // same bit position twice must not carry into a higher bit, and
            // summing 1 << 63 twice would overflow (panic in debug builds).
            // Can still generate fewer than `popcnt` distinct bits, which is
            // fine — we pass mask.count_ones() below.
            let popcnt = rng.gen_range(0..=8);
            let mask = (0..popcnt).fold(0u64, |m, _| m | (1 << rng.gen_range(0..64)));
            assert_eq!(
                naive_pext64(x, mask),
                pext64_polyfill(x, mask, mask.count_ones())
            );
        }
    }
}