polars_compute/
hyperloglogplus.rs

1//! # HyperLogLogPlus
2//!
//! The `hyperloglogplus` module contains an implementation of the HyperLogLogPlus
//! algorithm for cardinality estimation, so that the [`crate::series::approx_n_unique`]
//! function can be implemented efficiently.
6//!
7//! This module borrows code from [arrow-datafusion](https://github.com/apache/arrow-datafusion/blob/93771052c5ac31f2cf22b8c25bf938656afe1047/datafusion/physical-expr/src/aggregate/hyperloglog.rs).
8//!
9//! # Examples
10//!
11//! ```
12//!     # use polars_compute::hyperloglogplus::*;
13//!     let mut hllp = HyperLogLog::new();
14//!     hllp.add(&12345);
15//!     hllp.add(&23456);
16//!
17//!     assert_eq!(hllp.count(), 2);
18//! ```
19
20use std::hash::Hash;
21use std::marker::PhantomData;
22
23use polars_utils::aliases::PlRandomStateQuality;
24
/// Number of hash bits used to select a register (the precision `p`).
///
/// The larger `HLL_P` is, the smaller the estimation error, at the cost of
/// `2^HLL_P` one-byte registers.
const HLL_P: usize = 14_usize;
/// Number of hash bits left after the `HLL_P` index bits; the trailing zeros
/// of these bits determine the value written into a register.
const HLL_Q: usize = 64_usize - HLL_P;
/// Number of registers, `2^HLL_P`.
const NUM_REGISTERS: usize = 1_usize << HLL_P;
/// Mask selecting the low `HLL_P` bits of a hash, i.e. the register index.
const HLL_P_MASK: u64 = (NUM_REGISTERS as u64) - 1;
32
33#[derive(Clone, Debug)]
34pub struct HyperLogLog<T>
35where
36    T: Hash + ?Sized,
37{
38    registers: [u8; NUM_REGISTERS],
39    phantom: PhantomData<T>,
40}
41
42impl<T> Default for HyperLogLog<T>
43where
44    T: Hash + ?Sized,
45{
46    fn default() -> Self {
47        Self::new()
48    }
49}
50
51/// Fixed seed for the hashing so that values are consistent across runs
52///
53/// Note that when we later move on to have serialized HLL register binaries
54/// shared across cluster, this SEED will have to be consistent across all
55/// parties otherwise we might have corruption. So ideally for later this seed
56/// shall be part of the serialized form (or stay unchanged across versions).
57const SEED: PlRandomStateQuality = PlRandomStateQuality::with_seeds(
58    0x885f6cab121d01a3_u64,
59    0x71e4379f2976ad8f_u64,
60    0xbf30173dd28a8816_u64,
61    0x0eaea5d736d733a4_u64,
62);
63
64impl<T> HyperLogLog<T>
65where
66    T: Hash + ?Sized,
67{
68    /// Creates a new, empty HyperLogLog.
69    pub fn new() -> Self {
70        let registers = [0; NUM_REGISTERS];
71        Self::new_with_registers(registers)
72    }
73
74    /// Creates a HyperLogLog from already populated registers
75    /// note that this method should not be invoked in untrusted environment
76    /// because the internal structure of registers are not examined.
77    pub(crate) fn new_with_registers(registers: [u8; NUM_REGISTERS]) -> Self {
78        Self {
79            registers,
80            phantom: PhantomData,
81        }
82    }
83
84    #[inline]
85    fn hash_value(&self, obj: &T) -> u64 {
86        SEED.hash_one(obj)
87    }
88
89    /// Adds an element to the HyperLogLog.
90    pub fn add(&mut self, obj: &T) {
91        let hash = self.hash_value(obj);
92        let index = (hash & HLL_P_MASK) as usize;
93        let p = ((hash >> HLL_P) | (1_u64 << HLL_Q)).trailing_zeros() + 1;
94        self.registers[index] = self.registers[index].max(p as u8);
95    }
96
97    /// Get the register histogram (each value in register index into
98    /// the histogram; u32 is enough because we only have 2**14=16384 registers
99    #[inline]
100    fn get_histogram(&self) -> [u32; HLL_Q + 2] {
101        let mut histogram = [0; HLL_Q + 2];
102        // hopefully this can be unrolled
103        for r in self.registers {
104            histogram[r as usize] += 1;
105        }
106        histogram
107    }
108
109    /// Merge the other [`HyperLogLog`] into this one
110    pub fn merge(&mut self, other: &HyperLogLog<T>) {
111        assert!(
112            self.registers.len() == other.registers.len(),
113            "unexpected got unequal register size, expect {}, got {}",
114            self.registers.len(),
115            other.registers.len()
116        );
117        for i in 0..self.registers.len() {
118            self.registers[i] = self.registers[i].max(other.registers[i]);
119        }
120    }
121
122    /// Guess the number of unique elements seen by the HyperLogLog.
123    pub fn count(&self) -> usize {
124        let histogram = self.get_histogram();
125        let m = NUM_REGISTERS as f64;
126        let mut z = m * hll_tau((m - histogram[HLL_Q + 1] as f64) / m);
127        for i in histogram[1..=HLL_Q].iter().rev() {
128            z += *i as f64;
129            z *= 0.5;
130        }
131        z += m * hll_sigma(histogram[0] as f64 / m);
132        (0.5 / 2_f64.ln() * m * m / z).round() as usize
133    }
134}
135
/// The sigma helper from Otmar Ertl, "New cardinality estimation algorithms
/// for HyperLogLog sketches" (arXiv:1702.01284).
///
/// Evaluates `x + sum_{k>=1} x^(2^k) * 2^(k-1)` until the partial sum stops
/// changing in `f64`. Returns `f64::INFINITY` for `x == 1`.
#[inline]
fn hll_sigma(x: f64) -> f64 {
    if x == 1. {
        return f64::INFINITY;
    }
    let mut power = x;
    let mut weight = 1.0;
    let mut sum = x;
    loop {
        power *= power;
        let next = sum + power * weight;
        weight += weight;
        if next == sum {
            return next;
        }
        sum = next;
    }
}
159
/// The tau helper from Otmar Ertl, "New cardinality estimation algorithms
/// for HyperLogLog sketches" (arXiv:1702.01284).
///
/// Evaluates `(1 - x - sum_{k>=1} (1 - x^(2^-k))^2 * 2^-k) / 3` until the
/// partial sum stops changing in `f64`. Returns `0.0` for `x == 0` and `x == 1`.
#[inline]
fn hll_tau(x: f64) -> f64 {
    if x == 0.0 || x == 1.0 {
        return 0.0;
    }
    let mut root = x;
    let mut weight = 1.0;
    let mut acc = 1.0 - x;
    loop {
        root = root.sqrt();
        weight *= 0.5;
        let next = acc - (1.0 - root).powi(2) * weight;
        if next == acc {
            return next / 3.0;
        }
        acc = next;
    }
}
183
184impl<T> AsRef<[u8]> for HyperLogLog<T>
185where
186    T: Hash + ?Sized,
187{
188    fn as_ref(&self) -> &[u8] {
189        &self.registers
190    }
191}
192
193impl<T> Extend<T> for HyperLogLog<T>
194where
195    T: Hash,
196{
197    fn extend<S: IntoIterator<Item = T>>(&mut self, iter: S) {
198        for elem in iter {
199            self.add(&elem);
200        }
201    }
202}
203
204impl<'a, T> Extend<&'a T> for HyperLogLog<T>
205where
206    T: 'a + Hash + ?Sized,
207{
208    fn extend<S: IntoIterator<Item = &'a T>>(&mut self, iter: S) {
209        for elem in iter {
210            self.add(elem);
211        }
212    }
213}
214
#[cfg(test)]
mod tests {
    use super::{HyperLogLog, NUM_REGISTERS};

    /// Asserts that `got` is within a relative margin of `expected`.
    fn compare_with_delta(got: usize, expected: usize) {
        let expected = expected as f64;
        let diff = (got as f64) - expected;
        let diff = diff.abs() / expected;
        // times 6 because we want the tests to be stable
        // so we allow a rather large margin of error
        // this is adopted from redis's unit test version as well
        let margin = 1.04 / ((NUM_REGISTERS as f64).sqrt()) * 6.0;
        assert!(
            diff <= margin,
            "{} is not near {} percent of {} which is ({}, {})",
            got,
            margin,
            expected,
            expected * (1.0 - margin),
            expected * (1.0 + margin)
        );
    }

    macro_rules! sized_number_test {
        ($SIZE: expr, $T: tt) => {{
            let mut hll = HyperLogLog::<$T>::new();
            for i in 0..$SIZE {
                hll.add(&i);
            }
            compare_with_delta(hll.count(), $SIZE);
        }};
    }

    macro_rules! typed_large_number_test {
        ($SIZE: expr) => {{
            sized_number_test!($SIZE, u64);
            sized_number_test!($SIZE, u128);
            sized_number_test!($SIZE, i64);
            sized_number_test!($SIZE, i128);
        }};
    }

    macro_rules! typed_number_test {
        ($SIZE: expr) => {{
            sized_number_test!($SIZE, u16);
            sized_number_test!($SIZE, u32);
            sized_number_test!($SIZE, i16);
            sized_number_test!($SIZE, i32);
            typed_large_number_test!($SIZE);
        }};
    }

    #[test]
    fn test_empty() {
        let hll = HyperLogLog::<u64>::new();
        assert_eq!(hll.count(), 0);
    }

    #[test]
    fn test_one() {
        let mut hll = HyperLogLog::<u64>::new();
        hll.add(&1);
        assert_eq!(hll.count(), 1);
    }

    #[test]
    fn test_number_100() {
        typed_number_test!(100);
    }

    #[test]
    fn test_number_1k() {
        typed_number_test!(1_000);
    }

    #[test]
    fn test_number_10k() {
        typed_number_test!(10_000);
    }

    #[test]
    fn test_number_100k() {
        typed_large_number_test!(100_000);
    }

    #[test]
    fn test_number_1m() {
        typed_large_number_test!(1_000_000);
    }

    #[test]
    fn test_u8() {
        let mut hll = HyperLogLog::<[u8]>::new();
        for i in 0..1000 {
            let s = i.to_string();
            let b = s.as_bytes();
            hll.add(b);
        }
        compare_with_delta(hll.count(), 1000);
    }

    #[test]
    fn test_string() {
        let mut hll = HyperLogLog::<String>::new();
        hll.extend((0..1000).map(|i| i.to_string()));
        compare_with_delta(hll.count(), 1000);
    }

    #[test]
    fn test_empty_merge() {
        let mut hll = HyperLogLog::<u64>::new();
        hll.merge(&HyperLogLog::<u64>::new());
        assert_eq!(hll.count(), 0);
    }

    #[test]
    fn test_merge_overlapped() {
        let mut hll = HyperLogLog::<String>::new();
        hll.extend((0..1000).map(|i| i.to_string()));

        let mut other = HyperLogLog::<String>::new();
        other.extend((0..1000).map(|i| i.to_string()));

        hll.merge(&other);
        compare_with_delta(hll.count(), 1000);
    }

    /// Merging sketches of disjoint inputs must estimate the size of their union.
    #[test]
    fn test_merge_disjoint() {
        let mut hll = HyperLogLog::<String>::new();
        hll.extend((0..1000).map(|i| i.to_string()));

        let mut other = HyperLogLog::<String>::new();
        other.extend((1000..2000).map(|i| i.to_string()));

        hll.merge(&other);
        compare_with_delta(hll.count(), 2000);
    }

    /// Merging two halves must give exactly the registers of a single sketch fed all values.
    #[test]
    fn test_merge_equals_combined_insertion() {
        let mut left = HyperLogLog::<u64>::new();
        let mut right = HyperLogLog::<u64>::new();
        let mut combined = HyperLogLog::<u64>::new();
        for i in 0..5_000_u64 {
            if i % 2 == 0 {
                left.add(&i);
            } else {
                right.add(&i);
            }
            combined.add(&i);
        }
        left.merge(&right);
        assert_eq!(left.registers, combined.registers);
    }

    #[test]
    fn test_repetition() {
        let mut hll = HyperLogLog::<u32>::new();
        for i in 0..1_000_000 {
            hll.add(&(i % 1000));
        }
        compare_with_delta(hll.count(), 1000);
    }
}