numpy/
dtype.rs

1use std::mem::size_of;
2use std::os::raw::{c_int, c_long, c_longlong, c_short, c_uint, c_ulong, c_ulonglong, c_ushort};
3use std::ptr;
4
5#[cfg(feature = "half")]
6use half::{bf16, f16};
7use num_traits::{Bounded, Zero};
8#[cfg(feature = "half")]
9use pyo3::sync::GILOnceCell;
10use pyo3::{
11    conversion::IntoPyObject,
12    exceptions::{PyIndexError, PyValueError},
13    ffi::{self, PyTuple_Size},
14    pyobject_native_type_named,
15    types::{PyAnyMethods, PyDict, PyDictMethods, PyTuple, PyType},
16    Borrowed, Bound, Py, PyAny, PyObject, PyResult, PyTypeInfo, Python,
17};
18
19use crate::npyffi::{
20    NpyTypes, PyArray_Descr, PyDataType_ALIGNMENT, PyDataType_ELSIZE, PyDataType_FIELDS,
21    PyDataType_FLAGS, PyDataType_NAMES, PyDataType_SUBARRAY, NPY_ALIGNED_STRUCT,
22    NPY_BYTEORDER_CHAR, NPY_ITEM_HASOBJECT, NPY_TYPES, PY_ARRAY_API,
23};
24
25pub use num_complex::{Complex32, Complex64};
26
/// Binding of [`numpy.dtype`][dtype].
///
/// The accessor methods for a bound descriptor are provided by the
/// [`PyArrayDescrMethods`] trait, which is implemented for `Bound<'py, PyArrayDescr>`.
///
/// # Example
///
/// ```
/// use numpy::{dtype, get_array_module, PyArrayDescr, PyArrayDescrMethods};
/// use numpy::pyo3::{types::{IntoPyDict, PyAnyMethods}, Python, ffi::c_str};
///
/// # fn main() -> pyo3::PyResult<()> {
/// Python::with_gil(|py| {
///     let locals = [("np", get_array_module(py)?)].into_py_dict(py)?;
///
///     let dt = py
///         .eval(c_str!("np.array([1, 2, 3.0]).dtype"), Some(&locals), None)?
///         .downcast_into::<PyArrayDescr>()?;
///
///     assert!(dt.is_equiv_to(&dtype::<f64>(py)));
/// #   Ok(())
/// })
/// # }
/// ```
///
/// [dtype]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.html
// `repr(transparent)` gives this wrapper the same layout as the wrapped `PyAny`.
#[repr(transparent)]
pub struct PyArrayDescr(PyAny);

// NOTE(review): presumably generates the common native-type trait impls that pyo3
// expects for a wrapper around `PyAny` -- confirm against the pyo3 macro definition.
pyobject_native_type_named!(PyArrayDescr);
54
55unsafe impl PyTypeInfo for PyArrayDescr {
56    const NAME: &'static str = "PyArrayDescr";
57    const MODULE: Option<&'static str> = Some("numpy");
58
59    #[inline]
60    fn type_object_raw<'py>(py: Python<'py>) -> *mut ffi::PyTypeObject {
61        unsafe { PY_ARRAY_API.get_type_object(py, NpyTypes::PyArrayDescr_Type) }
62    }
63}
64
65/// Returns the type descriptor ("dtype") for a registered type.
66#[inline]
67pub fn dtype<'py, T: Element>(py: Python<'py>) -> Bound<'py, PyArrayDescr> {
68    T::get_dtype(py)
69}
70
71/// Deprecated name for [`dtype`].
72#[deprecated(since = "0.23.0", note = "renamed to `dtype`")]
73#[inline]
74pub fn dtype_bound<'py, T: Element>(py: Python<'py>) -> Bound<'py, PyArrayDescr> {
75    dtype::<T>(py)
76}
77
78impl PyArrayDescr {
79    /// Creates a new type descriptor ("dtype") object from an arbitrary object.
80    ///
81    /// Equivalent to invoking the constructor of [`numpy.dtype`][dtype].
82    ///
83    /// [dtype]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.html
84    #[inline]
85    pub fn new<'a, 'py, T>(py: Python<'py>, ob: T) -> PyResult<Bound<'py, Self>>
86    where
87        T: IntoPyObject<'py>,
88    {
89        fn inner<'py>(
90            py: Python<'py>,
91            obj: Borrowed<'_, 'py, PyAny>,
92        ) -> PyResult<Bound<'py, PyArrayDescr>> {
93            let mut descr: *mut PyArray_Descr = ptr::null_mut();
94            unsafe {
95                // None is an invalid input here and is not converted to NPY_DEFAULT_TYPE
96                PY_ARRAY_API.PyArray_DescrConverter2(py, obj.as_ptr(), &mut descr);
97                Bound::from_owned_ptr_or_err(py, descr.cast())
98                    .map(|any| any.downcast_into_unchecked())
99            }
100        }
101
102        inner(
103            py,
104            ob.into_pyobject(py)
105                .map_err(Into::into)?
106                .into_any()
107                .as_borrowed(),
108        )
109    }
110
111    /// Deprecated name for [`PyArrayDescr::new`].
112    #[deprecated(since = "0.23.0", note = "renamed to `PyArrayDescr::new`")]
113    #[allow(deprecated)]
114    #[inline]
115    pub fn new_bound<'py, T: pyo3::ToPyObject + ?Sized>(
116        py: Python<'py>,
117        ob: &T,
118    ) -> PyResult<Bound<'py, Self>> {
119        Self::new(py, ob.to_object(py))
120    }
121
122    /// Shortcut for creating a type descriptor of `object` type.
123    #[inline]
124    pub fn object(py: Python<'_>) -> Bound<'_, Self> {
125        Self::from_npy_type(py, NPY_TYPES::NPY_OBJECT)
126    }
127
128    /// Deprecated name for [`PyArrayDescr::object`].
129    #[deprecated(since = "0.23.0", note = "renamed to `PyArrayDescr::object`")]
130    #[inline]
131    pub fn object_bound(py: Python<'_>) -> Bound<'_, Self> {
132        Self::object(py)
133    }
134
135    /// Returns the type descriptor for a registered type.
136    #[inline]
137    pub fn of<'py, T: Element>(py: Python<'py>) -> Bound<'py, Self> {
138        T::get_dtype(py)
139    }
140
141    /// Deprecated name for [`PyArrayDescr::of`].
142    #[deprecated(since = "0.23.0", note = "renamed to `PyArrayDescr::of`")]
143    #[inline]
144    pub fn of_bound<'py, T: Element>(py: Python<'py>) -> Bound<'py, Self> {
145        Self::of::<T>(py)
146    }
147
148    fn from_npy_type<'py>(py: Python<'py>, npy_type: NPY_TYPES) -> Bound<'py, Self> {
149        unsafe {
150            let descr = PY_ARRAY_API.PyArray_DescrFromType(py, npy_type as _);
151            Bound::from_owned_ptr(py, descr.cast()).downcast_into_unchecked()
152        }
153    }
154
155    pub(crate) fn new_from_npy_type<'py>(py: Python<'py>, npy_type: NPY_TYPES) -> Bound<'py, Self> {
156        unsafe {
157            let descr = PY_ARRAY_API.PyArray_DescrNewFromType(py, npy_type as _);
158            Bound::from_owned_ptr(py, descr.cast()).downcast_into_unchecked()
159        }
160    }
161}
162
163/// Implementation of functionality for [`PyArrayDescr`].
164#[doc(alias = "PyArrayDescr")]
165pub trait PyArrayDescrMethods<'py>: Sealed {
166    /// Returns `self` as `*mut PyArray_Descr`.
167    fn as_dtype_ptr(&self) -> *mut PyArray_Descr;
168
169    /// Returns `self` as `*mut PyArray_Descr` while increasing the reference count.
170    ///
171    /// Useful in cases where the descriptor is stolen by the API.
172    fn into_dtype_ptr(self) -> *mut PyArray_Descr;
173
174    /// Returns true if two type descriptors are equivalent.
175    fn is_equiv_to(&self, other: &Self) -> bool;
176
177    /// Returns the [array scalar][arrays-scalars] corresponding to this type descriptor.
178    ///
179    /// Equivalent to [`numpy.dtype.type`][dtype-type].
180    ///
181    /// [arrays-scalars]: https://numpy.org/doc/stable/reference/arrays.scalars.html
182    /// [dtype-type]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.type.html
183    fn typeobj(&self) -> Bound<'py, PyType>;
184
185    /// Returns a unique number for each of the 21 different built-in
186    /// [enumerated types][enumerated-types].
187    ///
188    /// These are roughly ordered from least-to-most precision.
189    ///
190    /// Equivalent to [`numpy.dtype.num`][dtype-num].
191    ///
192    /// [enumerated-types]: https://numpy.org/doc/stable/reference/c-api/dtype.html#enumerated-types
193    /// [dtype-num]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.num.html
194    fn num(&self) -> c_int {
195        unsafe { &*self.as_dtype_ptr() }.type_num
196    }
197
198    /// Returns the element size of this type descriptor.
199    ///
200    /// Equivalent to [`numpy.dtype.itemsize`][dtype-itemsize].
201    ///
202    /// [dtype-itemsiize]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.itemsize.html
203    fn itemsize(&self) -> usize;
204
205    /// Returns the required alignment (bytes) of this type descriptor according to the compiler.
206    ///
207    /// Equivalent to [`numpy.dtype.alignment`][dtype-alignment].
208    ///
209    /// [dtype-alignment]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.alignment.html
210    fn alignment(&self) -> usize;
211
212    /// Returns an ASCII character indicating the byte-order of this type descriptor object.
213    ///
214    /// All built-in data-type objects have byteorder either `=` or `|`.
215    ///
216    /// Equivalent to [`numpy.dtype.byteorder`][dtype-byteorder].
217    ///
218    /// [dtype-byteorder]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.byteorder.html
219    fn byteorder(&self) -> u8 {
220        unsafe { &*self.as_dtype_ptr() }.byteorder.max(0) as _
221    }
222
223    /// Returns a unique ASCII character for each of the 21 different built-in types.
224    ///
225    /// Note that structured data types are categorized as `V` (void).
226    ///
227    /// Equivalent to [`numpy.dtype.char`][dtype-char].
228    ///
229    /// [dtype-char]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.char.html
230    fn char(&self) -> u8 {
231        unsafe { &*self.as_dtype_ptr() }.type_.max(0) as _
232    }
233
234    /// Returns an ASCII character (one of `biufcmMOSUV`) identifying the general kind of data.
235    ///
236    /// Note that structured data types are categorized as `V` (void).
237    ///
238    /// Equivalent to [`numpy.dtype.kind`][dtype-kind].
239    ///
240    /// [dtype-kind]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.kind.html
241    fn kind(&self) -> u8 {
242        unsafe { &*self.as_dtype_ptr() }.kind.max(0) as _
243    }
244
245    /// Returns bit-flags describing how this type descriptor is to be interpreted.
246    ///
247    /// Equivalent to [`numpy.dtype.flags`][dtype-flags].
248    ///
249    /// [dtype-flags]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.flags.html
250    fn flags(&self) -> u64;
251
252    /// Returns the number of dimensions if this type descriptor represents a sub-array, and zero otherwise.
253    ///
254    /// Equivalent to [`numpy.dtype.ndim`][dtype-ndim].
255    ///
256    /// [dtype-ndim]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.ndim.html
257    fn ndim(&self) -> usize;
258
259    /// Returns the type descriptor for the base element of subarrays, regardless of their dimension or shape.
260    ///
261    /// If the dtype is not a subarray, returns self.
262    ///
263    /// Equivalent to [`numpy.dtype.base`][dtype-base].
264    ///
265    /// [dtype-base]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.base.html
266    fn base(&self) -> Bound<'py, PyArrayDescr>;
267
268    /// Returns the shape of the sub-array.
269    ///
270    /// If the dtype is not a sub-array, an empty vector is returned.
271    ///
272    /// Equivalent to [`numpy.dtype.shape`][dtype-shape].
273    ///
274    /// [dtype-shape]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.shape.html
275    fn shape(&self) -> Vec<usize>;
276
277    /// Returns true if the type descriptor contains any reference-counted objects in any fields or sub-dtypes.
278    ///
279    /// Equivalent to [`numpy.dtype.hasobject`][dtype-hasobject].
280    ///
281    /// [dtype-hasobject]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.hasobject.html
282    fn has_object(&self) -> bool {
283        self.flags() & NPY_ITEM_HASOBJECT != 0
284    }
285
286    /// Returns true if the type descriptor is a struct which maintains field alignment.
287    ///
288    /// This flag is sticky, so when combining multiple structs together, it is preserved
289    /// and produces new dtypes which are also aligned.
290    ///
291    /// Equivalent to [`numpy.dtype.isalignedstruct`][dtype-isalignedstruct].
292    ///
293    /// [dtype-isalignedstruct]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.isalignedstruct.html
294    fn is_aligned_struct(&self) -> bool {
295        self.flags() & NPY_ALIGNED_STRUCT != 0
296    }
297
298    /// Returns true if the type descriptor is a sub-array.
299    ///
300    /// Equivalent to PyDataType_HASSUBARRAY(self).
301    fn has_subarray(&self) -> bool;
302
303    /// Returns true if the type descriptor is a structured type.
304    ///
305    /// Equivalent to PyDataType_HASFIELDS(self).
306    fn has_fields(&self) -> bool;
307
308    /// Returns true if type descriptor byteorder is native, or `None` if not applicable.
309    fn is_native_byteorder(&self) -> Option<bool> {
310        // based on PyArray_ISNBO(self->byteorder)
311        match self.byteorder() {
312            b'=' => Some(true),
313            b'|' => None,
314            byteorder => Some(byteorder == NPY_BYTEORDER_CHAR::NPY_NATBYTE as u8),
315        }
316    }
317
318    /// Returns an ordered list of field names, or `None` if there are no fields.
319    ///
320    /// The names are ordered according to increasing byte offset.
321    ///
322    /// Equivalent to [`numpy.dtype.names`][dtype-names].
323    ///
324    /// [dtype-names]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.names.html
325    fn names(&self) -> Option<Vec<String>>;
326
327    /// Returns the type descriptor and offset of the field with the given name.
328    ///
329    /// This method will return an error if this type descriptor is not structured,
330    /// or if it does not contain a field with a given name.
331    ///
332    /// The list of all names can be found via [`PyArrayDescr::names`].
333    ///
334    /// Equivalent to retrieving a single item from [`numpy.dtype.fields`][dtype-fields].
335    ///
336    /// [dtype-fields]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.fields.html
337    fn get_field(&self, name: &str) -> PyResult<(Bound<'py, PyArrayDescr>, usize)>;
338}
339
// Private module holding the marker trait used as a supertrait of
// `PyArrayDescrMethods`; because the module is not public, code outside of this
// crate cannot name `Sealed` and therefore cannot implement the methods trait.
mod sealed {
    /// Marker trait without any items.
    pub trait Sealed {}
}
343
344use sealed::Sealed;
345
346impl<'py> PyArrayDescrMethods<'py> for Bound<'py, PyArrayDescr> {
347    fn as_dtype_ptr(&self) -> *mut PyArray_Descr {
348        self.as_ptr() as _
349    }
350
351    fn into_dtype_ptr(self) -> *mut PyArray_Descr {
352        self.into_ptr() as _
353    }
354
355    fn is_equiv_to(&self, other: &Self) -> bool {
356        let self_ptr = self.as_dtype_ptr();
357        let other_ptr = other.as_dtype_ptr();
358
359        unsafe {
360            self_ptr == other_ptr
361                || PY_ARRAY_API.PyArray_EquivTypes(self.py(), self_ptr, other_ptr) != 0
362        }
363    }
364
365    fn typeobj(&self) -> Bound<'py, PyType> {
366        let dtype_type_ptr = unsafe { &*self.as_dtype_ptr() }.typeobj;
367        unsafe { PyType::from_borrowed_type_ptr(self.py(), dtype_type_ptr) }
368    }
369
370    fn itemsize(&self) -> usize {
371        unsafe { PyDataType_ELSIZE(self.py(), self.as_dtype_ptr()).max(0) as _ }
372    }
373
374    fn alignment(&self) -> usize {
375        unsafe { PyDataType_ALIGNMENT(self.py(), self.as_dtype_ptr()).max(0) as _ }
376    }
377
378    fn flags(&self) -> u64 {
379        unsafe { PyDataType_FLAGS(self.py(), self.as_dtype_ptr()) as _ }
380    }
381
382    fn ndim(&self) -> usize {
383        let subarray = unsafe { PyDataType_SUBARRAY(self.py(), self.as_dtype_ptr()).as_ref() };
384        match subarray {
385            None => 0,
386            Some(subarray) => unsafe { PyTuple_Size(subarray.shape) }.max(0) as _,
387        }
388    }
389
390    fn base(&self) -> Bound<'py, PyArrayDescr> {
391        let subarray = unsafe { PyDataType_SUBARRAY(self.py(), self.as_dtype_ptr()).as_ref() };
392        match subarray {
393            None => self.clone(),
394            Some(subarray) => unsafe {
395                Bound::from_borrowed_ptr(self.py(), subarray.base.cast()).downcast_into_unchecked()
396            },
397        }
398    }
399
400    fn shape(&self) -> Vec<usize> {
401        let subarray = unsafe { PyDataType_SUBARRAY(self.py(), self.as_dtype_ptr()).as_ref() };
402        match subarray {
403            None => Vec::new(),
404            Some(subarray) => {
405                // NumPy guarantees that shape is a tuple of non-negative integers so this should never panic.
406                let shape = unsafe { Borrowed::from_ptr(self.py(), subarray.shape) };
407                shape.extract().unwrap()
408            }
409        }
410    }
411
412    fn has_subarray(&self) -> bool {
413        unsafe { !PyDataType_SUBARRAY(self.py(), self.as_dtype_ptr()).is_null() }
414    }
415
416    fn has_fields(&self) -> bool {
417        unsafe { !PyDataType_NAMES(self.py(), self.as_dtype_ptr()).is_null() }
418    }
419
420    fn names(&self) -> Option<Vec<String>> {
421        if !self.has_fields() {
422            return None;
423        }
424        let names = unsafe {
425            Borrowed::from_ptr(self.py(), PyDataType_NAMES(self.py(), self.as_dtype_ptr()))
426        };
427        names.extract().ok()
428    }
429
430    fn get_field(&self, name: &str) -> PyResult<(Bound<'py, PyArrayDescr>, usize)> {
431        if !self.has_fields() {
432            return Err(PyValueError::new_err(
433                "cannot get field information: type descriptor has no fields",
434            ));
435        }
436        let dict = unsafe {
437            Borrowed::from_ptr(self.py(), PyDataType_FIELDS(self.py(), self.as_dtype_ptr()))
438        };
439        let dict = unsafe { dict.downcast_unchecked::<PyDict>() };
440        // NumPy guarantees that fields are tuples of proper size and type, so this should never panic.
441        let tuple = dict
442            .get_item(name)?
443            .ok_or_else(|| PyIndexError::new_err(name.to_owned()))?
444            .downcast_into::<PyTuple>()
445            .unwrap();
446        // Note that we cannot just extract the entire tuple since the third element can be a title.
447        let dtype = tuple
448            .get_item(0)
449            .unwrap()
450            .downcast_into::<PyArrayDescr>()
451            .unwrap();
452        let offset = tuple.get_item(1).unwrap().extract().unwrap();
453        Ok((dtype, offset))
454    }
455}
456
// Marks `Bound<'_, PyArrayDescr>` as eligible to implement `PyArrayDescrMethods`,
// which requires `Sealed` as a supertrait.
impl Sealed for Bound<'_, PyArrayDescr> {}
458
459/// Represents that a type can be an element of `PyArray`.
460///
461/// Currently, only integer/float/complex/object types are supported. The [NumPy documentation][enumerated-types]
462/// list the other built-in types which we are not yet implemented.
463///
464/// Note that NumPy's integer types like `numpy.int_` and `numpy.uint` are based on C's integer hierarchy
465/// which implies that their widths change depending on the platform's [data model][data-models].
466/// For example, `numpy.int_` matches C's `long` which is 32 bits wide on Windows (using the LLP64 data model)
467/// but 64 bits wide on Linux (using the LP64 data model).
468///
469/// In contrast, Rust's [`isize`] and [`usize`] types are defined to have the same width as a pointer
470/// and are therefore always 64 bits wide on 64-bit platforms. If you want to match NumPy's behaviour,
471/// consider using the [`c_long`][std::ffi::c_long] and [`c_ulong`][std::ffi::c_ulong] type aliases.
472///
473/// # Safety
474///
475/// A type `T` that implements this trait should be safe when managed by a NumPy
476/// array, thus implementing this trait is marked unsafe. Data types that don't
477/// contain Python objects (i.e., either the object type itself or record types
478/// containing object-type fields) are assumed to be trivially copyable, which
479/// is reflected in the `IS_COPY` flag. Furthermore, it is assumed that for
480/// the object type the elements are pointers into the Python heap and that the
481/// corresponding `Clone` implemenation will never panic as it only increases
482/// the reference count.
483///
484/// # Custom element types
485///
486/// Note that we cannot safely store `Py<T>` where `T: PyClass`, because the type information would be
487/// eliminated in the resulting NumPy array.
488/// In other words, objects are always treated as `Py<PyAny>` (a.k.a. `PyObject`) by Python code,
489/// and only `Py<PyAny>` can be stored in a type safe manner.
490///
491/// You can however create [`Array<Py<T>, D>`][ndarray::Array] and turn that into a NumPy array
492/// safely and efficiently using [`from_owned_object_array`][crate::PyArray::from_owned_object_array].
493///
494/// [enumerated-types]: https://numpy.org/doc/stable/reference/c-api/dtype.html#enumerated-types
495/// [data-models]: https://en.wikipedia.org/wiki/64-bit_computing#64-bit_data_models
496pub unsafe trait Element: Sized + Send + Sync {
497    /// Flag that indicates whether this type is trivially copyable.
498    ///
499    /// It should be set to true for all trivially copyable types (like scalar types
500    /// and record/array types only containing trivially copyable fields and elements).
501    ///
502    /// This flag should *always* be set to `false` for object types or record types
503    /// that contain object-type fields.
504    const IS_COPY: bool;
505
506    /// Returns the associated type descriptor ("dtype") for the given element type.
507    fn get_dtype(py: Python<'_>) -> Bound<'_, PyArrayDescr>;
508
509    /// Deprecated name for [`Element::get_dtype`].
510    #[deprecated(since = "0.23.0", note = "renamed to `Element::get_dtype`")]
511    #[inline]
512    fn get_dtype_bound(py: Python<'_>) -> Bound<'_, PyArrayDescr> {
513        Self::get_dtype(py)
514    }
515
516    /// Create a clone of the value while the GIL is guaranteed to be held.
517    fn clone_ref(&self, py: Python<'_>) -> Self;
518
519    /// Create an owned copy of the slice while the GIL is guaranteed to be held.
520    ///
521    /// Some types may provide implementations of this method that are more efficient
522    /// than simply mapping the `py_clone` method to each element in the slice.
523    #[inline]
524    fn vec_from_slice(py: Python<'_>, slc: &[Self]) -> Vec<Self> {
525        slc.iter().map(|elem| elem.clone_ref(py)).collect()
526    }
527
528    /// Create an owned copy of the array while the GIL is guaranteed to be held.
529    ///
530    /// Some types may provide implementations of this method that are more efficient
531    /// than simply mapping the `py_clone` method to each element in the view.
532    #[inline]
533    fn array_from_view<D>(
534        py: Python<'_>,
535        view: ::ndarray::ArrayView<'_, Self, D>,
536    ) -> ::ndarray::Array<Self, D>
537    where
538        D: ::ndarray::Dimension,
539    {
540        view.map(|elem| elem.clone_ref(py))
541    }
542}
543
544fn npy_int_type_lookup<T, T0, T1, T2>(npy_types: [NPY_TYPES; 3]) -> NPY_TYPES {
545    // `npy_common.h` defines the integer aliases. In order, it checks:
546    // NPY_BITSOF_LONG, NPY_BITSOF_LONGLONG, NPY_BITSOF_INT, NPY_BITSOF_SHORT, NPY_BITSOF_CHAR
547    // and assigns the alias to the first matching size, so we should check in this order.
548    match size_of::<T>() {
549        x if x == size_of::<T0>() => npy_types[0],
550        x if x == size_of::<T1>() => npy_types[1],
551        x if x == size_of::<T2>() => npy_types[2],
552        _ => panic!("Unable to match integer type descriptor: {:?}", npy_types),
553    }
554}
555
556fn npy_int_type<T: Bounded + Zero + Sized + PartialEq>() -> NPY_TYPES {
557    let is_unsigned = T::min_value() == T::zero();
558    let bit_width = 8 * size_of::<T>();
559
560    match (is_unsigned, bit_width) {
561        (false, 8) => NPY_TYPES::NPY_BYTE,
562        (false, 16) => NPY_TYPES::NPY_SHORT,
563        (false, 32) => npy_int_type_lookup::<i32, c_long, c_int, c_short>([
564            NPY_TYPES::NPY_LONG,
565            NPY_TYPES::NPY_INT,
566            NPY_TYPES::NPY_SHORT,
567        ]),
568        (false, 64) => npy_int_type_lookup::<i64, c_long, c_longlong, c_int>([
569            NPY_TYPES::NPY_LONG,
570            NPY_TYPES::NPY_LONGLONG,
571            NPY_TYPES::NPY_INT,
572        ]),
573        (true, 8) => NPY_TYPES::NPY_UBYTE,
574        (true, 16) => NPY_TYPES::NPY_USHORT,
575        (true, 32) => npy_int_type_lookup::<u32, c_ulong, c_uint, c_ushort>([
576            NPY_TYPES::NPY_ULONG,
577            NPY_TYPES::NPY_UINT,
578            NPY_TYPES::NPY_USHORT,
579        ]),
580        (true, 64) => npy_int_type_lookup::<u64, c_ulong, c_ulonglong, c_uint>([
581            NPY_TYPES::NPY_ULONG,
582            NPY_TYPES::NPY_ULONGLONG,
583            NPY_TYPES::NPY_UINT,
584        ]),
585        _ => unreachable!(),
586    }
587}
588
// Expands, inside an `Element` impl for a `Clone` type, to implementations of
// `clone_ref`, `vec_from_slice` and `array_from_view` that are built on the
// type's `Clone` impl and ignore the `py` token.
macro_rules! clone_methods_impl {
    ($Elem:ty) => {
        #[inline]
        fn clone_ref(&self, _py: ::pyo3::Python<'_>) -> $Elem {
            <$Elem as ::std::clone::Clone>::clone(self)
        }

        #[inline]
        fn vec_from_slice(_py: ::pyo3::Python<'_>, slc: &[$Elem]) -> Vec<$Elem> {
            <[$Elem] as ::std::borrow::ToOwned>::to_owned(slc)
        }

        #[inline]
        fn array_from_view<D>(
            _py: ::pyo3::Python<'_>,
            view: ::ndarray::ArrayView<'_, $Elem, D>,
        ) -> ::ndarray::Array<$Elem, D>
        where
            D: ::ndarray::Dimension,
        {
            ::ndarray::ArrayView::to_owned(&view)
        }
    };
}
615pub(crate) use clone_methods_impl;
616use pyo3::BoundObject;
617
// Generates `unsafe impl Element` blocks for scalar types. Three call forms exist:
//
// * `impl_element_scalar!(Type => NPY_NAME)` with optional trailing `#[...]` attributes,
// * `impl_element_scalar!(TypeA, TypeB, ...)` for integer types resolved via `npy_int_type`,
// * the internal `@impl:` form that both of the above forward to.
macro_rules! impl_element_scalar {
    // Internal worker: emits the impl for `$ty` using the expression `$npy_type` as the
    // NumPy type number; any attributes passed along are attached to the impl.
    (@impl: $ty:ty, $npy_type:expr $(,#[$meta:meta])*) => {
        $(#[$meta])*
        unsafe impl Element for $ty {
            const IS_COPY: bool = true;

            fn get_dtype(py: Python<'_>) -> Bound<'_, PyArrayDescr> {
                PyArrayDescr::from_npy_type(py, $npy_type)
            }

            clone_methods_impl!($ty);
        }
    };
    // Explicit mapping: `$npy_type` names a variant of `NPY_TYPES`.
    ($ty:ty => $npy_type:ident $(,#[$meta:meta])*) => {
        impl_element_scalar!(@impl: $ty, NPY_TYPES::$npy_type $(,#[$meta])*);
    };
    // List of integer types: the type number is computed by `npy_int_type::<T>()`
    // each time `get_dtype` is called.
    ($($tys:ty),+) => {
        $(impl_element_scalar!(@impl: $tys, npy_int_type::<$tys>());)+
    };
}
638
// Boolean elements use the explicit `NPY_BOOL` type number.
impl_element_scalar!(bool => NPY_BOOL);

// Fixed-width integers: the type number is looked up through `npy_int_type`.
impl_element_scalar!(i8, i16, i32, i64);
impl_element_scalar!(u8, u16, u32, u64);

// Floating-point elements use explicit type numbers.
impl_element_scalar!(f32 => NPY_FLOAT);
impl_element_scalar!(f64 => NPY_DOUBLE);

// Half-precision floats are only available with the `half` feature enabled.
#[cfg(feature = "half")]
impl_element_scalar!(f16 => NPY_HALF);
649
// `bf16` has no entry in `NPY_TYPES` here; its descriptor is resolved by name
// through `PyArrayDescr::new` and cached for later calls.
#[cfg(feature = "half")]
unsafe impl Element for bf16 {
    const IS_COPY: bool = true;

    /// Returns the `bfloat16` descriptor, resolving it on first use.
    ///
    /// # Panics
    ///
    /// Panics if the `bfloat16` dtype cannot be constructed.
    fn get_dtype(py: Python<'_>) -> Bound<'_, PyArrayDescr> {
        static DTYPE: GILOnceCell<Py<PyArrayDescr>> = GILOnceCell::new();

        let cached = DTYPE.get_or_init(py, || {
            let looked_up = PyArrayDescr::new(py, "bfloat16");
            looked_up.expect("A package which provides a `bfloat16` data type for NumPy is required to use the `half::bf16` element type.").unbind()
        });

        // Hand out a new reference to the cached descriptor.
        cached.clone_ref(py).into_bound(py)
    }

    clone_methods_impl!(Self);
}
667
// Complex elements use explicit type numbers; the `#[doc]` attributes become the
// documentation of the generated impls.
impl_element_scalar!(Complex32 => NPY_CFLOAT,
    #[doc = "Complex type with `f32` components which maps to `numpy.csingle` (`numpy.complex64`)."]);
impl_element_scalar!(Complex64 => NPY_CDOUBLE,
    #[doc = "Complex type with `f64` components which maps to `numpy.cdouble` (`numpy.complex128`)."]);

// Pointer-sized integers are only supported where they are 32 or 64 bits wide,
// which are widths that `npy_int_type` can resolve.
#[cfg(any(target_pointer_width = "32", target_pointer_width = "64"))]
impl_element_scalar!(usize, isize);
675
676unsafe impl Element for PyObject {
677    const IS_COPY: bool = false;
678
679    fn get_dtype(py: Python<'_>) -> Bound<'_, PyArrayDescr> {
680        PyArrayDescr::object(py)
681    }
682
683    #[inline]
684    fn clone_ref(&self, py: Python<'_>) -> Self {
685        Py::clone_ref(self, py)
686    }
687}
688
// Unit tests for descriptor construction and for the accessors of
// `PyArrayDescrMethods`; every test acquires the GIL itself.
#[cfg(test)]
mod tests {
    use super::*;

    use pyo3::types::PyString;
    use pyo3::{py_run, types::PyTypeMethods};

    use crate::npyffi::{is_numpy_2, NPY_NEEDS_PYAPI};

    /// `PyArrayDescr::new` accepts a dtype name and a list of field specs, and rejects an integer.
    #[test]
    fn test_dtype_new() {
        Python::with_gil(|py| {
            // A built-in dtype name yields the very same object as `dtype::<f64>`.
            assert!(PyArrayDescr::new(py, "float64")
                .unwrap()
                .is(&dtype::<f64>(py)));

            // A structured dtype with an object field and a boolean field.
            let dt = PyArrayDescr::new(py, [("a", "O"), ("b", "?")].as_ref()).unwrap();
            assert_eq!(dt.names(), Some(vec!["a".to_owned(), "b".to_owned()]));
            assert!(dt.has_object());
            assert!(dt.get_field("a").unwrap().0.is(&dtype::<PyObject>(py)));
            assert!(dt.get_field("b").unwrap().0.is(&dtype::<bool>(py)));

            // An integer is not a valid dtype specification.
            assert!(PyArrayDescr::new(py, 123_usize).is_err());
        });
    }

    /// The scalar type of each supported element type has the expected qualified name.
    #[test]
    fn test_dtype_names() {
        // Qualified name of the array scalar type that belongs to `T`.
        fn type_name<T: Element>(py: Python<'_>) -> Bound<'_, PyString> {
            dtype::<T>(py).typeobj().qualname().unwrap()
        }
        Python::with_gil(|py| {
            // The expected name of the boolean scalar depends on `is_numpy_2`.
            if is_numpy_2(py) {
                assert_eq!(type_name::<bool>(py), "bool");
            } else {
                assert_eq!(type_name::<bool>(py), "bool_");
            }

            assert_eq!(type_name::<i8>(py), "int8");
            assert_eq!(type_name::<i16>(py), "int16");
            assert_eq!(type_name::<i32>(py), "int32");
            assert_eq!(type_name::<i64>(py), "int64");
            assert_eq!(type_name::<u8>(py), "uint8");
            assert_eq!(type_name::<u16>(py), "uint16");
            assert_eq!(type_name::<u32>(py), "uint32");
            assert_eq!(type_name::<u64>(py), "uint64");
            assert_eq!(type_name::<f32>(py), "float32");
            assert_eq!(type_name::<f64>(py), "float64");

            assert_eq!(type_name::<Complex32>(py), "complex64");
            assert_eq!(type_name::<Complex64>(py), "complex128");

            // Pointer-sized integers follow the width of the target's pointers.
            #[cfg(target_pointer_width = "32")]
            {
                assert_eq!(type_name::<usize>(py), "uint32");
                assert_eq!(type_name::<isize>(py), "int32");
            }

            #[cfg(target_pointer_width = "64")]
            {
                assert_eq!(type_name::<usize>(py), "uint64");
                assert_eq!(type_name::<isize>(py), "int64");
            }
        });
    }

    /// Accessors on a plain scalar dtype (`f64`).
    #[test]
    fn test_dtype_methods_scalar() {
        Python::with_gil(|py| {
            let dt = dtype::<f64>(py);

            assert_eq!(dt.num(), NPY_TYPES::NPY_DOUBLE as c_int);
            assert_eq!(dt.flags(), 0);
            assert_eq!(dt.typeobj().qualname().unwrap(), "float64");
            assert_eq!(dt.char(), b'd');
            assert_eq!(dt.kind(), b'f');
            assert_eq!(dt.byteorder(), b'=');
            assert_eq!(dt.is_native_byteorder(), Some(true));
            assert_eq!(dt.itemsize(), 8);
            assert_eq!(dt.alignment(), 8);
            assert!(!dt.has_object());
            assert!(dt.names().is_none());
            assert!(!dt.has_fields());
            assert!(!dt.is_aligned_struct());
            assert!(!dt.has_subarray());
            // A scalar dtype is its own base and has no sub-array shape.
            assert!(dt.base().is_equiv_to(&dt));
            assert_eq!(dt.ndim(), 0);
            assert_eq!(dt.shape(), Vec::<usize>::new());
        });
    }

    /// Accessors on a sub-array dtype: `f8` elements with shape `(2, 3)`.
    #[test]
    fn test_dtype_methods_subarray() {
        Python::with_gil(|py| {
            // Build the dtype on the Python side and fetch it from the locals.
            let locals = PyDict::new(py);
            py_run!(
                py,
                *locals,
                "dtype = __import__('numpy').dtype(('f8', (2, 3)))"
            );
            let dt = locals
                .get_item("dtype")
                .unwrap()
                .unwrap()
                .downcast_into::<PyArrayDescr>()
                .unwrap();

            assert_eq!(dt.num(), NPY_TYPES::NPY_VOID as c_int);
            assert_eq!(dt.flags(), 0);
            assert_eq!(dt.typeobj().qualname().unwrap(), "void");
            assert_eq!(dt.char(), b'V');
            assert_eq!(dt.kind(), b'V');
            assert_eq!(dt.byteorder(), b'|');
            assert_eq!(dt.is_native_byteorder(), None);
            // 2 * 3 elements of 8 bytes each.
            assert_eq!(dt.itemsize(), 48);
            assert_eq!(dt.alignment(), 8);
            assert!(!dt.has_object());
            assert!(dt.names().is_none());
            assert!(!dt.has_fields());
            assert!(!dt.is_aligned_struct());
            assert!(dt.has_subarray());
            assert_eq!(dt.ndim(), 2);
            assert_eq!(dt.shape(), vec![2, 3]);
            assert!(dt.base().is_equiv_to(&dtype::<f64>(py)));
        });
    }

    /// Accessors on an aligned record dtype with `u1`, `f8` and object fields.
    #[test]
    fn test_dtype_methods_record() {
        Python::with_gil(|py| {
            // Build the dtype on the Python side and fetch it from the locals.
            let locals = PyDict::new(py);
            py_run!(
                py,
                *locals,
                "dtype = __import__('numpy').dtype([('x', 'u1'), ('y', 'f8'), ('z', 'O')], align=True)"
            );
            let dt = locals
                .get_item("dtype")
                .unwrap()
                .unwrap()
                .downcast_into::<PyArrayDescr>()
                .unwrap();

            assert_eq!(dt.num(), NPY_TYPES::NPY_VOID as c_int);
            assert_ne!(dt.flags() & NPY_ITEM_HASOBJECT, 0);
            assert_ne!(dt.flags() & NPY_NEEDS_PYAPI, 0);
            assert_ne!(dt.flags() & NPY_ALIGNED_STRUCT, 0);
            assert_eq!(dt.typeobj().qualname().unwrap(), "void");
            assert_eq!(dt.char(), b'V');
            assert_eq!(dt.kind(), b'V');
            assert_eq!(dt.byteorder(), b'|');
            assert_eq!(dt.is_native_byteorder(), None);
            assert_eq!(dt.itemsize(), 24);
            assert_eq!(dt.alignment(), 8);
            assert!(dt.has_object());
            assert_eq!(
                dt.names(),
                Some(vec!["x".to_owned(), "y".to_owned(), "z".to_owned()])
            );
            assert!(dt.has_fields());
            assert!(dt.is_aligned_struct());
            assert!(!dt.has_subarray());
            assert_eq!(dt.ndim(), 0);
            assert_eq!(dt.shape(), Vec::<usize>::new());
            assert!(dt.base().is_equiv_to(&dt));
            // Each field reports its own dtype and its byte offset within the record.
            let x = dt.get_field("x").unwrap();
            assert!(x.0.is_equiv_to(&dtype::<u8>(py)));
            assert_eq!(x.1, 0);
            let y = dt.get_field("y").unwrap();
            assert!(y.0.is_equiv_to(&dtype::<f64>(py)));
            assert_eq!(y.1, 8);
            let z = dt.get_field("z").unwrap();
            assert!(z.0.is_equiv_to(&dtype::<PyObject>(py)));
            assert_eq!(z.1, 16);
        });
    }
}