polars_core/frame/group_by/into_groups.rs

use arrow::legacy::kernels::sort_partition::{create_clean_partitions, partition_to_groups};
use polars_error::check_signals;
use polars_utils::total_ord::{ToTotalOrd, TotalHash};

use super::*;
use crate::chunked_array::cast::CastOptions;
use crate::chunked_array::ops::row_encode::_get_rows_encoded_ca_unordered;
use crate::config::verbose;
use crate::series::BitRepr;
use crate::utils::flatten::flatten_par;

/// Used to create the tuples for a group_by operation.
pub trait IntoGroupsType {
    /// Create the tuples needed for a group_by operation.
    ///     * The first value in the tuple is the first index of the group.
    ///     * The second value in the tuple is the indexes of the group, including the first value.
    fn group_tuples(&self, _multithreaded: bool, _sorted: bool) -> PolarsResult<GroupsType> {
        unimplemented!()
    }
}

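// Heuristic: the threaded grouping path is only worthwhile when the key
// column is reasonably large and more than one thread is available.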
fn group_multithreaded<T: PolarsDataType>(ca: &ChunkedArray<T>) -> bool {
    // TODO! change to something sensible
    ca.len() > 1000 && POOL.current_num_threads() > 1
}

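// Hash-based grouping for numeric keys. On the threaded path the chunks are
// handed to the partitioned kernels (`group_by_threaded_slice` /
// `group_by_threaded_iter`); on the single-threaded path null handling is
// skipped entirely when the array has no nulls.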
fn num_groups_proxy<T>(ca: &ChunkedArray<T>, multithreaded: bool, sorted: bool) -> GroupsType
where
    T: PolarsNumericType,
    T::Native: TotalHash + TotalEq + DirtyHash + ToTotalOrd,
    <T::Native as ToTotalOrd>::TotalOrdItem: Send + Sync + Copy + Hash + Eq + DirtyHash,
{
    if multithreaded && group_multithreaded(ca) {
        let n_partitions = _set_partition_size();

        // use the arrays as iterators
        if ca.null_count() == 0 {
            let keys = ca
                .downcast_iter()
                .map(|arr| arr.values().as_slice())
                .collect::<Vec<_>>();
            group_by_threaded_slice(keys, n_partitions, sorted)
        } else {
            let keys = ca
                .downcast_iter()
                .map(|arr| arr.iter().map(|o| o.copied()))
                .collect::<Vec<_>>();
            group_by_threaded_iter(&keys, n_partitions, sorted)
        }
    } else if !ca.has_nulls() {
        group_by(ca.into_no_null_iter(), sorted)
    } else {
        group_by(ca.iter(), sorted)
    }
}

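// Fast path for keys that carry a sorted flag: groups of equal values are
// contiguous runs, so they can be found by scanning partitions of the single
// (rechunked) array instead of hashing.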
impl<T> ChunkedArray<T>
where
    T: PolarsNumericType,
    T::Native: NumCast,
{
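    /// Build a `GroupsSlice` from an already sorted, single-chunk array.
    /// Nulls (which sit at either end of a sorted array) are sliced off the
    /// values and accounted for by `partition_to_groups` in the first or last
    /// partition; the other partitions only need an index offset into the
    /// full values slice.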
    fn create_groups_from_sorted(&self, multithreaded: bool) -> GroupsSlice {
        if verbose() {
            eprintln!("group_by keys are sorted; running sorted key fast path");
        }
        let arr = self.downcast_iter().next().unwrap();
        if arr.is_empty() {
            return GroupsSlice::default();
        }
        let mut values = arr.values().as_slice();
        let null_count = arr.null_count();
        let length = values.len();

        // all nulls
        if null_count == length {
            return vec![[0, length as IdxSize]];
        }

        let mut nulls_first = false;
        if null_count > 0 {
            nulls_first = arr.get(0).is_none()
        }

        if nulls_first {
            values = &values[null_count..];
        } else {
            values = &values[..length - null_count];
        };

        let n_threads = POOL.current_num_threads();
        let groups = if multithreaded && n_threads > 1 {
            let parts =
                create_clean_partitions(values, n_threads, self.is_sorted_descending_flag());
            let n_parts = parts.len();

            let first_ptr = &values[0] as *const T::Native as usize;
            let groups = parts.par_iter().enumerate().map(|(i, part)| {
                // we go via usize as *const is not Send
                let first_ptr = first_ptr as *const T::Native;

                let part_first_ptr = &part[0] as *const T::Native;
                let mut offset = unsafe { part_first_ptr.offset_from(first_ptr) } as IdxSize;

                // nulls first: only add the nulls at the first partition
                if nulls_first && i == 0 {
                    partition_to_groups(part, null_count as IdxSize, true, offset)
                }
                // nulls last: only compute at the last partition
                else if !nulls_first && i == n_parts - 1 {
                    partition_to_groups(part, null_count as IdxSize, false, offset)
                }
                // other partitions
                else {
                    if nulls_first {
                        offset += null_count as IdxSize;
                    };

                    partition_to_groups(part, 0, false, offset)
                }
            });
            let groups = POOL.install(|| groups.collect::<Vec<_>>());
            flatten_par(&groups)
        } else {
            partition_to_groups(values, null_count as IdxSize, nulls_first, 0)
        };
        groups
    }
}

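// Categorical keys can use the dedicated `group_tuples_perfect` kernel on
// their physical representation instead of the generic hash group-by.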
#[cfg(all(feature = "dtype-categorical", feature = "performant"))]
impl IntoGroupsType for CategoricalChunked {
    fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult<GroupsType> {
        Ok(self.group_tuples_perfect(multithreaded, sorted))
    }
}

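// Numeric keys: take the sorted fast path when a sort flag is set, otherwise
// dispatch on the dtype so grouping runs on a concrete physical integer or
// float type (reinterpreting or casting where needed).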
impl<T> IntoGroupsType for ChunkedArray<T>
where
    T: PolarsNumericType,
    T::Native: NumCast,
{
    fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult<GroupsType> {
        // sorted path
        if self.is_sorted_ascending_flag() || self.is_sorted_descending_flag() {
            // don't have to pass `sorted` arg, GroupsSlice is always sorted.
            return Ok(GroupsType::Slice {
                groups: self.rechunk().create_groups_from_sorted(multithreaded),
                rolling: false,
            });
        }

        let out = match self.dtype() {
            DataType::UInt64 => {
                // convince the compiler that we are this type.
                let ca: &UInt64Chunked = unsafe {
                    &*(self as *const ChunkedArray<T> as *const ChunkedArray<UInt64Type>)
                };
                num_groups_proxy(ca, multithreaded, sorted)
            },
            DataType::UInt32 => {
                // convince the compiler that we are this type.
                let ca: &UInt32Chunked = unsafe {
                    &*(self as *const ChunkedArray<T> as *const ChunkedArray<UInt32Type>)
                };
                num_groups_proxy(ca, multithreaded, sorted)
            },
            DataType::Int64 => {
                let BitRepr::Large(ca) = self.to_bit_repr() else {
                    unreachable!()
                };
                num_groups_proxy(&ca, multithreaded, sorted)
            },
            DataType::Int32 => {
                let BitRepr::Small(ca) = self.to_bit_repr() else {
                    unreachable!()
                };
                num_groups_proxy(&ca, multithreaded, sorted)
            },
            DataType::Float64 => {
                // convince the compiler that we are this type.
                let ca: &Float64Chunked = unsafe {
                    &*(self as *const ChunkedArray<T> as *const ChunkedArray<Float64Type>)
                };
                num_groups_proxy(ca, multithreaded, sorted)
            },
            DataType::Float32 => {
                // convince the compiler that we are this type.
                let ca: &Float32Chunked = unsafe {
                    &*(self as *const ChunkedArray<T> as *const ChunkedArray<Float32Type>)
                };
                num_groups_proxy(ca, multithreaded, sorted)
            },
            #[cfg(feature = "dtype-decimal")]
            DataType::Decimal(_, _) => {
                // convince the compiler that we are this type.
                let ca: &Int128Chunked = unsafe {
                    &*(self as *const ChunkedArray<T> as *const ChunkedArray<Int128Type>)
                };
                num_groups_proxy(ca, multithreaded, sorted)
            },
            #[cfg(all(feature = "performant", feature = "dtype-i8", feature = "dtype-u8"))]
            DataType::Int8 => {
                // convince the compiler that we are this type.
                let ca: &Int8Chunked =
                    unsafe { &*(self as *const ChunkedArray<T> as *const ChunkedArray<Int8Type>) };
                let s = ca.reinterpret_unsigned();
                return s.group_tuples(multithreaded, sorted);
            },
            #[cfg(all(feature = "performant", feature = "dtype-i8", feature = "dtype-u8"))]
            DataType::UInt8 => {
                // convince the compiler that we are this type.
                let ca: &UInt8Chunked =
                    unsafe { &*(self as *const ChunkedArray<T> as *const ChunkedArray<UInt8Type>) };
                num_groups_proxy(ca, multithreaded, sorted)
            },
            #[cfg(all(feature = "performant", feature = "dtype-i16", feature = "dtype-u16"))]
            DataType::Int16 => {
                // convince the compiler that we are this type.
                let ca: &Int16Chunked =
                    unsafe { &*(self as *const ChunkedArray<T> as *const ChunkedArray<Int16Type>) };
                let s = ca.reinterpret_unsigned();
                return s.group_tuples(multithreaded, sorted);
            },
            #[cfg(all(feature = "performant", feature = "dtype-i16", feature = "dtype-u16"))]
            DataType::UInt16 => {
                // convince the compiler that we are this type.
                let ca: &UInt16Chunked = unsafe {
                    &*(self as *const ChunkedArray<T> as *const ChunkedArray<UInt16Type>)
                };
                num_groups_proxy(ca, multithreaded, sorted)
            },
            _ => {
                let ca = unsafe { self.cast_unchecked(&DataType::UInt32).unwrap() };
                let ca = ca.u32().unwrap();
                num_groups_proxy(ca, multithreaded, sorted)
            },
        };
        check_signals()?;
        Ok(out)
    }
}
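// Booleans are grouped through an integer cast: u8 when the "performant"
// feature enables the small-int kernels, u32 otherwise.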
impl IntoGroupsType for BooleanChunked {
    fn group_tuples(&self, mut multithreaded: bool, sorted: bool) -> PolarsResult<GroupsType> {
        multithreaded &= POOL.current_num_threads() > 1;

        #[cfg(feature = "performant")]
        {
            let ca = self
                .cast_with_options(&DataType::UInt8, CastOptions::Overflowing)
                .unwrap();
            let ca = ca.u8().unwrap();
            ca.group_tuples(multithreaded, sorted)
        }
        #[cfg(not(feature = "performant"))]
        {
            let ca = self
                .cast_with_options(&DataType::UInt32, CastOptions::Overflowing)
                .unwrap();
            let ca = ca.u32().unwrap();
            ca.group_tuples(multithreaded, sorted)
        }
    }
}

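// Strings are grouped as their UTF-8 bytes by delegating to the binary
// implementation below.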
impl IntoGroupsType for StringChunked {
    #[allow(clippy::needless_lifetimes)]
    fn group_tuples<'a>(&'a self, multithreaded: bool, sorted: bool) -> PolarsResult<GroupsType> {
        self.as_binary().group_tuples(multithreaded, sorted)
    }
}

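// Binary keys are pre-hashed into per-partition buffers with
// `to_bytes_hashes`; the threaded path groups those partitions in parallel,
// while the single-threaded path gets a single partition back.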
impl IntoGroupsType for BinaryChunked {
    #[allow(clippy::needless_lifetimes)]
    fn group_tuples<'a>(
        &'a self,
        mut multithreaded: bool,
        sorted: bool,
    ) -> PolarsResult<GroupsType> {
        multithreaded &= POOL.current_num_threads() > 1;
        let bh = self.to_bytes_hashes(multithreaded, Default::default());

        let out = if multithreaded {
            let n_partitions = bh.len();
            // Take slices so that the vecs are not cloned.
            let bh = bh.iter().map(|v| v.as_slice()).collect::<Vec<_>>();
            group_by_threaded_slice(bh, n_partitions, sorted)
        } else {
            group_by(bh[0].iter(), sorted)
        };
        check_signals()?;
        Ok(out)
    }
}

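// Same scheme as `BinaryChunked`, for the offset-based binary layout.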
impl IntoGroupsType for BinaryOffsetChunked {
    #[allow(clippy::needless_lifetimes)]
    fn group_tuples<'a>(
        &'a self,
        mut multithreaded: bool,
        sorted: bool,
    ) -> PolarsResult<GroupsType> {
        multithreaded &= POOL.current_num_threads() > 1;
        let bh = self.to_bytes_hashes(multithreaded, Default::default());

        let out = if multithreaded {
            let n_partitions = bh.len();
            // Take slices so that the vecs are not cloned.
            let bh = bh.iter().map(|v| v.as_slice()).collect::<Vec<_>>();
            group_by_threaded_slice(bh, n_partitions, sorted)
        } else {
            group_by(bh[0].iter(), sorted)
        };
        Ok(out)
    }
}

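// Nested list keys have no direct hash kernel; they are row-encoded into an
// unordered binary representation first and then grouped as binary data.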
impl IntoGroupsType for ListChunked {
    #[allow(clippy::needless_lifetimes)]
    #[allow(unused_variables)]
    fn group_tuples<'a>(
        &'a self,
        mut multithreaded: bool,
        sorted: bool,
    ) -> PolarsResult<GroupsType> {
        multithreaded &= POOL.current_num_threads() > 1;
        let by = &[self.clone().into_column()];
        let ca = if multithreaded {
            encode_rows_vertical_par_unordered(by).unwrap()
        } else {
            _get_rows_encoded_ca_unordered(PlSmallStr::EMPTY, by).unwrap()
        };

        ca.group_tuples(multithreaded, sorted)
    }
}

#[cfg(feature = "dtype-array")]
impl IntoGroupsType for ArrayChunked {
    #[allow(clippy::needless_lifetimes)]
    #[allow(unused_variables)]
    fn group_tuples<'a>(&'a self, _multithreaded: bool, _sorted: bool) -> PolarsResult<GroupsType> {
        todo!("grouping FixedSizeList not yet supported")
    }
}

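// Objects are grouped with the single-threaded hash group-by over their
// iterator; the `_multithreaded` hint is ignored here.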
#[cfg(feature = "object")]
impl<T> IntoGroupsType for ObjectChunked<T>
where
    T: PolarsObject,
{
    fn group_tuples(&self, _multithreaded: bool, sorted: bool) -> PolarsResult<GroupsType> {
        Ok(group_by(self.into_iter(), sorted))
    }
}