polars_compute/horizontal_flatten/
mod.rs1use arrow::array::{
2 Array, ArrayCollectIterExt, BinaryArray, BinaryViewArray, BooleanArray, FixedSizeListArray,
3 ListArray, NullArray, PrimitiveArray, StaticArray, StructArray, Utf8ViewArray,
4};
5use arrow::bitmap::Bitmap;
6use arrow::datatypes::{ArrowDataType, PhysicalType};
7use arrow::with_match_primitive_type_full;
8use strength_reduce::StrengthReducedUsize;
9mod struct_;
10
11pub unsafe fn horizontal_flatten_unchecked(
23 arrays: &[Box<dyn Array>],
24 widths: &[usize],
25 output_height: usize,
26) -> Box<dyn Array> {
27 use PhysicalType::*;
28
29 let dtype = arrays[0].dtype();
30
31 match dtype.to_physical_type() {
32 Null => Box::new(NullArray::new(
33 dtype.clone(),
34 output_height * widths.iter().copied().sum::<usize>(),
35 )),
36 Boolean => Box::new(horizontal_flatten_unchecked_impl_generic(
37 &arrays
38 .iter()
39 .map(|x| x.as_any().downcast_ref::<BooleanArray>().unwrap().clone())
40 .collect::<Vec<_>>(),
41 widths,
42 output_height,
43 dtype,
44 )),
45 Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| {
46 Box::new(horizontal_flatten_unchecked_impl_generic(
47 &arrays
48 .iter()
49 .map(|x| x.as_any().downcast_ref::<PrimitiveArray<$T>>().unwrap().clone())
50 .collect::<Vec<_>>(),
51 widths,
52 output_height,
53 dtype
54 ))
55 }),
56 LargeBinary => Box::new(horizontal_flatten_unchecked_impl_generic(
57 &arrays
58 .iter()
59 .map(|x| {
60 x.as_any()
61 .downcast_ref::<BinaryArray<i64>>()
62 .unwrap()
63 .clone()
64 })
65 .collect::<Vec<_>>(),
66 widths,
67 output_height,
68 dtype,
69 )),
70 Struct => Box::new(struct_::horizontal_flatten_unchecked(
71 &arrays
72 .iter()
73 .map(|x| x.as_any().downcast_ref::<StructArray>().unwrap().clone())
74 .collect::<Vec<_>>(),
75 widths,
76 output_height,
77 )),
78 LargeList => Box::new(horizontal_flatten_unchecked_impl_generic(
79 &arrays
80 .iter()
81 .map(|x| x.as_any().downcast_ref::<ListArray<i64>>().unwrap().clone())
82 .collect::<Vec<_>>(),
83 widths,
84 output_height,
85 dtype,
86 )),
87 FixedSizeList => Box::new(horizontal_flatten_unchecked_impl_generic(
88 &arrays
89 .iter()
90 .map(|x| {
91 x.as_any()
92 .downcast_ref::<FixedSizeListArray>()
93 .unwrap()
94 .clone()
95 })
96 .collect::<Vec<_>>(),
97 widths,
98 output_height,
99 dtype,
100 )),
101 BinaryView => Box::new(horizontal_flatten_unchecked_impl_generic(
102 &arrays
103 .iter()
104 .map(|x| {
105 x.as_any()
106 .downcast_ref::<BinaryViewArray>()
107 .unwrap()
108 .clone()
109 })
110 .collect::<Vec<_>>(),
111 widths,
112 output_height,
113 dtype,
114 )),
115 Utf8View => Box::new(horizontal_flatten_unchecked_impl_generic(
116 &arrays
117 .iter()
118 .map(|x| x.as_any().downcast_ref::<Utf8ViewArray>().unwrap().clone())
119 .collect::<Vec<_>>(),
120 widths,
121 output_height,
122 dtype,
123 )),
124 t => unimplemented!("horizontal_flatten not supported for data type {:?}", t),
125 }
126}
127
128unsafe fn horizontal_flatten_unchecked_impl_generic<T>(
129 arrays: &[T],
130 widths: &[usize],
131 output_height: usize,
132 dtype: &ArrowDataType,
133) -> T
134where
135 T: StaticArray,
136{
137 assert!(!arrays.is_empty());
138 assert_eq!(widths.len(), arrays.len());
139
140 debug_assert!(widths.iter().all(|x| *x > 0));
141 debug_assert!(arrays
142 .iter()
143 .zip(widths)
144 .all(|(arr, width)| arr.len() == output_height * *width || arr.len() == *width));
145
146 let lengths = arrays
148 .iter()
149 .map(|x| StrengthReducedUsize::new(x.len()))
150 .collect::<Vec<_>>();
151 let out_row_width: usize = widths.iter().cloned().sum();
152 let out_len = out_row_width.checked_mul(output_height).unwrap();
153
154 let mut col_idx = 0;
155 let mut row_idx = 0;
156 let mut until = widths[0];
157 let mut outer_row_idx = 0;
158
159 (0..out_len)
161 .map(|_| {
162 let arr = arrays.get_unchecked(col_idx);
163 let out = arr.get_unchecked(row_idx % *lengths.get_unchecked(col_idx));
164
165 row_idx += 1;
166
167 if row_idx == until {
168 col_idx = if 1 + col_idx == widths.len() {
170 outer_row_idx += 1;
171 0
172 } else {
173 1 + col_idx
174 };
175 row_idx = outer_row_idx * *widths.get_unchecked(col_idx);
176 until = (1 + outer_row_idx) * *widths.get_unchecked(col_idx)
177 }
178
179 out
180 })
181 .collect_arr_trusted_with_dtype(dtype.clone())
182}