polars_core/frame/group_by/
mod.rs

1use std::fmt::{Debug, Display, Formatter};
2use std::hash::Hash;
3
4use num_traits::NumCast;
5use polars_utils::format_pl_smallstr;
6use polars_utils::hashing::DirtyHash;
7use rayon::prelude::*;
8
9use self::hashing::*;
10use crate::prelude::*;
11use crate::utils::{_set_partition_size, accumulate_dataframes_vertical};
12use crate::POOL;
13
14pub mod aggregations;
15pub mod expr;
16pub(crate) mod hashing;
17mod into_groups;
18mod perfect;
19mod position;
20
21pub use into_groups::*;
22pub use position::*;
23
24use crate::chunked_array::ops::row_encode::{
25    encode_rows_unordered, encode_rows_vertical_par_unordered,
26};
27
28impl DataFrame {
29    pub fn group_by_with_series(
30        &self,
31        mut by: Vec<Column>,
32        multithreaded: bool,
33        sorted: bool,
34    ) -> PolarsResult<GroupBy> {
35        polars_ensure!(
36            !by.is_empty(),
37            ComputeError: "at least one key is required in a group_by operation"
38        );
39        let minimal_by_len = by.iter().map(|s| s.len()).min().expect("at least 1 key");
40        let df_height = self.height();
41
42        // we only throw this error if self.width > 0
43        // so that we can still call this on a dummy dataframe where we provide the keys
44        if (minimal_by_len != df_height) && (self.width() > 0) {
45            polars_ensure!(
46                minimal_by_len == 1,
47                ShapeMismatch: "series used as keys should have the same length as the DataFrame"
48            );
49            for by_key in by.iter_mut() {
50                if by_key.len() == minimal_by_len {
51                    *by_key = by_key.new_from_index(0, df_height)
52                }
53            }
54        };
55
56        let groups = if by.len() == 1 {
57            let column = &by[0];
58            column
59                .as_materialized_series()
60                .group_tuples(multithreaded, sorted)
61        } else if by.iter().any(|s| s.dtype().is_object()) {
62            #[cfg(feature = "object")]
63            {
64                let mut df = DataFrame::new(by.clone()).unwrap();
65                let n = df.height();
66                let rows = df.to_av_rows();
67                let iter = (0..n).map(|i| rows.get(i));
68                Ok(group_by(iter, sorted))
69            }
70            #[cfg(not(feature = "object"))]
71            {
72                unreachable!()
73            }
74        } else {
75            // Skip null dtype.
76            let by = by
77                .iter()
78                .filter(|s| !s.dtype().is_null())
79                .cloned()
80                .collect::<Vec<_>>();
81            if by.is_empty() {
82                let groups = if self.is_empty() {
83                    vec![]
84                } else {
85                    vec![[0, self.height() as IdxSize]]
86                };
87                Ok(GroupsType::Slice {
88                    groups,
89                    rolling: false,
90                })
91            } else {
92                let rows = if multithreaded {
93                    encode_rows_vertical_par_unordered(&by)
94                } else {
95                    encode_rows_unordered(&by)
96                }?
97                .into_series();
98                rows.group_tuples(multithreaded, sorted)
99            }
100        };
101        Ok(GroupBy::new(self, by, groups?.into_sliceable(), None))
102    }
103
104    /// Group DataFrame using a Series column.
105    ///
106    /// # Example
107    ///
108    /// ```
109    /// use polars_core::prelude::*;
110    /// fn group_by_sum(df: &DataFrame) -> PolarsResult<DataFrame> {
111    ///     df.group_by(["column_name"])?
112    ///     .select(["agg_column_name"])
113    ///     .sum()
114    /// }
115    /// ```
116    pub fn group_by<I, S>(&self, by: I) -> PolarsResult<GroupBy>
117    where
118        I: IntoIterator<Item = S>,
119        S: Into<PlSmallStr>,
120    {
121        let selected_keys = self.select_columns(by)?;
122        self.group_by_with_series(selected_keys, true, false)
123    }
124
125    /// Group DataFrame using a Series column.
126    /// The groups are ordered by their smallest row index.
127    pub fn group_by_stable<I, S>(&self, by: I) -> PolarsResult<GroupBy>
128    where
129        I: IntoIterator<Item = S>,
130        S: Into<PlSmallStr>,
131    {
132        let selected_keys = self.select_columns(by)?;
133        self.group_by_with_series(selected_keys, true, true)
134    }
135}
136
137/// Returned by a group_by operation on a DataFrame. This struct supports
138/// several aggregations.
139///
140/// Until described otherwise, the examples in this struct are performed on the following DataFrame:
141///
142/// ```ignore
143/// use polars_core::prelude::*;
144///
145/// let dates = &[
146/// "2020-08-21",
147/// "2020-08-21",
148/// "2020-08-22",
149/// "2020-08-23",
150/// "2020-08-22",
151/// ];
152/// // date format
153/// let fmt = "%Y-%m-%d";
154/// // create date series
155/// let s0 = DateChunked::parse_from_str_slice("date", dates, fmt)
156///         .into_series();
157/// // create temperature series
158/// let s1 = Series::new("temp".into(), [20, 10, 7, 9, 1]);
159/// // create rain series
160/// let s2 = Series::new("rain".into(), [0.2, 0.1, 0.3, 0.1, 0.01]);
161/// // create a new DataFrame
162/// let df = DataFrame::new(vec![s0, s1, s2]).unwrap();
163/// println!("{:?}", df);
164/// ```
165///
166/// Outputs:
167///
168/// ```text
169/// +------------+------+------+
170/// | date       | temp | rain |
171/// | ---        | ---  | ---  |
172/// | Date       | i32  | f64  |
173/// +============+======+======+
174/// | 2020-08-21 | 20   | 0.2  |
175/// +------------+------+------+
176/// | 2020-08-21 | 10   | 0.1  |
177/// +------------+------+------+
178/// | 2020-08-22 | 7    | 0.3  |
179/// +------------+------+------+
180/// | 2020-08-23 | 9    | 0.1  |
181/// +------------+------+------+
182/// | 2020-08-22 | 1    | 0.01 |
183/// +------------+------+------+
184/// ```
185///
186#[derive(Debug, Clone)]
187pub struct GroupBy<'a> {
188    pub df: &'a DataFrame,
189    pub(crate) selected_keys: Vec<Column>,
190    // [first idx, [other idx]]
191    groups: GroupPositions,
192    // columns selected for aggregation
193    pub(crate) selected_agg: Option<Vec<PlSmallStr>>,
194}
195
196impl<'a> GroupBy<'a> {
197    pub fn new(
198        df: &'a DataFrame,
199        by: Vec<Column>,
200        groups: GroupPositions,
201        selected_agg: Option<Vec<PlSmallStr>>,
202    ) -> Self {
203        GroupBy {
204            df,
205            selected_keys: by,
206            groups,
207            selected_agg,
208        }
209    }
210
211    /// Select the column(s) that should be aggregated.
212    /// You can select a single column or a slice of columns.
213    ///
214    /// Note that making a selection with this method is not required. If you
215    /// skip it all columns (except for the keys) will be selected for aggregation.
216    #[must_use]
217    pub fn select<I: IntoIterator<Item = S>, S: Into<PlSmallStr>>(mut self, selection: I) -> Self {
218        self.selected_agg = Some(selection.into_iter().map(|s| s.into()).collect());
219        self
220    }
221
222    /// Get the internal representation of the GroupBy operation.
223    /// The Vec returned contains:
224    ///     (first_idx, [`Vec<indexes>`])
225    ///     Where second value in the tuple is a vector with all matching indexes.
226    pub fn get_groups(&self) -> &GroupPositions {
227        &self.groups
228    }
229
230    /// Get the internal representation of the GroupBy operation.
231    /// The Vec returned contains:
232    ///     (first_idx, [`Vec<indexes>`])
233    ///     Where second value in the tuple is a vector with all matching indexes.
234    ///
235    /// # Safety
236    /// Groups should always be in bounds of the `DataFrame` hold by this [`GroupBy`].
237    /// If you mutate it, you must hold that invariant.
238    pub unsafe fn get_groups_mut(&mut self) -> &mut GroupPositions {
239        &mut self.groups
240    }
241
242    pub fn take_groups(self) -> GroupPositions {
243        self.groups
244    }
245
246    pub fn take_groups_mut(&mut self) -> GroupPositions {
247        std::mem::take(&mut self.groups)
248    }
249
250    pub fn keys_sliced(&self, slice: Option<(i64, usize)>) -> Vec<Column> {
251        #[allow(unused_assignments)]
252        // needed to keep the lifetimes valid for this scope
253        let mut groups_owned = None;
254
255        let groups = if let Some((offset, len)) = slice {
256            groups_owned = Some(self.groups.slice(offset, len));
257            groups_owned.as_deref().unwrap()
258        } else {
259            &self.groups
260        };
261        POOL.install(|| {
262            self.selected_keys
263                .par_iter()
264                .map(Column::as_materialized_series)
265                .map(|s| {
266                    match groups {
267                        GroupsType::Idx(groups) => {
268                            // SAFETY: groups are always in bounds.
269                            let mut out = unsafe { s.take_slice_unchecked(groups.first()) };
270                            if groups.sorted {
271                                out.set_sorted_flag(s.is_sorted_flag());
272                            };
273                            out
274                        },
275                        GroupsType::Slice { groups, rolling } => {
276                            if *rolling && !groups.is_empty() {
277                                // Groups can be sliced.
278                                let offset = groups[0][0];
279                                let [upper_offset, upper_len] = groups[groups.len() - 1];
280                                return s.slice(
281                                    offset as i64,
282                                    ((upper_offset + upper_len) - offset) as usize,
283                                );
284                            }
285
286                            let indices = groups
287                                .iter()
288                                .map(|&[first, _len]| first)
289                                .collect_ca(PlSmallStr::EMPTY);
290                            // SAFETY: groups are always in bounds.
291                            let mut out = unsafe { s.take_unchecked(&indices) };
292                            // Sliced groups are always in order of discovery.
293                            out.set_sorted_flag(s.is_sorted_flag());
294                            out
295                        },
296                    }
297                })
298                .map(Column::from)
299                .collect()
300        })
301    }
302
303    pub fn keys(&self) -> Vec<Column> {
304        self.keys_sliced(None)
305    }
306
307    fn prepare_agg(&self) -> PolarsResult<(Vec<Column>, Vec<Column>)> {
308        let keys = self.keys();
309
310        let agg_col = match &self.selected_agg {
311            Some(selection) => self.df.select_columns_impl(selection.as_slice()),
312            None => {
313                let by: Vec<_> = self.selected_keys.iter().map(|s| s.name()).collect();
314                let selection = self
315                    .df
316                    .iter()
317                    .map(|s| s.name())
318                    .filter(|a| !by.contains(a))
319                    .cloned()
320                    .collect::<Vec<_>>();
321
322                self.df.select_columns_impl(selection.as_slice())
323            },
324        }?;
325
326        Ok((keys, agg_col))
327    }
328
329    /// Aggregate grouped series and compute the mean per group.
330    ///
331    /// # Example
332    ///
333    /// ```rust
334    /// # use polars_core::prelude::*;
335    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
336    ///     df.group_by(["date"])?.select(["temp", "rain"]).mean()
337    /// }
338    /// ```
339    /// Returns:
340    ///
341    /// ```text
342    /// +------------+-----------+-----------+
343    /// | date       | temp_mean | rain_mean |
344    /// | ---        | ---       | ---       |
345    /// | Date       | f64       | f64       |
346    /// +============+===========+===========+
347    /// | 2020-08-23 | 9         | 0.1       |
348    /// +------------+-----------+-----------+
349    /// | 2020-08-22 | 4         | 0.155     |
350    /// +------------+-----------+-----------+
351    /// | 2020-08-21 | 15        | 0.15      |
352    /// +------------+-----------+-----------+
353    /// ```
354    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
355    pub fn mean(&self) -> PolarsResult<DataFrame> {
356        let (mut cols, agg_cols) = self.prepare_agg()?;
357
358        for agg_col in agg_cols {
359            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Mean);
360            let mut agg = unsafe { agg_col.agg_mean(&self.groups) };
361            agg.rename(new_name);
362            cols.push(agg);
363        }
364        DataFrame::new(cols)
365    }
366
367    /// Aggregate grouped series and compute the sum per group.
368    ///
369    /// # Example
370    ///
371    /// ```rust
372    /// # use polars_core::prelude::*;
373    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
374    ///     df.group_by(["date"])?.select(["temp"]).sum()
375    /// }
376    /// ```
377    /// Returns:
378    ///
379    /// ```text
380    /// +------------+----------+
381    /// | date       | temp_sum |
382    /// | ---        | ---      |
383    /// | Date       | i32      |
384    /// +============+==========+
385    /// | 2020-08-23 | 9        |
386    /// +------------+----------+
387    /// | 2020-08-22 | 8        |
388    /// +------------+----------+
389    /// | 2020-08-21 | 30       |
390    /// +------------+----------+
391    /// ```
392    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
393    pub fn sum(&self) -> PolarsResult<DataFrame> {
394        let (mut cols, agg_cols) = self.prepare_agg()?;
395
396        for agg_col in agg_cols {
397            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Sum);
398            let mut agg = unsafe { agg_col.agg_sum(&self.groups) };
399            agg.rename(new_name);
400            cols.push(agg);
401        }
402        DataFrame::new(cols)
403    }
404
405    /// Aggregate grouped series and compute the minimal value per group.
406    ///
407    /// # Example
408    ///
409    /// ```rust
410    /// # use polars_core::prelude::*;
411    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
412    ///     df.group_by(["date"])?.select(["temp"]).min()
413    /// }
414    /// ```
415    /// Returns:
416    ///
417    /// ```text
418    /// +------------+----------+
419    /// | date       | temp_min |
420    /// | ---        | ---      |
421    /// | Date       | i32      |
422    /// +============+==========+
423    /// | 2020-08-23 | 9        |
424    /// +------------+----------+
425    /// | 2020-08-22 | 1        |
426    /// +------------+----------+
427    /// | 2020-08-21 | 10       |
428    /// +------------+----------+
429    /// ```
430    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
431    pub fn min(&self) -> PolarsResult<DataFrame> {
432        let (mut cols, agg_cols) = self.prepare_agg()?;
433        for agg_col in agg_cols {
434            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Min);
435            let mut agg = unsafe { agg_col.agg_min(&self.groups) };
436            agg.rename(new_name);
437            cols.push(agg);
438        }
439        DataFrame::new(cols)
440    }
441
442    /// Aggregate grouped series and compute the maximum value per group.
443    ///
444    /// # Example
445    ///
446    /// ```rust
447    /// # use polars_core::prelude::*;
448    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
449    ///     df.group_by(["date"])?.select(["temp"]).max()
450    /// }
451    /// ```
452    /// Returns:
453    ///
454    /// ```text
455    /// +------------+----------+
456    /// | date       | temp_max |
457    /// | ---        | ---      |
458    /// | Date       | i32      |
459    /// +============+==========+
460    /// | 2020-08-23 | 9        |
461    /// +------------+----------+
462    /// | 2020-08-22 | 7        |
463    /// +------------+----------+
464    /// | 2020-08-21 | 20       |
465    /// +------------+----------+
466    /// ```
467    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
468    pub fn max(&self) -> PolarsResult<DataFrame> {
469        let (mut cols, agg_cols) = self.prepare_agg()?;
470        for agg_col in agg_cols {
471            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Max);
472            let mut agg = unsafe { agg_col.agg_max(&self.groups) };
473            agg.rename(new_name);
474            cols.push(agg);
475        }
476        DataFrame::new(cols)
477    }
478
479    /// Aggregate grouped `Series` and find the first value per group.
480    ///
481    /// # Example
482    ///
483    /// ```rust
484    /// # use polars_core::prelude::*;
485    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
486    ///     df.group_by(["date"])?.select(["temp"]).first()
487    /// }
488    /// ```
489    /// Returns:
490    ///
491    /// ```text
492    /// +------------+------------+
493    /// | date       | temp_first |
494    /// | ---        | ---        |
495    /// | Date       | i32        |
496    /// +============+============+
497    /// | 2020-08-23 | 9          |
498    /// +------------+------------+
499    /// | 2020-08-22 | 7          |
500    /// +------------+------------+
501    /// | 2020-08-21 | 20         |
502    /// +------------+------------+
503    /// ```
504    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
505    pub fn first(&self) -> PolarsResult<DataFrame> {
506        let (mut cols, agg_cols) = self.prepare_agg()?;
507        for agg_col in agg_cols {
508            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::First);
509            let mut agg = unsafe { agg_col.agg_first(&self.groups) };
510            agg.rename(new_name);
511            cols.push(agg);
512        }
513        DataFrame::new(cols)
514    }
515
516    /// Aggregate grouped `Series` and return the last value per group.
517    ///
518    /// # Example
519    ///
520    /// ```rust
521    /// # use polars_core::prelude::*;
522    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
523    ///     df.group_by(["date"])?.select(["temp"]).last()
524    /// }
525    /// ```
526    /// Returns:
527    ///
528    /// ```text
529    /// +------------+------------+
530    /// | date       | temp_last |
531    /// | ---        | ---        |
532    /// | Date       | i32        |
533    /// +============+============+
534    /// | 2020-08-23 | 9          |
535    /// +------------+------------+
536    /// | 2020-08-22 | 1          |
537    /// +------------+------------+
538    /// | 2020-08-21 | 10         |
539    /// +------------+------------+
540    /// ```
541    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
542    pub fn last(&self) -> PolarsResult<DataFrame> {
543        let (mut cols, agg_cols) = self.prepare_agg()?;
544        for agg_col in agg_cols {
545            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Last);
546            let mut agg = unsafe { agg_col.agg_last(&self.groups) };
547            agg.rename(new_name);
548            cols.push(agg);
549        }
550        DataFrame::new(cols)
551    }
552
553    /// Aggregate grouped `Series` by counting the number of unique values.
554    ///
555    /// # Example
556    ///
557    /// ```rust
558    /// # use polars_core::prelude::*;
559    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
560    ///     df.group_by(["date"])?.select(["temp"]).n_unique()
561    /// }
562    /// ```
563    /// Returns:
564    ///
565    /// ```text
566    /// +------------+---------------+
567    /// | date       | temp_n_unique |
568    /// | ---        | ---           |
569    /// | Date       | u32           |
570    /// +============+===============+
571    /// | 2020-08-23 | 1             |
572    /// +------------+---------------+
573    /// | 2020-08-22 | 2             |
574    /// +------------+---------------+
575    /// | 2020-08-21 | 2             |
576    /// +------------+---------------+
577    /// ```
578    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
579    pub fn n_unique(&self) -> PolarsResult<DataFrame> {
580        let (mut cols, agg_cols) = self.prepare_agg()?;
581        for agg_col in agg_cols {
582            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::NUnique);
583            let mut agg = unsafe { agg_col.agg_n_unique(&self.groups) };
584            agg.rename(new_name);
585            cols.push(agg);
586        }
587        DataFrame::new(cols)
588    }
589
590    /// Aggregate grouped [`Series`] and determine the quantile per group.
591    ///
592    /// # Example
593    ///
594    /// ```rust
595    /// # use polars_core::prelude::*;
596    /// # use arrow::legacy::prelude::QuantileMethod;
597    ///
598    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
599    ///     df.group_by(["date"])?.select(["temp"]).quantile(0.2, QuantileMethod::default())
600    /// }
601    /// ```
602    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
603    pub fn quantile(&self, quantile: f64, method: QuantileMethod) -> PolarsResult<DataFrame> {
604        polars_ensure!(
605            (0.0..=1.0).contains(&quantile),
606            ComputeError: "`quantile` should be within 0.0 and 1.0"
607        );
608        let (mut cols, agg_cols) = self.prepare_agg()?;
609        for agg_col in agg_cols {
610            let new_name = fmt_group_by_column(
611                agg_col.name().as_str(),
612                GroupByMethod::Quantile(quantile, method),
613            );
614            let mut agg = unsafe { agg_col.agg_quantile(&self.groups, quantile, method) };
615            agg.rename(new_name);
616            cols.push(agg);
617        }
618        DataFrame::new(cols)
619    }
620
621    /// Aggregate grouped [`Series`] and determine the median per group.
622    ///
623    /// # Example
624    ///
625    /// ```rust
626    /// # use polars_core::prelude::*;
627    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
628    ///     df.group_by(["date"])?.select(["temp"]).median()
629    /// }
630    /// ```
631    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
632    pub fn median(&self) -> PolarsResult<DataFrame> {
633        let (mut cols, agg_cols) = self.prepare_agg()?;
634        for agg_col in agg_cols {
635            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Median);
636            let mut agg = unsafe { agg_col.agg_median(&self.groups) };
637            agg.rename(new_name);
638            cols.push(agg);
639        }
640        DataFrame::new(cols)
641    }
642
643    /// Aggregate grouped [`Series`] and determine the variance per group.
644    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
645    pub fn var(&self, ddof: u8) -> PolarsResult<DataFrame> {
646        let (mut cols, agg_cols) = self.prepare_agg()?;
647        for agg_col in agg_cols {
648            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Var(ddof));
649            let mut agg = unsafe { agg_col.agg_var(&self.groups, ddof) };
650            agg.rename(new_name);
651            cols.push(agg);
652        }
653        DataFrame::new(cols)
654    }
655
656    /// Aggregate grouped [`Series`] and determine the standard deviation per group.
657    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
658    pub fn std(&self, ddof: u8) -> PolarsResult<DataFrame> {
659        let (mut cols, agg_cols) = self.prepare_agg()?;
660        for agg_col in agg_cols {
661            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Std(ddof));
662            let mut agg = unsafe { agg_col.agg_std(&self.groups, ddof) };
663            agg.rename(new_name);
664            cols.push(agg);
665        }
666        DataFrame::new(cols)
667    }
668
669    /// Aggregate grouped series and compute the number of values per group.
670    ///
671    /// # Example
672    ///
673    /// ```rust
674    /// # use polars_core::prelude::*;
675    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
676    ///     df.group_by(["date"])?.select(["temp"]).count()
677    /// }
678    /// ```
679    /// Returns:
680    ///
681    /// ```text
682    /// +------------+------------+
683    /// | date       | temp_count |
684    /// | ---        | ---        |
685    /// | Date       | u32        |
686    /// +============+============+
687    /// | 2020-08-23 | 1          |
688    /// +------------+------------+
689    /// | 2020-08-22 | 2          |
690    /// +------------+------------+
691    /// | 2020-08-21 | 2          |
692    /// +------------+------------+
693    /// ```
694    pub fn count(&self) -> PolarsResult<DataFrame> {
695        let (mut cols, agg_cols) = self.prepare_agg()?;
696
697        for agg_col in agg_cols {
698            let new_name = fmt_group_by_column(
699                agg_col.name().as_str(),
700                GroupByMethod::Count {
701                    include_nulls: true,
702                },
703            );
704            let mut ca = self.groups.group_count();
705            ca.rename(new_name);
706            cols.push(ca.into_column());
707        }
708        DataFrame::new(cols)
709    }
710
711    /// Get the group_by group indexes.
712    ///
713    /// # Example
714    ///
715    /// ```rust
716    /// # use polars_core::prelude::*;
717    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
718    ///     df.group_by(["date"])?.groups()
719    /// }
720    /// ```
721    /// Returns:
722    ///
723    /// ```text
724    /// +--------------+------------+
725    /// | date         | groups     |
726    /// | ---          | ---        |
727    /// | Date(days)   | list [u32] |
728    /// +==============+============+
729    /// | 2020-08-23   | "[3]"      |
730    /// +--------------+------------+
731    /// | 2020-08-22   | "[2, 4]"   |
732    /// +--------------+------------+
733    /// | 2020-08-21   | "[0, 1]"   |
734    /// +--------------+------------+
735    /// ```
736    pub fn groups(&self) -> PolarsResult<DataFrame> {
737        let mut cols = self.keys();
738        let mut column = self.groups.as_list_chunked();
739        let new_name = fmt_group_by_column("", GroupByMethod::Groups);
740        column.rename(new_name);
741        cols.push(column.into_column());
742        DataFrame::new(cols)
743    }
744
745    /// Aggregate the groups of the group_by operation into lists.
746    ///
747    /// # Example
748    ///
749    /// ```rust
750    /// # use polars_core::prelude::*;
751    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
752    ///     // GroupBy and aggregate to Lists
753    ///     df.group_by(["date"])?.select(["temp"]).agg_list()
754    /// }
755    /// ```
756    /// Returns:
757    ///
758    /// ```text
759    /// +------------+------------------------+
760    /// | date       | temp_agg_list          |
761    /// | ---        | ---                    |
762    /// | Date       | list [i32]             |
763    /// +============+========================+
764    /// | 2020-08-23 | "[Some(9)]"            |
765    /// +------------+------------------------+
766    /// | 2020-08-22 | "[Some(7), Some(1)]"   |
767    /// +------------+------------------------+
768    /// | 2020-08-21 | "[Some(20), Some(10)]" |
769    /// +------------+------------------------+
770    /// ```
771    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
772    pub fn agg_list(&self) -> PolarsResult<DataFrame> {
773        let (mut cols, agg_cols) = self.prepare_agg()?;
774        for agg_col in agg_cols {
775            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Implode);
776            let mut agg = unsafe { agg_col.agg_list(&self.groups) };
777            agg.rename(new_name);
778            cols.push(agg);
779        }
780        DataFrame::new(cols)
781    }
782
783    fn prepare_apply(&self) -> PolarsResult<DataFrame> {
784        polars_ensure!(self.df.height() > 0, ComputeError: "cannot group_by + apply on empty 'DataFrame'");
785        if let Some(agg) = &self.selected_agg {
786            if agg.is_empty() {
787                Ok(self.df.clone())
788            } else {
789                let mut new_cols = Vec::with_capacity(self.selected_keys.len() + agg.len());
790                new_cols.extend_from_slice(&self.selected_keys);
791                let cols = self.df.select_columns_impl(agg.as_slice())?;
792                new_cols.extend(cols);
793                Ok(unsafe { DataFrame::new_no_checks(self.df.height(), new_cols) })
794            }
795        } else {
796            Ok(self.df.clone())
797        }
798    }
799
800    /// Apply a closure over the groups as a new [`DataFrame`] in parallel.
801    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
802    pub fn par_apply<F>(&self, f: F) -> PolarsResult<DataFrame>
803    where
804        F: Fn(DataFrame) -> PolarsResult<DataFrame> + Send + Sync,
805    {
806        let df = self.prepare_apply()?;
807        let dfs = self
808            .get_groups()
809            .par_iter()
810            .map(|g| {
811                // SAFETY:
812                // groups are in bounds
813                let sub_df = unsafe { take_df(&df, g) };
814                f(sub_df)
815            })
816            .collect::<PolarsResult<Vec<_>>>()?;
817
818        let mut df = accumulate_dataframes_vertical(dfs)?;
819        df.as_single_chunk_par();
820        Ok(df)
821    }
822
823    /// Apply a closure over the groups as a new [`DataFrame`].
824    pub fn apply<F>(&self, mut f: F) -> PolarsResult<DataFrame>
825    where
826        F: FnMut(DataFrame) -> PolarsResult<DataFrame> + Send + Sync,
827    {
828        let df = self.prepare_apply()?;
829        let dfs = self
830            .get_groups()
831            .iter()
832            .map(|g| {
833                // SAFETY:
834                // groups are in bounds
835                let sub_df = unsafe { take_df(&df, g) };
836                f(sub_df)
837            })
838            .collect::<PolarsResult<Vec<_>>>()?;
839
840        let mut df = accumulate_dataframes_vertical(dfs)?;
841        df.as_single_chunk_par();
842        Ok(df)
843    }
844
845    pub fn sliced(mut self, slice: Option<(i64, usize)>) -> Self {
846        match slice {
847            None => self,
848            Some((offset, length)) => {
849                self.groups = (self.groups.slice(offset, length)).clone();
850                self.selected_keys = self.keys_sliced(slice);
851                self
852            },
853        }
854    }
855}
856
857unsafe fn take_df(df: &DataFrame, g: GroupsIndicator) -> DataFrame {
858    match g {
859        GroupsIndicator::Idx(idx) => df.take_slice_unchecked(idx.1),
860        GroupsIndicator::Slice([first, len]) => df.slice(first as i64, len as usize),
861    }
862}
863
/// An aggregation that can be applied per group in a group-by.
///
/// Its [`Display`] form is a short name such as `min` or `n_unique`.
/// `fmt_group_by_column` maps each variant to the name of the resulting output column.
#[derive(Copy, Clone, Debug)]
pub enum GroupByMethod {
    /// Displayed as `min`; output column suffix `_min`.
    Min,
    /// Displayed as `nan_min`; output column suffix `_nan_min`.
    /// NOTE(review): a NaN-aware variant of `Min`. The exact NaN semantics are defined elsewhere.
    NanMin,
    /// Displayed as `max`; output column suffix `_max`.
    Max,
    /// Displayed as `nan_max`; output column suffix `_nan_max`.
    /// NOTE(review): a NaN-aware variant of `Max`. The exact NaN semantics are defined elsewhere.
    NanMax,
    /// Displayed as `median`; output column suffix `_median`.
    Median,
    /// Displayed as `mean`; output column suffix `_mean`.
    Mean,
    /// Displayed as `first`; output column suffix `_first`.
    First,
    /// Displayed as `last`; output column suffix `_last`.
    Last,
    /// Displayed as `sum`; output column suffix `_sum`.
    Sum,
    /// Displayed as `groups`. Its output column is always named `groups`, whatever the input name.
    Groups,
    /// Displayed as `n_unique`; output column suffix `_n_unique`.
    NUnique,
    /// Displayed as `quantile`.
    /// The `f64` is the requested quantile; column names include it rounded to two decimals,
    /// e.g. `_quantile_0.50`.
    /// The [`QuantileMethod`] is presumably the interpolation strategy
    /// (`fmt_group_by_column` binds it as `_interpol`).
    Quantile(f64, QuantileMethod),
    /// Displayed as `count`; output column suffix `_count`.
    /// `include_nulls` presumably controls whether null values are counted.
    Count { include_nulls: bool },
    /// Displayed as `list`; output column suffix `_agg_list`.
    /// Presumably gathers each group's values into a list.
    Implode,
    /// Displayed as `std`; output column suffix `_agg_std`.
    /// The `u8` is presumably the delta degrees of freedom. The tests call `.std(1)` and
    /// expect the sample standard deviation.
    Std(u8),
    /// Displayed as `var`; output column suffix `_agg_var`.
    /// The `u8` is presumably the delta degrees of freedom. The tests call `.var(1)` and
    /// expect the sample variance.
    Var(u8),
}
883
884impl Display for GroupByMethod {
885    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
886        use GroupByMethod::*;
887        let s = match self {
888            Min => "min",
889            NanMin => "nan_min",
890            Max => "max",
891            NanMax => "nan_max",
892            Median => "median",
893            Mean => "mean",
894            First => "first",
895            Last => "last",
896            Sum => "sum",
897            Groups => "groups",
898            NUnique => "n_unique",
899            Quantile(_, _) => "quantile",
900            Count { .. } => "count",
901            Implode => "list",
902            Std(_) => "std",
903            Var(_) => "var",
904        };
905        write!(f, "{s}")
906    }
907}
908
909// Formatting functions used in eager and lazy code for renaming grouped columns
910pub fn fmt_group_by_column(name: &str, method: GroupByMethod) -> PlSmallStr {
911    use GroupByMethod::*;
912    match method {
913        Min => format_pl_smallstr!("{name}_min"),
914        Max => format_pl_smallstr!("{name}_max"),
915        NanMin => format_pl_smallstr!("{name}_nan_min"),
916        NanMax => format_pl_smallstr!("{name}_nan_max"),
917        Median => format_pl_smallstr!("{name}_median"),
918        Mean => format_pl_smallstr!("{name}_mean"),
919        First => format_pl_smallstr!("{name}_first"),
920        Last => format_pl_smallstr!("{name}_last"),
921        Sum => format_pl_smallstr!("{name}_sum"),
922        Groups => PlSmallStr::from_static("groups"),
923        NUnique => format_pl_smallstr!("{name}_n_unique"),
924        Count { .. } => format_pl_smallstr!("{name}_count"),
925        Implode => format_pl_smallstr!("{name}_agg_list"),
926        Quantile(quantile, _interpol) => format_pl_smallstr!("{name}_quantile_{quantile:.2}"),
927        Std(_) => format_pl_smallstr!("{name}_agg_std"),
928        Var(_) => format_pl_smallstr!("{name}_agg_var"),
929    }
930}
931
#[cfg(test)]
mod test {
    use crate::prelude::*;

    #[test]
    #[cfg(feature = "dtype-date")]
    #[cfg_attr(miri, ignore)]
    fn test_group_by() -> PolarsResult<()> {
        let s0 = Column::new(
            PlSmallStr::from_static("date"),
            &[
                "2020-08-21",
                "2020-08-21",
                "2020-08-22",
                "2020-08-23",
                "2020-08-22",
            ],
        );
        let s1 = Column::new(PlSmallStr::from_static("temp"), [20, 10, 7, 9, 1]);
        let s2 = Column::new(PlSmallStr::from_static("rain"), [0.2, 0.1, 0.3, 0.1, 0.01]);
        let df = DataFrame::new(vec![s0, s1, s2])?;

        let out = df.group_by_stable(["date"])?.select(["temp"]).count()?;
        assert_eq!(
            out.column("temp_count")?,
            &Column::new(PlSmallStr::from_static("temp_count"), [2 as IdxSize, 2, 1])
        );

        // Select multiple columns.
        // Use of deprecated `mean()` for testing purposes
        #[allow(deprecated)]
        let out = df
            .group_by_stable(["date"])?
            .select(["temp", "rain"])
            .mean()?;
        assert_eq!(
            out.column("temp_mean")?,
            &Column::new(PlSmallStr::from_static("temp_mean"), [15.0f64, 4.0, 9.0])
        );

        // Group by multiple keys.
        // Use of deprecated `mean()` for testing purposes
        #[allow(deprecated)]
        let out = df
            .group_by_stable(["date", "temp"])?
            .select(["rain"])
            .mean()?;
        assert!(out.column("rain_mean").is_ok());

        // Use of deprecated `sum()` for testing purposes
        #[allow(deprecated)]
        let out = df.group_by_stable(["date"])?.select(["temp"]).sum()?;
        assert_eq!(
            out.column("temp_sum")?,
            &Column::new(PlSmallStr::from_static("temp_sum"), [30, 8, 9])
        );

        // Implicitly select all columns and only aggregate on methods that support that
        // aggregation.
        // Use of deprecated `n_unique()` for testing purposes
        #[allow(deprecated)]
        let gb = df.group_by(["date"])?.n_unique()?;
        // Check the group-by key column is filtered out of the aggregation.
        assert_eq!(gb.width(), 3);
        Ok(())
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_static_group_by_by_12_columns() {
        // Build GroupBy DataFrame.
        let s0 = Column::new("G1".into(), ["A", "A", "B", "B", "C"].as_ref());
        let s1 = Column::new("N".into(), [1, 2, 2, 4, 2].as_ref());
        let s2 = Column::new("G2".into(), ["k", "l", "m", "m", "l"].as_ref());
        let s3 = Column::new("G3".into(), ["a", "b", "c", "c", "d"].as_ref());
        let s4 = Column::new("G4".into(), ["1", "2", "3", "3", "4"].as_ref());
        let s5 = Column::new("G5".into(), ["X", "Y", "Z", "Z", "W"].as_ref());
        let s6 = Column::new("G6".into(), [false, true, true, true, false].as_ref());
        let s7 = Column::new("G7".into(), ["r", "x", "q", "q", "o"].as_ref());
        let s8 = Column::new("G8".into(), ["R", "X", "Q", "Q", "O"].as_ref());
        let s9 = Column::new("G9".into(), [1, 2, 3, 3, 4].as_ref());
        let s10 = Column::new("G10".into(), [".", "!", "?", "?", "/"].as_ref());
        let s11 = Column::new("G11".into(), ["(", ")", "@", "@", "$"].as_ref());
        let s12 = Column::new("G12".into(), ["-", "_", ";", ";", ","].as_ref());

        let df =
            DataFrame::new(vec![s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]).unwrap();

        // Use of deprecated `sum()` for testing purposes
        #[allow(deprecated)]
        let adf = df
            .group_by([
                "G1", "G2", "G3", "G4", "G5", "G6", "G7", "G8", "G9", "G10", "G11", "G12",
            ])
            .unwrap()
            .select(["N"])
            .sum()
            .unwrap();

        assert_eq!(
            Vec::from(&adf.column("N_sum").unwrap().i32().unwrap().sort(false)),
            &[Some(1), Some(2), Some(2), Some(6)]
        );
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_dynamic_group_by_by_13_columns() {
        // The content for every group_by series.
        let series_content = ["A", "A", "B", "B", "C"];

        // The name of every group_by series.
        let series_names = [
            "G1", "G2", "G3", "G4", "G5", "G6", "G7", "G8", "G9", "G10", "G11", "G12", "G13",
        ];

        // Vector to contain every key series plus the aggregation series.
        let mut columns = Vec::with_capacity(series_names.len() + 1);

        // Create a series for every group name.
        for series_name in series_names {
            let group_columns = Column::new(series_name.into(), series_content.as_ref());
            columns.push(group_columns);
        }

        // Create a series for the aggregation column.
        let agg_series = Column::new("N".into(), [1, 2, 3, 3, 4].as_ref());
        columns.push(agg_series);

        // Create the dataframe with the computed series.
        let df = DataFrame::new(columns).unwrap();

        // Compute the aggregated DataFrame by the 13 columns defined in `series_names`.
        // Use of deprecated `sum()` for testing purposes
        #[allow(deprecated)]
        let adf = df
            .group_by(series_names)
            .unwrap()
            .select(["N"])
            .sum()
            .unwrap();

        // Check that the results of the group-by are correct. Every key column has the same
        // content, so the grouped key columns must be equal and in the same order.
        for series_name in &series_names {
            assert_eq!(
                Vec::from(&adf.column(series_name).unwrap().str().unwrap().sort(false)),
                &[Some("A"), Some("B"), Some("C")]
            );
        }

        // Check the aggregated column is the expected one.
        assert_eq!(
            Vec::from(&adf.column("N_sum").unwrap().i32().unwrap().sort(false)),
            &[Some(3), Some(4), Some(6)]
        );
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_group_by_floats() {
        let df = df! {"flt" => [1., 1., 2., 2., 3.],
                    "val" => [1, 1, 1, 1, 1]
        }
        .unwrap();
        // Use of deprecated `sum()` for testing purposes
        #[allow(deprecated)]
        let res = df.group_by(["flt"]).unwrap().sum().unwrap();
        let res = res.sort(["flt"], SortMultipleOptions::default()).unwrap();
        assert_eq!(
            Vec::from(res.column("val_sum").unwrap().i32().unwrap()),
            &[Some(2), Some(2), Some(1)]
        );
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    #[cfg(feature = "dtype-categorical")]
    fn test_group_by_categorical() {
        let mut df = df! {"foo" => ["a", "a", "b", "b", "c"],
                    "ham" => ["a", "a", "b", "b", "c"],
                    "bar" => [1, 1, 1, 1, 1]
        }
        .unwrap();

        df.apply("foo", |s| {
            s.cast(&DataType::Categorical(None, Default::default()))
                .unwrap()
        })
        .unwrap();

        // Check multiple keys, one of them categorical.
        // Use of deprecated `sum()` for testing purposes
        #[allow(deprecated)]
        let res = df
            .group_by_stable(["foo", "ham"])
            .unwrap()
            .select(["bar"])
            .sum()
            .unwrap();

        assert_eq!(
            Vec::from(
                res.column("bar_sum")
                    .unwrap()
                    .as_materialized_series()
                    .i32()
                    .unwrap()
            ),
            &[Some(2), Some(2), Some(1)]
        );
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_group_by_null_handling() -> PolarsResult<()> {
        let df = df!(
            "a" => ["a", "a", "a", "b", "b"],
            "b" => [Some(1), Some(2), None, None, Some(1)]
        )?;
        // Use of deprecated `mean()` for testing purposes
        #[allow(deprecated)]
        let out = df.group_by_stable(["a"])?.mean()?;

        assert_eq!(
            Vec::from(out.column("b_mean")?.as_materialized_series().f64()?),
            &[Some(1.5), Some(1.0)]
        );
        Ok(())
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_group_by_var() -> PolarsResult<()> {
        // Check variance and proper coercion to f64.
        let df = df![
            "g" => ["foo", "foo", "bar"],
            "flt" => [1.0, 2.0, 3.0],
            "int" => [1, 2, 3]
        ]?;

        // Use of deprecated `var()` for testing purposes
        #[allow(deprecated)]
        let out = df.group_by_stable(["g"])?.select(["int"]).var(1)?;

        assert_eq!(out.column("int_agg_var")?.f64()?.get(0), Some(0.5));
        // Use of deprecated `std()` for testing purposes
        #[allow(deprecated)]
        let out = df.group_by_stable(["g"])?.select(["int"]).std(1)?;
        let val = out.column("int_agg_std")?.f64()?.get(0).unwrap();
        let expected = std::f64::consts::FRAC_1_SQRT_2;
        assert!((val - expected).abs() < 0.000001);
        Ok(())
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    #[cfg(feature = "dtype-categorical")]
    fn test_group_by_null_group() -> PolarsResult<()> {
        // Check that null forms its own group.
        let mut df = df![
            "g" => [Some("foo"), Some("foo"), Some("bar"), None, None],
            "flt" => [1.0, 2.0, 3.0, 1.0, 1.0],
            "int" => [1, 2, 3, 1, 1]
        ]?;

        df.try_apply("g", |s| {
            s.cast(&DataType::Categorical(None, Default::default()))
        })?;

        // Use of deprecated `sum()` for testing purposes
        #[allow(deprecated)]
        let out = df.group_by(["g"])?.sum()?;
        // "foo", "bar" and the null key must each form a separate group.
        assert_eq!(out.height(), 3);
        Ok(())
    }
}