polars_arrow/io/ipc/write/
schema.rs

1use arrow_format::ipc::planus::Builder;
2
3use super::super::IpcField;
4use crate::datatypes::{
5    ArrowDataType, ArrowSchema, Field, IntegerType, IntervalUnit, Metadata, TimeUnit, UnionMode,
6};
7use crate::io::ipc::endianness::is_native_little_endian;
8
9/// Converts a [ArrowSchema] and [IpcField]s to a flatbuffers-encoded [arrow_format::ipc::Message].
10pub fn schema_to_bytes(
11    schema: &ArrowSchema,
12    ipc_fields: &[IpcField],
13    custom_metadata: Option<&Metadata>,
14) -> Vec<u8> {
15    let schema = serialize_schema(schema, ipc_fields, custom_metadata);
16
17    let message = arrow_format::ipc::Message {
18        version: arrow_format::ipc::MetadataVersion::V5,
19        header: Some(arrow_format::ipc::MessageHeader::Schema(Box::new(schema))),
20        body_length: 0,
21        custom_metadata: None, // todo: allow writing custom metadata
22    };
23    let mut builder = Builder::new();
24    let footer_data = builder.finish(&message, None);
25    footer_data.to_vec()
26}
27
28pub fn serialize_schema(
29    schema: &ArrowSchema,
30    ipc_fields: &[IpcField],
31    custom_schema_metadata: Option<&Metadata>,
32) -> arrow_format::ipc::Schema {
33    let endianness = if is_native_little_endian() {
34        arrow_format::ipc::Endianness::Little
35    } else {
36        arrow_format::ipc::Endianness::Big
37    };
38
39    let fields = schema
40        .iter_values()
41        .zip(ipc_fields.iter())
42        .map(|(field, ipc_field)| serialize_field(field, ipc_field))
43        .collect::<Vec<_>>();
44
45    let custom_metadata = custom_schema_metadata.and_then(|custom_meta| {
46        let as_kv = custom_meta
47            .iter()
48            .map(|(key, val)| key_value(key.clone().into_string(), val.clone().into_string()))
49            .collect::<Vec<_>>();
50        (!as_kv.is_empty()).then_some(as_kv)
51    });
52
53    arrow_format::ipc::Schema {
54        endianness,
55        fields: Some(fields),
56        custom_metadata,
57        features: None, // todo add this one
58    }
59}
60
61fn key_value(key: impl Into<String>, val: impl Into<String>) -> arrow_format::ipc::KeyValue {
62    arrow_format::ipc::KeyValue {
63        key: Some(key.into()),
64        value: Some(val.into()),
65    }
66}
67
68fn write_metadata(metadata: &Metadata, kv_vec: &mut Vec<arrow_format::ipc::KeyValue>) {
69    for (k, v) in metadata {
70        if k.as_str() != "ARROW:extension:name" && k.as_str() != "ARROW:extension:metadata" {
71            kv_vec.push(key_value(k.clone().into_string(), v.clone().into_string()));
72        }
73    }
74}
75
76fn write_extension(
77    name: &str,
78    metadata: Option<&str>,
79    kv_vec: &mut Vec<arrow_format::ipc::KeyValue>,
80) {
81    if let Some(metadata) = metadata {
82        kv_vec.push(key_value("ARROW:extension:metadata".to_string(), metadata));
83    }
84
85    kv_vec.push(key_value("ARROW:extension:name".to_string(), name));
86}
87
88/// Create an IPC Field from an Arrow Field
89pub(crate) fn serialize_field(field: &Field, ipc_field: &IpcField) -> arrow_format::ipc::Field {
90    // custom metadata.
91    let mut kv_vec = vec![];
92    if let ArrowDataType::Extension(ext) = field.dtype() {
93        write_extension(
94            &ext.name,
95            ext.metadata.as_ref().map(|x| x.as_str()),
96            &mut kv_vec,
97        );
98    }
99
100    let type_ = serialize_type(field.dtype());
101    let children = serialize_children(field.dtype(), ipc_field);
102
103    let dictionary = if let ArrowDataType::Dictionary(index_type, inner, is_ordered) = field.dtype()
104    {
105        if let ArrowDataType::Extension(ext) = inner.as_ref() {
106            write_extension(
107                ext.name.as_str(),
108                ext.metadata.as_ref().map(|x| x.as_str()),
109                &mut kv_vec,
110            );
111        }
112        Some(serialize_dictionary(
113            index_type,
114            ipc_field
115                .dictionary_id
116                .expect("All Dictionary types have `dict_id`"),
117            *is_ordered,
118        ))
119    } else {
120        None
121    };
122
123    if let Some(metadata) = &field.metadata {
124        write_metadata(metadata, &mut kv_vec);
125    }
126
127    let custom_metadata = if !kv_vec.is_empty() {
128        Some(kv_vec)
129    } else {
130        None
131    };
132
133    arrow_format::ipc::Field {
134        name: Some(field.name.to_string()),
135        nullable: field.is_nullable,
136        type_: Some(type_),
137        dictionary: dictionary.map(Box::new),
138        children: Some(children),
139        custom_metadata,
140    }
141}
142
143fn serialize_time_unit(unit: &TimeUnit) -> arrow_format::ipc::TimeUnit {
144    match unit {
145        TimeUnit::Second => arrow_format::ipc::TimeUnit::Second,
146        TimeUnit::Millisecond => arrow_format::ipc::TimeUnit::Millisecond,
147        TimeUnit::Microsecond => arrow_format::ipc::TimeUnit::Microsecond,
148        TimeUnit::Nanosecond => arrow_format::ipc::TimeUnit::Nanosecond,
149    }
150}
151
152fn serialize_type(dtype: &ArrowDataType) -> arrow_format::ipc::Type {
153    use arrow_format::ipc;
154    use ArrowDataType::*;
155    match dtype {
156        Null => ipc::Type::Null(Box::new(ipc::Null {})),
157        Boolean => ipc::Type::Bool(Box::new(ipc::Bool {})),
158        UInt8 => ipc::Type::Int(Box::new(ipc::Int {
159            bit_width: 8,
160            is_signed: false,
161        })),
162        UInt16 => ipc::Type::Int(Box::new(ipc::Int {
163            bit_width: 16,
164            is_signed: false,
165        })),
166        UInt32 => ipc::Type::Int(Box::new(ipc::Int {
167            bit_width: 32,
168            is_signed: false,
169        })),
170        UInt64 => ipc::Type::Int(Box::new(ipc::Int {
171            bit_width: 64,
172            is_signed: false,
173        })),
174        Int8 => ipc::Type::Int(Box::new(ipc::Int {
175            bit_width: 8,
176            is_signed: true,
177        })),
178        Int16 => ipc::Type::Int(Box::new(ipc::Int {
179            bit_width: 16,
180            is_signed: true,
181        })),
182        Int32 => ipc::Type::Int(Box::new(ipc::Int {
183            bit_width: 32,
184            is_signed: true,
185        })),
186        Int64 => ipc::Type::Int(Box::new(ipc::Int {
187            bit_width: 64,
188            is_signed: true,
189        })),
190        Int128 => ipc::Type::Int(Box::new(ipc::Int {
191            bit_width: 128,
192            is_signed: true,
193        })),
194        Float16 => ipc::Type::FloatingPoint(Box::new(ipc::FloatingPoint {
195            precision: ipc::Precision::Half,
196        })),
197        Float32 => ipc::Type::FloatingPoint(Box::new(ipc::FloatingPoint {
198            precision: ipc::Precision::Single,
199        })),
200        Float64 => ipc::Type::FloatingPoint(Box::new(ipc::FloatingPoint {
201            precision: ipc::Precision::Double,
202        })),
203        Decimal(precision, scale) => ipc::Type::Decimal(Box::new(ipc::Decimal {
204            precision: *precision as i32,
205            scale: *scale as i32,
206            bit_width: 128,
207        })),
208        Decimal256(precision, scale) => ipc::Type::Decimal(Box::new(ipc::Decimal {
209            precision: *precision as i32,
210            scale: *scale as i32,
211            bit_width: 256,
212        })),
213        Binary => ipc::Type::Binary(Box::new(ipc::Binary {})),
214        LargeBinary => ipc::Type::LargeBinary(Box::new(ipc::LargeBinary {})),
215        Utf8 => ipc::Type::Utf8(Box::new(ipc::Utf8 {})),
216        LargeUtf8 => ipc::Type::LargeUtf8(Box::new(ipc::LargeUtf8 {})),
217        FixedSizeBinary(size) => ipc::Type::FixedSizeBinary(Box::new(ipc::FixedSizeBinary {
218            byte_width: *size as i32,
219        })),
220        Date32 => ipc::Type::Date(Box::new(ipc::Date {
221            unit: ipc::DateUnit::Day,
222        })),
223        Date64 => ipc::Type::Date(Box::new(ipc::Date {
224            unit: ipc::DateUnit::Millisecond,
225        })),
226        Duration(unit) => ipc::Type::Duration(Box::new(ipc::Duration {
227            unit: serialize_time_unit(unit),
228        })),
229        Time32(unit) => ipc::Type::Time(Box::new(ipc::Time {
230            unit: serialize_time_unit(unit),
231            bit_width: 32,
232        })),
233        Time64(unit) => ipc::Type::Time(Box::new(ipc::Time {
234            unit: serialize_time_unit(unit),
235            bit_width: 64,
236        })),
237        Timestamp(unit, tz) => ipc::Type::Timestamp(Box::new(ipc::Timestamp {
238            unit: serialize_time_unit(unit),
239            timezone: tz.as_ref().map(|x| x.to_string()),
240        })),
241        Interval(unit) => ipc::Type::Interval(Box::new(ipc::Interval {
242            unit: match unit {
243                IntervalUnit::YearMonth => ipc::IntervalUnit::YearMonth,
244                IntervalUnit::DayTime => ipc::IntervalUnit::DayTime,
245                IntervalUnit::MonthDayNano => ipc::IntervalUnit::MonthDayNano,
246            },
247        })),
248        List(_) => ipc::Type::List(Box::new(ipc::List {})),
249        LargeList(_) => ipc::Type::LargeList(Box::new(ipc::LargeList {})),
250        FixedSizeList(_, size) => ipc::Type::FixedSizeList(Box::new(ipc::FixedSizeList {
251            list_size: *size as i32,
252        })),
253        Union(u) => ipc::Type::Union(Box::new(ipc::Union {
254            mode: match u.mode {
255                UnionMode::Dense => ipc::UnionMode::Dense,
256                UnionMode::Sparse => ipc::UnionMode::Sparse,
257            },
258            type_ids: u.ids.clone(),
259        })),
260        Map(_, keys_sorted) => ipc::Type::Map(Box::new(ipc::Map {
261            keys_sorted: *keys_sorted,
262        })),
263        Struct(_) => ipc::Type::Struct(Box::new(ipc::Struct {})),
264        Dictionary(_, v, _) => serialize_type(v),
265        Extension(ext) => serialize_type(&ext.inner),
266        Utf8View => ipc::Type::Utf8View(Box::new(ipc::Utf8View {})),
267        BinaryView => ipc::Type::BinaryView(Box::new(ipc::BinaryView {})),
268        Unknown => unimplemented!(),
269    }
270}
271
272fn serialize_children(
273    dtype: &ArrowDataType,
274    ipc_field: &IpcField,
275) -> Vec<arrow_format::ipc::Field> {
276    use ArrowDataType::*;
277    match dtype {
278        Null
279        | Boolean
280        | Int8
281        | Int16
282        | Int32
283        | Int64
284        | UInt8
285        | UInt16
286        | UInt32
287        | UInt64
288        | Int128
289        | Float16
290        | Float32
291        | Float64
292        | Timestamp(_, _)
293        | Date32
294        | Date64
295        | Time32(_)
296        | Time64(_)
297        | Duration(_)
298        | Interval(_)
299        | Binary
300        | FixedSizeBinary(_)
301        | LargeBinary
302        | Utf8
303        | LargeUtf8
304        | Decimal(_, _)
305        | Utf8View
306        | BinaryView
307        | Decimal256(_, _) => vec![],
308        FixedSizeList(inner, _) | LargeList(inner) | List(inner) | Map(inner, _) => {
309            vec![serialize_field(inner, &ipc_field.fields[0])]
310        },
311        Struct(fields) => fields
312            .iter()
313            .zip(ipc_field.fields.iter())
314            .map(|(field, ipc)| serialize_field(field, ipc))
315            .collect(),
316        Union(u) => u
317            .fields
318            .iter()
319            .zip(ipc_field.fields.iter())
320            .map(|(field, ipc)| serialize_field(field, ipc))
321            .collect(),
322        Dictionary(_, inner, _) => serialize_children(inner, ipc_field),
323        Extension(ext) => serialize_children(&ext.inner, ipc_field),
324        Unknown => unimplemented!(),
325    }
326}
327
328/// Create an IPC dictionary encoding
329pub(crate) fn serialize_dictionary(
330    index_type: &IntegerType,
331    dict_id: i64,
332    dict_is_ordered: bool,
333) -> arrow_format::ipc::DictionaryEncoding {
334    use IntegerType::*;
335    let is_signed = match index_type {
336        Int8 | Int16 | Int32 | Int64 | Int128 => true,
337        UInt8 | UInt16 | UInt32 | UInt64 => false,
338    };
339
340    let bit_width = match index_type {
341        Int8 | UInt8 => 8,
342        Int16 | UInt16 => 16,
343        Int32 | UInt32 => 32,
344        Int64 | UInt64 => 64,
345        Int128 => 128,
346    };
347
348    let index_type = arrow_format::ipc::Int {
349        bit_width,
350        is_signed,
351    };
352
353    arrow_format::ipc::DictionaryEncoding {
354        id: dict_id,
355        index_type: Some(Box::new(index_type)),
356        is_ordered: dict_is_ordered,
357        dictionary_kind: arrow_format::ipc::DictionaryKind::DenseArray,
358    }
359}