polars_core/frame/group_by/aggregations/mod.rs

mod agg_list;
mod boolean;
mod dispatch;
mod string;

use std::cmp::Ordering;

pub use agg_list::*;
use arrow::bitmap::{Bitmap, MutableBitmap};
use arrow::legacy::kernels::rolling;
use arrow::legacy::kernels::rolling::no_nulls::{
    MaxWindow, MeanWindow, MinWindow, QuantileWindow, RollingAggWindowNoNulls, SumWindow, VarWindow,
};
use arrow::legacy::kernels::rolling::nulls::RollingAggWindowNulls;
use arrow::legacy::kernels::take_agg::*;
use arrow::legacy::prelude::QuantileMethod;
use arrow::legacy::trusted_len::TrustedLenPush;
use arrow::types::NativeType;
use num_traits::pow::Pow;
use num_traits::{Bounded, Float, Num, NumCast, ToPrimitive, Zero};
use polars_utils::float::IsFloat;
use polars_utils::idx_vec::IdxVec;
use polars_utils::ord::{compare_fn_nan_max, compare_fn_nan_min};
use rayon::prelude::*;

use crate::chunked_array::cast::CastOptions;
#[cfg(feature = "object")]
use crate::chunked_array::object::extension::create_extension;
use crate::frame::group_by::GroupsIdx;
#[cfg(feature = "object")]
use crate::frame::group_by::GroupsIndicator;
use crate::prelude::*;
use crate::series::implementations::SeriesWrap;
use crate::series::IsSorted;
use crate::utils::NoNull;
use crate::{apply_method_physical_integer, POOL};

fn idx2usize(idx: &[IdxSize]) -> impl ExactSizeIterator<Item = usize> + '_ {
    idx.iter().map(|i| *i as usize)
}
// If the windows overlap, we can use the rolling_<agg> kernels. They maintain
// state, which saves a lot of compute by not naively traversing all elements
// for every window.
//
// If the windows don't overlap, we should not use these kernels, as they are
// single-threaded, so we would miss out on easy parallelization.
pub fn _use_rolling_kernels(groups: &GroupsSlice, chunks: &[ArrayRef]) -> bool {
    match groups.len() {
        0 | 1 => false,
        _ => {
            let [first_offset, first_len] = groups[0];
            let second_offset = groups[1][0];

            second_offset >= first_offset // Prevent false positive from regular group-by that has out of order slices.
                                          // Rolling group-by is expected to have monotonically increasing slices.
                && second_offset < (first_offset + first_len)
                && chunks.len() == 1
        },
    }
}
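
// A worked illustration of the check above: for a rolling group-by the slices
// are monotonically increasing and overlap, e.g.
//
//   groups = [[0, 3], [1, 3], [2, 3]] -> second_offset (1) >= first_offset (0),
//                                        1 < 0 + 3 and a single chunk: use the
//                                        rolling kernels.
//   groups = [[0, 3], [5, 2]]         -> 5 >= 0 but 5 >= 0 + 3: the windows are
//                                        disjoint, so prefer the parallel
//                                        per-group path instead.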

// Use an aggregation window that maintains the state.
pub fn _rolling_apply_agg_window_nulls<'a, Agg, T, O>(
    values: &'a [T],
    validity: &'a Bitmap,
    offsets: O,
    params: Option<RollingFnParams>,
) -> PrimitiveArray<T>
where
    O: Iterator<Item = (IdxSize, IdxSize)> + TrustedLen,
    Agg: RollingAggWindowNulls<'a, T>,
    T: IsFloat + NativeType,
{
    if values.is_empty() {
        let out: Vec<T> = vec![];
        return PrimitiveArray::new(T::PRIMITIVE.into(), out.into(), None);
    }

    // This iterator's length can be trusted; it represents the number of
    // groups in the group_by operation.
    let output_len = offsets.size_hint().0;
    // Start with a dummy index; it will be overwritten on the first iteration.
    // SAFETY:
    // we are in bounds
    let mut agg_window = unsafe { Agg::new(values, validity, 0, 0, params) };

    let mut validity = MutableBitmap::with_capacity(output_len);
    validity.extend_constant(output_len, true);

    let out = offsets
        .enumerate()
        .map(|(idx, (start, len))| {
            let end = start + len;

            // SAFETY:
            // we are in bounds

            let agg = if start == end {
                None
            } else {
                unsafe { agg_window.update(start as usize, end as usize) }
            };

            match agg {
                Some(val) => val,
                None => {
                    // SAFETY: we are in bounds
                    unsafe { validity.set_unchecked(idx, false) };
                    T::default()
                },
            }
        })
        .collect_trusted::<Vec<_>>();

    PrimitiveArray::new(T::PRIMITIVE.into(), out.into(), Some(validity.into()))
}

// Use an aggregation window that maintains the state.
pub fn _rolling_apply_agg_window_no_nulls<'a, Agg, T, O>(
    values: &'a [T],
    offsets: O,
    params: Option<RollingFnParams>,
) -> PrimitiveArray<T>
where
    // items are (offset, len), so a window spans offset..offset + len
    Agg: RollingAggWindowNoNulls<'a, T>,
    O: Iterator<Item = (IdxSize, IdxSize)> + TrustedLen,
    T: IsFloat + NativeType,
{
    if values.is_empty() {
        let out: Vec<T> = vec![];
        return PrimitiveArray::new(T::PRIMITIVE.into(), out.into(), None);
    }
    // Start with a dummy index; it will be overwritten on the first iteration.
    let mut agg_window = Agg::new(values, 0, 0, params);

    offsets
        .map(|(start, len)| {
            let end = start + len;

            if start == end {
                None
            } else {
                // SAFETY: we are in bounds.
                unsafe { agg_window.update(start as usize, end as usize) }
            }
        })
        .collect::<PrimitiveArray<T>>()
}
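
// A worked example of the offsets contract shared by both window appliers
// above: with values = [1, 2, 3, 4] and offsets = [(0, 2), (1, 2), (2, 2)],
// a Sum window would produce [1 + 2, 2 + 3, 3 + 4] = [3, 5, 7]; an empty
// window (len == 0) yields a null in the output instead.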

pub fn _slice_from_offsets<T>(ca: &ChunkedArray<T>, first: IdxSize, len: IdxSize) -> ChunkedArray<T>
where
    T: PolarsDataType,
{
    ca.slice(first as i64, len as usize)
}

/// Helper that combines the groups into a parallel iterator over `(first, all): (IdxSize, &IdxVec)`.
pub fn _agg_helper_idx<T, F>(groups: &GroupsIdx, f: F) -> Series
where
    F: Fn((IdxSize, &IdxVec)) -> Option<T::Native> + Send + Sync,
    T: PolarsNumericType,
    ChunkedArray<T>: IntoSeries,
{
    let ca: ChunkedArray<T> = POOL.install(|| groups.into_par_iter().map(f).collect());
    ca.into_series()
}

/// Same helper as `_agg_helper_idx` but for aggregations that don't return an Option.
pub fn _agg_helper_idx_no_null<T, F>(groups: &GroupsIdx, f: F) -> Series
where
    F: Fn((IdxSize, &IdxVec)) -> T::Native + Send + Sync,
    T: PolarsNumericType,
    ChunkedArray<T>: IntoSeries,
{
    let ca: NoNull<ChunkedArray<T>> = POOL.install(|| groups.into_par_iter().map(f).collect());
    ca.into_inner().into_series()
}

/// Helper that iterates on the `all: Vec<IdxVec>` collection;
/// this doesn't have to traverse the `first: Vec<IdxSize>` memory and is therefore faster.
fn agg_helper_idx_on_all<T, F>(groups: &GroupsIdx, f: F) -> Series
where
    F: Fn(&IdxVec) -> Option<T::Native> + Send + Sync,
    T: PolarsNumericType,
    ChunkedArray<T>: IntoSeries,
{
    let ca: ChunkedArray<T> = POOL.install(|| groups.all().into_par_iter().map(f).collect());
    ca.into_series()
}

/// Same as `agg_helper_idx_on_all` but for aggregations that don't return an Option.
fn agg_helper_idx_on_all_no_null<T, F>(groups: &GroupsIdx, f: F) -> Series
where
    F: Fn(&IdxVec) -> T::Native + Send + Sync,
    T: PolarsNumericType,
    ChunkedArray<T>: IntoSeries,
{
    let ca: NoNull<ChunkedArray<T>> =
        POOL.install(|| groups.all().into_par_iter().map(f).collect());
    ca.into_inner().into_series()
}

pub fn _agg_helper_slice<T, F>(groups: &[[IdxSize; 2]], f: F) -> Series
where
    F: Fn([IdxSize; 2]) -> Option<T::Native> + Send + Sync,
    T: PolarsNumericType,
    ChunkedArray<T>: IntoSeries,
{
    let ca: ChunkedArray<T> = POOL.install(|| groups.par_iter().copied().map(f).collect());
    ca.into_series()
}

pub fn _agg_helper_slice_no_null<T, F>(groups: &[[IdxSize; 2]], f: F) -> Series
where
    F: Fn([IdxSize; 2]) -> T::Native + Send + Sync,
    T: PolarsNumericType,
    ChunkedArray<T>: IntoSeries,
{
    let ca: NoNull<ChunkedArray<T>> = POOL.install(|| groups.par_iter().copied().map(f).collect());
    ca.into_inner().into_series()
}
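
// The contract of the slice helpers above, illustrated: each `[first, len]`
// group maps to one output row, so `groups = [[0, 2], [2, 3]]` and a closure
// like `|[_first, len]| Some(len)` would yield a Series of length 2 holding
// [2, 3]. The `_no_null` variants collect into `NoNull`, skipping validity
// buffers entirely.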

pub trait TakeExtremum {
    fn take_min(self, other: Self) -> Self;

    fn take_max(self, other: Self) -> Self;
}

macro_rules! impl_take_extremum {
    ($tp:ty) => {
        impl TakeExtremum for $tp {
            #[inline(always)]
            fn take_min(self, other: Self) -> Self {
                if self < other {
                    self
                } else {
                    other
                }
            }

            #[inline(always)]
            fn take_max(self, other: Self) -> Self {
                if self > other {
                    self
                } else {
                    other
                }
            }
        }
    };

    (float: $tp:ty) => {
        impl TakeExtremum for $tp {
            #[inline(always)]
            fn take_min(self, other: Self) -> Self {
                if matches!(compare_fn_nan_max(&self, &other), Ordering::Less) {
                    self
                } else {
                    other
                }
            }

            #[inline(always)]
            fn take_max(self, other: Self) -> Self {
                if matches!(compare_fn_nan_min(&self, &other), Ordering::Greater) {
                    self
                } else {
                    other
                }
            }
        }
    };
}

#[cfg(feature = "dtype-u8")]
impl_take_extremum!(u8);
#[cfg(feature = "dtype-u16")]
impl_take_extremum!(u16);
impl_take_extremum!(u32);
impl_take_extremum!(u64);
#[cfg(feature = "dtype-i8")]
impl_take_extremum!(i8);
#[cfg(feature = "dtype-i16")]
impl_take_extremum!(i16);
impl_take_extremum!(i32);
impl_take_extremum!(i64);
#[cfg(any(feature = "dtype-decimal", feature = "dtype-i128"))]
impl_take_extremum!(i128);
impl_take_extremum!(float: f32);
impl_take_extremum!(float: f64);
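
// Note the asymmetric comparators in the float case above: `take_min` uses
// `compare_fn_nan_max` (NaN sorts as the largest value) and `take_max` uses
// `compare_fn_nan_min` (NaN sorts as the smallest value), so NaN only wins an
// extremum when every value in the comparison is NaN. See the illustrative
// test sketch at the bottom of this file.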

/// Intermediate helper trait so we can have a single generic implementation.
/// This trait will ensure the specific dispatch works without complicating
/// the trait bounds.
trait QuantileDispatcher<K> {
    fn _quantile(self, quantile: f64, method: QuantileMethod) -> PolarsResult<Option<K>>;

    fn _median(self) -> Option<K>;
}

impl<T> QuantileDispatcher<f64> for ChunkedArray<T>
where
    T: PolarsIntegerType,
    T::Native: Ord,
    ChunkedArray<T>: IntoSeries,
{
    fn _quantile(self, quantile: f64, method: QuantileMethod) -> PolarsResult<Option<f64>> {
        self.quantile_faster(quantile, method)
    }
    fn _median(self) -> Option<f64> {
        self.median_faster()
    }
}

impl QuantileDispatcher<f32> for Float32Chunked {
    fn _quantile(self, quantile: f64, method: QuantileMethod) -> PolarsResult<Option<f32>> {
        self.quantile_faster(quantile, method)
    }
    fn _median(self) -> Option<f32> {
        self.median_faster()
    }
}
impl QuantileDispatcher<f64> for Float64Chunked {
    fn _quantile(self, quantile: f64, method: QuantileMethod) -> PolarsResult<Option<f64>> {
        self.quantile_faster(quantile, method)
    }
    fn _median(self) -> Option<f64> {
        self.median_faster()
    }
}

unsafe fn agg_quantile_generic<T, K>(
    ca: &ChunkedArray<T>,
    groups: &GroupsType,
    quantile: f64,
    method: QuantileMethod,
) -> Series
where
    T: PolarsNumericType,
    ChunkedArray<T>: QuantileDispatcher<K::Native>,
    ChunkedArray<K>: IntoSeries,
    K: PolarsNumericType,
    <K as datatypes::PolarsNumericType>::Native: num_traits::Float,
{
    let invalid_quantile = !(0.0..=1.0).contains(&quantile);
    if invalid_quantile {
        return Series::full_null(ca.name().clone(), groups.len(), ca.dtype());
    }
    match groups {
        GroupsType::Idx(groups) => {
            let ca = ca.rechunk();
            agg_helper_idx_on_all::<K, _>(groups, |idx| {
                debug_assert!(idx.len() <= ca.len());
                if idx.is_empty() {
                    return None;
                }
                let take = { ca.take_unchecked(idx) };
                // SAFETY: checked by the invalid-quantile guard above.
                take._quantile(quantile, method).unwrap_unchecked()
            })
        },
        GroupsType::Slice { groups, .. } => {
            if _use_rolling_kernels(groups, ca.chunks()) {
                // this cast is a no-op for floats
                let s = ca
                    .cast_with_options(&K::get_dtype(), CastOptions::Overflowing)
                    .unwrap();
                let ca: &ChunkedArray<K> = s.as_ref().as_ref();
                let arr = ca.downcast_iter().next().unwrap();
                let values = arr.values().as_slice();
                let offset_iter = groups.iter().map(|[first, len]| (*first, *len));
                let arr = match arr.validity() {
                    None => _rolling_apply_agg_window_no_nulls::<QuantileWindow<_>, _, _>(
                        values,
                        offset_iter,
                        Some(RollingFnParams::Quantile(RollingQuantileParams {
                            prob: quantile,
                            method,
                        })),
                    ),
                    Some(validity) => {
                        _rolling_apply_agg_window_nulls::<rolling::nulls::QuantileWindow<_>, _, _>(
                            values,
                            validity,
                            offset_iter,
                            Some(RollingFnParams::Quantile(RollingQuantileParams {
                                prob: quantile,
                                method,
                            })),
                        )
                    },
                };
                // The rolling kernels work on the physical dtype, which is not
                // yet the float output type we need.
                ChunkedArray::from(arr).into_series()
            } else {
                _agg_helper_slice::<K, _>(groups, |[first, len]| {
                    debug_assert!(first + len <= ca.len() as IdxSize);
                    match len {
                        0 => None,
                        1 => ca.get(first as usize).map(|v| NumCast::from(v).unwrap()),
                        _ => {
                            let arr_group = _slice_from_offsets(ca, first, len);
                            // SAFETY: unwrap checked by the invalid-quantile guard above.
                            arr_group
                                ._quantile(quantile, method)
                                .unwrap_unchecked()
                                .map(|flt| NumCast::from(flt).unwrap_unchecked())
                        },
                    }
                })
            }
        },
    }
}
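
// In the generic quantile/median helpers, `T` is the input dtype and `K` the
// float output dtype: `Float32Chunked` dispatches with `K = Float32Type`,
// everything else with `K = Float64Type` (see the impls further down), which
// is why the cast to `K::get_dtype()` above is a no-op for floats.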

unsafe fn agg_median_generic<T, K>(ca: &ChunkedArray<T>, groups: &GroupsType) -> Series
where
    T: PolarsNumericType,
    ChunkedArray<T>: QuantileDispatcher<K::Native>,
    ChunkedArray<K>: IntoSeries,
    K: PolarsNumericType,
    <K as datatypes::PolarsNumericType>::Native: num_traits::Float,
{
    match groups {
        GroupsType::Idx(groups) => {
            let ca = ca.rechunk();
            agg_helper_idx_on_all::<K, _>(groups, |idx| {
                debug_assert!(idx.len() <= ca.len());
                if idx.is_empty() {
                    return None;
                }
                let take = { ca.take_unchecked(idx) };
                take._median()
            })
        },
        GroupsType::Slice { .. } => {
            agg_quantile_generic::<T, K>(ca, groups, 0.5, QuantileMethod::Linear)
        },
    }
}
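
// Note that the median has no dedicated rolling kernel: for slice groups it is
// computed as the 0.5 quantile with linear interpolation, while idx groups take
// the faster direct `_median` path per group.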

/// # Safety
///
/// No bounds checks on `groups`.
#[cfg(feature = "bitwise")]
unsafe fn bitwise_agg<T: PolarsNumericType>(
    ca: &ChunkedArray<T>,
    groups: &GroupsType,
    f: fn(&ChunkedArray<T>) -> Option<T::Native>,
) -> Series
where
    ChunkedArray<T>:
        ChunkTakeUnchecked<[IdxSize]> + ChunkBitwiseReduce<Physical = T::Native> + IntoSeries,
{
    // Prevent a rechunk for every individual group.
    let s = if groups.len() > 1 {
        ca.rechunk()
    } else {
        ca.clone()
    };

    match groups {
        GroupsType::Idx(groups) => agg_helper_idx_on_all::<T, _>(groups, |idx| {
            debug_assert!(idx.len() <= s.len());
            if idx.is_empty() {
                None
            } else {
                let take = unsafe { s.take_unchecked(idx) };
                f(&take)
            }
        }),
        GroupsType::Slice { groups, .. } => _agg_helper_slice::<T, _>(groups, |[first, len]| {
            debug_assert!(len <= s.len() as IdxSize);
            if len == 0 {
                None
            } else {
                let take = _slice_from_offsets(&s, first, len);
                f(&take)
            }
        }),
    }
}
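
// Usage sketch for the reducer above: `agg_and` folds each group with bitwise
// AND, so a group containing [0b1100, 0b1010] reduces to 0b1000, while
// `agg_or` would give 0b1110 and `agg_xor` 0b0110; empty groups produce null.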

#[cfg(feature = "bitwise")]
impl<T> ChunkedArray<T>
where
    T: PolarsNumericType,
    ChunkedArray<T>:
        ChunkTakeUnchecked<[IdxSize]> + ChunkBitwiseReduce<Physical = T::Native> + IntoSeries,
{
    /// # Safety
    ///
    /// No bounds checks on `groups`.
    pub(crate) unsafe fn agg_and(&self, groups: &GroupsType) -> Series {
        unsafe { bitwise_agg(self, groups, ChunkBitwiseReduce::and_reduce) }
    }

    /// # Safety
    ///
    /// No bounds checks on `groups`.
    pub(crate) unsafe fn agg_or(&self, groups: &GroupsType) -> Series {
        unsafe { bitwise_agg(self, groups, ChunkBitwiseReduce::or_reduce) }
    }

    /// # Safety
    ///
    /// No bounds checks on `groups`.
    pub(crate) unsafe fn agg_xor(&self, groups: &GroupsType) -> Series {
        unsafe { bitwise_agg(self, groups, ChunkBitwiseReduce::xor_reduce) }
    }
}

impl<T> ChunkedArray<T>
where
    T: PolarsNumericType + Sync,
    T::Native: NativeType
        + PartialOrd
        + Num
        + NumCast
        + Zero
        + Bounded
        + std::iter::Sum<T::Native>
        + TakeExtremum,
    ChunkedArray<T>: IntoSeries + ChunkAgg<T::Native>,
{
    pub(crate) unsafe fn agg_min(&self, groups: &GroupsType) -> Series {
        // faster paths
        match (self.is_sorted_flag(), self.null_count()) {
            (IsSorted::Ascending, 0) => {
                return self.clone().into_series().agg_first(groups);
            },
            (IsSorted::Descending, 0) => {
                return self.clone().into_series().agg_last(groups);
            },
            _ => {},
        }
        match groups {
            GroupsType::Idx(groups) => {
                let ca = self.rechunk();
                let arr = ca.downcast_iter().next().unwrap();
                let no_nulls = arr.null_count() == 0;
                _agg_helper_idx::<T, _>(groups, |(first, idx)| {
                    debug_assert!(idx.len() <= arr.len());
                    if idx.is_empty() {
                        None
                    } else if idx.len() == 1 {
                        arr.get(first as usize)
                    } else if no_nulls {
                        take_agg_no_null_primitive_iter_unchecked::<_, T::Native, _, _>(
                            arr,
                            idx2usize(idx),
                            |a, b| a.take_min(b),
                        )
                    } else {
                        take_agg_primitive_iter_unchecked(arr, idx2usize(idx), |a, b| a.take_min(b))
                    }
                })
            },
            GroupsType::Slice {
                groups: groups_slice,
                ..
            } => {
                if _use_rolling_kernels(groups_slice, self.chunks()) {
                    let arr = self.downcast_iter().next().unwrap();
                    let values = arr.values().as_slice();
                    let offset_iter = groups_slice.iter().map(|[first, len]| (*first, *len));
                    let arr = match arr.validity() {
                        None => _rolling_apply_agg_window_no_nulls::<MinWindow<_>, _, _>(
                            values,
                            offset_iter,
                            None,
                        ),
                        Some(validity) => _rolling_apply_agg_window_nulls::<
                            rolling::nulls::MinWindow<_>,
                            _,
                            _,
                        >(
                            values, validity, offset_iter, None
                        ),
                    };
                    Self::from(arr).into_series()
                } else {
                    _agg_helper_slice::<T, _>(groups_slice, |[first, len]| {
                        debug_assert!(len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            1 => self.get(first as usize),
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                ChunkAgg::min(&arr_group)
                            },
                        }
                    })
                }
            },
        }
    }

    pub(crate) unsafe fn agg_max(&self, groups: &GroupsType) -> Series {
        // faster paths
        match (self.is_sorted_flag(), self.null_count()) {
            (IsSorted::Ascending, 0) => {
                return self.clone().into_series().agg_last(groups);
            },
            (IsSorted::Descending, 0) => {
                return self.clone().into_series().agg_first(groups);
            },
            _ => {},
        }

        match groups {
            GroupsType::Idx(groups) => {
                let ca = self.rechunk();
                let arr = ca.downcast_iter().next().unwrap();
                let no_nulls = arr.null_count() == 0;
                _agg_helper_idx::<T, _>(groups, |(first, idx)| {
                    debug_assert!(idx.len() <= arr.len());
                    if idx.is_empty() {
                        None
                    } else if idx.len() == 1 {
                        arr.get(first as usize)
                    } else if no_nulls {
                        take_agg_no_null_primitive_iter_unchecked::<_, T::Native, _, _>(
                            arr,
                            idx2usize(idx),
                            |a, b| a.take_max(b),
                        )
                    } else {
                        take_agg_primitive_iter_unchecked(arr, idx2usize(idx), |a, b| a.take_max(b))
                    }
                })
            },
            GroupsType::Slice {
                groups: groups_slice,
                ..
            } => {
                if _use_rolling_kernels(groups_slice, self.chunks()) {
                    let arr = self.downcast_iter().next().unwrap();
                    let values = arr.values().as_slice();
                    let offset_iter = groups_slice.iter().map(|[first, len]| (*first, *len));
                    let arr = match arr.validity() {
                        None => _rolling_apply_agg_window_no_nulls::<MaxWindow<_>, _, _>(
                            values,
                            offset_iter,
                            None,
                        ),
                        Some(validity) => _rolling_apply_agg_window_nulls::<
                            rolling::nulls::MaxWindow<_>,
                            _,
                            _,
                        >(
                            values, validity, offset_iter, None
                        ),
                    };
                    Self::from(arr).into_series()
                } else {
                    _agg_helper_slice::<T, _>(groups_slice, |[first, len]| {
                        debug_assert!(len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            1 => self.get(first as usize),
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                ChunkAgg::max(&arr_group)
                            },
                        }
                    })
                }
            },
        }
    }

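    // Note: unlike min/max, an empty group sums to zero rather than null, so
    // the `_no_null` helpers are used below and missing values fall back to
    // `T::Native::zero()`.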
    pub(crate) unsafe fn agg_sum(&self, groups: &GroupsType) -> Series {
        match groups {
            GroupsType::Idx(groups) => {
                let ca = self.rechunk();
                let arr = ca.downcast_iter().next().unwrap();
                let no_nulls = arr.null_count() == 0;
                _agg_helper_idx_no_null::<T, _>(groups, |(first, idx)| {
                    debug_assert!(idx.len() <= self.len());
                    if idx.is_empty() {
                        T::Native::zero()
                    } else if idx.len() == 1 {
                        arr.get(first as usize).unwrap_or(T::Native::zero())
                    } else if no_nulls {
                        take_agg_no_null_primitive_iter_unchecked(arr, idx2usize(idx), |a, b| a + b)
                            .unwrap_or(T::Native::zero())
                    } else {
                        take_agg_primitive_iter_unchecked(arr, idx2usize(idx), |a, b| a + b)
                            .unwrap_or(T::Native::zero())
                    }
                })
            },
            GroupsType::Slice { groups, .. } => {
                if _use_rolling_kernels(groups, self.chunks()) {
                    let arr = self.downcast_iter().next().unwrap();
                    let values = arr.values().as_slice();
                    let offset_iter = groups.iter().map(|[first, len]| (*first, *len));
                    let arr = match arr.validity() {
                        None => _rolling_apply_agg_window_no_nulls::<SumWindow<_>, _, _>(
                            values,
                            offset_iter,
                            None,
                        ),
                        Some(validity) => _rolling_apply_agg_window_nulls::<
                            rolling::nulls::SumWindow<_>,
                            _,
                            _,
                        >(
                            values, validity, offset_iter, None
                        ),
                    };
                    Self::from(arr).into_series()
                } else {
                    _agg_helper_slice_no_null::<T, _>(groups, |[first, len]| {
                        debug_assert!(len <= self.len() as IdxSize);
                        match len {
                            0 => T::Native::zero(),
                            1 => self.get(first as usize).unwrap_or(T::Native::zero()),
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                arr_group.sum().unwrap_or(T::Native::zero())
                            },
                        }
                    })
                }
            },
        }
    }
}

impl<T> SeriesWrap<ChunkedArray<T>>
where
    T: PolarsFloatType,
    ChunkedArray<T>: IntoSeries
        + ChunkVar
        + VarAggSeries
        + ChunkQuantile<T::Native>
        + QuantileAggSeries
        + ChunkAgg<T::Native>,
    T::Native: Pow<T::Native, Output = T::Native>,
{
    pub(crate) unsafe fn agg_mean(&self, groups: &GroupsType) -> Series {
        match groups {
            GroupsType::Idx(groups) => {
                let ca = self.rechunk();
                let arr = ca.downcast_iter().next().unwrap();
                let no_nulls = arr.null_count() == 0;
                _agg_helper_idx::<T, _>(groups, |(first, idx)| {
                    // This can fail due to a bug in lazy code. There, users can
                    // create filters in aggregations and thereby create shorter
                    // columns than the original group tuples. The group tuples
                    // are then modified, but if that is done incorrectly there
                    // can be out-of-bounds access.
                    debug_assert!(idx.len() <= self.len());
                    let out = if idx.is_empty() {
                        None
                    } else if idx.len() == 1 {
                        arr.get(first as usize).map(|sum| sum.to_f64().unwrap())
                    } else if no_nulls {
                        take_agg_no_null_primitive_iter_unchecked::<_, T::Native, _, _>(
                            arr,
                            idx2usize(idx),
                            |a, b| a + b,
                        )
                        .unwrap()
                        .to_f64()
                        .map(|sum| sum / idx.len() as f64)
                    } else {
                        take_agg_primitive_iter_unchecked_count_nulls::<T::Native, _, _, _>(
                            arr,
                            idx2usize(idx),
                            |a, b| a + b,
                            T::Native::zero(),
                            idx.len() as IdxSize,
                        )
                        .map(|(sum, null_count)| {
                            sum.to_f64()
                                .map(|sum| sum / (idx.len() as f64 - null_count as f64))
                                .unwrap()
                        })
                    };
                    out.map(|flt| NumCast::from(flt).unwrap())
                })
            },
            GroupsType::Slice { groups, .. } => {
                if _use_rolling_kernels(groups, self.chunks()) {
                    let arr = self.downcast_iter().next().unwrap();
                    let values = arr.values().as_slice();
                    let offset_iter = groups.iter().map(|[first, len]| (*first, *len));
                    let arr = match arr.validity() {
                        None => _rolling_apply_agg_window_no_nulls::<MeanWindow<_>, _, _>(
                            values,
                            offset_iter,
                            None,
                        ),
                        Some(validity) => _rolling_apply_agg_window_nulls::<
                            rolling::nulls::MeanWindow<_>,
                            _,
                            _,
                        >(
                            values, validity, offset_iter, None
                        ),
                    };
                    ChunkedArray::from(arr).into_series()
                } else {
                    _agg_helper_slice::<T, _>(groups, |[first, len]| {
                        debug_assert!(len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            1 => self.get(first as usize),
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                arr_group.mean().map(|flt| NumCast::from(flt).unwrap())
                            },
                        }
                    })
                }
            },
        }
    }

    pub(crate) unsafe fn agg_var(&self, groups: &GroupsType, ddof: u8) -> Series
    where
        <T as datatypes::PolarsNumericType>::Native: num_traits::Float,
    {
        let ca = &self.0.rechunk();
        match groups {
            GroupsType::Idx(groups) => {
                let ca = ca.rechunk();
                let arr = ca.downcast_iter().next().unwrap();
                let no_nulls = arr.null_count() == 0;
                agg_helper_idx_on_all::<T, _>(groups, |idx| {
                    debug_assert!(idx.len() <= ca.len());
                    if idx.is_empty() {
                        return None;
                    }
                    let out = if no_nulls {
                        take_var_no_null_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
                    } else {
                        take_var_nulls_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
                    };
                    out.map(|flt| NumCast::from(flt).unwrap())
                })
            },
            GroupsType::Slice { groups, .. } => {
                if _use_rolling_kernels(groups, self.chunks()) {
                    let arr = self.downcast_iter().next().unwrap();
                    let values = arr.values().as_slice();
                    let offset_iter = groups.iter().map(|[first, len]| (*first, *len));
                    let arr = match arr.validity() {
                        None => _rolling_apply_agg_window_no_nulls::<VarWindow<_>, _, _>(
                            values,
                            offset_iter,
                            Some(RollingFnParams::Var(RollingVarParams { ddof })),
                        ),
                        Some(validity) => {
                            _rolling_apply_agg_window_nulls::<rolling::nulls::VarWindow<_>, _, _>(
                                values,
                                validity,
                                offset_iter,
                                Some(RollingFnParams::Var(RollingVarParams { ddof })),
                            )
                        },
                    };
                    ChunkedArray::from(arr).into_series()
                } else {
                    _agg_helper_slice::<T, _>(groups, |[first, len]| {
                        debug_assert!(len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            1 => {
                                if ddof == 0 {
                                    NumCast::from(0)
                                } else {
                                    None
                                }
                            },
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                arr_group.var(ddof).map(|flt| NumCast::from(flt).unwrap())
                            },
                        }
                    })
                }
            },
        }
    }
    pub(crate) unsafe fn agg_std(&self, groups: &GroupsType, ddof: u8) -> Series
    where
        <T as datatypes::PolarsNumericType>::Native: num_traits::Float,
    {
        let ca = &self.0.rechunk();
        match groups {
            GroupsType::Idx(groups) => {
                let arr = ca.downcast_iter().next().unwrap();
                let no_nulls = arr.null_count() == 0;
                agg_helper_idx_on_all::<T, _>(groups, |idx| {
                    debug_assert!(idx.len() <= ca.len());
                    if idx.is_empty() {
                        return None;
                    }
                    let out = if no_nulls {
                        take_var_no_null_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
                    } else {
                        take_var_nulls_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
                    };
                    out.map(|flt| NumCast::from(flt.sqrt()).unwrap())
                })
            },
            GroupsType::Slice { groups, .. } => {
                if _use_rolling_kernels(groups, self.chunks()) {
                    let arr = ca.downcast_iter().next().unwrap();
                    let values = arr.values().as_slice();
                    let offset_iter = groups.iter().map(|[first, len]| (*first, *len));
                    let arr = match arr.validity() {
                        None => _rolling_apply_agg_window_no_nulls::<VarWindow<_>, _, _>(
                            values,
                            offset_iter,
                            Some(RollingFnParams::Var(RollingVarParams { ddof })),
                        ),
                        Some(validity) => {
                            _rolling_apply_agg_window_nulls::<rolling::nulls::VarWindow<_>, _, _>(
                                values,
                                validity,
                                offset_iter,
                                Some(RollingFnParams::Var(RollingVarParams { ddof })),
                            )
                        },
                    };

                    let mut ca = ChunkedArray::<T>::from(arr);
                    ca.apply_mut(|v| v.powf(NumCast::from(0.5).unwrap()));
                    ca.into_series()
                } else {
                    _agg_helper_slice::<T, _>(groups, |[first, len]| {
                        debug_assert!(len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            1 => {
                                if ddof == 0 {
                                    NumCast::from(0)
                                } else {
                                    None
                                }
                            },
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                arr_group.std(ddof).map(|flt| NumCast::from(flt).unwrap())
                            },
                        }
                    })
                }
            },
        }
    }
}
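
// In both branches of `agg_std` above, the standard deviation is derived from
// the variance: the idx path takes `sqrt` of the per-group variance kernel,
// and the rolling path reuses `VarWindow` and then applies `powf(0.5)`
// element-wise.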

impl Float32Chunked {
    pub(crate) unsafe fn agg_quantile(
        &self,
        groups: &GroupsType,
        quantile: f64,
        method: QuantileMethod,
    ) -> Series {
        agg_quantile_generic::<_, Float32Type>(self, groups, quantile, method)
    }
    pub(crate) unsafe fn agg_median(&self, groups: &GroupsType) -> Series {
        agg_median_generic::<_, Float32Type>(self, groups)
    }
}
impl Float64Chunked {
    pub(crate) unsafe fn agg_quantile(
        &self,
        groups: &GroupsType,
        quantile: f64,
        method: QuantileMethod,
    ) -> Series {
        agg_quantile_generic::<_, Float64Type>(self, groups, quantile, method)
    }
    pub(crate) unsafe fn agg_median(&self, groups: &GroupsType) -> Series {
        agg_median_generic::<_, Float64Type>(self, groups)
    }
}

impl<T> ChunkedArray<T>
where
    T: PolarsIntegerType,
    ChunkedArray<T>: IntoSeries + ChunkAgg<T::Native> + ChunkVar,
    T::Native: NumericNative + Ord,
{
    pub(crate) unsafe fn agg_mean(&self, groups: &GroupsType) -> Series {
        match groups {
            GroupsType::Idx(groups) => {
                let ca = self.rechunk();
                let arr = ca.downcast_get(0).unwrap();
                _agg_helper_idx::<Float64Type, _>(groups, |(first, idx)| {
                    // This can fail due to a bug in lazy code. There, users can
                    // create filters in aggregations and thereby create shorter
                    // columns than the original group tuples. The group tuples
                    // are then modified, but if that is done incorrectly there
                    // can be out-of-bounds access.
                    debug_assert!(idx.len() <= self.len());
                    if idx.is_empty() {
                        None
                    } else if idx.len() == 1 {
                        self.get(first as usize).map(|sum| sum.to_f64().unwrap())
                    } else {
                        match (self.has_nulls(), self.chunks.len()) {
                            (false, 1) => {
                                take_agg_no_null_primitive_iter_unchecked::<_, f64, _, _>(
                                    arr,
                                    idx2usize(idx),
                                    |a, b| a + b,
                                )
                                .map(|sum| sum / idx.len() as f64)
                            },
                            (_, 1) => {
                                {
                                    take_agg_primitive_iter_unchecked_count_nulls::<
                                        T::Native,
                                        f64,
                                        _,
                                        _,
                                    >(
                                        arr, idx2usize(idx), |a, b| a + b, 0.0, idx.len() as IdxSize
                                    )
                                }
                                .map(|(sum, null_count)| {
                                    sum / (idx.len() as f64 - null_count as f64)
                                })
                            },
                            _ => {
                                let take = { self.take_unchecked(idx) };
                                take.mean()
                            },
                        }
                    }
                })
            },
            GroupsType::Slice {
                groups: groups_slice,
                ..
            } => {
                if _use_rolling_kernels(groups_slice, self.chunks()) {
                    let ca = self
                        .cast_with_options(&DataType::Float64, CastOptions::Overflowing)
                        .unwrap();
                    ca.agg_mean(groups)
                } else {
                    _agg_helper_slice::<Float64Type, _>(groups_slice, |[first, len]| {
                        debug_assert!(first + len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            1 => self.get(first as usize).map(|v| NumCast::from(v).unwrap()),
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                arr_group.mean()
                            },
                        }
                    })
                }
            },
        }
    }

    pub(crate) unsafe fn agg_var(&self, groups: &GroupsType, ddof: u8) -> Series {
        match groups {
            GroupsType::Idx(groups) => {
                let ca_self = self.rechunk();
                let arr = ca_self.downcast_iter().next().unwrap();
                let no_nulls = arr.null_count() == 0;
                agg_helper_idx_on_all::<Float64Type, _>(groups, |idx| {
                    debug_assert!(idx.len() <= arr.len());
                    if idx.is_empty() {
                        return None;
                    }
                    if no_nulls {
                        take_var_no_null_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
                    } else {
                        take_var_nulls_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
                    }
                })
            },
            GroupsType::Slice {
                groups: groups_slice,
                ..
            } => {
                if _use_rolling_kernels(groups_slice, self.chunks()) {
                    let ca = self
                        .cast_with_options(&DataType::Float64, CastOptions::Overflowing)
                        .unwrap();
                    ca.agg_var(groups, ddof)
                } else {
                    _agg_helper_slice::<Float64Type, _>(groups_slice, |[first, len]| {
                        debug_assert!(first + len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            1 => {
                                if ddof == 0 {
                                    NumCast::from(0)
                                } else {
                                    None
                                }
                            },
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                arr_group.var(ddof)
                            },
                        }
                    })
                }
            },
        }
    }
    pub(crate) unsafe fn agg_std(&self, groups: &GroupsType, ddof: u8) -> Series {
        match groups {
            GroupsType::Idx(groups) => {
                let ca_self = self.rechunk();
                let arr = ca_self.downcast_iter().next().unwrap();
                let no_nulls = arr.null_count() == 0;
                agg_helper_idx_on_all::<Float64Type, _>(groups, |idx| {
                    debug_assert!(idx.len() <= self.len());
                    if idx.is_empty() {
                        return None;
                    }
                    let out = if no_nulls {
                        take_var_no_null_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
                    } else {
                        take_var_nulls_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
                    };
                    out.map(|v| v.sqrt())
                })
            },
            GroupsType::Slice {
                groups: groups_slice,
                ..
            } => {
                if _use_rolling_kernels(groups_slice, self.chunks()) {
                    let ca = self
                        .cast_with_options(&DataType::Float64, CastOptions::Overflowing)
                        .unwrap();
                    ca.agg_std(groups, ddof)
                } else {
                    _agg_helper_slice::<Float64Type, _>(groups_slice, |[first, len]| {
                        debug_assert!(first + len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            1 => {
                                if ddof == 0 {
                                    NumCast::from(0)
                                } else {
                                    None
                                }
                            },
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                arr_group.std(ddof)
                            },
                        }
                    })
                }
            },
        }
    }

    pub(crate) unsafe fn agg_quantile(
        &self,
        groups: &GroupsType,
        quantile: f64,
        method: QuantileMethod,
    ) -> Series {
        agg_quantile_generic::<_, Float64Type>(self, groups, quantile, method)
    }
    pub(crate) unsafe fn agg_median(&self, groups: &GroupsType) -> Series {
        agg_median_generic::<_, Float64Type>(self, groups)
    }
}
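
// A minimal illustrative sketch of the NaN semantics of `TakeExtremum` for
// floats, following from the `compare_fn_nan_max`/`compare_fn_nan_min`
// orderings used in the float impls above: NaN only survives when no non-NaN
// value takes part in the comparison.
#[cfg(test)]
mod take_extremum_nan_tests {
    use super::*;

    #[test]
    fn take_extremum_skips_nan() {
        // A non-NaN value always wins against NaN, regardless of order.
        assert_eq!(1.0f64.take_min(f64::NAN), 1.0);
        assert_eq!(f64::NAN.take_min(1.0), 1.0);
        assert_eq!(2.0f64.take_max(f64::NAN), 2.0);
        assert_eq!(f64::NAN.take_max(2.0), 2.0);
        // Only an all-NaN comparison yields NaN.
        assert!(f64::NAN.take_min(f64::NAN).is_nan());
    }
}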