// polars_arrow/array/growable/fixed_size_list.rs

1use std::sync::Arc;
2
3use super::{make_growable, Growable};
4use crate::array::growable::utils::{extend_validity, extend_validity_copies, prepare_validity};
5use crate::array::{Array, FixedSizeListArray};
6use crate::bitmap::BitmapBuilder;
7
8/// Concrete [`Growable`] for the [`FixedSizeListArray`].
9pub struct GrowableFixedSizeList<'a> {
10    arrays: Vec<&'a FixedSizeListArray>,
11    validity: Option<BitmapBuilder>,
12    values: Box<dyn Growable<'a> + 'a>,
13    size: usize,
14    length: usize,
15}
16
17impl<'a> GrowableFixedSizeList<'a> {
18    /// Creates a new [`GrowableFixedSizeList`] bound to `arrays` with a pre-allocated `capacity`.
19    /// # Panics
20    /// If `arrays` is empty.
21    pub fn new(
22        arrays: Vec<&'a FixedSizeListArray>,
23        mut use_validity: bool,
24        capacity: usize,
25    ) -> Self {
26        assert!(!arrays.is_empty());
27
28        // if any of the arrays has nulls, insertions from any array requires setting bits
29        // as there is at least one array with nulls.
30        if !use_validity & arrays.iter().any(|array| array.null_count() > 0) {
31            use_validity = true;
32        };
33
34        let size = arrays[0].size();
35
36        let inner = arrays
37            .iter()
38            .map(|array| {
39                debug_assert_eq!(array.size(), size);
40                array.values().as_ref()
41            })
42            .collect::<Vec<_>>();
43        let values = make_growable(&inner, use_validity, 0);
44
45        assert_eq!(values.len(), 0);
46
47        Self {
48            arrays,
49            values,
50            validity: prepare_validity(use_validity, capacity),
51            size,
52            length: 0,
53        }
54    }
55
56    pub fn to(&mut self) -> FixedSizeListArray {
57        let validity = std::mem::take(&mut self.validity);
58        let values = self.values.as_box();
59
60        FixedSizeListArray::new(
61            self.arrays[0].dtype().clone(),
62            self.length,
63            values,
64            validity.map(|v| v.freeze()),
65        )
66    }
67}
68
69impl<'a> Growable<'a> for GrowableFixedSizeList<'a> {
70    unsafe fn extend(&mut self, index: usize, start: usize, len: usize) {
71        let array = *self.arrays.get_unchecked(index);
72        extend_validity(&mut self.validity, array, start, len);
73
74        self.length += len;
75        let start_length = self.values.len();
76        self.values
77            .extend(index, start * self.size, len * self.size);
78        debug_assert!(self.size == 0 || (self.values.len() - start_length) / self.size == len);
79    }
80
81    unsafe fn extend_copies(&mut self, index: usize, start: usize, len: usize, copies: usize) {
82        let array = *self.arrays.get_unchecked(index);
83        extend_validity_copies(&mut self.validity, array, start, len, copies);
84
85        self.length += len * copies;
86        let start_length = self.values.len();
87        self.values
88            .extend_copies(index, start * self.size, len * self.size, copies);
89        debug_assert!(
90            self.size == 0 || (self.values.len() - start_length) / self.size == len * copies
91        );
92    }
93
94    fn extend_validity(&mut self, additional: usize) {
95        self.values.extend_validity(additional * self.size);
96        if let Some(validity) = &mut self.validity {
97            validity.extend_constant(additional, false);
98        }
99        self.length += additional;
100    }
101
102    #[inline]
103    fn len(&self) -> usize {
104        self.length
105    }
106
107    fn as_arc(&mut self) -> Arc<dyn Array> {
108        Arc::new(self.to())
109    }
110
111    fn as_box(&mut self) -> Box<dyn Array> {
112        Box::new(self.to())
113    }
114}
115
116impl<'a> From<GrowableFixedSizeList<'a>> for FixedSizeListArray {
117    fn from(val: GrowableFixedSizeList<'a>) -> Self {
118        let mut values = val.values;
119        let values = values.as_box();
120
121        Self::new(
122            val.arrays[0].dtype().clone(),
123            val.length,
124            values,
125            val.validity.map(|v| v.freeze()),
126        )
127    }
128}