polars_utils/cache.rs

use std::borrow::Borrow;
use std::cell::Cell;
use std::hash::Hash;
use std::mem::MaybeUninit;

use bytemuck::allocation::zeroed_vec;
use bytemuck::Zeroable;

use crate::aliases::PlRandomState;

/// A cached function that uses `FastFixedCache` for access speed.
/// It is important that the key is relatively cheap to compute.
pub struct FastCachedFunc<T, R, F> {
    func: F,
    cache: FastFixedCache<T, R>,
}

impl<T, R, F> FastCachedFunc<T, R, F>
where
    F: FnMut(T) -> R,
    T: std::hash::Hash + Eq + Clone,
    R: Copy,
{
    pub fn new(func: F, size: usize) -> Self {
        Self {
            func,
            cache: FastFixedCache::new(size),
        }
    }

    pub fn eval(&mut self, x: T, use_cache: bool) -> R {
        if use_cache {
            *self
                .cache
                .get_or_insert_with(&x, |xr| (self.func)(xr.clone()))
        } else {
            (self.func)(x)
        }
    }
}
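
#[cfg(test)]
mod fast_cached_func_example {
    use std::cell::Cell;

    use super::*;

    // Usage sketch added for exposition (not part of the upstream file): the
    // call counter shows that with `use_cache: true` the wrapped function runs
    // only once per distinct key, while `use_cache: false` always calls through.
    #[test]
    fn caches_per_key() {
        let calls = Cell::new(0usize);
        let mut cached = FastCachedFunc::new(
            |s: String| {
                calls.set(calls.get() + 1);
                s.len()
            },
            16,
        );
        assert_eq!(cached.eval("hello".to_string(), true), 5);
        assert_eq!(cached.eval("hello".to_string(), true), 5); // served from cache
        assert_eq!(calls.get(), 1);
        assert_eq!(cached.eval("hello".to_string(), false), 5); // cache bypassed
        assert_eq!(calls.get(), 2);
    }
}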

const MIN_FAST_FIXED_CACHE_SIZE: usize = 16;

/// A fixed-size cache optimized for access speed. Does not implement LRU or use
/// a full hash table due to cost; instead we assign two pseudorandom slots
/// based on the hash of the key, and if both are full we evict the one that had
/// the older last access.
#[derive(Clone)]
pub struct FastFixedCache<K, V> {
    slots: Vec<CacheSlot<K, V>>,
    access_ctr: Cell<u32>,
    shift: u32,
    random_state: PlRandomState,
}

impl<K: Hash + Eq, V> Default for FastFixedCache<K, V> {
    fn default() -> Self {
        Self::new(MIN_FAST_FIXED_CACHE_SIZE)
    }
}

impl<K: Hash + Eq, V> FastFixedCache<K, V> {
    pub fn new(n: usize) -> Self {
        let n = n.max(MIN_FAST_FIXED_CACHE_SIZE).next_power_of_two();
        Self {
            slots: zeroed_vec(n),
            access_ctr: Cell::new(1),
            shift: 64 - n.ilog2(),
            random_state: PlRandomState::default(),
        }
    }

    pub fn get<Q: Hash + Eq + ?Sized>(&self, key: &Q) -> Option<&V>
    where
        K: Borrow<Q>,
    {
        unsafe {
            // SAFETY: slot_idx from raw_get is valid and occupied.
            let slot_idx = self.raw_get(self.hash(key), key)?;
            let slot = self.slots.get_unchecked(slot_idx);
            Some(slot.value.assume_init_ref())
        }
    }

    pub fn get_mut<Q: Hash + Eq + ?Sized>(&mut self, key: &Q) -> Option<&mut V>
    where
        K: Borrow<Q>,
    {
        unsafe {
            // SAFETY: slot_idx from raw_get is valid and occupied.
            let slot_idx = self.raw_get(self.hash(key), key)?;
            let slot = self.slots.get_unchecked_mut(slot_idx);
            Some(slot.value.assume_init_mut())
        }
    }

    pub fn insert(&mut self, key: K, value: V) -> &mut V {
        unsafe { self.raw_insert(self.hash(&key), key, value) }
    }

    pub fn get_or_insert_with<Q, F>(&mut self, key: &Q, f: F) -> &mut V
    where
        K: Borrow<Q>,
        Q: Hash + Eq + ToOwned<Owned = K> + ?Sized,
        F: FnOnce(&K) -> V,
    {
        unsafe {
            let h = self.hash(key);
            if let Some(slot_idx) = self.raw_get(h, key) {
                let slot = self.slots.get_unchecked_mut(slot_idx);
                return slot.value.assume_init_mut();
            }

            let key = key.to_owned();
            let val = f(&key);
            self.raw_insert(h, key, val)
        }
    }

    pub fn try_get_or_insert_with<Q, F, E>(&mut self, key: &Q, f: F) -> Result<&mut V, E>
    where
        K: Borrow<Q>,
        Q: Hash + Eq + ToOwned<Owned = K> + ?Sized,
        F: FnOnce(&K) -> Result<V, E>,
    {
        unsafe {
            let h = self.hash(key);
            if let Some(slot_idx) = self.raw_get(h, key) {
                let slot = self.slots.get_unchecked_mut(slot_idx);
                return Ok(slot.value.assume_init_mut());
            }

            let key = key.to_owned();
            let val = f(&key)?;
            Ok(self.raw_insert(h, key, val))
        }
    }

    unsafe fn raw_get<Q: Eq + ?Sized>(&self, h: HashResult, key: &Q) -> Option<usize>
    where
        K: Borrow<Q>,
    {
        unsafe {
            // SAFETY: we assume h is a HashResult from self.hash with valid indices
            // and we check slot.last_access != 0 before assuming the slot is initialized.
            let slot = self.slots.get_unchecked(h.i1);
            if slot.last_access.get() != 0
                && slot.hash_tag == h.tag
                && slot.key.assume_init_ref().borrow() == key
            {
                slot.last_access.set(self.new_access_ctr());
                return Some(h.i1);
            }

            let slot = self.slots.get_unchecked(h.i2);
            if slot.last_access.get() != 0
                && slot.hash_tag == h.tag
                && slot.key.assume_init_ref().borrow() == key
            {
                slot.last_access.set(self.new_access_ctr());
                return Some(h.i2);
            }
        }

        None
    }

    unsafe fn raw_insert(&mut self, h: HashResult, key: K, value: V) -> &mut V {
        let last_access = self.new_access_ctr();
        unsafe {
            // SAFETY: i1 and i2 are valid indices and older_idx returns one of them.
            let idx = self.older_idx(h.i1, h.i2);
            let slot = self.slots.get_unchecked_mut(idx);

            // Drop impl takes care of dropping old value, if occupied.
            *slot = CacheSlot {
                last_access: Cell::new(last_access),
                hash_tag: h.tag,
                key: MaybeUninit::new(key),
                value: MaybeUninit::new(value),
            };
            slot.value.assume_init_mut()
        }
    }

    /// Returns the older index based on access time, where unoccupied slots
    /// are considered infinitely old.
    unsafe fn older_idx(&mut self, i1: usize, i2: usize) -> usize {
        let age1 = self.slots.get_unchecked(i1).last_access.get();
        let age2 = self.slots.get_unchecked(i2).last_access.get();
        match (age1, age2) {
            (0, _) => i1,
            (_, 0) => i2,
            // This takes into account the wrap-around of our access_ctr.
            // We assume that the smaller value between age1.wrapping_sub(age2)
            // and age2.wrapping_sub(age1) is the true delta. Thus if
            // age1.wrapping_sub(age2) is >= 1 << 31, we know that
            // age2.wrapping_sub(age1) is smaller than it, and we also
            // immediately know that age1 is older.
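            // Worked example (illustrative numbers, not from the original
            // comment): if age1 == u32::MAX (written just before the counter
            // wrapped) and age2 == 3 (written just after), then
            // age1.wrapping_sub(age2) == u32::MAX - 3 >= 1 << 31, so we pick
            // i1, correctly treating it as older even though age1 > age2
            // numerically.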
            _ if age1.wrapping_sub(age2) >= (1 << 31) => i1,
            _ => i2,
        }
    }

    fn new_access_ctr(&self) -> u32 {
        // The counter starts at 1 and increments by 2, so it stays odd and can
        // never wrap around to 0; a slot with last_access == 0 would be treated
        // as unoccupied and its key/value would leak rather than drop.
        self.access_ctr.replace(self.access_ctr.get().wrapping_add(2))
    }

    /// Computes the hash tag and two slot indexes for a given key.
    fn hash<Q: Hash + ?Sized>(&self, key: &Q) -> HashResult {
        // An instantiation of Dietzfelbinger's multiply-shift, see 2.3 of
        // https://arxiv.org/pdf/1504.06804.pdf.
        // The magic constants are just two randomly chosen odd 64-bit numbers.
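        // Illustration (not in the original comment): with the minimum size of
        // 16 = 2^4 slots, shift == 64 - 4 == 60, so each multiply keeps only
        // the top 4 bits of the product, yielding two indices in 0..16.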
        let h = self.random_state.hash_one(key);
        let tag = h as u32;
        let i1 = (h.wrapping_mul(0x2e623b55bc0c9073) >> self.shift) as usize;
        let i2 = (h.wrapping_mul(0x921932b06a233d39) >> self.shift) as usize;
        HashResult { tag, i1, i2 }
    }
}
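
#[cfg(test)]
mod fast_fixed_cache_example {
    use super::*;

    // Usage sketch added for exposition (not part of the upstream file). Each
    // key hashes to two pseudorandom slots; lookups probe both, and an insert
    // into two occupied slots evicts the one with the older last access.
    #[test]
    fn basic_usage() {
        let mut cache = FastFixedCache::<String, usize>::new(16);
        cache.insert("a".to_string(), 1);
        assert_eq!(cache.get("a"), Some(&1));
        assert_eq!(cache.get("missing"), None);

        // The value closure only runs on a miss.
        let v = cache.get_or_insert_with("bc", |k| k.len());
        assert_eq!(*v, 2);
        assert_eq!(cache.get("bc"), Some(&2));
    }
}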

// HashResult is Copy so a computed hash can be reused for both the lookup
// probe and a subsequent insert.
#[derive(Clone, Copy)]
struct HashResult {
    tag: u32,
    i1: usize,
    i2: usize,
}

struct CacheSlot<K, V> {
    // If last_access != 0, the rest is assumed to be initialized.
    last_access: Cell<u32>,
    hash_tag: u32,
    key: MaybeUninit<K>,
    value: MaybeUninit<V>,
}

// SAFETY: an all-zero CacheSlot has last_access == 0, which marks the slot as
// unoccupied, so the zeroed MaybeUninit fields are never read.
unsafe impl<K, V> Zeroable for CacheSlot<K, V> {}

impl<K, V> Drop for CacheSlot<K, V> {
    fn drop(&mut self) {
        unsafe {
            if self.last_access.get() != 0 {
                self.key.assume_init_drop();
                self.value.assume_init_drop();
            }
        }
    }
}

impl<K: Clone, V: Clone> Clone for CacheSlot<K, V> {
    fn clone(&self) -> Self {
        unsafe {
            if self.last_access.get() != 0 {
                Self {
                    last_access: self.last_access.clone(),
                    hash_tag: self.hash_tag,
                    key: MaybeUninit::new(self.key.assume_init_ref().clone()),
                    value: MaybeUninit::new(self.value.assume_init_ref().clone()),
                }
            } else {
                Self::zeroed()
            }
        }
    }
}
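
#[cfg(test)]
mod cache_slot_clone_example {
    use super::*;

    // Illustrative check added for exposition (not part of the upstream file):
    // cloning a populated cache deep-clones the occupied slots, so the clone
    // serves hits independently of the original.
    #[test]
    fn clone_preserves_entries() {
        let mut cache = FastFixedCache::<String, u32>::new(16);
        cache.insert("a".to_string(), 1);
        let cloned = cache.clone();
        assert_eq!(cloned.get("a"), Some(&1));
    }
}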