polars_arrow/bitmap/bitmask.rs

#[cfg(feature = "simd")]
use std::simd::{LaneCount, Mask, MaskElement, SupportedLaneCount};

use polars_utils::slice::load_padded_le_u64;

use super::iterator::FastU56BitmapIter;
use super::utils::{count_zeros, fmt, BitmapIter};
use crate::bitmap::Bitmap;

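/// Returns the index of the `n`-th (0-based) set bit of `w`, or `None` if `w`
/// has `n` or fewer set bits.
///
/// A usage sketch (the values below are illustrative, not from the original
/// source):
///
/// ```ignore
/// assert_eq!(nth_set_bit_u32(0b0110_0101, 0), Some(0)); // bits 0, 2, 5, 6 set
/// assert_eq!(nth_set_bit_u32(0b0110_0101, 2), Some(5));
/// assert_eq!(nth_set_bit_u32(0b0110_0101, 4), None);
/// ```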
fn nth_set_bit_u32(w: u32, n: u32) -> Option<u32> {
    // With BMI2 available, PDEP deposits the single bit `1 << n` into the
    // positions where `w` is set, landing it exactly on the n-th set bit.
    #[cfg(all(not(miri), target_feature = "bmi2"))]
    {
        if n >= 32 {
            return None;
        }

        let nth_set_bit = unsafe { core::arch::x86_64::_pdep_u32(1 << n, w) };
        if nth_set_bit == 0 {
            return None;
        }

        Some(nth_set_bit.trailing_zeros())
    }

    #[cfg(any(miri, not(target_feature = "bmi2")))]
    {
        // Compute the set-bit count of each 2-, 4-, 8-, 16- and 32-bit block
        // (the classic SWAR popcount intermediates), then binary-search those
        // counts for the block containing the n-th set bit.
        let set_per_2 = w - ((w >> 1) & 0x55555555);
        let set_per_4 = (set_per_2 & 0x33333333) + ((set_per_2 >> 2) & 0x33333333);
        let set_per_8 = (set_per_4 + (set_per_4 >> 4)) & 0x0f0f0f0f;
        let set_per_16 = (set_per_8 + (set_per_8 >> 8)) & 0x00ff00ff;
        let set_per_32 = (set_per_16 + (set_per_16 >> 16)) & 0xff;
        if n >= set_per_32 {
            return None;
        }

        let mut idx = 0;
        let mut n = n;
        let next16 = set_per_16 & 0xff;
        if n >= next16 {
            n -= next16;
            idx += 16;
        }
        let next8 = (set_per_8 >> idx) & 0xff;
        if n >= next8 {
            n -= next8;
            idx += 8;
        }
        let next4 = (set_per_4 >> idx) & 0b1111;
        if n >= next4 {
            n -= next4;
            idx += 4;
        }
        let next2 = (set_per_2 >> idx) & 0b11;
        if n >= next2 {
            n -= next2;
            idx += 2;
        }
        let next1 = (w >> idx) & 0b1;
        if n >= next1 {
            idx += 1;
        }
        Some(idx)
    }
}

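/// A read-only view of `len` bits of a packed little-endian bitmap, starting
/// at bit `offset` within `bytes`. Bit `i` of the mask is
/// `(bytes[(offset + i) / 8] >> ((offset + i) % 8)) & 1`.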
#[derive(Default, Clone)]
pub struct BitMask<'a> {
    bytes: &'a [u8],
    offset: usize,
    len: usize,
}

impl std::fmt::Debug for BitMask<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { bytes, offset, len } = self;
        let offset_num_bytes = offset / 8;
        let offset_in_byte = offset % 8;
        fmt(&bytes[offset_num_bytes..], offset_in_byte, *len, f)
    }
}

impl<'a> BitMask<'a> {
    pub fn from_bitmap(bitmap: &'a Bitmap) -> Self {
        let (bytes, offset, len) = bitmap.as_slice();
        Self::new(bytes, offset, len)
    }

    pub fn inner(&self) -> (&[u8], usize, usize) {
        (self.bytes, self.offset, self.len)
    }

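    /// Creates a mask over `len` bits of `bytes`, starting at bit `offset`.
    ///
    /// A sketch (values illustrative, not from the original source): view the
    /// middle four bits of a single byte.
    ///
    /// ```ignore
    /// let mask = BitMask::new(&[0b0011_1100], 2, 4); // bits 2..6, all set
    /// assert_eq!(mask.set_bits(), 4);
    /// ```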
    pub fn new(bytes: &'a [u8], offset: usize, len: usize) -> Self {
        // Check the invariant that the view fits within the byte buffer;
        // the unchecked accessors below rely on it.
        assert!(bytes.len() * 8 >= len + offset);
        Self { bytes, offset, len }
    }

    #[inline(always)]
    pub fn len(&self) -> usize {
        self.len
    }

    #[inline]
    pub fn advance_by(&mut self, idx: usize) {
        assert!(idx <= self.len);
        self.offset += idx;
        self.len -= idx;
    }

    #[inline]
    pub fn split_at(&self, idx: usize) -> (Self, Self) {
        assert!(idx <= self.len);
        unsafe { self.split_at_unchecked(idx) }
    }

    /// # Safety
    /// `idx` must not exceed `self.len`.
    #[inline]
    pub unsafe fn split_at_unchecked(&self, idx: usize) -> (Self, Self) {
        debug_assert!(idx <= self.len);
        let left = Self { len: idx, ..*self };
        let right = Self {
            len: self.len - idx,
            offset: self.offset + idx,
            ..*self
        };
        (left, right)
    }

    #[inline]
    pub fn sliced(&self, offset: usize, length: usize) -> Self {
        assert!(offset.checked_add(length).unwrap() <= self.len);
        unsafe { self.sliced_unchecked(offset, length) }
    }

    /// # Safety
    /// `offset + length` must not overflow and must not exceed `self.len`.
    #[inline]
    pub unsafe fn sliced_unchecked(&self, offset: usize, length: usize) -> Self {
        if cfg!(debug_assertions) {
            assert!(offset.checked_add(length).unwrap() <= self.len);
        }

        Self {
            bytes: self.bytes,
            offset: self.offset + offset,
            len: length,
        }
    }

    pub fn unset_bits(&self) -> usize {
        count_zeros(self.bytes, self.offset, self.len)
    }

    pub fn set_bits(&self) -> usize {
        self.len - self.unset_bits()
    }

    pub fn fast_iter_u56(&self) -> FastU56BitmapIter {
        FastU56BitmapIter::new(self.bytes, self.offset, self.len)
    }

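    /// Loads the `N` bits starting at bit `idx` as a SIMD mask; lanes at or
    /// past `self.len` read as unset.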
    #[cfg(feature = "simd")]
    #[inline]
    pub fn get_simd<T, const N: usize>(&self, idx: usize) -> Mask<T, N>
    where
        T: MaskElement,
        LaneCount<N>: SupportedLaneCount,
    {
        // We don't support 64-lane masks: the lanes plus the up-to-7-bit byte
        // shift must fit in the single u64 we load.
        let lanes = LaneCount::<N>::BITMASK_LEN;
        assert!(lanes < 64);

        let start_byte_idx = (self.offset + idx) / 8;
        let byte_shift = (self.offset + idx) % 8;
        if idx + lanes <= self.len {
            // SAFETY: fast path, the load is completely in-bounds.
            let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) });
            Mask::from_bitmask(mask >> byte_shift)
        } else if idx < self.len {
            // SAFETY: at least the first byte is in-bounds; the load zero-pads
            // past the end of the byte slice.
            let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) });
            // Partially out of bounds: mask off the lanes past the end so they
            // read as unset, mirroring the out-of-bounds handling in get_u32.
            let num_in_bounds = self.len - idx;
            let in_bounds_mask = (1u64 << num_in_bounds) - 1;
            Mask::from_bitmask((mask >> byte_shift) & in_bounds_mask)
        } else {
            Mask::from_bitmask(0u64)
        }
    }

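    /// Loads the 32 bits starting at bit `idx` as a `u32`; bits at or past
    /// `self.len` read as zero.
    ///
    /// A sketch (values illustrative, not from the original source): for an
    /// all-ones mask of length 40, `get_u32(8)` returns `u32::MAX`, while
    /// `get_u32(24)` returns `0xffff` (only 16 valid bits remain).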
    #[inline]
    pub fn get_u32(&self, idx: usize) -> u32 {
        let start_byte_idx = (self.offset + idx) / 8;
        let byte_shift = (self.offset + idx) % 8;
        if idx + 32 <= self.len {
            // SAFETY: fast path, the load is completely in-bounds.
            let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) });
            (mask >> byte_shift) as u32
        } else if idx < self.len {
            // SAFETY: at least the first byte is in-bounds.
            let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) });
            // Partially out of bounds: mask off the bits past the end so they
            // read as zero.
            let out_of_bounds_mask = (1u32 << (self.len - idx)) - 1;
            ((mask >> byte_shift) as u32) & out_of_bounds_mask
        } else {
            0
        }
    }

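    /// Returns the index of the `n`-th (0-based) set bit at or after bit
    /// `start`, if there is one.
    ///
    /// A sketch (values illustrative, not from the original source):
    ///
    /// ```ignore
    /// let mask = BitMask::new(&[0b0110_0101], 0, 8); // bits 0, 2, 5, 6 set
    /// assert_eq!(mask.nth_set_bit_idx(0, 0), Some(0));
    /// assert_eq!(mask.nth_set_bit_idx(2, 0), Some(5));
    /// assert_eq!(mask.nth_set_bit_idx(1, 3), Some(6)); // counting from bit 3
    /// assert_eq!(mask.nth_set_bit_idx(4, 0), None);
    /// ```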
    pub fn nth_set_bit_idx(&self, mut n: usize, mut start: usize) -> Option<usize> {
        while start < self.len {
            let next_u32_mask = self.get_u32(start);
            if next_u32_mask == u32::MAX {
                // Fast path for all-ones: the bit is at a fixed offset.
                if n < 32 {
                    return Some(start + n);
                }
                n -= 32;
            } else {
                let ones = next_u32_mask.count_ones() as usize;
                if n < ones {
                    // SAFETY: n < ones, so the n-th set bit exists.
                    let idx = unsafe {
                        nth_set_bit_u32(next_u32_mask, n as u32).unwrap_unchecked() as usize
                    };
                    return Some(start + idx);
                }
                n -= ones;
            }

            start += 32;
        }

        None
    }

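    /// Returns the index of the `n`-th (0-based, counting backwards from
    /// `end`) set bit strictly before bit `end`, if there is one.
    ///
    /// A sketch (values illustrative, not from the original source):
    ///
    /// ```ignore
    /// let mask = BitMask::new(&[0b0110_0101], 0, 8); // bits 0, 2, 5, 6 set
    /// assert_eq!(mask.nth_set_bit_idx_rev(0, 8), Some(6));
    /// assert_eq!(mask.nth_set_bit_idx_rev(2, 8), Some(2));
    /// assert_eq!(mask.nth_set_bit_idx_rev(0, 5), Some(2)); // only bits below 5
    /// assert_eq!(mask.nth_set_bit_idx_rev(4, 8), None);
    /// ```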
    pub fn nth_set_bit_idx_rev(&self, mut n: usize, mut end: usize) -> Option<usize> {
        while end > 0 {
            // Cover the 32 bits (or fewer, on the last iteration) that
            // precede `end`.
            let (u32_mask_start, u32_mask_mask) = if end >= 32 {
                (end - 32, u32::MAX)
            } else {
                (0, (1 << end) - 1)
            };
            let next_u32_mask = self.get_u32(u32_mask_start) & u32_mask_mask;
            if next_u32_mask == u32::MAX {
                // Fast path for all-ones: the bit is at a fixed offset.
                if n < 32 {
                    return Some(end - 1 - n);
                }
                n -= 32;
            } else {
                let ones = next_u32_mask.count_ones() as usize;
                if n < ones {
                    // The n-th set bit from the back is the (ones - 1 - n)-th
                    // set bit from the front.
                    let rev_n = ones - 1 - n;
                    // SAFETY: rev_n < ones, so the rev_n-th set bit exists.
                    let idx = unsafe {
                        nth_set_bit_u32(next_u32_mask, rev_n as u32).unwrap_unchecked() as usize
                    };
                    return Some(u32_mask_start + idx);
                }
                n -= ones;
            }

            end = u32_mask_start;
        }

        None
    }

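    /// Returns whether bit `idx` is set; out-of-bounds indices read as unset.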
    #[inline]
    pub fn get(&self, idx: usize) -> bool {
        let byte_idx = (self.offset + idx) / 8;
        let byte_shift = (self.offset + idx) % 8;

        if idx < self.len {
            // SAFETY: idx < self.len, so byte_idx is in-bounds by the
            // invariant checked in new().
            let byte = unsafe { *self.bytes.get_unchecked(byte_idx) };
            (byte >> byte_shift) & 1 == 1
        } else {
            false
        }
    }

    pub fn iter(&self) -> BitmapIter {
        BitmapIter::new(self.bytes, self.offset, self.len)
    }
}

#[cfg(test)]
mod test {
    use super::*;

    // Reference implementation: scan the bits one at a time.
    fn naive_nth_bit_set(mut w: u32, mut n: u32) -> Option<u32> {
        for i in 0..32 {
            if w & (1 << i) != 0 {
                if n == 0 {
                    return Some(i);
                }
                n -= 1;
                w ^= 1 << i;
            }
        }
        None
    }

    #[test]
    fn test_nth_set_bit_u32() {
        for n in 0..256 {
            assert_eq!(nth_set_bit_u32(0, n), None);
        }

        for i in 0..32 {
            assert_eq!(nth_set_bit_u32(1 << i, 0), Some(i));
            assert_eq!(nth_set_bit_u32(1 << i, 1), None);
        }

        for i in 0..10000 {
            // Cheap multiplicative-hash pseudorandom u32.
            let rnd = (0xbdbc9d8ec9d5c461u64.wrapping_mul(i as u64) >> 32) as u32;
            for n in 0..=32 {
                assert_eq!(nth_set_bit_u32(rnd, n), naive_nth_bit_set(rnd, n));
            }
        }
    }
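
    // A small additional sketch (not from the original test suite) exercising
    // the BitMask accessors on a hand-built byte buffer.
    #[test]
    fn test_bitmask_basic_sketch() {
        // 0b0110_0101 has bits 0, 2, 5 and 6 set.
        let mask = BitMask::new(&[0b0110_0101], 0, 8);
        assert_eq!(mask.len(), 8);
        assert_eq!(mask.set_bits(), 4);
        assert_eq!(mask.unset_bits(), 4);
        assert!(mask.get(2));
        assert!(!mask.get(3));
        assert_eq!(mask.nth_set_bit_idx(1, 0), Some(2));
        assert_eq!(mask.nth_set_bit_idx_rev(0, 8), Some(6));

        let (left, right) = mask.split_at(4);
        assert_eq!(left.set_bits(), 2); // bits 0 and 2
        assert_eq!(right.set_bits(), 2); // bits 5 and 6, now at 1 and 2
        assert!(right.get(1));
    }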
}