polars_arrow/bitmap/bitmask.rs

#[cfg(feature = "simd")]
use std::simd::{LaneCount, Mask, MaskElement, SupportedLaneCount};

use polars_utils::slice::load_padded_le_u64;

use super::iterator::FastU56BitmapIter;
use super::utils::{count_zeros, fmt, BitmapIter};
use crate::bitmap::Bitmap;

/// Returns the index of the nth set bit in `w`, if at least `n + 1` bits are
/// set. The indexing is zero-based: `nth_set_bit_u32(w, 0)` returns the index
/// of the least significant set bit in `w`.
fn nth_set_bit_u32(w: u32, n: u32) -> Option<u32> {
    // If we have BMI2's PDEP available, we use it. It takes the lower order
    // bits of the first argument and spreads them along its second argument
    // where those bits are 1. So PDEP(abcdefgh, 11001001) becomes ef00g00h.
    // We use this by setting the first argument to 1 << n, which means the
    // n zero bits below its one bit will spread to the first n one bits of w,
    // after which the one bit gets copied exactly to the nth (zero-based) one
    // bit of w.
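    // For example, with w = 0b0110_1000 (set bits at indices 3, 5 and 6),
    // PDEP(1 << 1, w) = 0b0010_0000, and trailing_zeros() then yields 5: the
    // set bit with zero-based rank 1.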
    #[cfg(all(not(miri), target_feature = "bmi2"))]
    {
        if n >= 32 {
            return None;
        }

        let nth_set_bit = unsafe { core::arch::x86_64::_pdep_u32(1 << n, w) };
        if nth_set_bit == 0 {
            return None;
        }

        Some(nth_set_bit.trailing_zeros())
    }

    #[cfg(any(miri, not(target_feature = "bmi2")))]
    {
        // Each block of 2/4/8/16 bits contains how many set bits there are in that block.
        let set_per_2 = w - ((w >> 1) & 0x55555555);
        let set_per_4 = (set_per_2 & 0x33333333) + ((set_per_2 >> 2) & 0x33333333);
        let set_per_8 = (set_per_4 + (set_per_4 >> 4)) & 0x0f0f0f0f;
        let set_per_16 = (set_per_8 + (set_per_8 >> 8)) & 0x00ff00ff;
        let set_per_32 = (set_per_16 + (set_per_16 >> 16)) & 0xff;
        if n >= set_per_32 {
            return None;
        }

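        // Binary-search the popcount tree: at each level, if the lower half
        // of the remaining range contains at most n set bits, skip past it
        // and subtract its count from n.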
        let mut idx = 0;
        let mut n = n;
        let next16 = set_per_16 & 0xff;
        if n >= next16 {
            n -= next16;
            idx += 16;
        }
        let next8 = (set_per_8 >> idx) & 0xff;
        if n >= next8 {
            n -= next8;
            idx += 8;
        }
        let next4 = (set_per_4 >> idx) & 0b1111;
        if n >= next4 {
            n -= next4;
            idx += 4;
        }
        let next2 = (set_per_2 >> idx) & 0b11;
        if n >= next2 {
            n -= next2;
            idx += 2;
        }
        let next1 = (w >> idx) & 0b1;
        if n >= next1 {
            idx += 1;
        }
        Some(idx)
    }
}

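/// A read-only view of a range of bits: the `len` bits of `bytes` starting at
/// bit index `offset`.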
#[derive(Default, Clone)]
pub struct BitMask<'a> {
    bytes: &'a [u8],
    offset: usize,
    len: usize,
}

impl std::fmt::Debug for BitMask<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { bytes, offset, len } = self;
        let offset_num_bytes = offset / 8;
        let offset_in_byte = offset % 8;
        fmt(&bytes[offset_num_bytes..], offset_in_byte, *len, f)
    }
}

impl<'a> BitMask<'a> {
    pub fn from_bitmap(bitmap: &'a Bitmap) -> Self {
        let (bytes, offset, len) = bitmap.as_slice();
        Self::new(bytes, offset, len)
    }

    pub fn inner(&self) -> (&[u8], usize, usize) {
        (self.bytes, self.offset, self.len)
    }

    pub fn new(bytes: &'a [u8], offset: usize, len: usize) -> Self {
        // Check length so we can use unsafe access in our get.
        assert!(bytes.len() * 8 >= len + offset);
        Self { bytes, offset, len }
    }

    #[inline(always)]
    pub fn len(&self) -> usize {
        self.len
    }

    #[inline]
    pub fn advance_by(&mut self, idx: usize) {
        assert!(idx <= self.len);
        self.offset += idx;
        self.len -= idx;
    }

    #[inline]
    pub fn split_at(&self, idx: usize) -> (Self, Self) {
        assert!(idx <= self.len);
        unsafe { self.split_at_unchecked(idx) }
    }

    /// # Safety
    /// The index must be in-bounds.
    #[inline]
    pub unsafe fn split_at_unchecked(&self, idx: usize) -> (Self, Self) {
        debug_assert!(idx <= self.len);
        let left = Self { len: idx, ..*self };
        let right = Self {
            len: self.len - idx,
            offset: self.offset + idx,
            ..*self
        };
        (left, right)
    }

    #[inline]
    pub fn sliced(&self, offset: usize, length: usize) -> Self {
        assert!(offset.checked_add(length).unwrap() <= self.len);
        unsafe { self.sliced_unchecked(offset, length) }
    }

    /// # Safety
    /// The offset and length must describe an in-bounds range.
    #[inline]
    pub unsafe fn sliced_unchecked(&self, offset: usize, length: usize) -> Self {
        if cfg!(debug_assertions) {
            assert!(offset.checked_add(length).unwrap() <= self.len);
        }

        Self {
            bytes: self.bytes,
            offset: self.offset + offset,
            len: length,
        }
    }

    pub fn unset_bits(&self) -> usize {
        count_zeros(self.bytes, self.offset, self.len)
    }

    pub fn set_bits(&self) -> usize {
        self.len - self.unset_bits()
    }

    pub fn fast_iter_u56(&self) -> FastU56BitmapIter {
        FastU56BitmapIter::new(self.bytes, self.offset, self.len)
    }

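    /// Gets the `N` bits starting at bit `idx` as a SIMD mask. Bits at or
    /// beyond `self.len` read as unset.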
    #[cfg(feature = "simd")]
    #[inline]
    pub fn get_simd<T, const N: usize>(&self, idx: usize) -> Mask<T, N>
    where
        T: MaskElement,
        LaneCount<N>: SupportedLaneCount,
    {
        // We don't support 64-lane masks because then we couldn't load our
        // bitwise mask as a u64 and then do the byteshift on it.

        let lanes = LaneCount::<N>::BITMASK_LEN;
        assert!(lanes < 64);

        let start_byte_idx = (self.offset + idx) / 8;
        let byte_shift = (self.offset + idx) % 8;
        if idx + lanes <= self.len {
            // SAFETY: fast path, we know this is completely in-bounds.
            let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) });
            Mask::from_bitmask(mask >> byte_shift)
        } else if idx < self.len {
            // SAFETY: we know that at least the first byte is in-bounds.
            // This is partially out of bounds; we have to do extra masking.
            let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) });
            let num_out_of_bounds = idx + lanes - self.len;
            let shifted = (mask << num_out_of_bounds) >> (num_out_of_bounds + byte_shift);
            Mask::from_bitmask(shifted)
        } else {
            Mask::from_bitmask(0u64)
        }
    }

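    /// Gets the 32 bits starting at bit `idx` as a `u32`. Bits at or beyond
    /// `self.len` read as zero.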
    #[inline]
    pub fn get_u32(&self, idx: usize) -> u32 {
        let start_byte_idx = (self.offset + idx) / 8;
        let byte_shift = (self.offset + idx) % 8;
        if idx + 32 <= self.len {
            // SAFETY: fast path, we know this is completely in-bounds.
            let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) });
            (mask >> byte_shift) as u32
        } else if idx < self.len {
            // SAFETY: we know that at least the first byte is in-bounds.
            // This is partially out of bounds; we have to do extra masking.
            let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) });
            let out_of_bounds_mask = (1u32 << (self.len - idx)) - 1;
            ((mask >> byte_shift) as u32) & out_of_bounds_mask
        } else {
            0
        }
    }

    /// Computes the index of the nth set bit at or after `start`.
    ///
    /// Both are zero-indexed, so `nth_set_bit_idx(0, 0)` finds the index of the
    /// first set bit (which can be 0 as well). The returned index is absolute,
    /// not relative to `start`.
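    ///
    /// For example, if exactly bits 0, 1 and 3 are set, `nth_set_bit_idx(2, 0)`
    /// returns `Some(3)` and `nth_set_bit_idx(0, 1)` returns `Some(1)`.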
    pub fn nth_set_bit_idx(&self, mut n: usize, mut start: usize) -> Option<usize> {
        while start < self.len {
            let next_u32_mask = self.get_u32(start);
            if next_u32_mask == u32::MAX {
                // Happy fast path for dense non-null section.
                if n < 32 {
                    return Some(start + n);
                }
                n -= 32;
            } else {
                let ones = next_u32_mask.count_ones() as usize;
                if n < ones {
                    let idx = unsafe {
                        // SAFETY: we know the nth bit is in the mask.
                        nth_set_bit_u32(next_u32_mask, n as u32).unwrap_unchecked() as usize
                    };
                    return Some(start + idx);
                }
                n -= ones;
            }

            start += 32;
        }

        None
    }

    /// Computes the index of the nth set bit before `end`, counting backwards.
    ///
    /// Both are zero-indexed, so `nth_set_bit_idx_rev(0, len)` finds the index
    /// of the last set bit (which can be 0 as well). The returned index is
    /// absolute (counted from the start of the mask), not relative to `end`.
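    ///
    /// For example, if exactly bits 0, 1 and 3 of an 8-bit mask are set,
    /// `nth_set_bit_idx_rev(0, 8)` returns `Some(3)` and
    /// `nth_set_bit_idx_rev(2, 8)` returns `Some(0)`.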
    pub fn nth_set_bit_idx_rev(&self, mut n: usize, mut end: usize) -> Option<usize> {
        while end > 0 {
            // We want to find bits *before* end, so if end < 32 we must mask
            // out the bits at positions end and beyond.
            let (u32_mask_start, u32_mask_mask) = if end >= 32 {
                (end - 32, u32::MAX)
            } else {
                (0, (1 << end) - 1)
            };
            let next_u32_mask = self.get_u32(u32_mask_start) & u32_mask_mask;
            if next_u32_mask == u32::MAX {
                // Happy fast path for dense non-null section.
                if n < 32 {
                    return Some(end - 1 - n);
                }
                n -= 32;
            } else {
                let ones = next_u32_mask.count_ones() as usize;
                if n < ones {
                    let rev_n = ones - 1 - n;
                    let idx = unsafe {
                        // SAFETY: we know the rev_nth bit is in the mask.
                        nth_set_bit_u32(next_u32_mask, rev_n as u32).unwrap_unchecked() as usize
                    };
                    return Some(u32_mask_start + idx);
                }
                n -= ones;
            }

            end = u32_mask_start;
        }

        None
    }

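    /// Returns whether the bit at index `idx` is set; out-of-bounds indices
    /// return `false`.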
    #[inline]
    pub fn get(&self, idx: usize) -> bool {
        let byte_idx = (self.offset + idx) / 8;
        let byte_shift = (self.offset + idx) % 8;

        if idx < self.len {
            // SAFETY: we know this is in-bounds.
            let byte = unsafe { *self.bytes.get_unchecked(byte_idx) };
            (byte >> byte_shift) & 1 == 1
        } else {
            false
        }
    }

    pub fn iter(&self) -> BitmapIter {
        BitmapIter::new(self.bytes, self.offset, self.len)
    }
}

#[cfg(test)]
mod test {
    use super::*;

    fn naive_nth_bit_set(mut w: u32, mut n: u32) -> Option<u32> {
        for i in 0..32 {
            if w & (1 << i) != 0 {
                if n == 0 {
                    return Some(i);
                }
                n -= 1;
                w ^= 1 << i;
            }
        }
        None
    }

    #[test]
    fn test_nth_set_bit_u32() {
        for n in 0..256 {
            assert_eq!(nth_set_bit_u32(0, n), None);
        }

        for i in 0..32 {
            assert_eq!(nth_set_bit_u32(1 << i, 0), Some(i));
            assert_eq!(nth_set_bit_u32(1 << i, 1), None);
        }

        for i in 0..10000 {
            let rnd = (0xbdbc9d8ec9d5c461u64.wrapping_mul(i as u64) >> 32) as u32;
            for n in 0..=32 {
                assert_eq!(nth_set_bit_u32(rnd, n), naive_nth_bit_set(rnd, n));
            }
        }
    }
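
    #[test]
    fn test_nth_set_bit_idx() {
        // A small added sanity check (a sketch, not from the original suite):
        // the byte 0b0000_1011 has bits 0, 1 and 3 set.
        let mask = BitMask::new(&[0b0000_1011], 0, 8);
        assert_eq!(mask.nth_set_bit_idx(0, 0), Some(0));
        assert_eq!(mask.nth_set_bit_idx(1, 0), Some(1));
        assert_eq!(mask.nth_set_bit_idx(2, 0), Some(3));
        assert_eq!(mask.nth_set_bit_idx(3, 0), None);
        // Starting the search at 1 skips the set bit at index 0.
        assert_eq!(mask.nth_set_bit_idx(0, 1), Some(1));

        // The reverse search counts backwards from the exclusive end.
        assert_eq!(mask.nth_set_bit_idx_rev(0, 8), Some(3));
        assert_eq!(mask.nth_set_bit_idx_rev(1, 8), Some(1));
        assert_eq!(mask.nth_set_bit_idx_rev(2, 8), Some(0));
        assert_eq!(mask.nth_set_bit_idx_rev(3, 8), None);
    }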
}