polars_arrow/io/ipc/read/common.rs

1use std::collections::VecDeque;
2use std::io::{Read, Seek};
3use std::sync::Arc;
4
5use polars_error::{polars_bail, polars_err, PolarsResult};
6use polars_utils::aliases::PlHashMap;
7use polars_utils::pl_str::PlSmallStr;
8
9use super::deserialize::{read, skip};
10use super::Dictionaries;
11use crate::array::*;
12use crate::datatypes::{ArrowDataType, ArrowSchema, Field};
13use crate::io::ipc::read::OutOfSpecKind;
14use crate::io::ipc::{IpcField, IpcSchema};
15use crate::record_batch::RecordBatchT;
16
/// Whether an item yielded by [`ProjectionIter`] belongs to the projection.
#[derive(Debug, Eq, PartialEq, Hash)]
enum ProjectionResult<A> {
    /// The item's position is part of the projection.
    Selected(A),
    /// The item's position is not part of the projection.
    NotSelected(A),
}

/// An iterator adapter that tags every item of `iter` as [`ProjectionResult::Selected`] or
/// [`ProjectionResult::NotSelected`], depending on whether its (zero-based) position appears
/// in `projection`.
///
/// Positions in `projection` that lie beyond the end of `iter` are never reached and are
/// therefore ignored.
///
/// # Panics
/// The iterator panics if `projection` is not strictly increasing.
struct ProjectionIter<'a, A, I: Iterator<Item = A>> {
    /// Projected positions not reached yet; the first one is the next to be selected.
    projection: &'a [usize],
    iter: I,
    /// Position of the next item that `iter` will yield.
    current_count: usize,
}

impl<'a, A, I: Iterator<Item = A>> ProjectionIter<'a, A, I> {
    /// Creates a new adapter over `iter`.
    ///
    /// An empty `projection` is allowed: every item is then reported as
    /// [`ProjectionResult::NotSelected`].
    pub fn new(projection: &'a [usize], iter: I) -> Self {
        Self {
            projection,
            iter,
            current_count: 0,
        }
    }
}

impl<A, I: Iterator<Item = A>> Iterator for ProjectionIter<'_, A, I> {
    type Item = ProjectionResult<A>;

    fn next(&mut self) -> Option<Self::Item> {
        let item = self.iter.next()?;
        let position = self.current_count;
        self.current_count += 1;

        match self.projection.split_first() {
            Some((&selected, rest)) if selected == position => {
                // Validate ordering lazily, when the next projected index becomes current.
                if let Some(&next) = rest.first() {
                    assert!(next > selected, "projection must be strictly increasing");
                }
                self.projection = rest;
                Some(ProjectionResult::Selected(item))
            },
            _ => Some(ProjectionResult::NotSelected(item)),
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.iter.size_hint()
    }
}
74
75/// Returns a [`RecordBatchT`] from a reader.
76/// # Panic
77/// Panics iff the projection is not in increasing order (e.g. `[1, 0]` nor `[0, 1, 1]` are valid)
78#[allow(clippy::too_many_arguments)]
79pub fn read_record_batch<R: Read + Seek>(
80    batch: arrow_format::ipc::RecordBatchRef,
81    fields: &ArrowSchema,
82    ipc_schema: &IpcSchema,
83    projection: Option<&[usize]>,
84    limit: Option<usize>,
85    dictionaries: &Dictionaries,
86    version: arrow_format::ipc::MetadataVersion,
87    reader: &mut R,
88    block_offset: u64,
89    file_size: u64,
90    scratch: &mut Vec<u8>,
91) -> PolarsResult<RecordBatchT<Box<dyn Array>>> {
92    assert_eq!(fields.len(), ipc_schema.fields.len());
93    let buffers = batch
94        .buffers()
95        .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferBuffers(err)))?
96        .ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingMessageBuffers))?;
97    let mut variadic_buffer_counts = batch
98        .variadic_buffer_counts()
99        .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferRecordBatches(err)))?
100        .map(|v| v.iter().map(|v| v as usize).collect::<VecDeque<usize>>())
101        .unwrap_or_else(VecDeque::new);
102    let mut buffers: VecDeque<arrow_format::ipc::BufferRef> = buffers.iter().collect();
103
104    // check that the sum of the sizes of all buffers is <= than the size of the file
105    let buffers_size = buffers
106        .iter()
107        .map(|buffer| {
108            let buffer_size: u64 = buffer
109                .length()
110                .try_into()
111                .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?;
112            Ok(buffer_size)
113        })
114        .sum::<PolarsResult<u64>>()?;
115    if buffers_size > file_size {
116        return Err(polars_err!(
117            oos = OutOfSpecKind::InvalidBuffersLength {
118                buffers_size,
119                file_size,
120            }
121        ));
122    }
123
124    let field_nodes = batch
125        .nodes()
126        .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferNodes(err)))?
127        .ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingMessageNodes))?;
128    let mut field_nodes = field_nodes.iter().collect::<VecDeque<_>>();
129
130    let columns = if let Some(projection) = projection {
131        let projection = ProjectionIter::new(
132            projection,
133            fields.iter_values().zip(ipc_schema.fields.iter()),
134        );
135
136        projection
137            .map(|maybe_field| match maybe_field {
138                ProjectionResult::Selected((field, ipc_field)) => Ok(Some(read(
139                    &mut field_nodes,
140                    &mut variadic_buffer_counts,
141                    field,
142                    ipc_field,
143                    &mut buffers,
144                    reader,
145                    dictionaries,
146                    block_offset,
147                    ipc_schema.is_little_endian,
148                    batch.compression().map_err(|err| {
149                        polars_err!(oos = OutOfSpecKind::InvalidFlatbufferCompression(err))
150                    })?,
151                    limit,
152                    version,
153                    scratch,
154                )?)),
155                ProjectionResult::NotSelected((field, _)) => {
156                    skip(
157                        &mut field_nodes,
158                        &field.dtype,
159                        &mut buffers,
160                        &mut variadic_buffer_counts,
161                    )?;
162                    Ok(None)
163                },
164            })
165            .filter_map(|x| x.transpose())
166            .collect::<PolarsResult<Vec<_>>>()?
167    } else {
168        fields
169            .iter_values()
170            .zip(ipc_schema.fields.iter())
171            .map(|(field, ipc_field)| {
172                read(
173                    &mut field_nodes,
174                    &mut variadic_buffer_counts,
175                    field,
176                    ipc_field,
177                    &mut buffers,
178                    reader,
179                    dictionaries,
180                    block_offset,
181                    ipc_schema.is_little_endian,
182                    batch.compression().map_err(|err| {
183                        polars_err!(oos = OutOfSpecKind::InvalidFlatbufferCompression(err))
184                    })?,
185                    limit,
186                    version,
187                    scratch,
188                )
189            })
190            .collect::<PolarsResult<Vec<_>>>()?
191    };
192
193    let length = batch
194        .length()
195        .map_err(|_| polars_err!(oos = OutOfSpecKind::MissingData))
196        .unwrap()
197        .try_into()
198        .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?;
199    let length = limit.map(|limit| limit.min(length)).unwrap_or(length);
200
201    let mut schema: ArrowSchema = fields.iter_values().cloned().collect();
202    if let Some(projection) = projection {
203        schema = schema.try_project_indices(projection).unwrap();
204    }
205    RecordBatchT::try_new(length, Arc::new(schema), columns)
206}
207
208fn find_first_dict_field_d<'a>(
209    id: i64,
210    dtype: &'a ArrowDataType,
211    ipc_field: &'a IpcField,
212) -> Option<(&'a Field, &'a IpcField)> {
213    use ArrowDataType::*;
214    match dtype {
215        Dictionary(_, inner, _) => find_first_dict_field_d(id, inner.as_ref(), ipc_field),
216        List(field) | LargeList(field) | FixedSizeList(field, ..) | Map(field, ..) => {
217            find_first_dict_field(id, field.as_ref(), &ipc_field.fields[0])
218        },
219        Struct(fields) => {
220            for (field, ipc_field) in fields.iter().zip(ipc_field.fields.iter()) {
221                if let Some(f) = find_first_dict_field(id, field, ipc_field) {
222                    return Some(f);
223                }
224            }
225            None
226        },
227        Union(u) => {
228            for (field, ipc_field) in u.fields.iter().zip(ipc_field.fields.iter()) {
229                if let Some(f) = find_first_dict_field(id, field, ipc_field) {
230                    return Some(f);
231                }
232            }
233            None
234        },
235        _ => None,
236    }
237}
238
239fn find_first_dict_field<'a>(
240    id: i64,
241    field: &'a Field,
242    ipc_field: &'a IpcField,
243) -> Option<(&'a Field, &'a IpcField)> {
244    if let Some(field_id) = ipc_field.dictionary_id {
245        if id == field_id {
246            return Some((field, ipc_field));
247        }
248    }
249    find_first_dict_field_d(id, &field.dtype, ipc_field)
250}
251
252pub(crate) fn first_dict_field<'a>(
253    id: i64,
254    fields: &'a ArrowSchema,
255    ipc_fields: &'a [IpcField],
256) -> PolarsResult<(&'a Field, &'a IpcField)> {
257    assert_eq!(fields.len(), ipc_fields.len());
258    for (field, ipc_field) in fields.iter_values().zip(ipc_fields.iter()) {
259        if let Some(field) = find_first_dict_field(id, field, ipc_field) {
260            return Ok(field);
261        }
262    }
263    Err(polars_err!(
264        oos = OutOfSpecKind::InvalidId { requested_id: id }
265    ))
266}
267
268/// Reads a dictionary from the reader,
269/// updating `dictionaries` with the resulting dictionary
270#[allow(clippy::too_many_arguments)]
271pub fn read_dictionary<R: Read + Seek>(
272    batch: arrow_format::ipc::DictionaryBatchRef,
273    fields: &ArrowSchema,
274    ipc_schema: &IpcSchema,
275    dictionaries: &mut Dictionaries,
276    reader: &mut R,
277    block_offset: u64,
278    file_size: u64,
279    scratch: &mut Vec<u8>,
280) -> PolarsResult<()> {
281    if batch
282        .is_delta()
283        .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferIsDelta(err)))?
284    {
285        polars_bail!(ComputeError: "delta dictionary batches not supported")
286    }
287
288    let id = batch
289        .id()
290        .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferId(err)))?;
291    let (first_field, first_ipc_field) = first_dict_field(id, fields, &ipc_schema.fields)?;
292
293    let batch = batch
294        .data()
295        .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferData(err)))?
296        .ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingData))?;
297
298    let value_type =
299        if let ArrowDataType::Dictionary(_, value_type, _) = first_field.dtype.to_logical_type() {
300            value_type.as_ref()
301        } else {
302            polars_bail!(oos = OutOfSpecKind::InvalidIdDataType { requested_id: id })
303        };
304
305    // Make a fake schema for the dictionary batch.
306    let fields = std::iter::once((
307        PlSmallStr::EMPTY,
308        Field::new(PlSmallStr::EMPTY, value_type.clone(), false),
309    ))
310    .collect();
311    let ipc_schema = IpcSchema {
312        fields: vec![first_ipc_field.clone()],
313        is_little_endian: ipc_schema.is_little_endian,
314    };
315    let chunk = read_record_batch(
316        batch,
317        &fields,
318        &ipc_schema,
319        None,
320        None, // we must read the whole dictionary
321        dictionaries,
322        arrow_format::ipc::MetadataVersion::V5,
323        reader,
324        block_offset,
325        file_size,
326        scratch,
327    )?;
328
329    dictionaries.insert(id, chunk.into_arrays().pop().unwrap());
330
331    Ok(())
332}
333
334#[derive(Clone)]
335pub struct ProjectionInfo {
336    pub columns: Vec<usize>,
337    pub map: PlHashMap<usize, usize>,
338    pub schema: ArrowSchema,
339}
340
341pub fn prepare_projection(schema: &ArrowSchema, mut projection: Vec<usize>) -> ProjectionInfo {
342    let schema = projection
343        .iter()
344        .map(|x| {
345            let (k, v) = schema.get_at_index(*x).unwrap();
346            (k.clone(), v.clone())
347        })
348        .collect();
349
350    // todo: find way to do this more efficiently
351    let mut indices = (0..projection.len()).collect::<Vec<_>>();
352    indices.sort_unstable_by_key(|&i| &projection[i]);
353    let map = indices.iter().copied().enumerate().fold(
354        PlHashMap::default(),
355        |mut acc, (index, new_index)| {
356            acc.insert(index, new_index);
357            acc
358        },
359    );
360    projection.sort_unstable();
361
362    // check unique
363    if !projection.is_empty() {
364        let mut previous = projection[0];
365
366        for &i in &projection[1..] {
367            assert!(
368                previous < i,
369                "The projection on IPC must not contain duplicates"
370            );
371            previous = i;
372        }
373    }
374
375    ProjectionInfo {
376        columns: projection,
377        map,
378        schema,
379    }
380}
381
382pub fn apply_projection(
383    chunk: RecordBatchT<Box<dyn Array>>,
384    map: &PlHashMap<usize, usize>,
385) -> RecordBatchT<Box<dyn Array>> {
386    let length = chunk.len();
387
388    // re-order according to projection
389    let (schema, arrays) = chunk.into_schema_and_arrays();
390    let mut new_schema = schema.as_ref().clone();
391    let mut new_arrays = arrays.clone();
392
393    map.iter().for_each(|(old, new)| {
394        let (old_name, old_field) = schema.get_at_index(*old).unwrap();
395        let (new_name, new_field) = new_schema.get_at_index_mut(*new).unwrap();
396
397        *new_name = old_name.clone();
398        *new_field = old_field.clone();
399
400        new_arrays[*new] = arrays[*old].clone();
401    });
402
403    RecordBatchT::new(length, Arc::new(new_schema), new_arrays)
404}
405
#[cfg(test)]
mod tests {
    use super::ProjectionResult::{NotSelected, Selected};
    use super::*;

    /// Every position listed in the projection is selected; all others are passed through.
    #[test]
    fn project_iter() {
        let tagged: Vec<_> = ProjectionIter::new(&[0, 2, 4], 1..6).collect();
        let expected = vec![
            Selected(1),
            NotSelected(2),
            Selected(3),
            NotSelected(4),
            Selected(5),
        ];
        assert_eq!(tagged, expected);
    }
}