// polars_compute/gather/sublist/list.rs
use arrow::array::{Array, ArrayRef, ListArray};
use arrow::legacy::prelude::*;
use arrow::legacy::trusted_len::TrustedLenPush;
use arrow::legacy::utils::CustomIterTools;
use arrow::offset::{Offsets, OffsetsBuffer};
use polars_utils::IdxSize;

use crate::gather::take_unchecked;

10fn sublist_get_indexes(arr: &ListArray<i64>, index: i64) -> IdxArr {
35 let offsets = arr.offsets().as_slice();
36 let mut iter = offsets.iter();
37
38 let mut cum_offset = (*offsets.first().unwrap_or(&0)) as IdxSize;
40
41 if let Some(mut previous) = iter.next().copied() {
42 if arr.null_count() == 0 {
43 iter.map(|&offset| {
44 let len = offset - previous;
45 previous = offset;
46 if len == 0 {
49 return None;
50 }
51 if index >= len {
52 cum_offset += len as IdxSize;
53 return None;
54 }
55
56 let out = index
57 .negative_to_usize(len as usize)
58 .map(|idx| idx as IdxSize + cum_offset);
59 cum_offset += len as IdxSize;
60 out
61 })
62 .collect_trusted()
63 } else {
64 let validity = arr.validity().unwrap();
66 iter.enumerate()
67 .map(|(i, &offset)| {
68 let len = offset - previous;
69 previous = offset;
70 if len == 0 || !unsafe { validity.get_bit_unchecked(i) } {
73 cum_offset += len as IdxSize;
74 return None;
75 }
76
77 if index >= len {
79 cum_offset += len as IdxSize;
80 return None;
81 }
82
83 let out = index
84 .negative_to_usize(len as usize)
85 .map(|idx| idx as IdxSize + cum_offset);
86 cum_offset += len as IdxSize;
87 out
88 })
89 .collect_trusted()
90 }
91 } else {
92 IdxArr::from_slice([])
93 }
94}
96pub fn sublist_get(arr: &ListArray<i64>, index: i64) -> ArrayRef {
97 let take_by = sublist_get_indexes(arr, index);
98 let values = arr.values();
99 unsafe { take_unchecked(&**values, &take_by) }
102}
104pub fn index_is_oob(arr: &ListArray<i64>, index: i64) -> bool {
106 if arr.null_count() == 0 {
107 arr.offsets()
108 .lengths()
109 .any(|len| index.negative_to_usize(len).is_none())
110 } else {
111 arr.offsets()
112 .lengths()
113 .zip(arr.validity().unwrap())
114 .any(|(len, valid)| {
115 if valid {
116 index.negative_to_usize(len).is_none()
117 } else {
118 false
120 }
121 })
122 }
123}
125pub fn array_to_unit_list(array: ArrayRef) -> ListArray<i64> {
127 let len = array.len();
128 let mut offsets = Vec::with_capacity(len + 1);
129 unsafe {
131 offsets.push_unchecked(0i64);
132
133 for _ in 0..len {
134 offsets.push_unchecked(offsets.len() as i64)
135 }
136 };
137
138 unsafe {
141 let offsets: OffsetsBuffer<i64> = Offsets::new_unchecked(offsets).into();
142 let dtype = ListArray::<i64>::default_datatype(array.dtype().clone());
143 ListArray::<i64>::new(dtype, offsets, array, None)
144 }
145}
#[cfg(test)]
mod test {
    use arrow::array::{Int32Array, PrimitiveArray};
    use arrow::datatypes::ArrowDataType;

    use super::*;

    /// Builds three sublists over `[1, 2, 3, 4, 5, 6]`: `[1, 2, 3]`, `[4, 5]` and `[6]`.
    fn get_array() -> ListArray<i64> {
        let dtype = ListArray::<i64>::default_datatype(ArrowDataType::Int32);
        let offsets = OffsetsBuffer::try_from(vec![0i64, 3, 5, 6]).unwrap();
        let values = Int32Array::from_slice([1, 2, 3, 4, 5, 6]);

        ListArray::<i64>::new(dtype, offsets, Box::new(values), None)
    }

    #[test]
    fn test_sublist_get_indexes() {
        let lists = get_array();

        // First element of each sublist.
        let first = sublist_get_indexes(&lists, 0);
        assert_eq!(first.values().as_slice(), &[0, 3, 5]);

        // Last element of each sublist.
        let last = sublist_get_indexes(&lists, -1);
        assert_eq!(last.values().as_slice(), &[2, 4, 5]);

        // Index 3 is past the end of every sublist.
        let past_end = sublist_get_indexes(&lists, 3);
        assert_eq!(past_end.null_count(), 3);

        // Sublists of lengths 1, 1, 1, 3, 3 and 2; the child values contain one null.
        let child = Int32Array::from_iter([
            Some(1),
            Some(1),
            Some(3),
            Some(4),
            Some(5),
            Some(6),
            Some(7),
            Some(8),
            Some(9),
            None,
            Some(11),
        ]);
        let bounds = OffsetsBuffer::try_from(vec![0i64, 1, 2, 3, 6, 9, 11]).unwrap();
        let dtype = ListArray::<i64>::default_datatype(ArrowDataType::Int32);
        let ragged = ListArray::<i64>::new(dtype, bounds, Box::new(child), None);

        // The single-element sublists have no second element.
        let second = sublist_get_indexes(&ragged, 1);
        assert_eq!(
            second.into_iter().collect::<Vec<_>>(),
            &[None, None, None, Some(4), Some(7), Some(10)]
        );
    }

    #[test]
    fn test_sublist_get() {
        let lists = get_array();

        // (index, expected element taken from each sublist)
        for (index, expected) in [(0i64, [1i32, 4, 6]), (-1, [3, 5, 6])] {
            let picked = sublist_get(&lists, index);
            let picked = picked
                .as_any()
                .downcast_ref::<PrimitiveArray<i32>>()
                .unwrap();
            assert_eq!(picked.values().as_slice(), &expected);
        }
    }
}