1use std::collections::VecDeque;
2use std::io::{Read, Seek};
3use std::sync::Arc;
4
5use polars_error::{polars_bail, polars_err, PolarsResult};
6use polars_utils::aliases::PlHashMap;
7use polars_utils::pl_str::PlSmallStr;
8
9use super::deserialize::{read, skip};
10use super::Dictionaries;
11use crate::array::*;
12use crate::datatypes::{ArrowDataType, ArrowSchema, Field};
13use crate::io::ipc::read::OutOfSpecKind;
14use crate::io::ipc::{IpcField, IpcSchema};
15use crate::record_batch::RecordBatchT;
16
/// Outcome of matching one item of the underlying iterator against the
/// requested projection.
#[derive(Debug, Eq, PartialEq, Hash)]
enum ProjectionResult<A> {
    /// The item's position is part of the projection.
    Selected(A),
    /// The item's position is not part of the projection.
    NotSelected(A),
}

/// Iterator adapter that tags each item of `iter` as [`ProjectionResult::Selected`]
/// when its position appears in `projection`, and `NotSelected` otherwise.
///
/// `projection` must be sorted ascending and free of duplicates (asserted
/// lazily while iterating). An empty projection selects nothing.
struct ProjectionIter<'a, A, I: Iterator<Item = A>> {
    /// Remaining (not-yet-matched) projection indices, strictly increasing.
    projection: &'a [usize],
    iter: I,
    /// Position of the next item to be yielded by `iter`.
    current_count: usize,
    /// Next projected index to match, or `None` once the projection is exhausted.
    current_projection: Option<usize>,
}

impl<'a, A, I: Iterator<Item = A>> ProjectionIter<'a, A, I> {
    /// Creates a new [`ProjectionIter`].
    ///
    /// An empty `projection` yields every item as `NotSelected`.
    pub fn new(projection: &'a [usize], iter: I) -> Self {
        let (first, rest) = match projection.split_first() {
            Some((&first, rest)) => (Some(first), rest),
            None => (None, projection),
        };
        Self {
            projection: rest,
            iter,
            current_count: 0,
            current_projection: first,
        }
    }
}

impl<A, I: Iterator<Item = A>> Iterator for ProjectionIter<'_, A, I> {
    type Item = ProjectionResult<A>;

    fn next(&mut self) -> Option<Self::Item> {
        let item = self.iter.next()?;
        let result = if self.current_projection == Some(self.current_count) {
            // Advance to the next projected index, enforcing strict ordering.
            let remaining = self.projection;
            self.current_projection = match remaining.split_first() {
                Some((&next, rest)) => {
                    assert!(
                        next > self.current_count,
                        "projection must be sorted and free of duplicates"
                    );
                    self.projection = rest;
                    Some(next)
                },
                None => None,
            };
            ProjectionResult::Selected(item)
        } else {
            ProjectionResult::NotSelected(item)
        };
        self.current_count += 1;
        Some(result)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.iter.size_hint()
    }
}
74
75#[allow(clippy::too_many_arguments)]
79pub fn read_record_batch<R: Read + Seek>(
80 batch: arrow_format::ipc::RecordBatchRef,
81 fields: &ArrowSchema,
82 ipc_schema: &IpcSchema,
83 projection: Option<&[usize]>,
84 limit: Option<usize>,
85 dictionaries: &Dictionaries,
86 version: arrow_format::ipc::MetadataVersion,
87 reader: &mut R,
88 block_offset: u64,
89 file_size: u64,
90 scratch: &mut Vec<u8>,
91) -> PolarsResult<RecordBatchT<Box<dyn Array>>> {
92 assert_eq!(fields.len(), ipc_schema.fields.len());
93 let buffers = batch
94 .buffers()
95 .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferBuffers(err)))?
96 .ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingMessageBuffers))?;
97 let mut variadic_buffer_counts = batch
98 .variadic_buffer_counts()
99 .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferRecordBatches(err)))?
100 .map(|v| v.iter().map(|v| v as usize).collect::<VecDeque<usize>>())
101 .unwrap_or_else(VecDeque::new);
102 let mut buffers: VecDeque<arrow_format::ipc::BufferRef> = buffers.iter().collect();
103
104 let buffers_size = buffers
106 .iter()
107 .map(|buffer| {
108 let buffer_size: u64 = buffer
109 .length()
110 .try_into()
111 .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?;
112 Ok(buffer_size)
113 })
114 .sum::<PolarsResult<u64>>()?;
115 if buffers_size > file_size {
116 return Err(polars_err!(
117 oos = OutOfSpecKind::InvalidBuffersLength {
118 buffers_size,
119 file_size,
120 }
121 ));
122 }
123
124 let field_nodes = batch
125 .nodes()
126 .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferNodes(err)))?
127 .ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingMessageNodes))?;
128 let mut field_nodes = field_nodes.iter().collect::<VecDeque<_>>();
129
130 let columns = if let Some(projection) = projection {
131 let projection = ProjectionIter::new(
132 projection,
133 fields.iter_values().zip(ipc_schema.fields.iter()),
134 );
135
136 projection
137 .map(|maybe_field| match maybe_field {
138 ProjectionResult::Selected((field, ipc_field)) => Ok(Some(read(
139 &mut field_nodes,
140 &mut variadic_buffer_counts,
141 field,
142 ipc_field,
143 &mut buffers,
144 reader,
145 dictionaries,
146 block_offset,
147 ipc_schema.is_little_endian,
148 batch.compression().map_err(|err| {
149 polars_err!(oos = OutOfSpecKind::InvalidFlatbufferCompression(err))
150 })?,
151 limit,
152 version,
153 scratch,
154 )?)),
155 ProjectionResult::NotSelected((field, _)) => {
156 skip(
157 &mut field_nodes,
158 &field.dtype,
159 &mut buffers,
160 &mut variadic_buffer_counts,
161 )?;
162 Ok(None)
163 },
164 })
165 .filter_map(|x| x.transpose())
166 .collect::<PolarsResult<Vec<_>>>()?
167 } else {
168 fields
169 .iter_values()
170 .zip(ipc_schema.fields.iter())
171 .map(|(field, ipc_field)| {
172 read(
173 &mut field_nodes,
174 &mut variadic_buffer_counts,
175 field,
176 ipc_field,
177 &mut buffers,
178 reader,
179 dictionaries,
180 block_offset,
181 ipc_schema.is_little_endian,
182 batch.compression().map_err(|err| {
183 polars_err!(oos = OutOfSpecKind::InvalidFlatbufferCompression(err))
184 })?,
185 limit,
186 version,
187 scratch,
188 )
189 })
190 .collect::<PolarsResult<Vec<_>>>()?
191 };
192
193 let length = batch
194 .length()
195 .map_err(|_| polars_err!(oos = OutOfSpecKind::MissingData))
196 .unwrap()
197 .try_into()
198 .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?;
199 let length = limit.map(|limit| limit.min(length)).unwrap_or(length);
200
201 let mut schema: ArrowSchema = fields.iter_values().cloned().collect();
202 if let Some(projection) = projection {
203 schema = schema.try_project_indices(projection).unwrap();
204 }
205 RecordBatchT::try_new(length, Arc::new(schema), columns)
206}
207
208fn find_first_dict_field_d<'a>(
209 id: i64,
210 dtype: &'a ArrowDataType,
211 ipc_field: &'a IpcField,
212) -> Option<(&'a Field, &'a IpcField)> {
213 use ArrowDataType::*;
214 match dtype {
215 Dictionary(_, inner, _) => find_first_dict_field_d(id, inner.as_ref(), ipc_field),
216 List(field) | LargeList(field) | FixedSizeList(field, ..) | Map(field, ..) => {
217 find_first_dict_field(id, field.as_ref(), &ipc_field.fields[0])
218 },
219 Struct(fields) => {
220 for (field, ipc_field) in fields.iter().zip(ipc_field.fields.iter()) {
221 if let Some(f) = find_first_dict_field(id, field, ipc_field) {
222 return Some(f);
223 }
224 }
225 None
226 },
227 Union(u) => {
228 for (field, ipc_field) in u.fields.iter().zip(ipc_field.fields.iter()) {
229 if let Some(f) = find_first_dict_field(id, field, ipc_field) {
230 return Some(f);
231 }
232 }
233 None
234 },
235 _ => None,
236 }
237}
238
239fn find_first_dict_field<'a>(
240 id: i64,
241 field: &'a Field,
242 ipc_field: &'a IpcField,
243) -> Option<(&'a Field, &'a IpcField)> {
244 if let Some(field_id) = ipc_field.dictionary_id {
245 if id == field_id {
246 return Some((field, ipc_field));
247 }
248 }
249 find_first_dict_field_d(id, &field.dtype, ipc_field)
250}
251
252pub(crate) fn first_dict_field<'a>(
253 id: i64,
254 fields: &'a ArrowSchema,
255 ipc_fields: &'a [IpcField],
256) -> PolarsResult<(&'a Field, &'a IpcField)> {
257 assert_eq!(fields.len(), ipc_fields.len());
258 for (field, ipc_field) in fields.iter_values().zip(ipc_fields.iter()) {
259 if let Some(field) = find_first_dict_field(id, field, ipc_field) {
260 return Ok(field);
261 }
262 }
263 Err(polars_err!(
264 oos = OutOfSpecKind::InvalidId { requested_id: id }
265 ))
266}
267
/// Reads a dictionary batch and stores the decoded values array in
/// `dictionaries` under the batch's dictionary id.
///
/// Delta dictionaries (`is_delta == true`) are rejected with a compute error.
///
/// # Errors
/// Returns an error for malformed flatbuffer metadata, for an id that does
/// not correspond to any dictionary-encoded field in `fields`, or when
/// decoding the values fails.
#[allow(clippy::too_many_arguments)]
pub fn read_dictionary<R: Read + Seek>(
    batch: arrow_format::ipc::DictionaryBatchRef,
    fields: &ArrowSchema,
    ipc_schema: &IpcSchema,
    dictionaries: &mut Dictionaries,
    reader: &mut R,
    block_offset: u64,
    file_size: u64,
    scratch: &mut Vec<u8>,
) -> PolarsResult<()> {
    // Delta dictionaries append to a previously-sent dictionary; this reader
    // only supports full replacements.
    if batch
        .is_delta()
        .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferIsDelta(err)))?
    {
        polars_bail!(ComputeError: "delta dictionary batches not supported")
    }

    let id = batch
        .id()
        .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferId(err)))?;
    // Locate the (first) field that uses this dictionary id; its value type
    // determines how the dictionary body is decoded.
    let (first_field, first_ipc_field) = first_dict_field(id, fields, &ipc_schema.fields)?;

    let batch = batch
        .data()
        .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferData(err)))?
        .ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingData))?;

    let value_type =
        if let ArrowDataType::Dictionary(_, value_type, _) = first_field.dtype.to_logical_type() {
            value_type.as_ref()
        } else {
            polars_bail!(oos = OutOfSpecKind::InvalidIdDataType { requested_id: id })
        };

    // The dictionary body is serialized as a single-column record batch whose
    // only column holds the dictionary values, so decode it through
    // `read_record_batch` with a synthetic one-field schema.
    let fields = std::iter::once((
        PlSmallStr::EMPTY,
        Field::new(PlSmallStr::EMPTY, value_type.clone(), false),
    ))
    .collect();
    let ipc_schema = IpcSchema {
        fields: vec![first_ipc_field.clone()],
        is_little_endian: ipc_schema.is_little_endian,
    };
    let chunk = read_record_batch(
        batch,
        &fields,
        &ipc_schema,
        None,
        None, dictionaries,
        arrow_format::ipc::MetadataVersion::V5,
        reader,
        block_offset,
        file_size,
        scratch,
    )?;

    // The synthetic schema has exactly one column, so `pop` yields the values.
    dictionaries.insert(id, chunk.into_arrays().pop().unwrap());

    Ok(())
}
333
/// Projection bookkeeping for IPC reads: the sorted column indices used while
/// reading, plus the mapping needed to restore the caller's requested order.
#[derive(Clone)]
pub struct ProjectionInfo {
    /// Projected column indices, sorted ascending (reading order).
    pub columns: Vec<usize>,
    /// Maps a column's position in sorted reading order to its position in
    /// the originally requested order (consumed by `apply_projection`).
    pub map: PlHashMap<usize, usize>,
    /// Schema restricted to the projected columns, in requested order.
    pub schema: ArrowSchema,
}
340
341pub fn prepare_projection(schema: &ArrowSchema, mut projection: Vec<usize>) -> ProjectionInfo {
342 let schema = projection
343 .iter()
344 .map(|x| {
345 let (k, v) = schema.get_at_index(*x).unwrap();
346 (k.clone(), v.clone())
347 })
348 .collect();
349
350 let mut indices = (0..projection.len()).collect::<Vec<_>>();
352 indices.sort_unstable_by_key(|&i| &projection[i]);
353 let map = indices.iter().copied().enumerate().fold(
354 PlHashMap::default(),
355 |mut acc, (index, new_index)| {
356 acc.insert(index, new_index);
357 acc
358 },
359 );
360 projection.sort_unstable();
361
362 if !projection.is_empty() {
364 let mut previous = projection[0];
365
366 for &i in &projection[1..] {
367 assert!(
368 previous < i,
369 "The projection on IPC must not contain duplicates"
370 );
371 previous = i;
372 }
373 }
374
375 ProjectionInfo {
376 columns: projection,
377 map,
378 schema,
379 }
380}
381
382pub fn apply_projection(
383 chunk: RecordBatchT<Box<dyn Array>>,
384 map: &PlHashMap<usize, usize>,
385) -> RecordBatchT<Box<dyn Array>> {
386 let length = chunk.len();
387
388 let (schema, arrays) = chunk.into_schema_and_arrays();
390 let mut new_schema = schema.as_ref().clone();
391 let mut new_arrays = arrays.clone();
392
393 map.iter().for_each(|(old, new)| {
394 let (old_name, old_field) = schema.get_at_index(*old).unwrap();
395 let (new_name, new_field) = new_schema.get_at_index_mut(*new).unwrap();
396
397 *new_name = old_name.clone();
398 *new_field = old_field.clone();
399
400 new_arrays[*new] = arrays[*old].clone();
401 });
402
403 RecordBatchT::new(length, Arc::new(new_schema), new_arrays)
404}
405
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn project_iter() {
        use ProjectionResult::*;
        // Positions 0, 2 and 4 of the sequence 1..6 should be selected.
        let collected: Vec<_> = ProjectionIter::new(&[0, 2, 4], 1..6).collect();
        assert_eq!(
            collected,
            vec![
                Selected(1),
                NotSelected(2),
                Selected(3),
                NotSelected(4),
                Selected(5)
            ]
        )
    }
}
427}