numpy/
strings.rs

1//! Types to support arrays of [ASCII][ascii] and [UCS4][ucs4] strings
2//!
3//! [ascii]: https://numpy.org/doc/stable/reference/c-api/dtype.html#c.NPY_STRING
4//! [ucs4]: https://numpy.org/doc/stable/reference/c-api/dtype.html#c.NPY_UNICODE
5
6use std::cell::RefCell;
7use std::collections::hash_map::Entry;
8use std::fmt;
9use std::mem::size_of;
10use std::os::raw::c_char;
11use std::str;
12
13use pyo3::{
14    ffi::{Py_UCS1, Py_UCS4},
15    sync::GILProtected,
16    Bound, Py, Python,
17};
18use rustc_hash::FxHashMap;
19
20use crate::dtype::{clone_methods_impl, Element, PyArrayDescr, PyArrayDescrMethods};
21use crate::npyffi::PyDataType_SET_ELSIZE;
22use crate::npyffi::NPY_TYPES;
23
24/// A newtype wrapper around [`[u8; N]`][Py_UCS1] to handle [`byte` scalars][numpy-bytes] while satisfying coherence.
25///
26/// Note that when creating arrays of ASCII strings without an explicit `dtype`,
27/// NumPy will automatically determine the smallest possible array length at runtime.
28///
29/// For example,
30///
31/// ```python
32/// array = numpy.array([b"foo", b"bar", b"foobar"])
33/// ```
34///
35/// yields `S6` for `array.dtype`.
36///
37/// On the Rust side however, the length `N` of `PyFixedString<N>` must always be given
38/// explicitly and as a compile-time constant. For this work reliably, the Python code
39/// should set the `dtype` explicitly, e.g.
40///
41/// ```python
42/// numpy.array([b"foo", b"bar", b"foobar"], dtype='S12')
43/// ```
44///
45/// always matching `PyArray1<PyFixedString<12>>`.
46///
47/// # Example
48///
49/// ```rust
50/// # use pyo3::Python;
51/// use numpy::{PyArray1, PyUntypedArrayMethods, PyFixedString};
52///
53/// # Python::with_gil(|py| {
54/// let array = PyArray1::<PyFixedString<3>>::from_vec(py, vec![[b'f', b'o', b'o'].into()]);
55///
56/// assert!(array.dtype().to_string().contains("S3"));
57/// # });
58/// ```
59///
60/// [numpy-bytes]: https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.bytes_
61#[repr(transparent)]
62#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
63pub struct PyFixedString<const N: usize>(pub [Py_UCS1; N]);
64
65impl<const N: usize> fmt::Display for PyFixedString<N> {
66    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
67        fmt.write_str(str::from_utf8(&self.0).unwrap().trim_end_matches('\0'))
68    }
69}
70
71impl<const N: usize> From<[Py_UCS1; N]> for PyFixedString<N> {
72    fn from(val: [Py_UCS1; N]) -> Self {
73        Self(val)
74    }
75}
76
77unsafe impl<const N: usize> Element for PyFixedString<N> {
78    const IS_COPY: bool = true;
79
80    fn get_dtype(py: Python<'_>) -> Bound<'_, PyArrayDescr> {
81        static DTYPES: TypeDescriptors = TypeDescriptors::new();
82
83        unsafe { DTYPES.from_size(py, NPY_TYPES::NPY_STRING, b'|' as _, size_of::<Self>()) }
84    }
85
86    clone_methods_impl!(Self);
87}
88
89/// A newtype wrapper around [`[PyUCS4; N]`][Py_UCS4] to handle [`str_` scalars][numpy-str] while satisfying coherence.
90///
91/// Note that when creating arrays of Unicode strings without an explicit `dtype`,
92/// NumPy will automatically determine the smallest possible array length at runtime.
93///
94/// For example,
95///
96/// ```python
97/// numpy.array(["foo🐍", "bar🦀", "foobar"])
98/// ```
99///
100/// yields `U6` for `array.dtype`.
101///
102/// On the Rust side however, the length `N` of `PyFixedUnicode<N>` must always be given
103/// explicitly and as a compile-time constant. For this work reliably, the Python code
104/// should set the `dtype` explicitly, e.g.
105///
106/// ```python
107/// numpy.array(["foo🐍", "bar🦀", "foobar"], dtype='U12')
108/// ```
109///
110/// always matching `PyArray1<PyFixedUnicode<12>>`.
111///
112/// # Example
113///
114/// ```rust
115/// # use pyo3::Python;
116/// use numpy::{PyArray1, PyUntypedArrayMethods, PyFixedUnicode};
117///
118/// # Python::with_gil(|py| {
119/// let array = PyArray1::<PyFixedUnicode<3>>::from_vec(py, vec![[b'b' as _, b'a' as _, b'r' as _].into()]);
120///
121/// assert!(array.dtype().to_string().contains("U3"));
122/// # });
123/// ```
124///
125/// [numpy-str]: https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.str_
126#[repr(transparent)]
127#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
128pub struct PyFixedUnicode<const N: usize>(pub [Py_UCS4; N]);
129
130impl<const N: usize> fmt::Display for PyFixedUnicode<N> {
131    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
132        for character in self.0 {
133            if character == 0 {
134                break;
135            }
136
137            write!(fmt, "{}", char::from_u32(character).unwrap())?;
138        }
139
140        Ok(())
141    }
142}
143
144impl<const N: usize> From<[Py_UCS4; N]> for PyFixedUnicode<N> {
145    fn from(val: [Py_UCS4; N]) -> Self {
146        Self(val)
147    }
148}
149
150unsafe impl<const N: usize> Element for PyFixedUnicode<N> {
151    const IS_COPY: bool = true;
152
153    fn get_dtype(py: Python<'_>) -> Bound<'_, PyArrayDescr> {
154        static DTYPES: TypeDescriptors = TypeDescriptors::new();
155
156        unsafe { DTYPES.from_size(py, NPY_TYPES::NPY_UNICODE, b'=' as _, size_of::<Self>()) }
157    }
158
159    clone_methods_impl!(Self);
160}
161
162struct TypeDescriptors {
163    #[allow(clippy::type_complexity)]
164    dtypes: GILProtected<RefCell<Option<FxHashMap<usize, Py<PyArrayDescr>>>>>,
165}
166
167impl TypeDescriptors {
168    const fn new() -> Self {
169        Self {
170            dtypes: GILProtected::new(RefCell::new(None)),
171        }
172    }
173
174    /// `npy_type` must be either `NPY_STRING` or `NPY_UNICODE` with matching `byteorder` and `size`
175    #[allow(clippy::wrong_self_convention)]
176    unsafe fn from_size<'py>(
177        &self,
178        py: Python<'py>,
179        npy_type: NPY_TYPES,
180        byteorder: c_char,
181        size: usize,
182    ) -> Bound<'py, PyArrayDescr> {
183        let mut dtypes = self.dtypes.get(py).borrow_mut();
184
185        let dtype = match dtypes.get_or_insert_with(Default::default).entry(size) {
186            Entry::Occupied(entry) => entry.into_mut(),
187            Entry::Vacant(entry) => {
188                let dtype = PyArrayDescr::new_from_npy_type(py, npy_type);
189
190                let descr = &mut *dtype.as_dtype_ptr();
191                PyDataType_SET_ELSIZE(py, descr, size.try_into().unwrap());
192                descr.byteorder = byteorder;
193
194                entry.insert(dtype.into())
195            }
196        };
197
198        dtype.bind(py).to_owned()
199    }
200}
201
#[cfg(test)]
mod tests {
    use super::*;

    /// Trailing NUL padding is dropped; a fully used buffer is printed unchanged.
    #[test]
    fn format_fixed_string() {
        let padded = PyFixedString([b'f', b'o', b'o', 0, 0, 0]);
        let full = PyFixedString(*b"foobar");

        assert_eq!(padded.to_string(), "foo");
        assert_eq!(full.to_string(), "foobar");
    }

    /// Code points are printed up to the first NUL, including non-BMP characters.
    #[test]
    fn format_fixed_unicode() {
        let cases = [
            (['f', 'o', 'o', '\0', '\0', '\0'].map(u32::from), "foo"),
            (['🦀', '🐍', '\0', '\0', '\0', '\0'].map(u32::from), "🦀🐍"),
            (['f', 'o', 'o', 'b', 'a', 'r'].map(u32::from), "foobar"),
        ];

        for (code_points, expected) in cases {
            assert_eq!(PyFixedUnicode(code_points).to_string(), expected);
        }
    }
}