// polars_arrow/io/ipc/write/common.rs

1use std::borrow::{Borrow, Cow};
2
3use arrow_format::ipc::planus::Builder;
4use polars_error::{polars_bail, polars_err, PolarsResult};
5
6use super::super::IpcField;
7use super::{write, write_dictionary};
8use crate::array::*;
9use crate::datatypes::*;
10use crate::io::ipc::endianness::is_native_little_endian;
11use crate::io::ipc::read::Dictionaries;
12use crate::legacy::prelude::LargeListArray;
13use crate::match_integer_type;
14use crate::record_batch::RecordBatchT;
15
/// Compression codec applied to the IPC body buffers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Compression {
    /// LZ4 (frame format, not raw block)
    LZ4,
    /// Zstandard
    ZSTD,
}
24
/// Options declaring the behaviour of writing to IPC
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct WriteOptions {
    /// Whether the buffers should be compressed and which codec to use.
    /// `None` (the default) writes the buffers uncompressed.
    /// Note: to use compression the crate must be compiled with feature `io_ipc_compression`.
    pub compression: Option<Compression>,
}
32
33/// Find the dictionary that are new and need to be encoded.
34pub fn dictionaries_to_encode(
35    field: &IpcField,
36    array: &dyn Array,
37    dictionary_tracker: &mut DictionaryTracker,
38    dicts_to_encode: &mut Vec<(i64, Box<dyn Array>)>,
39) -> PolarsResult<()> {
40    use PhysicalType::*;
41    match array.dtype().to_physical_type() {
42        Utf8 | LargeUtf8 | Binary | LargeBinary | Primitive(_) | Boolean | Null
43        | FixedSizeBinary | BinaryView | Utf8View => Ok(()),
44        Dictionary(key_type) => match_integer_type!(key_type, |$T| {
45            let dict_id = field.dictionary_id
46                .ok_or_else(|| polars_err!(InvalidOperation: "Dictionaries must have an associated id"))?;
47
48            if dictionary_tracker.insert(dict_id, array)? {
49                dicts_to_encode.push((dict_id, array.to_boxed()));
50            }
51
52            let array = array.as_any().downcast_ref::<DictionaryArray<$T>>().unwrap();
53            let values = array.values();
54            // @Q? Should this not pick fields[0]?
55            dictionaries_to_encode(field,
56                values.as_ref(),
57                dictionary_tracker,
58                dicts_to_encode,
59            )?;
60
61            Ok(())
62        }),
63        Struct => {
64            let array = array.as_any().downcast_ref::<StructArray>().unwrap();
65            let fields = field.fields.as_slice();
66            if array.fields().len() != fields.len() {
67                polars_bail!(InvalidOperation:
68                    "The number of fields in a struct must equal the number of children in IpcField".to_string(),
69                );
70            }
71            fields
72                .iter()
73                .zip(array.values().iter())
74                .try_for_each(|(field, values)| {
75                    dictionaries_to_encode(
76                        field,
77                        values.as_ref(),
78                        dictionary_tracker,
79                        dicts_to_encode,
80                    )
81                })
82        },
83        List => {
84            let values = array
85                .as_any()
86                .downcast_ref::<ListArray<i32>>()
87                .unwrap()
88                .values();
89            let field = &field.fields[0]; // todo: error instead
90            dictionaries_to_encode(field, values.as_ref(), dictionary_tracker, dicts_to_encode)
91        },
92        LargeList => {
93            let values = array
94                .as_any()
95                .downcast_ref::<ListArray<i64>>()
96                .unwrap()
97                .values();
98            let field = &field.fields[0]; // todo: error instead
99            dictionaries_to_encode(field, values.as_ref(), dictionary_tracker, dicts_to_encode)
100        },
101        FixedSizeList => {
102            let values = array
103                .as_any()
104                .downcast_ref::<FixedSizeListArray>()
105                .unwrap()
106                .values();
107            let field = &field.fields[0]; // todo: error instead
108            dictionaries_to_encode(field, values.as_ref(), dictionary_tracker, dicts_to_encode)
109        },
110        Union => {
111            let values = array
112                .as_any()
113                .downcast_ref::<UnionArray>()
114                .unwrap()
115                .fields();
116            let fields = &field.fields[..]; // todo: error instead
117            if values.len() != fields.len() {
118                polars_bail!(InvalidOperation:
119                    "The number of fields in a union must equal the number of children in IpcField"
120                );
121            }
122            fields
123                .iter()
124                .zip(values.iter())
125                .try_for_each(|(field, values)| {
126                    dictionaries_to_encode(
127                        field,
128                        values.as_ref(),
129                        dictionary_tracker,
130                        dicts_to_encode,
131                    )
132                })
133        },
134        Map => {
135            let values = array.as_any().downcast_ref::<MapArray>().unwrap().field();
136            let field = &field.fields[0]; // todo: error instead
137            dictionaries_to_encode(field, values.as_ref(), dictionary_tracker, dicts_to_encode)
138        },
139    }
140}
141
142/// Encode a dictionary array with a certain id.
143///
144/// # Panics
145///
146/// This will panic if the given array is not a [`DictionaryArray`].
147pub fn encode_dictionary(
148    dict_id: i64,
149    array: &dyn Array,
150    options: &WriteOptions,
151    encoded_dictionaries: &mut Vec<EncodedData>,
152) -> PolarsResult<()> {
153    let PhysicalType::Dictionary(key_type) = array.dtype().to_physical_type() else {
154        panic!("Given array is not a DictionaryArray")
155    };
156
157    match_integer_type!(key_type, |$T| {
158        let array = array.as_any().downcast_ref::<DictionaryArray<$T>>().unwrap();
159        encoded_dictionaries.push(dictionary_batch_to_bytes::<$T>(
160            dict_id,
161            array,
162            options,
163            is_native_little_endian(),
164        ));
165    });
166
167    Ok(())
168}
169
170pub fn encode_new_dictionaries(
171    field: &IpcField,
172    array: &dyn Array,
173    options: &WriteOptions,
174    dictionary_tracker: &mut DictionaryTracker,
175    encoded_dictionaries: &mut Vec<EncodedData>,
176) -> PolarsResult<()> {
177    let mut dicts_to_encode = Vec::new();
178    dictionaries_to_encode(field, array, dictionary_tracker, &mut dicts_to_encode)?;
179    for (dict_id, dict_array) in dicts_to_encode {
180        encode_dictionary(dict_id, dict_array.as_ref(), options, encoded_dictionaries)?;
181    }
182    Ok(())
183}
184
185pub fn encode_chunk(
186    chunk: &RecordBatchT<Box<dyn Array>>,
187    fields: &[IpcField],
188    dictionary_tracker: &mut DictionaryTracker,
189    options: &WriteOptions,
190) -> PolarsResult<(Vec<EncodedData>, EncodedData)> {
191    let mut encoded_message = EncodedData::default();
192    let encoded_dictionaries = encode_chunk_amortized(
193        chunk,
194        fields,
195        dictionary_tracker,
196        options,
197        &mut encoded_message,
198    )?;
199    Ok((encoded_dictionaries, encoded_message))
200}
201
202// Amortizes `EncodedData` allocation.
203pub fn encode_chunk_amortized(
204    chunk: &RecordBatchT<Box<dyn Array>>,
205    fields: &[IpcField],
206    dictionary_tracker: &mut DictionaryTracker,
207    options: &WriteOptions,
208    encoded_message: &mut EncodedData,
209) -> PolarsResult<Vec<EncodedData>> {
210    let mut encoded_dictionaries = vec![];
211
212    for (field, array) in fields.iter().zip(chunk.as_ref()) {
213        encode_new_dictionaries(
214            field,
215            array.as_ref(),
216            options,
217            dictionary_tracker,
218            &mut encoded_dictionaries,
219        )?;
220    }
221
222    encode_record_batch(chunk, options, encoded_message);
223
224    Ok(encoded_dictionaries)
225}
226
227fn serialize_compression(
228    compression: Option<Compression>,
229) -> Option<Box<arrow_format::ipc::BodyCompression>> {
230    if let Some(compression) = compression {
231        let codec = match compression {
232            Compression::LZ4 => arrow_format::ipc::CompressionType::Lz4Frame,
233            Compression::ZSTD => arrow_format::ipc::CompressionType::Zstd,
234        };
235        Some(Box::new(arrow_format::ipc::BodyCompression {
236            codec,
237            method: arrow_format::ipc::BodyCompressionMethod::Buffer,
238        }))
239    } else {
240        None
241    }
242}
243
244fn set_variadic_buffer_counts(counts: &mut Vec<i64>, array: &dyn Array) {
245    match array.dtype() {
246        ArrowDataType::Utf8View => {
247            let array = array.as_any().downcast_ref::<Utf8ViewArray>().unwrap();
248            counts.push(array.data_buffers().len() as i64);
249        },
250        ArrowDataType::BinaryView => {
251            let array = array.as_any().downcast_ref::<BinaryViewArray>().unwrap();
252            counts.push(array.data_buffers().len() as i64);
253        },
254        ArrowDataType::Struct(_) => {
255            let array = array.as_any().downcast_ref::<StructArray>().unwrap();
256            for array in array.values() {
257                set_variadic_buffer_counts(counts, array.as_ref())
258            }
259        },
260        ArrowDataType::LargeList(_) => {
261            let array = array.as_any().downcast_ref::<LargeListArray>().unwrap();
262            set_variadic_buffer_counts(counts, array.values().as_ref())
263        },
264        ArrowDataType::FixedSizeList(_, _) => {
265            let array = array.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
266            set_variadic_buffer_counts(counts, array.values().as_ref())
267        },
268        // Don't traverse dictionary values as those are set when the `Dictionary` IPC struct
269        // is read.
270        ArrowDataType::Dictionary(_, _, _) => (),
271        _ => (),
272    }
273}
274
275fn gc_bin_view<'a, T: ViewType + ?Sized>(
276    arr: &'a Box<dyn Array>,
277    concrete_arr: &'a BinaryViewArrayGeneric<T>,
278) -> Cow<'a, Box<dyn Array>> {
279    let bytes_len = concrete_arr.total_bytes_len();
280    let buffer_len = concrete_arr.total_buffer_len();
281    let extra_len = buffer_len.saturating_sub(bytes_len);
282    if extra_len < bytes_len.min(1024) {
283        // We can afford some tiny waste.
284        Cow::Borrowed(arr)
285    } else {
286        // Force GC it.
287        Cow::Owned(concrete_arr.clone().gc().boxed())
288    }
289}
290
/// Write [`RecordBatchT`] into two sets of bytes, one for the header (ipc::Schema::Message) and the
/// other for the batch's data
pub fn encode_record_batch(
    chunk: &RecordBatchT<Box<dyn Array>>,
    options: &WriteOptions,
    encoded_message: &mut EncodedData,
) {
    let mut nodes: Vec<arrow_format::ipc::FieldNode> = vec![];
    let mut buffers: Vec<arrow_format::ipc::Buffer> = vec![];
    // Reuse the caller's body buffer so repeated calls amortize the allocation.
    let mut arrow_data = std::mem::take(&mut encoded_message.arrow_data);
    arrow_data.clear();

    let mut offset = 0;
    let mut variadic_buffer_counts = vec![];
    for array in chunk.arrays() {
        // We don't want to write all buffers in sliced arrays.
        // View arrays may be garbage-collected first (see `gc_bin_view`) so dead
        // buffer bytes are not serialized.
        let array = match array.dtype() {
            ArrowDataType::BinaryView => {
                let concrete_arr = array.as_any().downcast_ref::<BinaryViewArray>().unwrap();
                gc_bin_view(array, concrete_arr)
            },
            ArrowDataType::Utf8View => {
                let concrete_arr = array.as_any().downcast_ref::<Utf8ViewArray>().unwrap();
                gc_bin_view(array, concrete_arr)
            },
            _ => Cow::Borrowed(array),
        };
        // `&Cow<Box<dyn Array>>` -> `&dyn Array`.
        let array = array.as_ref().as_ref();

        set_variadic_buffer_counts(&mut variadic_buffer_counts, array);

        write(
            array,
            &mut buffers,
            &mut arrow_data,
            &mut nodes,
            &mut offset,
            is_native_little_endian(),
            options.compression,
        )
    }

    // Omit the field entirely when no view arrays were present.
    let variadic_buffer_counts = if variadic_buffer_counts.is_empty() {
        None
    } else {
        Some(variadic_buffer_counts)
    };

    let compression = serialize_compression(options.compression);

    let message = arrow_format::ipc::Message {
        version: arrow_format::ipc::MetadataVersion::V5,
        header: Some(arrow_format::ipc::MessageHeader::RecordBatch(Box::new(
            arrow_format::ipc::RecordBatch {
                length: chunk.len() as i64,
                nodes: Some(nodes),
                buffers: Some(buffers),
                compression,
                variadic_buffer_counts,
            },
        ))),
        body_length: arrow_data.len() as i64,
        custom_metadata: None,
    };

    // Flatbuffer-serialize the header and hand both parts back to the caller.
    let mut builder = Builder::new();
    let ipc_message = builder.finish(&message, None);
    encoded_message.ipc_message = ipc_message.to_vec();
    encoded_message.arrow_data = arrow_data
}
361
/// Write dictionary values into two sets of bytes, one for the header (ipc::Schema::Message) and the
/// other for the data
fn dictionary_batch_to_bytes<K: DictionaryKey>(
    dict_id: i64,
    array: &DictionaryArray<K>,
    options: &WriteOptions,
    is_little_endian: bool,
) -> EncodedData {
    let mut nodes: Vec<arrow_format::ipc::FieldNode> = vec![];
    let mut buffers: Vec<arrow_format::ipc::Buffer> = vec![];
    let mut arrow_data: Vec<u8> = vec![];
    let mut variadic_buffer_counts = vec![];
    // Only the dictionary *values* are serialized here, so count their variadic buffers.
    set_variadic_buffer_counts(&mut variadic_buffer_counts, array.values().as_ref());

    // Omit the field entirely when no view arrays were present.
    let variadic_buffer_counts = if variadic_buffer_counts.is_empty() {
        None
    } else {
        Some(variadic_buffer_counts)
    };

    let length = write_dictionary(
        array,
        &mut buffers,
        &mut arrow_data,
        &mut nodes,
        &mut 0,
        is_little_endian,
        options.compression,
        false,
    );

    let compression = serialize_compression(options.compression);

    let message = arrow_format::ipc::Message {
        version: arrow_format::ipc::MetadataVersion::V5,
        header: Some(arrow_format::ipc::MessageHeader::DictionaryBatch(Box::new(
            arrow_format::ipc::DictionaryBatch {
                id: dict_id,
                data: Some(Box::new(arrow_format::ipc::RecordBatch {
                    length: length as i64,
                    nodes: Some(nodes),
                    buffers: Some(buffers),
                    compression,
                    variadic_buffer_counts,
                })),
                // This writer always emits full dictionaries, never deltas.
                is_delta: false,
            },
        ))),
        body_length: arrow_data.len() as i64,
        custom_metadata: None,
    };

    // Flatbuffer-serialize the message header.
    let mut builder = Builder::new();
    let ipc_message = builder.finish(&message, None);

    EncodedData {
        ipc_message: ipc_message.to_vec(),
        arrow_data,
    }
}
422
/// Keeps track of dictionaries that have been written, to avoid emitting the same dictionary
/// multiple times. Can optionally error if an update to an existing dictionary is attempted, which
/// isn't allowed in the `FileWriter`.
pub struct DictionaryTracker {
    /// Dictionary values already emitted, keyed by dictionary id.
    pub dictionaries: Dictionaries,
    /// When `true`, replacing an already-tracked dictionary with different values
    /// is an error (the IPC *file* format allows only one dictionary per field).
    pub cannot_replace: bool,
}
430
431impl DictionaryTracker {
432    /// Keep track of the dictionary with the given ID and values. Behavior:
433    ///
434    /// * If this ID has been written already and has the same data, return `Ok(false)` to indicate
435    ///   that the dictionary was not actually inserted (because it's already been seen).
436    /// * If this ID has been written already but with different data, and this tracker is
437    ///   configured to return an error, return an error.
438    /// * If the tracker has not been configured to error on replacement or this dictionary
439    ///   has never been seen before, return `Ok(true)` to indicate that the dictionary was just
440    ///   inserted.
441    pub fn insert(&mut self, dict_id: i64, array: &dyn Array) -> PolarsResult<bool> {
442        let values = match array.dtype() {
443            ArrowDataType::Dictionary(key_type, _, _) => {
444                match_integer_type!(key_type, |$T| {
445                    let array = array
446                        .as_any()
447                        .downcast_ref::<DictionaryArray<$T>>()
448                        .unwrap();
449                    array.values()
450                })
451            },
452            _ => unreachable!(),
453        };
454
455        // If a dictionary with this id was already emitted, check if it was the same.
456        if let Some(last) = self.dictionaries.get(&dict_id) {
457            if last.as_ref() == values.as_ref() {
458                // Same dictionary values => no need to emit it again
459                return Ok(false);
460            } else if self.cannot_replace {
461                polars_bail!(InvalidOperation:
462                    "Dictionary replacement detected when writing IPC file format. \
463                     Arrow IPC files only support a single dictionary for a given field \
464                     across all batches."
465                );
466            }
467        };
468
469        self.dictionaries.insert(dict_id, values.clone());
470        Ok(true)
471    }
472}
473
/// Stores the encoded data, which is an ipc::Schema::Message, and optional Arrow data
#[derive(Debug, Default)]
pub struct EncodedData {
    /// An encoded ipc::Schema::Message (flatbuffer bytes)
    pub ipc_message: Vec<u8>,
    /// Arrow buffers to be written, should be an empty vec for schema messages
    pub arrow_data: Vec<u8>,
}
482
/// Compute the number of padding bytes needed to round `len` up to the next
/// 64-byte boundary (0 when `len` is already a multiple of 64).
///
/// Note: the previous doc comment said "8 bytes", but `(len + 63) & !63`
/// clearly aligns to 64 bytes, matching Arrow's recommended buffer alignment.
#[inline]
pub(crate) fn pad_to_64(len: usize) -> usize {
    ((len + 63) & !63) - len
}
488
/// An array [`RecordBatchT`] with optional accompanying IPC fields.
#[derive(Debug, Clone, PartialEq)]
pub struct Record<'a> {
    // `Cow` lets the `From` impls accept either owned or borrowed columns.
    columns: Cow<'a, RecordBatchT<Box<dyn Array>>>,
    // Optional per-column IPC metadata.
    fields: Option<Cow<'a, [IpcField]>>,
}
495
496impl Record<'_> {
497    /// Get the IPC fields for this record.
498    pub fn fields(&self) -> Option<&[IpcField]> {
499        self.fields.as_deref()
500    }
501
502    /// Get the Arrow columns in this record.
503    pub fn columns(&self) -> &RecordBatchT<Box<dyn Array>> {
504        self.columns.borrow()
505    }
506}
507
508impl From<RecordBatchT<Box<dyn Array>>> for Record<'static> {
509    fn from(columns: RecordBatchT<Box<dyn Array>>) -> Self {
510        Self {
511            columns: Cow::Owned(columns),
512            fields: None,
513        }
514    }
515}
516
517impl<'a, F> From<(RecordBatchT<Box<dyn Array>>, Option<F>)> for Record<'a>
518where
519    F: Into<Cow<'a, [IpcField]>>,
520{
521    fn from((columns, fields): (RecordBatchT<Box<dyn Array>>, Option<F>)) -> Self {
522        Self {
523            columns: Cow::Owned(columns),
524            fields: fields.map(|f| f.into()),
525        }
526    }
527}
528
529impl<'a, F> From<(&'a RecordBatchT<Box<dyn Array>>, Option<F>)> for Record<'a>
530where
531    F: Into<Cow<'a, [IpcField]>>,
532{
533    fn from((columns, fields): (&'a RecordBatchT<Box<dyn Array>>, Option<F>)) -> Self {
534        Self {
535            columns: Cow::Borrowed(columns),
536            fields: fields.map(|f| f.into()),
537        }
538    }
539}