polars_arrow/array/union/
mod.rs1use polars_error::{polars_bail, polars_err, PolarsResult};
2
3use super::{new_empty_array, new_null_array, Array, Splitable};
4use crate::bitmap::Bitmap;
5use crate::buffer::Buffer;
6use crate::datatypes::{ArrowDataType, Field, UnionMode};
7use crate::scalar::{new_scalar, Scalar};
8
9mod ffi;
10pub(super) mod fmt;
11mod iterator;
12
13type UnionComponents<'a> = (&'a [Field], Option<&'a [i32]>, UnionMode);
14
15#[derive(Clone)]
26pub struct UnionArray {
27 types: Buffer<i8>,
29 map: Option<[usize; 127]>,
32 fields: Vec<Box<dyn Array>>,
33 offsets: Option<Buffer<i32>>,
35 dtype: ArrowDataType,
36 offset: usize,
37}
38
39impl UnionArray {
40 pub fn try_new(
48 dtype: ArrowDataType,
49 types: Buffer<i8>,
50 fields: Vec<Box<dyn Array>>,
51 offsets: Option<Buffer<i32>>,
52 ) -> PolarsResult<Self> {
53 let (f, ids, mode) = Self::try_get_all(&dtype)?;
54
55 if f.len() != fields.len() {
56 polars_bail!(ComputeError: "the number of `fields` must equal the number of children fields in DataType::Union")
57 };
58 let number_of_fields: i8 = fields.len().try_into().map_err(
59 |_| polars_err!(ComputeError: "the number of `fields` cannot be larger than i8::MAX"),
60 )?;
61
62 f
63 .iter().map(|a| a.dtype())
64 .zip(fields.iter().map(|a| a.dtype()))
65 .enumerate()
66 .try_for_each(|(index, (dtype, child))| {
67 if dtype != child {
68 polars_bail!(ComputeError:
69 "the children DataTypes of a UnionArray must equal the children data types.
70 However, the field {index} has data type {dtype:?} but the value has data type {child:?}"
71 )
72 } else {
73 Ok(())
74 }
75 })?;
76
77 if let Some(offsets) = &offsets {
78 if offsets.len() != types.len() {
79 polars_bail!(ComputeError:
80 "in a UnionArray, the offsets' length must be equal to the number of types"
81 )
82 }
83 }
84 if offsets.is_none() != mode.is_sparse() {
85 polars_bail!(ComputeError:
86 "in a sparse UnionArray, the offsets must be set (and vice-versa)",
87 )
88 }
89
90 let map = if let Some(&ids) = ids.as_ref() {
92 if ids.len() != fields.len() {
93 polars_bail!(ComputeError:
94 "in a union, when the ids are set, their length must be equal to the number of fields",
95 )
96 }
97
98 let mut hash = [0; 127];
103
104 for (pos, &id) in ids.iter().enumerate() {
105 if !(0..=127).contains(&id) {
106 polars_bail!(ComputeError:
107 "in a union, when the ids are set, every id must belong to [0, 128[",
108 )
109 }
110 hash[id as usize] = pos;
111 }
112
113 types.iter().try_for_each(|&type_| {
114 if type_ < 0 {
115 polars_bail!(ComputeError:
116 "in a union, when the ids are set, every type must be >= 0"
117 )
118 }
119 let id = hash[type_ as usize];
120 if id >= fields.len() {
121 polars_bail!(ComputeError:
122 "in a union, when the ids are set, each id must be smaller than the number of fields."
123 )
124 } else {
125 Ok(())
126 }
127 })?;
128
129 Some(hash)
130 } else {
131 let mut is_valid = true;
133 for &type_ in types.iter() {
134 if type_ < 0 || type_ >= number_of_fields {
135 is_valid = false
136 }
137 }
138 if !is_valid {
139 polars_bail!(ComputeError:
140 "every type in `types` must be larger than 0 and smaller than the number of fields.",
141 )
142 }
143
144 None
145 };
146
147 Ok(Self {
148 dtype,
149 map,
150 fields,
151 offsets,
152 types,
153 offset: 0,
154 })
155 }
156
157 pub fn new(
164 dtype: ArrowDataType,
165 types: Buffer<i8>,
166 fields: Vec<Box<dyn Array>>,
167 offsets: Option<Buffer<i32>>,
168 ) -> Self {
169 Self::try_new(dtype, types, fields, offsets).unwrap()
170 }
171
172 pub fn new_null(dtype: ArrowDataType, length: usize) -> Self {
174 if let ArrowDataType::Union(u) = &dtype {
175 let fields = u
176 .fields
177 .iter()
178 .map(|x| new_null_array(x.dtype().clone(), length))
179 .collect();
180
181 let offsets = if u.mode.is_sparse() {
182 None
183 } else {
184 Some((0..length as i32).collect::<Vec<_>>().into())
185 };
186
187 let types = vec![0i8; length].into();
189
190 Self::new(dtype, types, fields, offsets)
191 } else {
192 panic!("Union struct must be created with the corresponding Union DataType")
193 }
194 }
195
196 pub fn new_empty(dtype: ArrowDataType) -> Self {
198 if let ArrowDataType::Union(u) = dtype.to_logical_type() {
199 let fields = u
200 .fields
201 .iter()
202 .map(|x| new_empty_array(x.dtype().clone()))
203 .collect();
204
205 let offsets = if u.mode.is_sparse() {
206 None
207 } else {
208 Some(Buffer::default())
209 };
210
211 Self {
212 dtype,
213 map: None,
214 fields,
215 offsets,
216 types: Buffer::new(),
217 offset: 0,
218 }
219 } else {
220 panic!("Union struct must be created with the corresponding Union DataType")
221 }
222 }
223}
224
225impl UnionArray {
226 #[inline]
232 pub fn slice(&mut self, offset: usize, length: usize) {
233 assert!(
234 offset + length <= self.len(),
235 "the offset of the new array cannot exceed the existing length"
236 );
237 unsafe { self.slice_unchecked(offset, length) }
238 }
239
240 #[inline]
247 pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) {
248 debug_assert!(offset + length <= self.len());
249
250 self.types.slice_unchecked(offset, length);
251 if let Some(offsets) = self.offsets.as_mut() {
252 offsets.slice_unchecked(offset, length)
253 }
254 self.offset += offset;
255 }
256
257 impl_sliced!();
258 impl_into_array!();
259}
260
261impl UnionArray {
262 #[inline]
264 pub fn len(&self) -> usize {
265 self.types.len()
266 }
267
268 pub fn offsets(&self) -> Option<&Buffer<i32>> {
270 self.offsets.as_ref()
271 }
272
273 pub fn fields(&self) -> &Vec<Box<dyn Array>> {
275 &self.fields
276 }
277
278 pub fn types(&self) -> &Buffer<i8> {
280 &self.types
281 }
282
283 #[inline]
284 unsafe fn field_slot_unchecked(&self, index: usize) -> usize {
285 self.offsets()
286 .as_ref()
287 .map(|x| *x.get_unchecked(index) as usize)
288 .unwrap_or(index + self.offset)
289 }
290
291 #[inline]
293 pub fn index(&self, index: usize) -> (usize, usize) {
294 assert!(index < self.len());
295 unsafe { self.index_unchecked(index) }
296 }
297
298 #[inline]
304 pub unsafe fn index_unchecked(&self, index: usize) -> (usize, usize) {
305 debug_assert!(index < self.len());
306 let type_ = unsafe { *self.types.get_unchecked(index) };
308 let type_ = self
310 .map
311 .as_ref()
312 .map(|map| unsafe { *map.get_unchecked(type_ as usize) })
313 .unwrap_or(type_ as usize);
314 let index = self.field_slot_unchecked(index);
316 (type_, index)
317 }
318
319 pub fn value(&self, index: usize) -> Box<dyn Scalar> {
323 assert!(index < self.len());
324 unsafe { self.value_unchecked(index) }
325 }
326
327 pub unsafe fn value_unchecked(&self, index: usize) -> Box<dyn Scalar> {
332 debug_assert!(index < self.len());
333 let (type_, index) = self.index_unchecked(index);
334 debug_assert!(type_ < self.fields.len());
336 let field = self.fields.get_unchecked(type_).as_ref();
337 new_scalar(field, index)
338 }
339}
340
341impl Array for UnionArray {
342 impl_common_array!();
343
344 fn validity(&self) -> Option<&Bitmap> {
345 None
346 }
347
348 fn with_validity(&self, _: Option<Bitmap>) -> Box<dyn Array> {
349 panic!("cannot set validity of a union array")
350 }
351}
352
353impl UnionArray {
354 fn try_get_all(dtype: &ArrowDataType) -> PolarsResult<UnionComponents> {
355 match dtype.to_logical_type() {
356 ArrowDataType::Union(u) => Ok((&u.fields, u.ids.as_ref().map(|x| x.as_ref()), u.mode)),
357 _ => polars_bail!(ComputeError:
358 "The UnionArray requires a logical type of DataType::Union",
359 ),
360 }
361 }
362
363 fn get_all(dtype: &ArrowDataType) -> (&[Field], Option<&[i32]>, UnionMode) {
364 Self::try_get_all(dtype).unwrap()
365 }
366
367 pub fn get_fields(dtype: &ArrowDataType) -> &[Field] {
371 Self::get_all(dtype).0
372 }
373
374 pub fn is_sparse(dtype: &ArrowDataType) -> bool {
378 Self::get_all(dtype).2.is_sparse()
379 }
380}
381
382impl Splitable for UnionArray {
383 fn check_bound(&self, offset: usize) -> bool {
384 offset <= self.len()
385 }
386
387 unsafe fn _split_at_unchecked(&self, offset: usize) -> (Self, Self) {
388 let (lhs_types, rhs_types) = unsafe { self.types.split_at_unchecked(offset) };
389 let (lhs_offsets, rhs_offsets) = self.offsets.as_ref().map_or((None, None), |v| {
390 let (lhs, rhs) = unsafe { v.split_at_unchecked(offset) };
391 (Some(lhs), Some(rhs))
392 });
393
394 (
395 Self {
396 types: lhs_types,
397 map: self.map,
398 fields: self.fields.clone(),
399 offsets: lhs_offsets,
400 dtype: self.dtype.clone(),
401 offset: self.offset,
402 },
403 Self {
404 types: rhs_types,
405 map: self.map,
406 fields: self.fields.clone(),
407 offsets: rhs_offsets,
408 dtype: self.dtype.clone(),
409 offset: self.offset + offset,
410 },
411 )
412 }
413}