polars_compute/unique/
primitive.rs

1use std::ops::{Add, RangeInclusive, Sub};
2
3use arrow::array::PrimitiveArray;
4use arrow::bitmap::bitmask::BitMask;
5use arrow::bitmap::{BitmapBuilder, MutableBitmap};
6use arrow::datatypes::ArrowDataType;
7use arrow::types::NativeType;
8use num_traits::{FromPrimitive, ToPrimitive};
9use polars_utils::total_ord::TotalOrd;
10
11use super::RangedUniqueKernel;
12
13/// A specialized unique kernel for [`PrimitiveArray`] for when all values are in a small known
14/// range.
15pub struct PrimitiveRangedUniqueState<T: NativeType> {
16    seen: Seen,
17    range: RangeInclusive<T>,
18}
19
20enum Seen {
21    Small(u128),
22    Large(MutableBitmap),
23}
24
25impl Seen {
26    pub fn from_size(size: usize) -> Self {
27        if size <= 128 {
28            Self::Small(0)
29        } else {
30            Self::Large(MutableBitmap::from_len_zeroed(size))
31        }
32    }
33
34    fn num_seen(&self) -> usize {
35        match self {
36            Seen::Small(v) => v.count_ones() as usize,
37            Seen::Large(v) => v.set_bits(),
38        }
39    }
40
41    fn has_seen_null(&self, size: usize) -> bool {
42        match self {
43            Self::Small(v) => v >> (size - 1) != 0,
44            Self::Large(v) => v.get(size - 1),
45        }
46    }
47}
48
49impl<T: NativeType> PrimitiveRangedUniqueState<T>
50where
51    T: Add<T, Output = T> + Sub<T, Output = T> + ToPrimitive + FromPrimitive,
52{
53    pub fn new(min_value: T, max_value: T) -> Self {
54        let size = (max_value - min_value).to_usize().unwrap();
55        // Range is inclusive
56        let size = size + 1;
57        // One value is left for null
58        let size = size + 1;
59
60        Self {
61            seen: Seen::from_size(size),
62            range: min_value..=max_value,
63        }
64    }
65
66    fn size(&self) -> usize {
67        (*self.range.end() - *self.range.start())
68            .to_usize()
69            .unwrap()
70            + 1
71    }
72}
73
74impl<T: NativeType> RangedUniqueKernel for PrimitiveRangedUniqueState<T>
75where
76    T: Add<T, Output = T> + Sub<T, Output = T> + ToPrimitive + FromPrimitive,
77{
78    type Array = PrimitiveArray<T>;
79
80    fn has_seen_all(&self) -> bool {
81        let size = self.size();
82        match &self.seen {
83            Seen::Small(v) if size == 128 => !v == 0,
84            Seen::Small(v) => *v == ((1 << size) - 1),
85            Seen::Large(v) => BitMask::new(v.as_slice(), 0, size).unset_bits() == 0,
86        }
87    }
88
89    fn append(&mut self, array: &Self::Array) {
90        let size = self.size();
91        match array.validity().as_ref().filter(|v| v.unset_bits() > 0) {
92            None => {
93                const STEP_SIZE: usize = 512;
94
95                let mut i = 0;
96                let values = array.values().as_slice();
97
98                match self.seen {
99                    Seen::Small(ref mut seen) => {
100                        // Check every so often whether we have already seen all the values.
101                        while *seen != ((1 << (size - 1)) - 1) && i < values.len() {
102                            for v in values[i..].iter().take(STEP_SIZE) {
103                                if cfg!(debug_assertions) {
104                                    assert!(TotalOrd::tot_ge(v, self.range.start()));
105                                    assert!(TotalOrd::tot_le(v, self.range.end()));
106                                }
107
108                                let v = *v - *self.range.start();
109                                let v = unsafe { v.to_usize().unwrap_unchecked() };
110                                *seen |= 1 << v;
111                            }
112
113                            i += STEP_SIZE;
114                        }
115                    },
116                    Seen::Large(ref mut seen) => {
117                        // Check every so often whether we have already seen all the values.
118                        while BitMask::new(seen.as_slice(), 0, size - 1).unset_bits() > 0
119                            && i < values.len()
120                        {
121                            for v in values[i..].iter().take(STEP_SIZE) {
122                                if cfg!(debug_assertions) {
123                                    assert!(TotalOrd::tot_ge(v, self.range.start()));
124                                    assert!(TotalOrd::tot_le(v, self.range.end()));
125                                }
126
127                                let v = *v - *self.range.start();
128                                let v = unsafe { v.to_usize().unwrap_unchecked() };
129                                seen.set(v, true);
130                            }
131
132                            i += STEP_SIZE;
133                        }
134                    },
135                }
136            },
137            Some(_) => {
138                let iter = array.non_null_values_iter();
139
140                match self.seen {
141                    Seen::Small(ref mut seen) => {
142                        *seen |= 1 << (size - 1);
143
144                        for v in iter {
145                            if cfg!(debug_assertions) {
146                                assert!(TotalOrd::tot_ge(&v, self.range.start()));
147                                assert!(TotalOrd::tot_le(&v, self.range.end()));
148                            }
149
150                            let v = v - *self.range.start();
151                            let v = unsafe { v.to_usize().unwrap_unchecked() };
152                            *seen |= 1 << v;
153                        }
154                    },
155                    Seen::Large(ref mut seen) => {
156                        seen.set(size - 1, true);
157
158                        for v in iter {
159                            if cfg!(debug_assertions) {
160                                assert!(TotalOrd::tot_ge(&v, self.range.start()));
161                                assert!(TotalOrd::tot_le(&v, self.range.end()));
162                            }
163
164                            let v = v - *self.range.start();
165                            let v = unsafe { v.to_usize().unwrap_unchecked() };
166                            seen.set(v, true);
167                        }
168                    },
169                }
170            },
171        }
172    }
173
174    fn append_state(&mut self, other: &Self) {
175        debug_assert_eq!(self.size(), other.size());
176        match (&mut self.seen, &other.seen) {
177            (Seen::Small(lhs), Seen::Small(rhs)) => *lhs |= rhs,
178            (Seen::Large(lhs), Seen::Large(ref rhs)) => {
179                let mut lhs = lhs;
180                <&mut MutableBitmap as std::ops::BitOrAssign<&MutableBitmap>>::bitor_assign(
181                    &mut lhs, rhs,
182                )
183            },
184            _ => unreachable!(),
185        }
186    }
187
188    fn finalize_unique(self) -> Self::Array {
189        let size = self.size();
190        let seen = self.seen;
191
192        let has_null = seen.has_seen_null(size);
193        let num_values = seen.num_seen();
194        let mut values = Vec::with_capacity(num_values);
195
196        let mut offset = 0;
197        match seen {
198            Seen::Small(mut v) => {
199                while v != 0 {
200                    let shift = v.trailing_zeros();
201                    offset += shift as u8;
202                    values.push(*self.range.start() + T::from_u8(offset).unwrap());
203
204                    v >>= shift + 1;
205                    offset += 1;
206                }
207            },
208            Seen::Large(v) => {
209                for offset in v.freeze().true_idx_iter() {
210                    values.push(*self.range.start() + T::from_usize(offset).unwrap());
211                }
212            },
213        }
214
215        let validity = if has_null {
216            let mut validity = BitmapBuilder::new();
217            validity.extend_constant(values.len() - 1, true);
218            validity.push(false);
219            // The null has already been pushed.
220            *values.last_mut().unwrap() = T::zeroed();
221            Some(validity.freeze())
222        } else {
223            None
224        };
225
226        PrimitiveArray::new(ArrowDataType::from(T::PRIMITIVE), values.into(), validity)
227    }
228
229    fn finalize_n_unique(&self) -> usize {
230        self.seen.num_seen()
231    }
232
233    fn finalize_n_unique_non_null(&self) -> usize {
234        self.seen.num_seen() - usize::from(self.seen.has_seen_null(self.size()))
235    }
236}