polars_compute/gather/sublist/
list.rs

1use arrow::array::{Array, ArrayRef, ListArray};
2use arrow::legacy::prelude::*;
3use arrow::legacy::trusted_len::TrustedLenPush;
4use arrow::legacy::utils::CustomIterTools;
5use arrow::offset::{Offsets, OffsetsBuffer};
6use polars_utils::IdxSize;
7
8use crate::gather::take_unchecked;
9
10/// Get the indices that would result in a get operation on the lists values.
11/// for example, consider this list:
12/// ```text
13/// [[1, 2, 3],
14///  [4, 5],
15///  [6]]
16///
17///  This contains the following values array:
18/// [1, 2, 3, 4, 5, 6]
19///
20/// get index 0
21/// would lead to the following indexes:
22///     [0, 3, 5].
23/// if we use those in a take operation on the values array we get:
24///     [1, 4, 6]
25///
26///
27/// get index -1
28/// would lead to the following indexes:
29///     [2, 4, 5].
30/// if we use those in a take operation on the values array we get:
31///     [3, 5, 6]
32///
33/// ```
34fn sublist_get_indexes(arr: &ListArray<i64>, index: i64) -> IdxArr {
35    let offsets = arr.offsets().as_slice();
36    let mut iter = offsets.iter();
37
38    // the indices can be sliced, so we should not start at 0.
39    let mut cum_offset = (*offsets.first().unwrap_or(&0)) as IdxSize;
40
41    if let Some(mut previous) = iter.next().copied() {
42        if arr.null_count() == 0 {
43            iter.map(|&offset| {
44                let len = offset - previous;
45                previous = offset;
46                // make sure that empty lists don't get accessed
47                // and out of bounds return null
48                if len == 0 {
49                    return None;
50                }
51                if index >= len {
52                    cum_offset += len as IdxSize;
53                    return None;
54                }
55
56                let out = index
57                    .negative_to_usize(len as usize)
58                    .map(|idx| idx as IdxSize + cum_offset);
59                cum_offset += len as IdxSize;
60                out
61            })
62            .collect_trusted()
63        } else {
64            // we can ensure that validity is not none as we have null value.
65            let validity = arr.validity().unwrap();
66            iter.enumerate()
67                .map(|(i, &offset)| {
68                    let len = offset - previous;
69                    previous = offset;
70                    // make sure that empty and null lists don't get accessed and return null.
71                    // SAFETY, we are within bounds
72                    if len == 0 || !unsafe { validity.get_bit_unchecked(i) } {
73                        cum_offset += len as IdxSize;
74                        return None;
75                    }
76
77                    // make sure that out of bounds return null
78                    if index >= len {
79                        cum_offset += len as IdxSize;
80                        return None;
81                    }
82
83                    let out = index
84                        .negative_to_usize(len as usize)
85                        .map(|idx| idx as IdxSize + cum_offset);
86                    cum_offset += len as IdxSize;
87                    out
88                })
89                .collect_trusted()
90        }
91    } else {
92        IdxArr::from_slice([])
93    }
94}
95
96pub fn sublist_get(arr: &ListArray<i64>, index: i64) -> ArrayRef {
97    let take_by = sublist_get_indexes(arr, index);
98    let values = arr.values();
99    // SAFETY:
100    // the indices we generate are in bounds
101    unsafe { take_unchecked(&**values, &take_by) }
102}
103
104/// Check if an index is out of bounds for at least one sublist.
105pub fn index_is_oob(arr: &ListArray<i64>, index: i64) -> bool {
106    if arr.null_count() == 0 {
107        arr.offsets()
108            .lengths()
109            .any(|len| index.negative_to_usize(len).is_none())
110    } else {
111        arr.offsets()
112            .lengths()
113            .zip(arr.validity().unwrap())
114            .any(|(len, valid)| {
115                if valid {
116                    index.negative_to_usize(len).is_none()
117                } else {
118                    // skip nulls
119                    false
120                }
121            })
122    }
123}
124
125/// Convert a list `[1, 2, 3]` to a list type of `[[1], [2], [3]]`
126pub fn array_to_unit_list(array: ArrayRef) -> ListArray<i64> {
127    let len = array.len();
128    let mut offsets = Vec::with_capacity(len + 1);
129    // SAFETY: we allocated enough
130    unsafe {
131        offsets.push_unchecked(0i64);
132
133        for _ in 0..len {
134            offsets.push_unchecked(offsets.len() as i64)
135        }
136    };
137
138    // SAFETY:
139    // offsets are monotonically increasing
140    unsafe {
141        let offsets: OffsetsBuffer<i64> = Offsets::new_unchecked(offsets).into();
142        let dtype = ListArray::<i64>::default_datatype(array.dtype().clone());
143        ListArray::<i64>::new(dtype, offsets, array, None)
144    }
145}
146
#[cfg(test)]
mod test {
    use arrow::array::{Int32Array, PrimitiveArray};
    use arrow::datatypes::ArrowDataType;

    use super::*;

    /// Wraps `values` in an Int32 list array without validity, split at `offsets`.
    fn int32_list(values: Int32Array, offsets: Vec<i64>) -> ListArray<i64> {
        let offsets = OffsetsBuffer::try_from(offsets).unwrap();
        let dtype = ListArray::<i64>::default_datatype(ArrowDataType::Int32);
        ListArray::<i64>::new(dtype, offsets, Box::new(values), None)
    }

    /// Builds the list `[[1, 2, 3], [4, 5], [6]]`.
    fn get_array() -> ListArray<i64> {
        int32_list(Int32Array::from_slice([1, 2, 3, 4, 5, 6]), vec![0, 3, 5, 6])
    }

    #[test]
    fn test_sublist_get_indexes() {
        let list = get_array();

        // First element of every sublist.
        let first = sublist_get_indexes(&list, 0);
        assert_eq!(first.values().as_slice(), &[0, 3, 5]);

        // Last element of every sublist.
        let last = sublist_get_indexes(&list, -1);
        assert_eq!(last.values().as_slice(), &[2, 4, 5]);

        // Index 3 does not exist in any sublist, so every position is null.
        let beyond = sublist_get_indexes(&list, 3);
        assert_eq!(beyond.null_count(), 3);

        // Sublists of length 1, 1, 1, 3, 3 and 2.
        let values = Int32Array::from_iter([
            Some(1),
            Some(1),
            Some(3),
            Some(4),
            Some(5),
            Some(6),
            Some(7),
            Some(8),
            Some(9),
            None,
            Some(11),
        ]);
        let list = int32_list(values, vec![0, 1, 2, 3, 6, 9, 11]);

        let second = sublist_get_indexes(&list, 1);
        assert_eq!(
            second.into_iter().collect::<Vec<_>>(),
            &[None, None, None, Some(4), Some(7), Some(10)]
        );
    }

    #[test]
    fn test_sublist_get() {
        let list = get_array();

        for (index, expected) in [(0i64, [1, 4, 6]), (-1, [3, 5, 6])] {
            let taken = sublist_get(&list, index);
            let taken = taken
                .as_any()
                .downcast_ref::<PrimitiveArray<i32>>()
                .unwrap();
            assert_eq!(taken.values().as_slice(), &expected);
        }
    }
}