polars_utils/mmap.rs

use std::fs::File;
use std::io;

pub use memmap::Mmap;

mod private {
    use std::fs::File;
    use std::ops::Deref;
    use std::sync::Arc;

    use polars_error::PolarsResult;

    use super::MMapSemaphore;
    use crate::mem::prefetch_l2;

    /// A read-only reference to a slice of memory that can potentially be memory-mapped.
    ///
    /// A reference count is kept to the underlying buffer to ensure the memory is kept alive.
    /// [`MemSlice::slice`] can be used to slice the memory in a zero-copy manner.
    ///
    /// This still owns all the original memory and therefore should probably not be a
    /// long-lasting structure.
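    ///
    /// # Example
    ///
    /// A minimal sketch of zero-copy construction and slicing (assuming this crate is
    /// available as `polars_utils`):
    ///
    /// ```
    /// use polars_utils::mmap::MemSlice;
    ///
    /// let mem = MemSlice::from_vec(vec![1u8, 2, 3, 4, 5]);
    /// // Slicing is zero-copy: the sub-slice points into the same buffer.
    /// let sub = mem.slice(1..4);
    /// assert_eq!(&*sub, &[2, 3, 4]);
    /// assert_eq!(sub.as_ptr(), unsafe { mem.as_ptr().add(1) });
    /// ```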
    #[derive(Clone, Debug)]
    pub struct MemSlice {
        // Store the `&[u8]` to make the `Deref` free.
        // `slice` is not 'static - it is backed by `inner`. This is safe as long as `slice` is not
        // directly accessed, and we are in a private module to guarantee that. Access should only
        // be done through `Deref<Target = [u8]>`, which automatically gives the correct lifetime.
        slice: &'static [u8],
        #[allow(unused)]
        inner: MemSliceInner,
    }

    /// Keeps the underlying buffer alive. This should be cheaply cloneable.
    #[derive(Clone, Debug)]
    #[allow(unused)]
    enum MemSliceInner {
        Bytes(bytes::Bytes), // Separate because it does atomic refcounting internally
        Arc(Arc<dyn std::fmt::Debug + Send + Sync>),
    }

    impl Deref for MemSlice {
        type Target = [u8];

        #[inline(always)]
        fn deref(&self) -> &Self::Target {
            self.slice
        }
    }

    impl AsRef<[u8]> for MemSlice {
        #[inline(always)]
        fn as_ref(&self) -> &[u8] {
            self.slice
        }
    }

    impl Default for MemSlice {
        fn default() -> Self {
            Self::from_bytes(bytes::Bytes::new())
        }
    }

    impl From<Vec<u8>> for MemSlice {
        fn from(value: Vec<u8>) -> Self {
            Self::from_vec(value)
        }
    }

    impl MemSlice {
        pub const EMPTY: Self = Self::from_static(&[]);

        /// Copy the contents into a new owned `Vec`.
        #[inline(always)]
        pub fn to_vec(self) -> Vec<u8> {
            <[u8]>::to_vec(self.deref())
        }

        /// Construct a `MemSlice` from an existing `Vec<u8>`. This is zero-copy.
        #[inline]
        pub fn from_vec(v: Vec<u8>) -> Self {
            Self::from_bytes(bytes::Bytes::from(v))
        }

        /// Construct a `MemSlice` from [`bytes::Bytes`]. This is zero-copy.
        #[inline]
        pub fn from_bytes(bytes: bytes::Bytes) -> Self {
            Self {
                slice: unsafe { std::mem::transmute::<&[u8], &'static [u8]>(bytes.as_ref()) },
                inner: MemSliceInner::Bytes(bytes),
            }
        }

        /// Construct a `MemSlice` from a memory-mapped file. The returned slice keeps the
        /// mapping alive via the `Arc`. This is zero-copy.
        #[inline]
        pub fn from_mmap(mmap: Arc<MMapSemaphore>) -> Self {
            Self {
                slice: unsafe {
                    std::mem::transmute::<&[u8], &'static [u8]>(mmap.as_ref().as_ref())
                },
                inner: MemSliceInner::Arc(mmap),
            }
        }

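        /// Construct a `MemSlice` from a `&[u8]` whose backing storage is owned by `arc`. The
        /// returned slice keeps `arc` alive; the caller must ensure `slice` really is backed by
        /// memory owned by `arc`.
        ///
        /// A minimal usage sketch (assuming this crate is available as `polars_utils`):
        ///
        /// ```
        /// use std::sync::Arc;
        ///
        /// use polars_utils::mmap::MemSlice;
        ///
        /// let owner: Arc<Vec<u8>> = Arc::new(vec![1u8, 2, 3, 4, 5]);
        /// // Tie a view into the buffer to the refcount of `owner`.
        /// let mem = MemSlice::from_arc(&owner[1..4], Arc::clone(&owner));
        /// assert_eq!(&*mem, &[2, 3, 4]);
        /// ```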
        #[inline]
        pub fn from_arc<T>(slice: &[u8], arc: Arc<T>) -> Self
        where
            T: std::fmt::Debug + Send + Sync + 'static,
        {
            Self {
                slice: unsafe { std::mem::transmute::<&[u8], &'static [u8]>(slice) },
                inner: MemSliceInner::Arc(arc),
            }
        }

        /// Construct a `MemSlice` by memory-mapping the given file.
        #[inline]
        pub fn from_file(file: &File) -> PolarsResult<Self> {
            let mmap = MMapSemaphore::new_from_file(file)?;
            Ok(Self::from_mmap(Arc::new(mmap)))
        }

        /// Construct a `MemSlice` that simply wraps around a `&'static [u8]`.
        #[inline]
        pub const fn from_static(slice: &'static [u8]) -> Self {
            let inner = MemSliceInner::Bytes(bytes::Bytes::from_static(slice));
            Self { slice, inner }
        }

        /// Attempt to prefetch the memory belonging to this [`MemSlice`] into the L2 cache.
        #[inline]
        pub fn prefetch(&self) {
            prefetch_l2(self.as_ref());
        }

        /// # Panics
        /// Panics if range is not in bounds.
        #[inline]
        #[track_caller]
        pub fn slice(&self, range: std::ops::Range<usize>) -> Self {
            let mut out = self.clone();
            out.slice = &out.slice[range];
            out
        }
    }

    impl From<bytes::Bytes> for MemSlice {
        fn from(value: bytes::Bytes) -> Self {
            Self::from_bytes(value)
        }
    }
}

use memmap::MmapOptions;
#[cfg(target_family = "unix")]
use polars_error::polars_bail;
use polars_error::PolarsResult;
pub use private::MemSlice;

/// A cursor over a [`MemSlice`].
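///
/// # Example
///
/// A minimal sketch of cursor-style reading (assuming this crate is available as
/// `polars_utils`):
///
/// ```
/// use std::io::{Read, Seek, SeekFrom};
///
/// use polars_utils::mmap::MemReader;
///
/// let mut reader = MemReader::from_vec(vec![1u8, 2, 3, 4, 5]);
/// // `read_slice` hands out zero-copy `MemSlice`s and advances the cursor.
/// let head = reader.read_slice(2);
/// assert_eq!(&*head, &[1, 2]);
/// assert_eq!(reader.position(), 2);
///
/// // The `std::io::Read` and `Seek` impls work as usual.
/// reader.seek(SeekFrom::Start(3)).unwrap();
/// let mut buf = [0u8; 2];
/// reader.read_exact(&mut buf).unwrap();
/// assert_eq!(buf, [4, 5]);
/// ```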
#[derive(Debug, Clone)]
pub struct MemReader {
    data: MemSlice,
    position: usize,
}

impl MemReader {
    pub fn new(data: MemSlice) -> Self {
        Self { data, position: 0 }
    }

    #[inline(always)]
    pub fn remaining_len(&self) -> usize {
        self.data.len() - self.position
    }

    #[inline(always)]
    pub fn total_len(&self) -> usize {
        self.data.len()
    }

    #[inline(always)]
    pub fn position(&self) -> usize {
        self.position
    }

    /// Construct a `MemReader` from an existing `Vec<u8>`. This is zero-copy.
    #[inline(always)]
    pub fn from_vec(v: Vec<u8>) -> Self {
        Self::new(MemSlice::from_vec(v))
    }

    /// Construct a `MemReader` from [`bytes::Bytes`]. This is zero-copy.
    #[inline(always)]
    pub fn from_bytes(bytes: bytes::Bytes) -> Self {
        Self::new(MemSlice::from_bytes(bytes))
    }

    /// Construct a `MemReader` that simply wraps around a `&'static [u8]`.
    #[inline]
    pub fn from_slice(slice: &'static [u8]) -> Self {
        Self::new(MemSlice::from_static(slice))
    }

    /// Construct a `MemReader` by reading `reader` to the end into an owned buffer.
    #[inline(always)]
    pub fn from_reader<R: io::Read>(mut reader: R) -> io::Result<Self> {
        let mut vec = Vec::new();
        reader.read_to_end(&mut vec)?;
        Ok(Self::from_vec(vec))
    }

    /// Read the next `n` bytes as a zero-copy [`MemSlice`], advancing the cursor. The result
    /// is truncated if fewer than `n` bytes remain.
    #[inline(always)]
    pub fn read_slice(&mut self, n: usize) -> MemSlice {
        let start = self.position;
        let end = usize::min(self.position + n, self.data.len());
        self.position = end;
        self.data.slice(start..end)
    }
}

impl From<MemSlice> for MemReader {
    fn from(data: MemSlice) -> Self {
        Self { data, position: 0 }
    }
}

impl io::Read for MemReader {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let n = usize::min(buf.len(), self.remaining_len());
        buf[..n].copy_from_slice(&self.data[self.position..self.position + n]);
        self.position += n;
        Ok(n)
    }
}

impl io::Seek for MemReader {
    fn seek(&mut self, pos: io::SeekFrom) -> io::Result<u64> {
        let position = match pos {
            io::SeekFrom::Start(position) => usize::min(position as usize, self.total_len()),
            io::SeekFrom::End(offset) => {
                let Some(position) = self.total_len().checked_add_signed(offset as isize) else {
                    return Err(io::Error::new(
                        io::ErrorKind::Other,
                        "seek to before the start of the buffer",
                    ));
                };

                position
            },
            io::SeekFrom::Current(offset) => {
                let Some(position) = self.position.checked_add_signed(offset as isize) else {
                    return Err(io::Error::new(
                        io::ErrorKind::Other,
                        "seek to before the start of the buffer",
                    ));
                };

                position
            },
        };

        self.position = position;

        Ok(position as u64)
    }
}

// Keep track of memory-mapped files so we don't write to them while reading.
// Use a BTreeMap, as it uses less memory than a hashmap and this registry never shrinks.
// A write handle on Windows is exclusive, so this is only necessary on Unix.
#[cfg(target_family = "unix")]
static MEMORY_MAPPED_FILES: once_cell::sync::Lazy<
    std::sync::Mutex<std::collections::BTreeMap<(u64, u64), u32>>,
> = once_cell::sync::Lazy::new(|| std::sync::Mutex::new(Default::default()));

/// Wraps an [`Mmap`] and, on Unix, registers the mapped file (keyed by device and inode) in a
/// process-global registry so that [`ensure_not_mapped`] can refuse writes to the file while
/// the mapping is alive.
#[derive(Debug)]
pub struct MMapSemaphore {
    #[cfg(target_family = "unix")]
    key: (u64, u64),
    mmap: Mmap,
}

impl MMapSemaphore {
    pub fn new_from_file_with_options(
        file: &File,
        options: MmapOptions,
    ) -> PolarsResult<MMapSemaphore> {
        let mmap = unsafe { options.map(file) }?;

        #[cfg(target_family = "unix")]
        {
            // FIXME: We aren't handling the case where the file is already open in write-mode here.

            use std::os::unix::fs::MetadataExt;
            let metadata = file.metadata()?;

            let mut guard = MEMORY_MAPPED_FILES.lock().unwrap();
            let key = (metadata.dev(), metadata.ino());
            match guard.entry(key) {
                std::collections::btree_map::Entry::Occupied(mut e) => *e.get_mut() += 1,
                std::collections::btree_map::Entry::Vacant(e) => _ = e.insert(1),
            }
            Ok(Self { key, mmap })
        }

        #[cfg(not(target_family = "unix"))]
        Ok(Self { mmap })
    }

    pub fn new_from_file(file: &File) -> PolarsResult<MMapSemaphore> {
        Self::new_from_file_with_options(file, MmapOptions::default())
    }

    pub fn as_ptr(&self) -> *const u8 {
        self.mmap.as_ptr()
    }
}

impl AsRef<[u8]> for MMapSemaphore {
    #[inline]
    fn as_ref(&self) -> &[u8] {
        self.mmap.as_ref()
    }
}

#[cfg(target_family = "unix")]
impl Drop for MMapSemaphore {
    fn drop(&mut self) {
        let mut guard = MEMORY_MAPPED_FILES.lock().unwrap();
        if let std::collections::btree_map::Entry::Occupied(mut e) = guard.entry(self.key) {
            let v = e.get_mut();
            *v -= 1;

            if *v == 0 {
                e.remove_entry();
            }
        }
    }
}

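/// Ensure the file described by `file_md` is not currently memory mapped by this process.
///
/// A minimal sketch of the intended use (assuming this crate is available as `polars_utils`;
/// note the check is a no-op on non-Unix targets, and the file path here is illustrative):
///
/// ```no_run
/// use polars_utils::mmap::{ensure_not_mapped, MMapSemaphore};
///
/// let file = std::fs::File::open("data.bin").unwrap();
/// let mmap = MMapSemaphore::new_from_file(&file).unwrap();
/// // While the mapping is alive, a writer should refuse to touch the file (on Unix).
/// assert!(ensure_not_mapped(&file.metadata().unwrap()).is_err());
/// drop(mmap);
/// assert!(ensure_not_mapped(&file.metadata().unwrap()).is_ok());
/// ```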
pub fn ensure_not_mapped(
    #[cfg_attr(not(target_family = "unix"), allow(unused))] file_md: &std::fs::Metadata,
) -> PolarsResult<()> {
    // TODO: We need to actually register that this file has been write-opened and prevent
    // read-opening this file based on that.
    #[cfg(target_family = "unix")]
    {
        use std::os::unix::fs::MetadataExt;
        let guard = MEMORY_MAPPED_FILES.lock().unwrap();
        if guard.contains_key(&(file_md.dev(), file_md.ino())) {
            polars_bail!(ComputeError: "cannot write to file: already memory mapped");
        }
    }
    Ok(())
}

#[cfg(test)]
mod tests {
    #[test]
    fn test_mem_slice_zero_copy() {
        use std::sync::Arc;

        use super::MemSlice;

        {
            let vec = vec![1u8, 2, 3, 4, 5];
            let ptr = vec.as_ptr();

            let mem_slice = MemSlice::from_vec(vec);
            let ptr_out = mem_slice.as_ptr();

            assert_eq!(ptr_out, ptr);
        }

        {
            let mut vec = vec![1u8, 2, 3, 4, 5];
            vec.truncate(2);
            let ptr = vec.as_ptr();

            let mem_slice = MemSlice::from_vec(vec);
            let ptr_out = mem_slice.as_ptr();

            assert_eq!(ptr_out, ptr);
        }

        {
            let bytes = bytes::Bytes::from(vec![1u8, 2, 3, 4, 5]);
            let ptr = bytes.as_ptr();

            let mem_slice = MemSlice::from_bytes(bytes);
            let ptr_out = mem_slice.as_ptr();

            assert_eq!(ptr_out, ptr);
        }

        {
            use crate::mmap::MMapSemaphore;

            let path = "../../examples/datasets/foods1.csv";
            let file = std::fs::File::open(path).unwrap();
            let mmap = MMapSemaphore::new_from_file(&file).unwrap();
            let ptr = mmap.as_ptr();

            let mem_slice = MemSlice::from_mmap(Arc::new(mmap));
            let ptr_out = mem_slice.as_ptr();

            assert_eq!(ptr_out, ptr);
        }

        {
            let vec = vec![1u8, 2, 3, 4, 5];
            let slice = vec.as_slice();
            let ptr = slice.as_ptr();

            let mem_slice = MemSlice::from_static(unsafe {
                std::mem::transmute::<&[u8], &'static [u8]>(slice)
            });
            let ptr_out = mem_slice.as_ptr();

            assert_eq!(ptr_out, ptr);
        }
    }

    #[test]
    fn test_mem_slice_slicing() {
        use super::MemSlice;

        {
            let vec = vec![1u8, 2, 3, 4, 5];
            let slice = vec.as_slice();

            let mem_slice = MemSlice::from_static(unsafe {
                std::mem::transmute::<&[u8], &'static [u8]>(slice)
            });

            let out = &*mem_slice.slice(3..5);
            assert_eq!(out, &slice[3..5]);
            assert_eq!(out.as_ptr(), slice[3..5].as_ptr());
        }
    }
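
    // A small extra check sketching the `std::io::Read` and `Seek` behaviour of `MemReader`
    // documented above; the data here is illustrative.
    #[test]
    fn test_mem_reader_read_seek() {
        use std::io::{Read, Seek, SeekFrom};

        use super::MemReader;

        let mut reader = MemReader::from_vec(vec![1u8, 2, 3, 4, 5]);

        // Reading advances the cursor and copies out of the underlying `MemSlice`.
        let mut buf = [0u8; 3];
        reader.read_exact(&mut buf).unwrap();
        assert_eq!(buf, [1, 2, 3]);
        assert_eq!(reader.position(), 3);

        // Seeking relative to the end repositions the cursor without touching the data.
        reader.seek(SeekFrom::End(-2)).unwrap();
        let mut rest = Vec::new();
        reader.read_to_end(&mut rest).unwrap();
        assert_eq!(rest, vec![4, 5]);
    }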
}