1use std::borrow::{Borrow, Cow};
2
3use arrow_format::ipc::planus::Builder;
4use polars_error::{polars_bail, polars_err, PolarsResult};
5
6use super::super::IpcField;
7use super::{write, write_dictionary};
8use crate::array::*;
9use crate::datatypes::*;
10use crate::io::ipc::endianness::is_native_little_endian;
11use crate::io::ipc::read::Dictionaries;
12use crate::legacy::prelude::LargeListArray;
13use crate::match_integer_type;
14use crate::record_batch::RecordBatchT;
15
/// Compression codec applied to the buffers of IPC message bodies.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Compression {
    /// LZ4; advertised in messages as the IPC `Lz4Frame` codec.
    LZ4,
    /// Zstandard; advertised in messages as the IPC `Zstd` codec.
    ZSTD,
}
24
/// Options controlling how record batches and dictionary batches are encoded.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct WriteOptions {
    /// Body compression codec. `None` (the `Default`) leaves the `compression` entry of
    /// every written message unset.
    pub compression: Option<Compression>,
}
32
33pub fn dictionaries_to_encode(
35 field: &IpcField,
36 array: &dyn Array,
37 dictionary_tracker: &mut DictionaryTracker,
38 dicts_to_encode: &mut Vec<(i64, Box<dyn Array>)>,
39) -> PolarsResult<()> {
40 use PhysicalType::*;
41 match array.dtype().to_physical_type() {
42 Utf8 | LargeUtf8 | Binary | LargeBinary | Primitive(_) | Boolean | Null
43 | FixedSizeBinary | BinaryView | Utf8View => Ok(()),
44 Dictionary(key_type) => match_integer_type!(key_type, |$T| {
45 let dict_id = field.dictionary_id
46 .ok_or_else(|| polars_err!(InvalidOperation: "Dictionaries must have an associated id"))?;
47
48 if dictionary_tracker.insert(dict_id, array)? {
49 dicts_to_encode.push((dict_id, array.to_boxed()));
50 }
51
52 let array = array.as_any().downcast_ref::<DictionaryArray<$T>>().unwrap();
53 let values = array.values();
54 dictionaries_to_encode(field,
56 values.as_ref(),
57 dictionary_tracker,
58 dicts_to_encode,
59 )?;
60
61 Ok(())
62 }),
63 Struct => {
64 let array = array.as_any().downcast_ref::<StructArray>().unwrap();
65 let fields = field.fields.as_slice();
66 if array.fields().len() != fields.len() {
67 polars_bail!(InvalidOperation:
68 "The number of fields in a struct must equal the number of children in IpcField".to_string(),
69 );
70 }
71 fields
72 .iter()
73 .zip(array.values().iter())
74 .try_for_each(|(field, values)| {
75 dictionaries_to_encode(
76 field,
77 values.as_ref(),
78 dictionary_tracker,
79 dicts_to_encode,
80 )
81 })
82 },
83 List => {
84 let values = array
85 .as_any()
86 .downcast_ref::<ListArray<i32>>()
87 .unwrap()
88 .values();
89 let field = &field.fields[0]; dictionaries_to_encode(field, values.as_ref(), dictionary_tracker, dicts_to_encode)
91 },
92 LargeList => {
93 let values = array
94 .as_any()
95 .downcast_ref::<ListArray<i64>>()
96 .unwrap()
97 .values();
98 let field = &field.fields[0]; dictionaries_to_encode(field, values.as_ref(), dictionary_tracker, dicts_to_encode)
100 },
101 FixedSizeList => {
102 let values = array
103 .as_any()
104 .downcast_ref::<FixedSizeListArray>()
105 .unwrap()
106 .values();
107 let field = &field.fields[0]; dictionaries_to_encode(field, values.as_ref(), dictionary_tracker, dicts_to_encode)
109 },
110 Union => {
111 let values = array
112 .as_any()
113 .downcast_ref::<UnionArray>()
114 .unwrap()
115 .fields();
116 let fields = &field.fields[..]; if values.len() != fields.len() {
118 polars_bail!(InvalidOperation:
119 "The number of fields in a union must equal the number of children in IpcField"
120 );
121 }
122 fields
123 .iter()
124 .zip(values.iter())
125 .try_for_each(|(field, values)| {
126 dictionaries_to_encode(
127 field,
128 values.as_ref(),
129 dictionary_tracker,
130 dicts_to_encode,
131 )
132 })
133 },
134 Map => {
135 let values = array.as_any().downcast_ref::<MapArray>().unwrap().field();
136 let field = &field.fields[0]; dictionaries_to_encode(field, values.as_ref(), dictionary_tracker, dicts_to_encode)
138 },
139 }
140}
141
142pub fn encode_dictionary(
148 dict_id: i64,
149 array: &dyn Array,
150 options: &WriteOptions,
151 encoded_dictionaries: &mut Vec<EncodedData>,
152) -> PolarsResult<()> {
153 let PhysicalType::Dictionary(key_type) = array.dtype().to_physical_type() else {
154 panic!("Given array is not a DictionaryArray")
155 };
156
157 match_integer_type!(key_type, |$T| {
158 let array = array.as_any().downcast_ref::<DictionaryArray<$T>>().unwrap();
159 encoded_dictionaries.push(dictionary_batch_to_bytes::<$T>(
160 dict_id,
161 array,
162 options,
163 is_native_little_endian(),
164 ));
165 });
166
167 Ok(())
168}
169
170pub fn encode_new_dictionaries(
171 field: &IpcField,
172 array: &dyn Array,
173 options: &WriteOptions,
174 dictionary_tracker: &mut DictionaryTracker,
175 encoded_dictionaries: &mut Vec<EncodedData>,
176) -> PolarsResult<()> {
177 let mut dicts_to_encode = Vec::new();
178 dictionaries_to_encode(field, array, dictionary_tracker, &mut dicts_to_encode)?;
179 for (dict_id, dict_array) in dicts_to_encode {
180 encode_dictionary(dict_id, dict_array.as_ref(), options, encoded_dictionaries)?;
181 }
182 Ok(())
183}
184
185pub fn encode_chunk(
186 chunk: &RecordBatchT<Box<dyn Array>>,
187 fields: &[IpcField],
188 dictionary_tracker: &mut DictionaryTracker,
189 options: &WriteOptions,
190) -> PolarsResult<(Vec<EncodedData>, EncodedData)> {
191 let mut encoded_message = EncodedData::default();
192 let encoded_dictionaries = encode_chunk_amortized(
193 chunk,
194 fields,
195 dictionary_tracker,
196 options,
197 &mut encoded_message,
198 )?;
199 Ok((encoded_dictionaries, encoded_message))
200}
201
202pub fn encode_chunk_amortized(
204 chunk: &RecordBatchT<Box<dyn Array>>,
205 fields: &[IpcField],
206 dictionary_tracker: &mut DictionaryTracker,
207 options: &WriteOptions,
208 encoded_message: &mut EncodedData,
209) -> PolarsResult<Vec<EncodedData>> {
210 let mut encoded_dictionaries = vec![];
211
212 for (field, array) in fields.iter().zip(chunk.as_ref()) {
213 encode_new_dictionaries(
214 field,
215 array.as_ref(),
216 options,
217 dictionary_tracker,
218 &mut encoded_dictionaries,
219 )?;
220 }
221
222 encode_record_batch(chunk, options, encoded_message);
223
224 Ok(encoded_dictionaries)
225}
226
227fn serialize_compression(
228 compression: Option<Compression>,
229) -> Option<Box<arrow_format::ipc::BodyCompression>> {
230 if let Some(compression) = compression {
231 let codec = match compression {
232 Compression::LZ4 => arrow_format::ipc::CompressionType::Lz4Frame,
233 Compression::ZSTD => arrow_format::ipc::CompressionType::Zstd,
234 };
235 Some(Box::new(arrow_format::ipc::BodyCompression {
236 codec,
237 method: arrow_format::ipc::BodyCompressionMethod::Buffer,
238 }))
239 } else {
240 None
241 }
242}
243
244fn set_variadic_buffer_counts(counts: &mut Vec<i64>, array: &dyn Array) {
245 match array.dtype() {
246 ArrowDataType::Utf8View => {
247 let array = array.as_any().downcast_ref::<Utf8ViewArray>().unwrap();
248 counts.push(array.data_buffers().len() as i64);
249 },
250 ArrowDataType::BinaryView => {
251 let array = array.as_any().downcast_ref::<BinaryViewArray>().unwrap();
252 counts.push(array.data_buffers().len() as i64);
253 },
254 ArrowDataType::Struct(_) => {
255 let array = array.as_any().downcast_ref::<StructArray>().unwrap();
256 for array in array.values() {
257 set_variadic_buffer_counts(counts, array.as_ref())
258 }
259 },
260 ArrowDataType::LargeList(_) => {
261 let array = array.as_any().downcast_ref::<LargeListArray>().unwrap();
262 set_variadic_buffer_counts(counts, array.values().as_ref())
263 },
264 ArrowDataType::FixedSizeList(_, _) => {
265 let array = array.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
266 set_variadic_buffer_counts(counts, array.values().as_ref())
267 },
268 ArrowDataType::Dictionary(_, _, _) => (),
271 _ => (),
272 }
273}
274
275fn gc_bin_view<'a, T: ViewType + ?Sized>(
276 arr: &'a Box<dyn Array>,
277 concrete_arr: &'a BinaryViewArrayGeneric<T>,
278) -> Cow<'a, Box<dyn Array>> {
279 let bytes_len = concrete_arr.total_bytes_len();
280 let buffer_len = concrete_arr.total_buffer_len();
281 let extra_len = buffer_len.saturating_sub(bytes_len);
282 if extra_len < bytes_len.min(1024) {
283 Cow::Borrowed(arr)
285 } else {
286 Cow::Owned(concrete_arr.clone().gc().boxed())
288 }
289}
290
291pub fn encode_record_batch(
294 chunk: &RecordBatchT<Box<dyn Array>>,
295 options: &WriteOptions,
296 encoded_message: &mut EncodedData,
297) {
298 let mut nodes: Vec<arrow_format::ipc::FieldNode> = vec![];
299 let mut buffers: Vec<arrow_format::ipc::Buffer> = vec![];
300 let mut arrow_data = std::mem::take(&mut encoded_message.arrow_data);
301 arrow_data.clear();
302
303 let mut offset = 0;
304 let mut variadic_buffer_counts = vec![];
305 for array in chunk.arrays() {
306 let array = match array.dtype() {
308 ArrowDataType::BinaryView => {
309 let concrete_arr = array.as_any().downcast_ref::<BinaryViewArray>().unwrap();
310 gc_bin_view(array, concrete_arr)
311 },
312 ArrowDataType::Utf8View => {
313 let concrete_arr = array.as_any().downcast_ref::<Utf8ViewArray>().unwrap();
314 gc_bin_view(array, concrete_arr)
315 },
316 _ => Cow::Borrowed(array),
317 };
318 let array = array.as_ref().as_ref();
319
320 set_variadic_buffer_counts(&mut variadic_buffer_counts, array);
321
322 write(
323 array,
324 &mut buffers,
325 &mut arrow_data,
326 &mut nodes,
327 &mut offset,
328 is_native_little_endian(),
329 options.compression,
330 )
331 }
332
333 let variadic_buffer_counts = if variadic_buffer_counts.is_empty() {
334 None
335 } else {
336 Some(variadic_buffer_counts)
337 };
338
339 let compression = serialize_compression(options.compression);
340
341 let message = arrow_format::ipc::Message {
342 version: arrow_format::ipc::MetadataVersion::V5,
343 header: Some(arrow_format::ipc::MessageHeader::RecordBatch(Box::new(
344 arrow_format::ipc::RecordBatch {
345 length: chunk.len() as i64,
346 nodes: Some(nodes),
347 buffers: Some(buffers),
348 compression,
349 variadic_buffer_counts,
350 },
351 ))),
352 body_length: arrow_data.len() as i64,
353 custom_metadata: None,
354 };
355
356 let mut builder = Builder::new();
357 let ipc_message = builder.finish(&message, None);
358 encoded_message.ipc_message = ipc_message.to_vec();
359 encoded_message.arrow_data = arrow_data
360}
361
362fn dictionary_batch_to_bytes<K: DictionaryKey>(
365 dict_id: i64,
366 array: &DictionaryArray<K>,
367 options: &WriteOptions,
368 is_little_endian: bool,
369) -> EncodedData {
370 let mut nodes: Vec<arrow_format::ipc::FieldNode> = vec![];
371 let mut buffers: Vec<arrow_format::ipc::Buffer> = vec![];
372 let mut arrow_data: Vec<u8> = vec![];
373 let mut variadic_buffer_counts = vec![];
374 set_variadic_buffer_counts(&mut variadic_buffer_counts, array.values().as_ref());
375
376 let variadic_buffer_counts = if variadic_buffer_counts.is_empty() {
377 None
378 } else {
379 Some(variadic_buffer_counts)
380 };
381
382 let length = write_dictionary(
383 array,
384 &mut buffers,
385 &mut arrow_data,
386 &mut nodes,
387 &mut 0,
388 is_little_endian,
389 options.compression,
390 false,
391 );
392
393 let compression = serialize_compression(options.compression);
394
395 let message = arrow_format::ipc::Message {
396 version: arrow_format::ipc::MetadataVersion::V5,
397 header: Some(arrow_format::ipc::MessageHeader::DictionaryBatch(Box::new(
398 arrow_format::ipc::DictionaryBatch {
399 id: dict_id,
400 data: Some(Box::new(arrow_format::ipc::RecordBatch {
401 length: length as i64,
402 nodes: Some(nodes),
403 buffers: Some(buffers),
404 compression,
405 variadic_buffer_counts,
406 })),
407 is_delta: false,
408 },
409 ))),
410 body_length: arrow_data.len() as i64,
411 custom_metadata: None,
412 };
413
414 let mut builder = Builder::new();
415 let ipc_message = builder.finish(&message, None);
416
417 EncodedData {
418 ipc_message: ipc_message.to_vec(),
419 arrow_data,
420 }
421}
422
/// Records the dictionary values already emitted, keyed by dictionary id, so that
/// unchanged dictionaries are not written again.
pub struct DictionaryTracker {
    /// The most recently recorded dictionary values for each dictionary id.
    pub dictionaries: Dictionaries,
    /// When `true`, giving an already-recorded dictionary id different values makes
    /// `insert` return an error instead of replacing them (the error message ties this
    /// restriction to the IPC file format).
    pub cannot_replace: bool,
}
430
431impl DictionaryTracker {
432 pub fn insert(&mut self, dict_id: i64, array: &dyn Array) -> PolarsResult<bool> {
442 let values = match array.dtype() {
443 ArrowDataType::Dictionary(key_type, _, _) => {
444 match_integer_type!(key_type, |$T| {
445 let array = array
446 .as_any()
447 .downcast_ref::<DictionaryArray<$T>>()
448 .unwrap();
449 array.values()
450 })
451 },
452 _ => unreachable!(),
453 };
454
455 if let Some(last) = self.dictionaries.get(&dict_id) {
457 if last.as_ref() == values.as_ref() {
458 return Ok(false);
460 } else if self.cannot_replace {
461 polars_bail!(InvalidOperation:
462 "Dictionary replacement detected when writing IPC file format. \
463 Arrow IPC files only support a single dictionary for a given field \
464 across all batches."
465 );
466 }
467 };
468
469 self.dictionaries.insert(dict_id, values.clone());
470 Ok(true)
471 }
472}
473
/// One encoded IPC message: the serialized header and its body bytes.
#[derive(Debug, Default)]
pub struct EncodedData {
    /// The flatbuffer-serialized IPC `Message` header.
    pub ipc_message: Vec<u8>,
    /// The message body; the header's `body_length` is this vector's length.
    pub arrow_data: Vec<u8>,
}
482
483#[inline]
485pub(crate) fn pad_to_64(len: usize) -> usize {
486 ((len + 63) & !63) - len
487}
488
/// A record batch to be written, optionally paired with explicit IPC field metadata.
///
/// Both the columns and the field metadata may be either owned or borrowed.
#[derive(Debug, Clone, PartialEq)]
pub struct Record<'a> {
    /// The columns of the batch.
    columns: Cow<'a, RecordBatchT<Box<dyn Array>>>,
    /// IPC field metadata for the columns; `None` when none was provided.
    fields: Option<Cow<'a, [IpcField]>>,
}
495
496impl Record<'_> {
497 pub fn fields(&self) -> Option<&[IpcField]> {
499 self.fields.as_deref()
500 }
501
502 pub fn columns(&self) -> &RecordBatchT<Box<dyn Array>> {
504 self.columns.borrow()
505 }
506}
507
508impl From<RecordBatchT<Box<dyn Array>>> for Record<'static> {
509 fn from(columns: RecordBatchT<Box<dyn Array>>) -> Self {
510 Self {
511 columns: Cow::Owned(columns),
512 fields: None,
513 }
514 }
515}
516
517impl<'a, F> From<(RecordBatchT<Box<dyn Array>>, Option<F>)> for Record<'a>
518where
519 F: Into<Cow<'a, [IpcField]>>,
520{
521 fn from((columns, fields): (RecordBatchT<Box<dyn Array>>, Option<F>)) -> Self {
522 Self {
523 columns: Cow::Owned(columns),
524 fields: fields.map(|f| f.into()),
525 }
526 }
527}
528
529impl<'a, F> From<(&'a RecordBatchT<Box<dyn Array>>, Option<F>)> for Record<'a>
530where
531 F: Into<Cow<'a, [IpcField]>>,
532{
533 fn from((columns, fields): (&'a RecordBatchT<Box<dyn Array>>, Option<F>)) -> Self {
534 Self {
535 columns: Cow::Borrowed(columns),
536 fields: fields.map(|f| f.into()),
537 }
538 }
539}