polars_arrow/io/ipc/write/writer.rs

1use std::io::Write;
2use std::sync::Arc;
3
4use arrow_format::ipc::planus::Builder;
5use polars_error::{polars_bail, PolarsResult};
6
7use super::super::{IpcField, ARROW_MAGIC_V2};
8use super::common::{DictionaryTracker, EncodedData, WriteOptions};
9use super::common_sync::{write_continuation, write_message};
10use super::{default_ipc_fields, encode_record_batch, schema, schema_to_bytes};
11use crate::array::Array;
12use crate::datatypes::*;
13use crate::io::ipc::write::common::encode_chunk_amortized;
14use crate::record_batch::RecordBatchT;
15
/// Lifecycle of a [`FileWriter`]: `None` -> `Started` -> `Finished`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum State {
    /// Nothing has been written yet; `start` has not been called.
    None,
    /// The header and schema message have been written; batches may be written.
    Started,
    /// The footer has been written; no further writes are accepted.
    Finished,
}
22
23/// Arrow file writer
24pub struct FileWriter<W: Write> {
25    /// The object to write to
26    pub(crate) writer: W,
27    /// IPC write options
28    pub(crate) options: WriteOptions,
29    /// A reference to the schema, used in validating record batches
30    pub(crate) schema: ArrowSchemaRef,
31    pub(crate) ipc_fields: Vec<IpcField>,
32    /// The number of bytes between each block of bytes, as an offset for random access
33    pub(crate) block_offsets: usize,
34    /// Dictionary blocks that will be written as part of the IPC footer
35    pub(crate) dictionary_blocks: Vec<arrow_format::ipc::Block>,
36    /// Record blocks that will be written as part of the IPC footer
37    pub(crate) record_blocks: Vec<arrow_format::ipc::Block>,
38    /// Whether the writer footer has been written, and the writer is finished
39    pub(crate) state: State,
40    /// Keeps track of dictionaries that have been written
41    pub(crate) dictionary_tracker: DictionaryTracker,
42    /// Buffer/scratch that is reused between writes
43    pub(crate) encoded_message: EncodedData,
44    /// Custom schema-level metadata
45    pub(crate) custom_schema_metadata: Option<Arc<Metadata>>,
46}
47
48impl<W: Write> FileWriter<W> {
49    /// Creates a new [`FileWriter`] and writes the header to `writer`
50    pub fn try_new(
51        writer: W,
52        schema: ArrowSchemaRef,
53        ipc_fields: Option<Vec<IpcField>>,
54        options: WriteOptions,
55    ) -> PolarsResult<Self> {
56        let mut slf = Self::new(writer, schema, ipc_fields, options);
57        slf.start()?;
58
59        Ok(slf)
60    }
61
62    /// Creates a new [`FileWriter`].
63    pub fn new(
64        writer: W,
65        schema: ArrowSchemaRef,
66        ipc_fields: Option<Vec<IpcField>>,
67        options: WriteOptions,
68    ) -> Self {
69        let ipc_fields = if let Some(ipc_fields) = ipc_fields {
70            ipc_fields
71        } else {
72            default_ipc_fields(schema.iter_values())
73        };
74
75        Self {
76            writer,
77            options,
78            schema,
79            ipc_fields,
80            block_offsets: 0,
81            dictionary_blocks: vec![],
82            record_blocks: vec![],
83            state: State::None,
84            dictionary_tracker: DictionaryTracker {
85                dictionaries: Default::default(),
86                cannot_replace: true,
87            },
88            encoded_message: Default::default(),
89            custom_schema_metadata: None,
90        }
91    }
92
93    /// Consumes itself into the inner writer
94    pub fn into_inner(self) -> W {
95        self.writer
96    }
97
98    /// Get the inner memory scratches so they can be reused in a new writer.
99    /// This can be utilized to save memory allocations for performance reasons.
100    pub fn get_scratches(&mut self) -> EncodedData {
101        std::mem::take(&mut self.encoded_message)
102    }
103    /// Set the inner memory scratches so they can be reused in a new writer.
104    /// This can be utilized to save memory allocations for performance reasons.
105    pub fn set_scratches(&mut self, scratches: EncodedData) {
106        self.encoded_message = scratches;
107    }
108
109    /// Writes the header and first (schema) message to the file.
110    /// # Errors
111    /// Errors if the file has been started or has finished.
112    pub fn start(&mut self) -> PolarsResult<()> {
113        if self.state != State::None {
114            polars_bail!(oos = "The IPC file can only be started once");
115        }
116        // write magic to header
117        self.writer.write_all(&ARROW_MAGIC_V2[..])?;
118        // create an 8-byte boundary after the header
119        self.writer.write_all(&[0, 0])?;
120        // write the schema, set the written bytes to the schema
121
122        let encoded_message = EncodedData {
123            ipc_message: schema_to_bytes(
124                &self.schema,
125                &self.ipc_fields,
126                // No need to pass metadata here, as it is already written to the footer in `finish`
127                None,
128            ),
129            arrow_data: vec![],
130        };
131
132        let (meta, data) = write_message(&mut self.writer, &encoded_message)?;
133        self.block_offsets += meta + data + 8; // 8 <=> arrow magic + 2 bytes for alignment
134        self.state = State::Started;
135        Ok(())
136    }
137
138    /// Writes [`RecordBatchT`] to the file
139    pub fn write(
140        &mut self,
141        chunk: &RecordBatchT<Box<dyn Array>>,
142        ipc_fields: Option<&[IpcField]>,
143    ) -> PolarsResult<()> {
144        if self.state != State::Started {
145            polars_bail!(
146                oos ="The IPC file must be started before it can be written to. Call `start` before `write`"
147            );
148        }
149
150        let ipc_fields = if let Some(ipc_fields) = ipc_fields {
151            ipc_fields
152        } else {
153            self.ipc_fields.as_ref()
154        };
155        let encoded_dictionaries = encode_chunk_amortized(
156            chunk,
157            ipc_fields,
158            &mut self.dictionary_tracker,
159            &self.options,
160            &mut self.encoded_message,
161        )?;
162        encode_record_batch(chunk, &self.options, &mut self.encoded_message);
163
164        let encoded_message = std::mem::take(&mut self.encoded_message);
165        self.write_encoded(&encoded_dictionaries[..], &encoded_message)?;
166        self.encoded_message = encoded_message;
167
168        Ok(())
169    }
170
171    pub fn write_encoded(
172        &mut self,
173        encoded_dictionaries: &[EncodedData],
174        encoded_message: &EncodedData,
175    ) -> PolarsResult<()> {
176        // add all dictionaries
177        for encoded_dictionary in encoded_dictionaries {
178            let (meta, data) = write_message(&mut self.writer, encoded_dictionary)?;
179
180            let block = arrow_format::ipc::Block {
181                offset: self.block_offsets as i64,
182                meta_data_length: meta as i32,
183                body_length: data as i64,
184            };
185            self.dictionary_blocks.push(block);
186            self.block_offsets += meta + data;
187        }
188
189        self.write_encoded_record_batch(encoded_message)?;
190
191        Ok(())
192    }
193
194    pub fn write_encoded_record_batch(
195        &mut self,
196        encoded_message: &EncodedData,
197    ) -> PolarsResult<()> {
198        let (meta, data) = write_message(&mut self.writer, encoded_message)?;
199        // add a record block for the footer
200        let block = arrow_format::ipc::Block {
201            offset: self.block_offsets as i64,
202            meta_data_length: meta as i32, // TODO: is this still applicable?
203            body_length: data as i64,
204        };
205        self.record_blocks.push(block);
206        self.block_offsets += meta + data;
207
208        Ok(())
209    }
210
211    /// Write footer and closing tag, then mark the writer as done
212    pub fn finish(&mut self) -> PolarsResult<()> {
213        if self.state != State::Started {
214            polars_bail!(
215                oos = "The IPC file must be started before it can be finished. Call `start` before `finish`"
216            );
217        }
218
219        // write EOS
220        write_continuation(&mut self.writer, 0)?;
221
222        let schema = schema::serialize_schema(
223            &self.schema,
224            &self.ipc_fields,
225            self.custom_schema_metadata.as_deref(),
226        );
227
228        let root = arrow_format::ipc::Footer {
229            version: arrow_format::ipc::MetadataVersion::V5,
230            schema: Some(Box::new(schema)),
231            dictionaries: Some(std::mem::take(&mut self.dictionary_blocks)),
232            record_batches: Some(std::mem::take(&mut self.record_blocks)),
233            custom_metadata: None,
234        };
235        let mut builder = Builder::new();
236        let footer_data = builder.finish(&root, None);
237        self.writer.write_all(footer_data)?;
238        self.writer
239            .write_all(&(footer_data.len() as i32).to_le_bytes())?;
240        self.writer.write_all(&ARROW_MAGIC_V2)?;
241        self.writer.flush()?;
242        self.state = State::Finished;
243
244        Ok(())
245    }
246
247    /// Sets custom schema metadata. Must be called before `start` is called
248    pub fn set_custom_schema_metadata(&mut self, custom_metadata: Arc<Metadata>) {
249        self.custom_schema_metadata = Some(custom_metadata);
250    }
251}