polars_utils/
total_ord.rs

1use std::cmp::Ordering;
2use std::hash::{Hash, Hasher};
3
4use bytemuck::TransparentWrapper;
5
6use crate::hashing::{BytesHash, DirtyHash};
7use crate::nulls::IsNull;
8
/// Normalizes an `f32` so that values which should compare equal share a
/// single bit pattern.
///
/// Negative zero becomes positive zero and every NaN is replaced by one
/// quiet-NaN bit pattern (`0x7fc00000`); any other value is returned as is.
#[inline]
pub fn canonical_f32(x: f32) -> f32 {
    if x.is_nan() {
        // Collapse every NaN (any sign or payload) into one canonical quiet NaN.
        return f32::from_bits(0x7fc00000);
    }
    // Adding +0.0 turns -0.0 into +0.0 and does not change other values.
    x + 0.0
}
21
/// Normalizes an `f64` so that values which should compare equal share a
/// single bit pattern.
///
/// Negative zero becomes positive zero and every NaN is replaced by one
/// quiet-NaN bit pattern (`0x7ff8000000000000`); any other value is returned
/// as is.
#[inline]
pub fn canonical_f64(x: f64) -> f64 {
    if x.is_nan() {
        // Collapse every NaN (any sign or payload) into one canonical quiet NaN.
        return f64::from_bits(0x7ff8000000000000);
    }
    // Adding +0.0 turns -0.0 into +0.0 and does not change other values.
    x + 0.0
}
34
/// Stand-in for [`Eq`]. Using it consistently lets generic code stay generic
/// over equality while floats still get a total equality relation.
pub trait TotalEq {
    /// Returns `true` when `self` and `other` are equal under this relation.
    fn tot_eq(&self, other: &Self) -> bool;

    /// Returns `true` when `self` and `other` are not equal.
    ///
    /// The default implementation is the negation of [`TotalEq::tot_eq`].
    #[inline]
    fn tot_ne(&self, other: &Self) -> bool {
        let equal = self.tot_eq(other);
        !equal
    }
}
45
46/// Alternative trait for Ord. By consistently using this we can still be
47/// generic w.r.t Ord while getting a total ordering for floats.
48pub trait TotalOrd: TotalEq {
49    fn tot_cmp(&self, other: &Self) -> Ordering;
50
51    #[inline]
52    fn tot_lt(&self, other: &Self) -> bool {
53        self.tot_cmp(other) == Ordering::Less
54    }
55
56    #[inline]
57    fn tot_gt(&self, other: &Self) -> bool {
58        self.tot_cmp(other) == Ordering::Greater
59    }
60
61    #[inline]
62    fn tot_le(&self, other: &Self) -> bool {
63        self.tot_cmp(other) != Ordering::Greater
64    }
65
66    #[inline]
67    fn tot_ge(&self, other: &Self) -> bool {
68        self.tot_cmp(other) != Ordering::Less
69    }
70}
71
/// Stand-in for [`Hash`](std::hash::Hash). Using it consistently lets generic
/// code stay generic over hashing while still being able to hash floats.
pub trait TotalHash {
    /// Feeds `self` into the hasher `state`.
    fn tot_hash<H>(&self, state: &mut H)
    where
        H: Hasher;

    /// Feeds every element of `data` into `state`, in slice order.
    ///
    /// The default implementation calls [`TotalHash::tot_hash`] once per
    /// element and writes nothing else to the hasher.
    fn tot_hash_slice<H>(data: &[Self], state: &mut H)
    where
        H: Hasher,
        Self: Sized,
    {
        data.iter().for_each(|element| element.tot_hash(state));
    }
}
89
90#[repr(transparent)]
91pub struct TotalOrdWrap<T>(pub T);
92unsafe impl<T> TransparentWrapper<T> for TotalOrdWrap<T> {}
93
94impl<T: TotalOrd> PartialOrd for TotalOrdWrap<T> {
95    #[inline(always)]
96    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
97        Some(self.cmp(other))
98    }
99
100    #[inline(always)]
101    fn lt(&self, other: &Self) -> bool {
102        self.0.tot_lt(&other.0)
103    }
104
105    #[inline(always)]
106    fn le(&self, other: &Self) -> bool {
107        self.0.tot_le(&other.0)
108    }
109
110    #[inline(always)]
111    fn gt(&self, other: &Self) -> bool {
112        self.0.tot_gt(&other.0)
113    }
114
115    #[inline(always)]
116    fn ge(&self, other: &Self) -> bool {
117        self.0.tot_ge(&other.0)
118    }
119}
120
121impl<T: TotalOrd> Ord for TotalOrdWrap<T> {
122    #[inline(always)]
123    fn cmp(&self, other: &Self) -> Ordering {
124        self.0.tot_cmp(&other.0)
125    }
126}
127
128impl<T: TotalEq> PartialEq for TotalOrdWrap<T> {
129    #[inline(always)]
130    fn eq(&self, other: &Self) -> bool {
131        self.0.tot_eq(&other.0)
132    }
133
134    #[inline(always)]
135    #[allow(clippy::partialeq_ne_impl)]
136    fn ne(&self, other: &Self) -> bool {
137        self.0.tot_ne(&other.0)
138    }
139}
140
141impl<T: TotalEq> Eq for TotalOrdWrap<T> {}
142
143impl<T: TotalHash> Hash for TotalOrdWrap<T> {
144    #[inline(always)]
145    fn hash<H: Hasher>(&self, state: &mut H) {
146        self.0.tot_hash(state);
147    }
148}
149
150impl<T: Clone> Clone for TotalOrdWrap<T> {
151    #[inline]
152    fn clone(&self) -> Self {
153        Self(self.0.clone())
154    }
155}
156
157impl<T: Copy> Copy for TotalOrdWrap<T> {}
158
159impl<T: IsNull> IsNull for TotalOrdWrap<T> {
160    const HAS_NULLS: bool = T::HAS_NULLS;
161    type Inner = T::Inner;
162
163    #[inline(always)]
164    fn is_null(&self) -> bool {
165        self.0.is_null()
166    }
167
168    #[inline(always)]
169    fn unwrap_inner(self) -> Self::Inner {
170        self.0.unwrap_inner()
171    }
172}
173
174impl DirtyHash for f32 {
175    #[inline(always)]
176    fn dirty_hash(&self) -> u64 {
177        canonical_f32(*self).to_bits().dirty_hash()
178    }
179}
180
181impl DirtyHash for f64 {
182    #[inline(always)]
183    fn dirty_hash(&self) -> u64 {
184        canonical_f64(*self).to_bits().dirty_hash()
185    }
186}
187
188impl<T: DirtyHash> DirtyHash for TotalOrdWrap<T> {
189    #[inline(always)]
190    fn dirty_hash(&self) -> u64 {
191        self.0.dirty_hash()
192    }
193}
194
195macro_rules! impl_trivial_total {
196    ($T: ty) => {
197        impl TotalEq for $T {
198            #[inline(always)]
199            fn tot_eq(&self, other: &Self) -> bool {
200                self == other
201            }
202
203            #[inline(always)]
204            fn tot_ne(&self, other: &Self) -> bool {
205                self != other
206            }
207        }
208
209        impl TotalOrd for $T {
210            #[inline(always)]
211            fn tot_cmp(&self, other: &Self) -> Ordering {
212                self.cmp(other)
213            }
214
215            #[inline(always)]
216            fn tot_lt(&self, other: &Self) -> bool {
217                self < other
218            }
219
220            #[inline(always)]
221            fn tot_gt(&self, other: &Self) -> bool {
222                self > other
223            }
224
225            #[inline(always)]
226            fn tot_le(&self, other: &Self) -> bool {
227                self <= other
228            }
229
230            #[inline(always)]
231            fn tot_ge(&self, other: &Self) -> bool {
232                self >= other
233            }
234        }
235
236        impl TotalHash for $T {
237            #[inline(always)]
238            fn tot_hash<H>(&self, state: &mut H)
239            where
240                H: Hasher,
241            {
242                self.hash(state);
243            }
244        }
245    };
246}
247
248// We can't do a blanket impl because Rust complains f32 might implement
249// Ord / Eq someday.
250impl_trivial_total!(bool);
251impl_trivial_total!(u8);
252impl_trivial_total!(u16);
253impl_trivial_total!(u32);
254impl_trivial_total!(u64);
255impl_trivial_total!(u128);
256impl_trivial_total!(usize);
257impl_trivial_total!(i8);
258impl_trivial_total!(i16);
259impl_trivial_total!(i32);
260impl_trivial_total!(i64);
261impl_trivial_total!(i128);
262impl_trivial_total!(isize);
263impl_trivial_total!(char);
264impl_trivial_total!(&str);
265impl_trivial_total!(&[u8]);
266impl_trivial_total!(String);
267
268macro_rules! impl_float_eq_ord {
269    ($T:ty) => {
270        impl TotalEq for $T {
271            #[inline]
272            fn tot_eq(&self, other: &Self) -> bool {
273                if self.is_nan() {
274                    other.is_nan()
275                } else {
276                    self == other
277                }
278            }
279        }
280
281        impl TotalOrd for $T {
282            #[inline(always)]
283            fn tot_cmp(&self, other: &Self) -> Ordering {
284                if self.tot_lt(other) {
285                    Ordering::Less
286                } else if self.tot_gt(other) {
287                    Ordering::Greater
288                } else {
289                    Ordering::Equal
290                }
291            }
292
293            #[inline(always)]
294            fn tot_lt(&self, other: &Self) -> bool {
295                !self.tot_ge(other)
296            }
297
298            #[inline(always)]
299            fn tot_gt(&self, other: &Self) -> bool {
300                other.tot_lt(self)
301            }
302
303            #[inline(always)]
304            fn tot_le(&self, other: &Self) -> bool {
305                other.tot_ge(self)
306            }
307
308            #[inline(always)]
309            fn tot_ge(&self, other: &Self) -> bool {
310                // We consider all NaNs equal, and NaN is the largest possible
311                // value. Thus if self is NaN we always return true. Otherwise
312                // self >= other is correct. If other is not NaN it is trivially
313                // correct, and if it is we note that nothing can be greater or
314                // equal to NaN except NaN itself, which we already handled earlier.
315                self.is_nan() | (self >= other)
316            }
317        }
318    };
319}
320
321impl_float_eq_ord!(f32);
322impl_float_eq_ord!(f64);
323
324impl TotalHash for f32 {
325    #[inline(always)]
326    fn tot_hash<H>(&self, state: &mut H)
327    where
328        H: Hasher,
329    {
330        canonical_f32(*self).to_bits().hash(state)
331    }
332}
333
334impl TotalHash for f64 {
335    #[inline(always)]
336    fn tot_hash<H>(&self, state: &mut H)
337    where
338        H: Hasher,
339    {
340        canonical_f64(*self).to_bits().hash(state)
341    }
342}
343
344// Blanket implementations.
345impl<T: TotalEq> TotalEq for Option<T> {
346    #[inline(always)]
347    fn tot_eq(&self, other: &Self) -> bool {
348        match (self, other) {
349            (None, None) => true,
350            (Some(a), Some(b)) => a.tot_eq(b),
351            _ => false,
352        }
353    }
354
355    #[inline(always)]
356    fn tot_ne(&self, other: &Self) -> bool {
357        match (self, other) {
358            (None, None) => false,
359            (Some(a), Some(b)) => a.tot_ne(b),
360            _ => true,
361        }
362    }
363}
364
365impl<T: TotalOrd> TotalOrd for Option<T> {
366    #[inline(always)]
367    fn tot_cmp(&self, other: &Self) -> Ordering {
368        match (self, other) {
369            (None, None) => Ordering::Equal,
370            (None, Some(_)) => Ordering::Less,
371            (Some(_), None) => Ordering::Greater,
372            (Some(a), Some(b)) => a.tot_cmp(b),
373        }
374    }
375
376    #[inline(always)]
377    fn tot_lt(&self, other: &Self) -> bool {
378        match (self, other) {
379            (None, Some(_)) => true,
380            (Some(a), Some(b)) => a.tot_lt(b),
381            _ => false,
382        }
383    }
384
385    #[inline(always)]
386    fn tot_gt(&self, other: &Self) -> bool {
387        other.tot_lt(self)
388    }
389
390    #[inline(always)]
391    fn tot_le(&self, other: &Self) -> bool {
392        match (self, other) {
393            (Some(_), None) => false,
394            (Some(a), Some(b)) => a.tot_lt(b),
395            _ => true,
396        }
397    }
398
399    #[inline(always)]
400    fn tot_ge(&self, other: &Self) -> bool {
401        other.tot_le(self)
402    }
403}
404
405impl<T: TotalHash> TotalHash for Option<T> {
406    #[inline]
407    fn tot_hash<H>(&self, state: &mut H)
408    where
409        H: Hasher,
410    {
411        self.is_some().tot_hash(state);
412        if let Some(slf) = self {
413            slf.tot_hash(state)
414        }
415    }
416}
417
418impl<T: TotalEq + ?Sized> TotalEq for &T {
419    #[inline(always)]
420    fn tot_eq(&self, other: &Self) -> bool {
421        (*self).tot_eq(*other)
422    }
423
424    #[inline(always)]
425    fn tot_ne(&self, other: &Self) -> bool {
426        (*self).tot_ne(*other)
427    }
428}
429
430impl<T: TotalHash + ?Sized> TotalHash for &T {
431    #[inline(always)]
432    fn tot_hash<H>(&self, state: &mut H)
433    where
434        H: Hasher,
435    {
436        (*self).tot_hash(state)
437    }
438}
439
440impl<T: TotalEq, U: TotalEq> TotalEq for (T, U) {
441    #[inline]
442    fn tot_eq(&self, other: &Self) -> bool {
443        self.0.tot_eq(&other.0) && self.1.tot_eq(&other.1)
444    }
445}
446
447impl<T: TotalOrd, U: TotalOrd> TotalOrd for (T, U) {
448    #[inline]
449    fn tot_cmp(&self, other: &Self) -> Ordering {
450        self.0
451            .tot_cmp(&other.0)
452            .then_with(|| self.1.tot_cmp(&other.1))
453    }
454}
455
456impl TotalHash for BytesHash<'_> {
457    #[inline(always)]
458    fn tot_hash<H>(&self, state: &mut H)
459    where
460        H: Hasher,
461    {
462        self.hash(state)
463    }
464}
465
466impl TotalEq for BytesHash<'_> {
467    #[inline(always)]
468    fn tot_eq(&self, other: &Self) -> bool {
469        self == other
470    }
471}
472
/// Conversion to a key type that implements [`Hash`] and [`Eq`], and back.
///
/// This elides creating a [`TotalOrdWrap`] for types that don't need it: such
/// types can use themselves as the key type.
pub trait ToTotalOrd {
    /// Key type produced by [`ToTotalOrd::to_total_ord`].
    type TotalOrdItem: Hash + Eq;
    /// Type obtained when a key is converted back with
    /// [`ToTotalOrd::peel_total_ord`].
    type SourceItem;

    /// Produces the key for `self`.
    fn to_total_ord(&self) -> Self::TotalOrdItem;

    /// Converts a key back into the source representation.
    fn peel_total_ord(ord_item: Self::TotalOrdItem) -> Self::SourceItem;
}
482
483macro_rules! impl_to_total_ord_identity {
484    ($T: ty) => {
485        impl ToTotalOrd for $T {
486            type TotalOrdItem = $T;
487            type SourceItem = $T;
488
489            #[inline]
490            fn to_total_ord(&self) -> Self::TotalOrdItem {
491                self.clone()
492            }
493
494            #[inline]
495            fn peel_total_ord(ord_item: Self::TotalOrdItem) -> Self::SourceItem {
496                ord_item
497            }
498        }
499    };
500}
501
502impl_to_total_ord_identity!(bool);
503impl_to_total_ord_identity!(u8);
504impl_to_total_ord_identity!(u16);
505impl_to_total_ord_identity!(u32);
506impl_to_total_ord_identity!(u64);
507impl_to_total_ord_identity!(u128);
508impl_to_total_ord_identity!(usize);
509impl_to_total_ord_identity!(i8);
510impl_to_total_ord_identity!(i16);
511impl_to_total_ord_identity!(i32);
512impl_to_total_ord_identity!(i64);
513impl_to_total_ord_identity!(i128);
514impl_to_total_ord_identity!(isize);
515impl_to_total_ord_identity!(char);
516impl_to_total_ord_identity!(String);
517
518macro_rules! impl_to_total_ord_lifetimed_ref_identity {
519    ($T: ty) => {
520        impl<'a> ToTotalOrd for &'a $T {
521            type TotalOrdItem = &'a $T;
522            type SourceItem = &'a $T;
523
524            #[inline]
525            fn to_total_ord(&self) -> Self::TotalOrdItem {
526                *self
527            }
528
529            #[inline]
530            fn peel_total_ord(ord_item: Self::TotalOrdItem) -> Self::SourceItem {
531                ord_item
532            }
533        }
534    };
535}
536
537impl_to_total_ord_lifetimed_ref_identity!(str);
538impl_to_total_ord_lifetimed_ref_identity!([u8]);
539
540macro_rules! impl_to_total_ord_wrapped {
541    ($T: ty) => {
542        impl ToTotalOrd for $T {
543            type TotalOrdItem = TotalOrdWrap<$T>;
544            type SourceItem = $T;
545
546            #[inline]
547            fn to_total_ord(&self) -> Self::TotalOrdItem {
548                TotalOrdWrap(self.clone())
549            }
550
551            #[inline]
552            fn peel_total_ord(ord_item: Self::TotalOrdItem) -> Self::SourceItem {
553                ord_item.0
554            }
555        }
556    };
557}
558
559impl_to_total_ord_wrapped!(f32);
560impl_to_total_ord_wrapped!(f64);
561
562/// This is safe without needing to map the option value to TotalOrdWrap, since
563/// for example:
564/// `TotalOrdWrap<Option<T>>` implements `Eq + Hash`, iff:
565/// `Option<T>` implements `TotalEq + TotalHash`, iff:
566/// `T` implements `TotalEq + TotalHash`
567impl<T: Copy + TotalEq + TotalHash> ToTotalOrd for Option<T> {
568    type TotalOrdItem = TotalOrdWrap<Option<T>>;
569    type SourceItem = Option<T>;
570
571    #[inline]
572    fn to_total_ord(&self) -> Self::TotalOrdItem {
573        TotalOrdWrap(*self)
574    }
575
576    #[inline]
577    fn peel_total_ord(ord_item: Self::TotalOrdItem) -> Self::SourceItem {
578        ord_item.0
579    }
580}
581
582impl<T: ToTotalOrd> ToTotalOrd for &T {
583    type TotalOrdItem = T::TotalOrdItem;
584    type SourceItem = T::SourceItem;
585
586    #[inline]
587    fn to_total_ord(&self) -> Self::TotalOrdItem {
588        (*self).to_total_ord()
589    }
590
591    #[inline]
592    fn peel_total_ord(ord_item: Self::TotalOrdItem) -> Self::SourceItem {
593        T::peel_total_ord(ord_item)
594    }
595}
596
597impl<'a> ToTotalOrd for BytesHash<'a> {
598    type TotalOrdItem = BytesHash<'a>;
599    type SourceItem = BytesHash<'a>;
600
601    #[inline]
602    fn to_total_ord(&self) -> Self::TotalOrdItem {
603        *self
604    }
605
606    #[inline]
607    fn peel_total_ord(ord_item: Self::TotalOrdItem) -> Self::SourceItem {
608        ord_item
609    }
610}