polars_arrow/io/ipc/write/
writer.rs1use std::io::Write;
2use std::sync::Arc;
3
4use arrow_format::ipc::planus::Builder;
5use polars_error::{polars_bail, PolarsResult};
6
7use super::super::{IpcField, ARROW_MAGIC_V2};
8use super::common::{DictionaryTracker, EncodedData, WriteOptions};
9use super::common_sync::{write_continuation, write_message};
10use super::{default_ipc_fields, encode_record_batch, schema, schema_to_bytes};
11use crate::array::Array;
12use crate::datatypes::*;
13use crate::io::ipc::write::common::encode_chunk_amortized;
14use crate::record_batch::RecordBatchT;
15
/// Lifecycle of a `FileWriter`.
///
/// `start` is only accepted in `None` and moves to `Started`. Batches are only accepted
/// in `Started`. `finish` moves from `Started` to `Finished`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum State {
    /// Nothing has been written yet; `start` has not been called.
    None,
    /// The header and schema message have been written; batches may be appended.
    Started,
    /// The footer has been written; no further writes are accepted.
    Finished,
}
22
23pub struct FileWriter<W: Write> {
25 pub(crate) writer: W,
27 pub(crate) options: WriteOptions,
29 pub(crate) schema: ArrowSchemaRef,
31 pub(crate) ipc_fields: Vec<IpcField>,
32 pub(crate) block_offsets: usize,
34 pub(crate) dictionary_blocks: Vec<arrow_format::ipc::Block>,
36 pub(crate) record_blocks: Vec<arrow_format::ipc::Block>,
38 pub(crate) state: State,
40 pub(crate) dictionary_tracker: DictionaryTracker,
42 pub(crate) encoded_message: EncodedData,
44 pub(crate) custom_schema_metadata: Option<Arc<Metadata>>,
46}
47
48impl<W: Write> FileWriter<W> {
49 pub fn try_new(
51 writer: W,
52 schema: ArrowSchemaRef,
53 ipc_fields: Option<Vec<IpcField>>,
54 options: WriteOptions,
55 ) -> PolarsResult<Self> {
56 let mut slf = Self::new(writer, schema, ipc_fields, options);
57 slf.start()?;
58
59 Ok(slf)
60 }
61
62 pub fn new(
64 writer: W,
65 schema: ArrowSchemaRef,
66 ipc_fields: Option<Vec<IpcField>>,
67 options: WriteOptions,
68 ) -> Self {
69 let ipc_fields = if let Some(ipc_fields) = ipc_fields {
70 ipc_fields
71 } else {
72 default_ipc_fields(schema.iter_values())
73 };
74
75 Self {
76 writer,
77 options,
78 schema,
79 ipc_fields,
80 block_offsets: 0,
81 dictionary_blocks: vec![],
82 record_blocks: vec![],
83 state: State::None,
84 dictionary_tracker: DictionaryTracker {
85 dictionaries: Default::default(),
86 cannot_replace: true,
87 },
88 encoded_message: Default::default(),
89 custom_schema_metadata: None,
90 }
91 }
92
93 pub fn into_inner(self) -> W {
95 self.writer
96 }
97
98 pub fn get_scratches(&mut self) -> EncodedData {
101 std::mem::take(&mut self.encoded_message)
102 }
103 pub fn set_scratches(&mut self, scratches: EncodedData) {
106 self.encoded_message = scratches;
107 }
108
109 pub fn start(&mut self) -> PolarsResult<()> {
113 if self.state != State::None {
114 polars_bail!(oos = "The IPC file can only be started once");
115 }
116 self.writer.write_all(&ARROW_MAGIC_V2[..])?;
118 self.writer.write_all(&[0, 0])?;
120 let encoded_message = EncodedData {
123 ipc_message: schema_to_bytes(
124 &self.schema,
125 &self.ipc_fields,
126 None,
128 ),
129 arrow_data: vec![],
130 };
131
132 let (meta, data) = write_message(&mut self.writer, &encoded_message)?;
133 self.block_offsets += meta + data + 8; self.state = State::Started;
135 Ok(())
136 }
137
138 pub fn write(
140 &mut self,
141 chunk: &RecordBatchT<Box<dyn Array>>,
142 ipc_fields: Option<&[IpcField]>,
143 ) -> PolarsResult<()> {
144 if self.state != State::Started {
145 polars_bail!(
146 oos ="The IPC file must be started before it can be written to. Call `start` before `write`"
147 );
148 }
149
150 let ipc_fields = if let Some(ipc_fields) = ipc_fields {
151 ipc_fields
152 } else {
153 self.ipc_fields.as_ref()
154 };
155 let encoded_dictionaries = encode_chunk_amortized(
156 chunk,
157 ipc_fields,
158 &mut self.dictionary_tracker,
159 &self.options,
160 &mut self.encoded_message,
161 )?;
162 encode_record_batch(chunk, &self.options, &mut self.encoded_message);
163
164 let encoded_message = std::mem::take(&mut self.encoded_message);
165 self.write_encoded(&encoded_dictionaries[..], &encoded_message)?;
166 self.encoded_message = encoded_message;
167
168 Ok(())
169 }
170
171 pub fn write_encoded(
172 &mut self,
173 encoded_dictionaries: &[EncodedData],
174 encoded_message: &EncodedData,
175 ) -> PolarsResult<()> {
176 for encoded_dictionary in encoded_dictionaries {
178 let (meta, data) = write_message(&mut self.writer, encoded_dictionary)?;
179
180 let block = arrow_format::ipc::Block {
181 offset: self.block_offsets as i64,
182 meta_data_length: meta as i32,
183 body_length: data as i64,
184 };
185 self.dictionary_blocks.push(block);
186 self.block_offsets += meta + data;
187 }
188
189 self.write_encoded_record_batch(encoded_message)?;
190
191 Ok(())
192 }
193
194 pub fn write_encoded_record_batch(
195 &mut self,
196 encoded_message: &EncodedData,
197 ) -> PolarsResult<()> {
198 let (meta, data) = write_message(&mut self.writer, encoded_message)?;
199 let block = arrow_format::ipc::Block {
201 offset: self.block_offsets as i64,
202 meta_data_length: meta as i32, body_length: data as i64,
204 };
205 self.record_blocks.push(block);
206 self.block_offsets += meta + data;
207
208 Ok(())
209 }
210
211 pub fn finish(&mut self) -> PolarsResult<()> {
213 if self.state != State::Started {
214 polars_bail!(
215 oos = "The IPC file must be started before it can be finished. Call `start` before `finish`"
216 );
217 }
218
219 write_continuation(&mut self.writer, 0)?;
221
222 let schema = schema::serialize_schema(
223 &self.schema,
224 &self.ipc_fields,
225 self.custom_schema_metadata.as_deref(),
226 );
227
228 let root = arrow_format::ipc::Footer {
229 version: arrow_format::ipc::MetadataVersion::V5,
230 schema: Some(Box::new(schema)),
231 dictionaries: Some(std::mem::take(&mut self.dictionary_blocks)),
232 record_batches: Some(std::mem::take(&mut self.record_blocks)),
233 custom_metadata: None,
234 };
235 let mut builder = Builder::new();
236 let footer_data = builder.finish(&root, None);
237 self.writer.write_all(footer_data)?;
238 self.writer
239 .write_all(&(footer_data.len() as i32).to_le_bytes())?;
240 self.writer.write_all(&ARROW_MAGIC_V2)?;
241 self.writer.flush()?;
242 self.state = State::Finished;
243
244 Ok(())
245 }
246
247 pub fn set_custom_schema_metadata(&mut self, custom_metadata: Arc<Metadata>) {
249 self.custom_schema_metadata = Some(custom_metadata);
250 }
251}