1use std::cell::RefCell;
58use std::collections::hash_map::Entry;
59use std::fmt;
60use std::hash::Hash;
61use std::marker::PhantomData;
62
63use pyo3::{sync::GILProtected, Bound, Py, Python};
64use rustc_hash::FxHashMap;
65
66use crate::dtype::{clone_methods_impl, Element, PyArrayDescr, PyArrayDescrMethods};
67use crate::npyffi::{
68 PyArray_DatetimeDTypeMetaData, PyDataType_C_METADATA, NPY_DATETIMEUNIT, NPY_TYPES,
69};
70
71pub trait Unit: Send + Sync + Clone + Copy + PartialEq + Eq + Hash + PartialOrd + Ord {
75 const UNIT: NPY_DATETIMEUNIT;
79
80 const ABBREV: &'static str;
82}
83
/// Declares one zero-sized marker struct per entry and implements [`Unit`] for it.
///
/// Each entry has the form `[attributes] TypeName => NPY_FR_variant "abbrev",`
/// (the trailing comma is required). `Unit` and `NPY_DATETIMEUNIT` are resolved at
/// the invocation site.
macro_rules! define_units {
    ($($(#[$attr:meta])* $name:ident => $variant:ident $abbreviation:literal,)+) => {
        $(
            $(#[$attr])*
            #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
            pub struct $name;

            impl Unit for $name {
                const UNIT: NPY_DATETIMEUNIT = NPY_DATETIMEUNIT::$variant;
                const ABBREV: &'static str = $abbreviation;
            }
        )+
    };
}
101
102pub mod units {
104 use super::*;
105
106 define_units!(
107 #[doc = "Years, i.e. 12 months"]
108 Years => NPY_FR_Y "a",
109 #[doc = "Months, i.e. 30 days"]
110 Months => NPY_FR_M "mo",
111 #[doc = "Weeks, i.e. 7 days"]
112 Weeks => NPY_FR_W "w",
113 #[doc = "Days, i.e. 24 hours"]
114 Days => NPY_FR_D "d",
115 #[doc = "Hours, i.e. 60 minutes"]
116 Hours => NPY_FR_h "h",
117 #[doc = "Minutes, i.e. 60 seconds"]
118 Minutes => NPY_FR_m "min",
119 #[doc = "Seconds"]
120 Seconds => NPY_FR_s "s",
121 #[doc = "Milliseconds, i.e. 10^-3 seconds"]
122 Milliseconds => NPY_FR_ms "ms",
123 #[doc = "Microseconds, i.e. 10^-6 seconds"]
124 Microseconds => NPY_FR_us "µs",
125 #[doc = "Nanoseconds, i.e. 10^-9 seconds"]
126 Nanoseconds => NPY_FR_ns "ns",
127 #[doc = "Picoseconds, i.e. 10^-12 seconds"]
128 Picoseconds => NPY_FR_ps "ps",
129 #[doc = "Femtoseconds, i.e. 10^-15 seconds"]
130 Femtoseconds => NPY_FR_fs "fs",
131 #[doc = "Attoseconds, i.e. 10^-18 seconds"]
132 Attoseconds => NPY_FR_as "as",
133 );
134}
135
136#[derive(Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
140#[repr(transparent)]
141pub struct Datetime<U: Unit>(i64, PhantomData<U>);
142
143impl<U: Unit> From<i64> for Datetime<U> {
144 fn from(val: i64) -> Self {
145 Self(val, PhantomData)
146 }
147}
148
149impl<U: Unit> From<Datetime<U>> for i64 {
150 fn from(val: Datetime<U>) -> Self {
151 val.0
152 }
153}
154
155unsafe impl<U: Unit> Element for Datetime<U> {
156 const IS_COPY: bool = true;
157
158 fn get_dtype(py: Python<'_>) -> Bound<'_, PyArrayDescr> {
159 static DTYPES: TypeDescriptors = unsafe { TypeDescriptors::new(NPY_TYPES::NPY_DATETIME) };
160
161 DTYPES.from_unit(py, U::UNIT)
162 }
163
164 clone_methods_impl!(Self);
165}
166
167impl<U: Unit> fmt::Debug for Datetime<U> {
168 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
169 write!(f, "Datetime({} {})", self.0, U::ABBREV)
170 }
171}
172
173#[derive(Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
177#[repr(transparent)]
178pub struct Timedelta<U: Unit>(i64, PhantomData<U>);
179
180impl<U: Unit> From<i64> for Timedelta<U> {
181 fn from(val: i64) -> Self {
182 Self(val, PhantomData)
183 }
184}
185
186impl<U: Unit> From<Timedelta<U>> for i64 {
187 fn from(val: Timedelta<U>) -> Self {
188 val.0
189 }
190}
191
192unsafe impl<U: Unit> Element for Timedelta<U> {
193 const IS_COPY: bool = true;
194
195 fn get_dtype(py: Python<'_>) -> Bound<'_, PyArrayDescr> {
196 static DTYPES: TypeDescriptors = unsafe { TypeDescriptors::new(NPY_TYPES::NPY_TIMEDELTA) };
197
198 DTYPES.from_unit(py, U::UNIT)
199 }
200
201 clone_methods_impl!(Self);
202}
203
204impl<U: Unit> fmt::Debug for Timedelta<U> {
205 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
206 write!(f, "Timedelta({} {})", self.0, U::ABBREV)
207 }
208}
209
210struct TypeDescriptors {
211 npy_type: NPY_TYPES,
212 #[allow(clippy::type_complexity)]
213 dtypes: GILProtected<RefCell<Option<FxHashMap<NPY_DATETIMEUNIT, Py<PyArrayDescr>>>>>,
214}
215
impl TypeDescriptors {
    /// Creates an empty cache for the NumPy type `npy_type`.
    ///
    /// This is a `const fn` so it can initialize a `static`; the map itself is only
    /// allocated by the first call to [`Self::from_unit`].
    ///
    /// # Safety
    ///
    /// `npy_type` must be a datetime-like type (`NPY_DATETIME` or `NPY_TIMEDELTA`).
    /// `from_unit` reinterprets the C metadata of every descriptor it creates as
    /// `PyArray_DatetimeDTypeMetaData`.
    const unsafe fn new(npy_type: NPY_TYPES) -> Self {
        // Nothing is dereferenced here; the safety contract is relied upon in `from_unit`.
        Self {
            npy_type,
            dtypes: GILProtected::new(RefCell::new(None)),
        }
    }

    /// Returns the descriptor for `unit`, creating and caching it on first request.
    ///
    /// A newly created descriptor has its datetime metadata set to `unit` with a
    /// multiplier of 1 before it is stored. The returned `Bound` is a new owned
    /// reference to the cached object.
    ///
    /// # Panics
    ///
    /// Panics if the cache's `RefCell` is already borrowed, i.e. if this is re-entered
    /// on the same `TypeDescriptors` while a previous call still holds the borrow.
    /// NOTE(review): presumably unreachable in practice; confirm that
    /// `new_from_npy_type` cannot run Python code that calls back into here.
    // `from_*` methods conventionally take no `self`; this one is a cached lookup,
    // so that clippy convention is silenced deliberately.
    #[allow(clippy::wrong_self_convention)]
    fn from_unit<'py>(&self, py: Python<'py>, unit: NPY_DATETIMEUNIT) -> Bound<'py, PyArrayDescr> {
        // `GILProtected::get` requires the `py` token, so the cache is only touched with the GIL held.
        let mut dtypes = self.dtypes.get(py).borrow_mut();

        // The map is created lazily here because `new` must stay `const`.
        let dtype = match dtypes.get_or_insert_with(Default::default).entry(unit) {
            Entry::Occupied(entry) => entry.into_mut(),
            Entry::Vacant(entry) => {
                let dtype = PyArrayDescr::new_from_npy_type(py, self.npy_type);

                // SAFETY: by `new`'s contract `self.npy_type` is datetime-like, so the
                // descriptor's C metadata is a `PyArray_DatetimeDTypeMetaData`.
                // NOTE(review): mutating it in place assumes `new_from_npy_type` returns a
                // fresh, unshared descriptor rather than NumPy's singleton; confirm.
                unsafe {
                    // Reinterpret the type-erased metadata pointer as datetime metadata.
                    let metadata = &mut *(PyDataType_C_METADATA(py, dtype.as_dtype_ptr())
                        as *mut PyArray_DatetimeDTypeMetaData);

                    // Fix the unit, with a multiplier of one (e.g. `timedelta64[m]`, not `[2m]`).
                    metadata.meta.base = unit;
                    metadata.meta.num = 1;
                }

                // Store only after the metadata is final, so cached descriptors are complete.
                entry.insert(dtype.into())
            }
        };

        dtype.bind(py).to_owned()
    }
}
250
#[cfg(test)]
mod tests {
    use super::*;

    use pyo3::{
        ffi::c_str,
        py_run,
        types::{PyAnyMethods, PyDict, PyModule},
    };

    use crate::array::{PyArray1, PyArrayMethods};

    #[test]
    fn from_python_to_rust() {
        Python::with_gil(|py| {
            let namespace = py
                .eval(c_str!("{ 'np': __import__('numpy') }"), None, None)
                .unwrap()
                .downcast_into::<PyDict>()
                .unwrap();

            let epoch = py
                .eval(
                    c_str!("np.array([np.datetime64('1970-01-01')])"),
                    None,
                    Some(&namespace),
                )
                .unwrap()
                .downcast_into::<PyArray1<Datetime<units::Days>>>()
                .unwrap();

            let days_since_epoch = i64::from(epoch.get_owned(0).unwrap());
            assert_eq!(days_since_epoch, 0);
        });
    }

    #[test]
    fn from_rust_to_python() {
        Python::with_gil(|py| {
            // `py_run!` exposes Rust locals to Python under their own names, so the
            // bindings below must stay `array` and `np`.
            let array = PyArray1::<Timedelta<units::Minutes>>::zeros(py, 1, false);

            {
                let mut view = array.readwrite();
                *view.get_mut(0).unwrap() = Timedelta::<units::Minutes>::from(5);
            }

            let np = py
                .eval(c_str!("__import__('numpy')"), None, None)
                .unwrap()
                .downcast_into::<PyModule>()
                .unwrap();

            py_run!(py, array np, "assert array.dtype == np.dtype('timedelta64[m]')");
            py_run!(py, array np, "assert array[0] == np.timedelta64(5, 'm')");
        });
    }

    #[test]
    fn debug_formatting() {
        let datetime = Datetime::<units::Days>::from(28);
        assert_eq!(format!("{datetime:?}"), "Datetime(28 d)");

        let timedelta = Timedelta::<units::Milliseconds>::from(160);
        assert_eq!(format!("{timedelta:?}"), "Timedelta(160 ms)");
    }

    #[test]
    fn unit_conversion() {
        /// Casts a single `1` from unit `S` to unit `D` and checks the resulting count.
        #[track_caller]
        fn convert<S: Unit, D: Unit>(py: Python<'_>, expected_value: i64) {
            let source = PyArray1::<Timedelta<S>>::from_slice(py, &[Timedelta::<S>::from(1)]);
            let target = source.cast::<Timedelta<D>>(false).unwrap();

            let actual: i64 = target.get_owned(0).unwrap().into();
            assert_eq!(actual, expected_value);
        }

        // Gregorian average: 97 leap days per 400 years (integer division truncates).
        const DAYS_PER_YEAR: i64 = (97 + 400 * 365) / 400;
        const MONTHS_PER_YEAR: i64 = 12;

        const SECS_PER_MINUTE: i64 = 60;
        const SECS_PER_HOUR: i64 = 60 * SECS_PER_MINUTE;
        const SECS_PER_DAY: i64 = 24 * SECS_PER_HOUR;
        const SECS_PER_WEEK: i64 = 7 * SECS_PER_DAY;

        Python::with_gil(|py| {
            convert::<units::Years, units::Days>(py, DAYS_PER_YEAR);
            convert::<units::Months, units::Days>(py, DAYS_PER_YEAR / MONTHS_PER_YEAR);

            convert::<units::Weeks, units::Seconds>(py, SECS_PER_WEEK);
            convert::<units::Days, units::Seconds>(py, SECS_PER_DAY);
            convert::<units::Hours, units::Seconds>(py, SECS_PER_HOUR);
            convert::<units::Minutes, units::Seconds>(py, SECS_PER_MINUTE);

            convert::<units::Seconds, units::Milliseconds>(py, 1_000);
            convert::<units::Seconds, units::Microseconds>(py, 1_000_000);
            convert::<units::Seconds, units::Nanoseconds>(py, 1_000_000_000);
            convert::<units::Seconds, units::Picoseconds>(py, 1_000_000_000_000);
            convert::<units::Seconds, units::Femtoseconds>(py, 1_000_000_000_000_000);

            convert::<units::Femtoseconds, units::Attoseconds>(py, 1_000);
        });
    }
}