polars_utils/idx_map/
bytes_idx_map.rs

1use hashbrown::hash_table::{
2    Entry as TEntry, HashTable, OccupiedEntry as TOccupiedEntry, VacantEntry as TVacantEntry,
3};
4
5use crate::IdxSize;
6
/// Capacity, in bytes, of the first key-data buffer.
///
/// Later buffers are sized from this base shifted left by the number of
/// buffers that already exist (or the key length, if that is larger).
const BASE_KEY_DATA_CAPACITY: usize = 1024;
8
/// Locator for one stored key: its pre-computed hash plus the position of its
/// bytes inside the map's list of key-data buffers.
struct Key {
    /// The caller-supplied (unseeded) hash of the key bytes.
    key_hash: u64,
    /// Index of the buffer in `key_data` that holds the key bytes.
    key_buffer: u32,
    /// Byte offset of the key within that buffer.
    key_offset: usize,
    /// Length of the key in bytes.
    key_length: u32,
}

impl Key {
    /// Returns the key bytes this locator points at.
    ///
    /// # Safety
    /// `self.key_buffer` must be a valid index into `key_data`, and the range
    /// `key_offset..key_offset + key_length` must lie within the length of
    /// that buffer. No bounds checks are performed.
    unsafe fn get<'k>(&self, key_data: &'k [Vec<u8>]) -> &'k [u8] {
        let start = self.key_offset;
        let end = start + self.key_length as usize;
        // SAFETY: the caller guarantees `key_buffer` is in bounds for `key_data`.
        let buf = unsafe { key_data.get_unchecked(self.key_buffer as usize) };
        // SAFETY: the caller guarantees `start..end` lies within `buf`.
        unsafe { buf.get_unchecked(start..end) }
    }
}
22
23/// An IndexMap where the keys are always [u8] slices which are pre-hashed.
24pub struct BytesIndexMap<V> {
25    table: HashTable<IdxSize>,
26    tuples: Vec<(Key, V)>,
27    key_data: Vec<Vec<u8>>,
28
29    // Internal random seed used to keep hash iteration order decorrelated.
30    // We simply store a random odd number and multiply the canonical hash by it.
31    seed: u64,
32}
33
34impl<V> Default for BytesIndexMap<V> {
35    fn default() -> Self {
36        Self {
37            table: HashTable::new(),
38            tuples: Vec::new(),
39            key_data: vec![Vec::with_capacity(BASE_KEY_DATA_CAPACITY)],
40            seed: rand::random::<u64>() | 1,
41        }
42    }
43}
44
45impl<V> BytesIndexMap<V> {
46    pub fn new() -> Self {
47        Self::default()
48    }
49
50    pub fn reserve(&mut self, additional: usize) {
51        self.table.reserve(additional, |i| unsafe {
52            let tuple = self.tuples.get_unchecked(*i as usize);
53            tuple.0.key_hash.wrapping_mul(self.seed)
54        });
55        self.tuples.reserve(additional);
56    }
57
58    pub fn len(&self) -> IdxSize {
59        self.table.len() as IdxSize
60    }
61
62    pub fn is_empty(&self) -> bool {
63        self.table.is_empty()
64    }
65
66    pub fn get(&self, hash: u64, key: &[u8]) -> Option<&V> {
67        let idx = self.table.find(hash.wrapping_mul(self.seed), |i| unsafe {
68            let t = self.tuples.get_unchecked(*i as usize);
69            hash == t.0.key_hash && key == t.0.get(&self.key_data)
70        })?;
71        unsafe { Some(&self.tuples.get_unchecked(*idx as usize).1) }
72    }
73
74    pub fn entry<'k>(&mut self, hash: u64, key: &'k [u8]) -> Entry<'_, 'k, V> {
75        let entry = self.table.entry(
76            hash.wrapping_mul(self.seed),
77            |i| unsafe {
78                let t = self.tuples.get_unchecked(*i as usize);
79                hash == t.0.key_hash && key == t.0.get(&self.key_data)
80            },
81            |i| unsafe {
82                let t = self.tuples.get_unchecked(*i as usize);
83                t.0.key_hash.wrapping_mul(self.seed)
84            },
85        );
86
87        match entry {
88            TEntry::Occupied(o) => Entry::Occupied(OccupiedEntry {
89                entry: o,
90                tuples: &mut self.tuples,
91            }),
92            TEntry::Vacant(v) => Entry::Vacant(VacantEntry {
93                key,
94                hash,
95                entry: v,
96                tuples: &mut self.tuples,
97                key_data: &mut self.key_data,
98            }),
99        }
100    }
101
102    /// Gets the hash, key and value at the given index by insertion order.
103    #[inline(always)]
104    pub fn get_index(&self, idx: IdxSize) -> Option<(u64, &[u8], &V)> {
105        let t = self.tuples.get(idx as usize)?;
106        Some((t.0.key_hash, unsafe { t.0.get(&self.key_data) }, &t.1))
107    }
108
109    /// Gets the hash, key and value at the given index by insertion order.
110    ///
111    /// # Safety
112    /// The index must be less than len().
113    #[inline(always)]
114    pub unsafe fn get_index_unchecked(&self, idx: IdxSize) -> (u64, &[u8], &V) {
115        let t = self.tuples.get_unchecked(idx as usize);
116        (t.0.key_hash, t.0.get(&self.key_data), &t.1)
117    }
118
119    /// Iterates over the (hash, key) pairs in insertion order.
120    pub fn iter_hash_keys(&self) -> impl Iterator<Item = (u64, &[u8])> {
121        self.tuples
122            .iter()
123            .map(|t| unsafe { (t.0.key_hash, t.0.get(&self.key_data)) })
124    }
125
126    /// Iterates over the values in insertion order.
127    pub fn iter_values(&self) -> impl Iterator<Item = &V> {
128        self.tuples.iter().map(|t| &t.1)
129    }
130}
131
132pub enum Entry<'a, 'k, V> {
133    Occupied(OccupiedEntry<'a, V>),
134    Vacant(VacantEntry<'a, 'k, V>),
135}
136
137pub struct OccupiedEntry<'a, V> {
138    entry: TOccupiedEntry<'a, IdxSize>,
139    tuples: &'a mut Vec<(Key, V)>,
140}
141
142impl<'a, V> OccupiedEntry<'a, V> {
143    pub fn index(&self) -> IdxSize {
144        *self.entry.get()
145    }
146
147    pub fn into_mut(self) -> &'a mut V {
148        let idx = self.index();
149        unsafe { &mut self.tuples.get_unchecked_mut(idx as usize).1 }
150    }
151}
152
153pub struct VacantEntry<'a, 'k, V> {
154    hash: u64,
155    key: &'k [u8],
156    entry: TVacantEntry<'a, IdxSize>,
157    tuples: &'a mut Vec<(Key, V)>,
158    key_data: &'a mut Vec<Vec<u8>>,
159}
160
161#[allow(clippy::needless_lifetimes)]
162impl<'a, 'k, V> VacantEntry<'a, 'k, V> {
163    pub fn index(&self) -> IdxSize {
164        self.tuples.len() as IdxSize
165    }
166
167    pub fn insert(self, value: V) -> &'a mut V {
168        unsafe {
169            let tuple_idx: IdxSize = self.tuples.len().try_into().unwrap();
170
171            let mut num_buffers = self.key_data.len() as u32;
172            let mut active_buf = self.key_data.last_mut().unwrap_unchecked();
173            let key_len = self.key.len();
174            if active_buf.len() + key_len > active_buf.capacity() {
175                let ideal_next_cap = BASE_KEY_DATA_CAPACITY.checked_shl(num_buffers).unwrap();
176                let next_capacity = std::cmp::max(ideal_next_cap, key_len);
177                self.key_data.push(Vec::with_capacity(next_capacity));
178                active_buf = self.key_data.last_mut().unwrap_unchecked();
179                num_buffers += 1;
180            }
181
182            let tuple_key = Key {
183                key_hash: self.hash,
184                key_buffer: num_buffers - 1,
185                key_offset: active_buf.len(),
186                key_length: self.key.len().try_into().unwrap(),
187            };
188            self.tuples.push((tuple_key, value));
189            active_buf.extend_from_slice(self.key);
190            self.entry.insert(tuple_idx);
191            &mut self.tuples.last_mut().unwrap_unchecked().1
192        }
193    }
194}