// polars_compute/horizontal_flatten/mod.rs

1use arrow::array::{
2    Array, ArrayCollectIterExt, BinaryArray, BinaryViewArray, BooleanArray, FixedSizeListArray,
3    ListArray, NullArray, PrimitiveArray, StaticArray, StructArray, Utf8ViewArray,
4};
5use arrow::bitmap::Bitmap;
6use arrow::datatypes::{ArrowDataType, PhysicalType};
7use arrow::with_match_primitive_type_full;
8use strength_reduce::StrengthReducedUsize;
9mod struct_;
10
11/// Low-level operation used by `concat_arr`. This should be called with the inner values array of
12/// every FixedSizeList array.
13///
14/// # Safety
15/// * `arrays` is non-empty
16/// * `arrays` and `widths` have equal length
17/// * All widths in `widths` are non-zero
18/// * Every array `arrays[i]` has a length of either
19///   * `widths[i] * output_height`
20///   * `widths[i]` (this would be broadcasted)
21/// * All arrays in `arrays` have the same type
22pub unsafe fn horizontal_flatten_unchecked(
23    arrays: &[Box<dyn Array>],
24    widths: &[usize],
25    output_height: usize,
26) -> Box<dyn Array> {
27    use PhysicalType::*;
28
29    let dtype = arrays[0].dtype();
30
31    match dtype.to_physical_type() {
32        Null => Box::new(NullArray::new(
33            dtype.clone(),
34            output_height * widths.iter().copied().sum::<usize>(),
35        )),
36        Boolean => Box::new(horizontal_flatten_unchecked_impl_generic(
37            &arrays
38                .iter()
39                .map(|x| x.as_any().downcast_ref::<BooleanArray>().unwrap().clone())
40                .collect::<Vec<_>>(),
41            widths,
42            output_height,
43            dtype,
44        )),
45        Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| {
46            Box::new(horizontal_flatten_unchecked_impl_generic(
47                &arrays
48                    .iter()
49                    .map(|x| x.as_any().downcast_ref::<PrimitiveArray<$T>>().unwrap().clone())
50                    .collect::<Vec<_>>(),
51                widths,
52                output_height,
53                dtype
54            ))
55        }),
56        LargeBinary => Box::new(horizontal_flatten_unchecked_impl_generic(
57            &arrays
58                .iter()
59                .map(|x| {
60                    x.as_any()
61                        .downcast_ref::<BinaryArray<i64>>()
62                        .unwrap()
63                        .clone()
64                })
65                .collect::<Vec<_>>(),
66            widths,
67            output_height,
68            dtype,
69        )),
70        Struct => Box::new(struct_::horizontal_flatten_unchecked(
71            &arrays
72                .iter()
73                .map(|x| x.as_any().downcast_ref::<StructArray>().unwrap().clone())
74                .collect::<Vec<_>>(),
75            widths,
76            output_height,
77        )),
78        LargeList => Box::new(horizontal_flatten_unchecked_impl_generic(
79            &arrays
80                .iter()
81                .map(|x| x.as_any().downcast_ref::<ListArray<i64>>().unwrap().clone())
82                .collect::<Vec<_>>(),
83            widths,
84            output_height,
85            dtype,
86        )),
87        FixedSizeList => Box::new(horizontal_flatten_unchecked_impl_generic(
88            &arrays
89                .iter()
90                .map(|x| {
91                    x.as_any()
92                        .downcast_ref::<FixedSizeListArray>()
93                        .unwrap()
94                        .clone()
95                })
96                .collect::<Vec<_>>(),
97            widths,
98            output_height,
99            dtype,
100        )),
101        BinaryView => Box::new(horizontal_flatten_unchecked_impl_generic(
102            &arrays
103                .iter()
104                .map(|x| {
105                    x.as_any()
106                        .downcast_ref::<BinaryViewArray>()
107                        .unwrap()
108                        .clone()
109                })
110                .collect::<Vec<_>>(),
111            widths,
112            output_height,
113            dtype,
114        )),
115        Utf8View => Box::new(horizontal_flatten_unchecked_impl_generic(
116            &arrays
117                .iter()
118                .map(|x| x.as_any().downcast_ref::<Utf8ViewArray>().unwrap().clone())
119                .collect::<Vec<_>>(),
120            widths,
121            output_height,
122            dtype,
123        )),
124        t => unimplemented!("horizontal_flatten not supported for data type {:?}", t),
125    }
126}
127
128unsafe fn horizontal_flatten_unchecked_impl_generic<T>(
129    arrays: &[T],
130    widths: &[usize],
131    output_height: usize,
132    dtype: &ArrowDataType,
133) -> T
134where
135    T: StaticArray,
136{
137    assert!(!arrays.is_empty());
138    assert_eq!(widths.len(), arrays.len());
139
140    debug_assert!(widths.iter().all(|x| *x > 0));
141    debug_assert!(arrays
142        .iter()
143        .zip(widths)
144        .all(|(arr, width)| arr.len() == output_height * *width || arr.len() == *width));
145
146    // We modulo the array length to support broadcasting.
147    let lengths = arrays
148        .iter()
149        .map(|x| StrengthReducedUsize::new(x.len()))
150        .collect::<Vec<_>>();
151    let out_row_width: usize = widths.iter().cloned().sum();
152    let out_len = out_row_width.checked_mul(output_height).unwrap();
153
154    let mut col_idx = 0;
155    let mut row_idx = 0;
156    let mut until = widths[0];
157    let mut outer_row_idx = 0;
158
159    // We do `0..out_len` to get an `ExactSizeIterator`.
160    (0..out_len)
161        .map(|_| {
162            let arr = arrays.get_unchecked(col_idx);
163            let out = arr.get_unchecked(row_idx % *lengths.get_unchecked(col_idx));
164
165            row_idx += 1;
166
167            if row_idx == until {
168                // Safety: All widths are non-zero so we only need to increment once.
169                col_idx = if 1 + col_idx == widths.len() {
170                    outer_row_idx += 1;
171                    0
172                } else {
173                    1 + col_idx
174                };
175                row_idx = outer_row_idx * *widths.get_unchecked(col_idx);
176                until = (1 + outer_row_idx) * *widths.get_unchecked(col_idx)
177            }
178
179            out
180        })
181        .collect_arr_trusted_with_dtype(dtype.clone())
182}