// polars_arrow/array/union/mod.rs

1use polars_error::{polars_bail, polars_err, PolarsResult};
2
3use super::{new_empty_array, new_null_array, Array, Splitable};
4use crate::bitmap::Bitmap;
5use crate::buffer::Buffer;
6use crate::datatypes::{ArrowDataType, Field, UnionMode};
7use crate::scalar::{new_scalar, Scalar};
8
9mod ffi;
10pub(super) mod fmt;
11mod iterator;
12
13type UnionComponents<'a> = (&'a [Field], Option<&'a [i32]>, UnionMode);
14
15/// [`UnionArray`] represents an array whose each slot can contain different values.
16///
17// How to read a value at slot i:
18// ```
19// let index = self.types()[i] as usize;
20// let field = self.fields()[index];
21// let offset = self.offsets().map(|x| x[index]).unwrap_or(i);
22// let field = field.as_any().downcast to correct type;
23// let value = field.value(offset);
24// ```
25#[derive(Clone)]
26pub struct UnionArray {
27    // Invariant: every item in `types` is `> 0 && < fields.len()`
28    types: Buffer<i8>,
29    // Invariant: `map.len() == fields.len()`
30    // Invariant: every item in `map` is `> 0 && < fields.len()`
31    map: Option<[usize; 127]>,
32    fields: Vec<Box<dyn Array>>,
33    // Invariant: when set, `offsets.len() == types.len()`
34    offsets: Option<Buffer<i32>>,
35    dtype: ArrowDataType,
36    offset: usize,
37}
38
39impl UnionArray {
40    /// Returns a new [`UnionArray`].
41    /// # Errors
42    /// This function errors iff:
43    /// * `dtype`'s physical type is not [`crate::datatypes::PhysicalType::Union`].
44    /// * the fields's len is different from the `dtype`'s children's length
45    /// * The number of `fields` is larger than `i8::MAX`
46    /// * any of the values's data type is different from its corresponding children' data type
47    pub fn try_new(
48        dtype: ArrowDataType,
49        types: Buffer<i8>,
50        fields: Vec<Box<dyn Array>>,
51        offsets: Option<Buffer<i32>>,
52    ) -> PolarsResult<Self> {
53        let (f, ids, mode) = Self::try_get_all(&dtype)?;
54
55        if f.len() != fields.len() {
56            polars_bail!(ComputeError: "the number of `fields` must equal the number of children fields in DataType::Union")
57        };
58        let number_of_fields: i8 = fields.len().try_into().map_err(
59            |_| polars_err!(ComputeError: "the number of `fields` cannot be larger than i8::MAX"),
60        )?;
61
62        f
63            .iter().map(|a| a.dtype())
64            .zip(fields.iter().map(|a| a.dtype()))
65            .enumerate()
66            .try_for_each(|(index, (dtype, child))| {
67                if dtype != child {
68                    polars_bail!(ComputeError:
69                        "the children DataTypes of a UnionArray must equal the children data types.
70                         However, the field {index} has data type {dtype:?} but the value has data type {child:?}"
71                    )
72                } else {
73                    Ok(())
74                }
75            })?;
76
77        if let Some(offsets) = &offsets {
78            if offsets.len() != types.len() {
79                polars_bail!(ComputeError:
80                "in a UnionArray, the offsets' length must be equal to the number of types"
81                )
82            }
83        }
84        if offsets.is_none() != mode.is_sparse() {
85            polars_bail!(ComputeError:
86            "in a sparse UnionArray, the offsets must be set (and vice-versa)",
87                )
88        }
89
90        // build hash
91        let map = if let Some(&ids) = ids.as_ref() {
92            if ids.len() != fields.len() {
93                polars_bail!(ComputeError:
94                "in a union, when the ids are set, their length must be equal to the number of fields",
95                )
96            }
97
98            // example:
99            // * types = [5, 7, 5, 7, 7, 7, 5, 7, 7, 5, 5]
100            // * ids = [5, 7]
101            // => hash = [0, 0, 0, 0, 0, 0, 1, 0, ...]
102            let mut hash = [0; 127];
103
104            for (pos, &id) in ids.iter().enumerate() {
105                if !(0..=127).contains(&id) {
106                    polars_bail!(ComputeError:
107                        "in a union, when the ids are set, every id must belong to [0, 128[",
108                    )
109                }
110                hash[id as usize] = pos;
111            }
112
113            types.iter().try_for_each(|&type_| {
114                if type_ < 0 {
115                    polars_bail!(ComputeError:
116                        "in a union, when the ids are set, every type must be >= 0"
117                    )
118                }
119                let id = hash[type_ as usize];
120                if id >= fields.len() {
121                    polars_bail!(ComputeError:
122    "in a union, when the ids are set, each id must be smaller than the number of fields."
123                    )
124                } else {
125                    Ok(())
126                }
127            })?;
128
129            Some(hash)
130        } else {
131            // SAFETY: every type in types is smaller than number of fields
132            let mut is_valid = true;
133            for &type_ in types.iter() {
134                if type_ < 0 || type_ >= number_of_fields {
135                    is_valid = false
136                }
137            }
138            if !is_valid {
139                polars_bail!(ComputeError:
140                    "every type in `types` must be larger than 0 and smaller than the number of fields.",
141                )
142            }
143
144            None
145        };
146
147        Ok(Self {
148            dtype,
149            map,
150            fields,
151            offsets,
152            types,
153            offset: 0,
154        })
155    }
156
157    /// Returns a new [`UnionArray`].
158    /// # Panics
159    /// This function panics iff:
160    /// * `dtype`'s physical type is not [`crate::datatypes::PhysicalType::Union`].
161    /// * the fields's len is different from the `dtype`'s children's length
162    /// * any of the values's data type is different from its corresponding children' data type
163    pub fn new(
164        dtype: ArrowDataType,
165        types: Buffer<i8>,
166        fields: Vec<Box<dyn Array>>,
167        offsets: Option<Buffer<i32>>,
168    ) -> Self {
169        Self::try_new(dtype, types, fields, offsets).unwrap()
170    }
171
172    /// Creates a new null [`UnionArray`].
173    pub fn new_null(dtype: ArrowDataType, length: usize) -> Self {
174        if let ArrowDataType::Union(u) = &dtype {
175            let fields = u
176                .fields
177                .iter()
178                .map(|x| new_null_array(x.dtype().clone(), length))
179                .collect();
180
181            let offsets = if u.mode.is_sparse() {
182                None
183            } else {
184                Some((0..length as i32).collect::<Vec<_>>().into())
185            };
186
187            // all from the same field
188            let types = vec![0i8; length].into();
189
190            Self::new(dtype, types, fields, offsets)
191        } else {
192            panic!("Union struct must be created with the corresponding Union DataType")
193        }
194    }
195
196    /// Creates a new empty [`UnionArray`].
197    pub fn new_empty(dtype: ArrowDataType) -> Self {
198        if let ArrowDataType::Union(u) = dtype.to_logical_type() {
199            let fields = u
200                .fields
201                .iter()
202                .map(|x| new_empty_array(x.dtype().clone()))
203                .collect();
204
205            let offsets = if u.mode.is_sparse() {
206                None
207            } else {
208                Some(Buffer::default())
209            };
210
211            Self {
212                dtype,
213                map: None,
214                fields,
215                offsets,
216                types: Buffer::new(),
217                offset: 0,
218            }
219        } else {
220            panic!("Union struct must be created with the corresponding Union DataType")
221        }
222    }
223}
224
225impl UnionArray {
226    /// Returns a slice of this [`UnionArray`].
227    /// # Implementation
228    /// This operation is `O(F)` where `F` is the number of fields.
229    /// # Panic
230    /// This function panics iff `offset + length > self.len()`.
231    #[inline]
232    pub fn slice(&mut self, offset: usize, length: usize) {
233        assert!(
234            offset + length <= self.len(),
235            "the offset of the new array cannot exceed the existing length"
236        );
237        unsafe { self.slice_unchecked(offset, length) }
238    }
239
240    /// Returns a slice of this [`UnionArray`].
241    /// # Implementation
242    /// This operation is `O(F)` where `F` is the number of fields.
243    ///
244    /// # Safety
245    /// The caller must ensure that `offset + length <= self.len()`.
246    #[inline]
247    pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) {
248        debug_assert!(offset + length <= self.len());
249
250        self.types.slice_unchecked(offset, length);
251        if let Some(offsets) = self.offsets.as_mut() {
252            offsets.slice_unchecked(offset, length)
253        }
254        self.offset += offset;
255    }
256
257    impl_sliced!();
258    impl_into_array!();
259}
260
261impl UnionArray {
262    /// Returns the length of this array
263    #[inline]
264    pub fn len(&self) -> usize {
265        self.types.len()
266    }
267
268    /// The optional offsets.
269    pub fn offsets(&self) -> Option<&Buffer<i32>> {
270        self.offsets.as_ref()
271    }
272
273    /// The fields.
274    pub fn fields(&self) -> &Vec<Box<dyn Array>> {
275        &self.fields
276    }
277
278    /// The types.
279    pub fn types(&self) -> &Buffer<i8> {
280        &self.types
281    }
282
283    #[inline]
284    unsafe fn field_slot_unchecked(&self, index: usize) -> usize {
285        self.offsets()
286            .as_ref()
287            .map(|x| *x.get_unchecked(index) as usize)
288            .unwrap_or(index + self.offset)
289    }
290
291    /// Returns the index and slot of the field to select from `self.fields`.
292    #[inline]
293    pub fn index(&self, index: usize) -> (usize, usize) {
294        assert!(index < self.len());
295        unsafe { self.index_unchecked(index) }
296    }
297
298    /// Returns the index and slot of the field to select from `self.fields`.
299    /// The first value is guaranteed to be `< self.fields().len()`
300    ///
301    /// # Safety
302    /// This function is safe iff `index < self.len`.
303    #[inline]
304    pub unsafe fn index_unchecked(&self, index: usize) -> (usize, usize) {
305        debug_assert!(index < self.len());
306        // SAFETY: assumption of the function
307        let type_ = unsafe { *self.types.get_unchecked(index) };
308        // SAFETY: assumption of the struct
309        let type_ = self
310            .map
311            .as_ref()
312            .map(|map| unsafe { *map.get_unchecked(type_ as usize) })
313            .unwrap_or(type_ as usize);
314        // SAFETY: assumption of the function
315        let index = self.field_slot_unchecked(index);
316        (type_, index)
317    }
318
319    /// Returns the slot `index` as a [`Scalar`].
320    /// # Panics
321    /// iff `index >= self.len()`
322    pub fn value(&self, index: usize) -> Box<dyn Scalar> {
323        assert!(index < self.len());
324        unsafe { self.value_unchecked(index) }
325    }
326
327    /// Returns the slot `index` as a [`Scalar`].
328    ///
329    /// # Safety
330    /// This function is safe iff `i < self.len`.
331    pub unsafe fn value_unchecked(&self, index: usize) -> Box<dyn Scalar> {
332        debug_assert!(index < self.len());
333        let (type_, index) = self.index_unchecked(index);
334        // SAFETY: assumption of the struct
335        debug_assert!(type_ < self.fields.len());
336        let field = self.fields.get_unchecked(type_).as_ref();
337        new_scalar(field, index)
338    }
339}
340
341impl Array for UnionArray {
342    impl_common_array!();
343
344    fn validity(&self) -> Option<&Bitmap> {
345        None
346    }
347
348    fn with_validity(&self, _: Option<Bitmap>) -> Box<dyn Array> {
349        panic!("cannot set validity of a union array")
350    }
351}
352
353impl UnionArray {
354    fn try_get_all(dtype: &ArrowDataType) -> PolarsResult<UnionComponents> {
355        match dtype.to_logical_type() {
356            ArrowDataType::Union(u) => Ok((&u.fields, u.ids.as_ref().map(|x| x.as_ref()), u.mode)),
357            _ => polars_bail!(ComputeError:
358                "The UnionArray requires a logical type of DataType::Union",
359            ),
360        }
361    }
362
363    fn get_all(dtype: &ArrowDataType) -> (&[Field], Option<&[i32]>, UnionMode) {
364        Self::try_get_all(dtype).unwrap()
365    }
366
367    /// Returns all fields from [`ArrowDataType::Union`].
368    /// # Panic
369    /// Panics iff `dtype`'s logical type is not [`ArrowDataType::Union`].
370    pub fn get_fields(dtype: &ArrowDataType) -> &[Field] {
371        Self::get_all(dtype).0
372    }
373
374    /// Returns whether the [`ArrowDataType::Union`] is sparse or not.
375    /// # Panic
376    /// Panics iff `dtype`'s logical type is not [`ArrowDataType::Union`].
377    pub fn is_sparse(dtype: &ArrowDataType) -> bool {
378        Self::get_all(dtype).2.is_sparse()
379    }
380}
381
382impl Splitable for UnionArray {
383    fn check_bound(&self, offset: usize) -> bool {
384        offset <= self.len()
385    }
386
387    unsafe fn _split_at_unchecked(&self, offset: usize) -> (Self, Self) {
388        let (lhs_types, rhs_types) = unsafe { self.types.split_at_unchecked(offset) };
389        let (lhs_offsets, rhs_offsets) = self.offsets.as_ref().map_or((None, None), |v| {
390            let (lhs, rhs) = unsafe { v.split_at_unchecked(offset) };
391            (Some(lhs), Some(rhs))
392        });
393
394        (
395            Self {
396                types: lhs_types,
397                map: self.map,
398                fields: self.fields.clone(),
399                offsets: lhs_offsets,
400                dtype: self.dtype.clone(),
401                offset: self.offset,
402            },
403            Self {
404                types: rhs_types,
405                map: self.map,
406                fields: self.fields.clone(),
407                offsets: rhs_offsets,
408                dtype: self.dtype.clone(),
409                offset: self.offset + offset,
410            },
411        )
412    }
413}