polars_arrow/io/ipc/write/
stream.rs1use std::io::Write;
7use std::sync::Arc;
8
9use polars_error::{PolarsError, PolarsResult};
10
11use super::super::IpcField;
12use super::common::{encode_chunk, DictionaryTracker, EncodedData, WriteOptions};
13use super::common_sync::{write_continuation, write_message};
14use super::{default_ipc_fields, schema_to_bytes};
15use crate::array::Array;
16use crate::datatypes::*;
17use crate::record_batch::RecordBatchT;
18
19pub struct StreamWriter<W: Write> {
26 writer: W,
28 write_options: WriteOptions,
30 finished: bool,
32 dictionary_tracker: DictionaryTracker,
34 custom_schema_metadata: Option<Arc<Metadata>>,
36
37 ipc_fields: Option<Vec<IpcField>>,
38}
39
40impl<W: Write> StreamWriter<W> {
41 pub fn new(writer: W, write_options: WriteOptions) -> Self {
43 Self {
44 writer,
45 write_options,
46 finished: false,
47 dictionary_tracker: DictionaryTracker {
48 dictionaries: Default::default(),
49 cannot_replace: false,
50 },
51 ipc_fields: None,
52 custom_schema_metadata: None,
53 }
54 }
55
56 pub fn set_custom_schema_metadata(&mut self, custom_metadata: Arc<Metadata>) {
58 self.custom_schema_metadata = Some(custom_metadata);
59 }
60
61 pub fn start(
64 &mut self,
65 schema: &ArrowSchema,
66 ipc_fields: Option<Vec<IpcField>>,
67 ) -> PolarsResult<()> {
68 self.ipc_fields = Some(if let Some(ipc_fields) = ipc_fields {
69 ipc_fields
70 } else {
71 default_ipc_fields(schema.iter_values())
72 });
73
74 let encoded_message = EncodedData {
75 ipc_message: schema_to_bytes(
76 schema,
77 self.ipc_fields.as_ref().unwrap(),
78 self.custom_schema_metadata.as_deref(),
79 ),
80 arrow_data: vec![],
81 };
82 write_message(&mut self.writer, &encoded_message)?;
83 Ok(())
84 }
85
86 pub fn write(
88 &mut self,
89 columns: &RecordBatchT<Box<dyn Array>>,
90 ipc_fields: Option<&[IpcField]>,
91 ) -> PolarsResult<()> {
92 if self.finished {
93 let io_err = std::io::Error::new(
94 std::io::ErrorKind::UnexpectedEof,
95 "Cannot write to a finished stream".to_string(),
96 );
97 return Err(PolarsError::from(io_err));
98 }
99
100 #[allow(clippy::or_fun_call)]
102 let fields = ipc_fields.unwrap_or(self.ipc_fields.as_ref().unwrap());
103
104 let (encoded_dictionaries, encoded_message) = encode_chunk(
105 columns,
106 fields,
107 &mut self.dictionary_tracker,
108 &self.write_options,
109 )?;
110
111 for encoded_dictionary in encoded_dictionaries {
112 write_message(&mut self.writer, &encoded_dictionary)?;
113 }
114
115 write_message(&mut self.writer, &encoded_message)?;
116 Ok(())
117 }
118
119 pub fn finish(&mut self) -> PolarsResult<()> {
121 write_continuation(&mut self.writer, 0)?;
122
123 self.finished = true;
124
125 Ok(())
126 }
127
128 pub fn into_inner(self) -> W {
130 self.writer
131 }
132}