polars_arrow/io/ipc/read/
schema.rs

1use std::sync::Arc;
2
3use arrow_format::ipc::planus::ReadAsRoot;
4use arrow_format::ipc::{FieldRef, FixedSizeListRef, MapRef, TimeRef, TimestampRef, UnionRef};
5use polars_error::{polars_bail, polars_err, PolarsResult};
6use polars_utils::pl_str::PlSmallStr;
7
8use super::super::{IpcField, IpcSchema};
9use super::{OutOfSpecKind, StreamMetadata};
10use crate::datatypes::{
11    get_extension, ArrowDataType, ArrowSchema, Extension, ExtensionType, Field, IntegerType,
12    IntervalUnit, Metadata, TimeUnit, UnionMode, UnionType,
13};
14
15fn try_unzip_vec<A, B, I: Iterator<Item = PolarsResult<(A, B)>>>(
16    iter: I,
17) -> PolarsResult<(Vec<A>, Vec<B>)> {
18    let mut a = vec![];
19    let mut b = vec![];
20    for maybe_item in iter {
21        let (a_i, b_i) = maybe_item?;
22        a.push(a_i);
23        b.push(b_i);
24    }
25
26    Ok((a, b))
27}
28
29fn deserialize_field(ipc_field: arrow_format::ipc::FieldRef) -> PolarsResult<(Field, IpcField)> {
30    let metadata = read_metadata(&ipc_field)?;
31
32    let extension = metadata.as_ref().and_then(get_extension);
33
34    let (dtype, ipc_field_) = get_dtype(ipc_field, extension, true)?;
35
36    let field = Field {
37        name: PlSmallStr::from_str(
38            ipc_field
39                .name()?
40                .ok_or_else(|| polars_err!(oos = "Every field in IPC must have a name"))?,
41        ),
42        dtype,
43        is_nullable: ipc_field.nullable()?,
44        metadata: metadata.map(Arc::new),
45    };
46
47    Ok((field, ipc_field_))
48}
49
50fn read_metadata(field: &arrow_format::ipc::FieldRef) -> PolarsResult<Option<Metadata>> {
51    Ok(if let Some(list) = field.custom_metadata()? {
52        let mut metadata_map = Metadata::new();
53        for kv in list {
54            let kv = kv?;
55            if let (Some(k), Some(v)) = (kv.key()?, kv.value()?) {
56                metadata_map.insert(PlSmallStr::from_str(k), PlSmallStr::from_str(v));
57            }
58        }
59        Some(metadata_map)
60    } else {
61        None
62    })
63}
64
65fn deserialize_integer(int: arrow_format::ipc::IntRef) -> PolarsResult<IntegerType> {
66    Ok(match (int.bit_width()?, int.is_signed()?) {
67        (8, true) => IntegerType::Int8,
68        (8, false) => IntegerType::UInt8,
69        (16, true) => IntegerType::Int16,
70        (16, false) => IntegerType::UInt16,
71        (32, true) => IntegerType::Int32,
72        (32, false) => IntegerType::UInt32,
73        (64, true) => IntegerType::Int64,
74        (64, false) => IntegerType::UInt64,
75        (128, true) => IntegerType::Int128,
76        _ => polars_bail!(oos = "IPC: indexType can only be 8, 16, 32, 64 or 128."),
77    })
78}
79
80fn deserialize_timeunit(time_unit: arrow_format::ipc::TimeUnit) -> PolarsResult<TimeUnit> {
81    use arrow_format::ipc::TimeUnit::*;
82    Ok(match time_unit {
83        Second => TimeUnit::Second,
84        Millisecond => TimeUnit::Millisecond,
85        Microsecond => TimeUnit::Microsecond,
86        Nanosecond => TimeUnit::Nanosecond,
87    })
88}
89
90fn deserialize_time(time: TimeRef) -> PolarsResult<(ArrowDataType, IpcField)> {
91    let unit = deserialize_timeunit(time.unit()?)?;
92
93    let dtype = match (time.bit_width()?, unit) {
94        (32, TimeUnit::Second) => ArrowDataType::Time32(TimeUnit::Second),
95        (32, TimeUnit::Millisecond) => ArrowDataType::Time32(TimeUnit::Millisecond),
96        (64, TimeUnit::Microsecond) => ArrowDataType::Time64(TimeUnit::Microsecond),
97        (64, TimeUnit::Nanosecond) => ArrowDataType::Time64(TimeUnit::Nanosecond),
98        (bits, precision) => {
99            polars_bail!(ComputeError:
100                "Time type with bit width of {bits} and unit of {precision:?}"
101            )
102        },
103    };
104    Ok((dtype, IpcField::default()))
105}
106
107fn deserialize_timestamp(timestamp: TimestampRef) -> PolarsResult<(ArrowDataType, IpcField)> {
108    let timezone = timestamp.timezone()?;
109    let time_unit = deserialize_timeunit(timestamp.unit()?)?;
110    Ok((
111        ArrowDataType::Timestamp(time_unit, timezone.map(PlSmallStr::from_str)),
112        IpcField::default(),
113    ))
114}
115
116fn deserialize_union(union_: UnionRef, field: FieldRef) -> PolarsResult<(ArrowDataType, IpcField)> {
117    let mode = UnionMode::sparse(union_.mode()? == arrow_format::ipc::UnionMode::Sparse);
118    let ids = union_.type_ids()?.map(|x| x.iter().collect());
119
120    let fields = field
121        .children()?
122        .ok_or_else(|| polars_err!(oos = "IPC: Union must contain children"))?;
123    if fields.is_empty() {
124        polars_bail!(oos = "IPC: Union must contain at least one child");
125    }
126
127    let (fields, ipc_fields) = try_unzip_vec(fields.iter().map(|field| {
128        let (field, fields) = deserialize_field(field?)?;
129        Ok((field, fields))
130    }))?;
131    let ipc_field = IpcField {
132        fields: ipc_fields,
133        dictionary_id: None,
134    };
135    Ok((
136        ArrowDataType::Union(Box::new(UnionType { fields, ids, mode })),
137        ipc_field,
138    ))
139}
140
141fn deserialize_map(map: MapRef, field: FieldRef) -> PolarsResult<(ArrowDataType, IpcField)> {
142    let is_sorted = map.keys_sorted()?;
143
144    let children = field
145        .children()?
146        .ok_or_else(|| polars_err!(oos = "IPC: Map must contain children"))?;
147    let inner = children
148        .get(0)
149        .ok_or_else(|| polars_err!(oos = "IPC: Map must contain one child"))??;
150    let (field, ipc_field) = deserialize_field(inner)?;
151
152    let dtype = ArrowDataType::Map(Box::new(field), is_sorted);
153    Ok((
154        dtype,
155        IpcField {
156            fields: vec![ipc_field],
157            dictionary_id: None,
158        },
159    ))
160}
161
162fn deserialize_struct(field: FieldRef) -> PolarsResult<(ArrowDataType, IpcField)> {
163    let fields = field
164        .children()?
165        .ok_or_else(|| polars_err!(oos = "IPC: Struct must contain children"))?;
166    let (fields, ipc_fields) = try_unzip_vec(fields.iter().map(|field| {
167        let (field, fields) = deserialize_field(field?)?;
168        Ok((field, fields))
169    }))?;
170    let ipc_field = IpcField {
171        fields: ipc_fields,
172        dictionary_id: None,
173    };
174    Ok((ArrowDataType::Struct(fields), ipc_field))
175}
176
177fn deserialize_list(field: FieldRef) -> PolarsResult<(ArrowDataType, IpcField)> {
178    let children = field
179        .children()?
180        .ok_or_else(|| polars_err!(oos = "IPC: List must contain children"))?;
181    let inner = children
182        .get(0)
183        .ok_or_else(|| polars_err!(oos = "IPC: List must contain one child"))??;
184    let (field, ipc_field) = deserialize_field(inner)?;
185
186    Ok((
187        ArrowDataType::List(Box::new(field)),
188        IpcField {
189            fields: vec![ipc_field],
190            dictionary_id: None,
191        },
192    ))
193}
194
195fn deserialize_large_list(field: FieldRef) -> PolarsResult<(ArrowDataType, IpcField)> {
196    let children = field
197        .children()?
198        .ok_or_else(|| polars_err!(oos = "IPC: List must contain children"))?;
199    let inner = children
200        .get(0)
201        .ok_or_else(|| polars_err!(oos = "IPC: List must contain one child"))??;
202    let (field, ipc_field) = deserialize_field(inner)?;
203
204    Ok((
205        ArrowDataType::LargeList(Box::new(field)),
206        IpcField {
207            fields: vec![ipc_field],
208            dictionary_id: None,
209        },
210    ))
211}
212
213fn deserialize_fixed_size_list(
214    list: FixedSizeListRef,
215    field: FieldRef,
216) -> PolarsResult<(ArrowDataType, IpcField)> {
217    let children = field
218        .children()?
219        .ok_or_else(|| polars_err!(oos = "IPC: FixedSizeList must contain children"))?;
220    let inner = children
221        .get(0)
222        .ok_or_else(|| polars_err!(oos = "IPC: FixedSizeList must contain one child"))??;
223    let (field, ipc_field) = deserialize_field(inner)?;
224
225    let size = list
226        .list_size()?
227        .try_into()
228        .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?;
229
230    Ok((
231        ArrowDataType::FixedSizeList(Box::new(field), size),
232        IpcField {
233            fields: vec![ipc_field],
234            dictionary_id: None,
235        },
236    ))
237}
238
239/// Get the Arrow data type from the flatbuffer Field table
240fn get_dtype(
241    field: arrow_format::ipc::FieldRef,
242    extension: Extension,
243    may_be_dictionary: bool,
244) -> PolarsResult<(ArrowDataType, IpcField)> {
245    if let Some(dictionary) = field.dictionary()? {
246        if may_be_dictionary {
247            let int = dictionary
248                .index_type()?
249                .ok_or_else(|| polars_err!(oos = "indexType is mandatory in Dictionary."))?;
250            let index_type = deserialize_integer(int)?;
251            let (inner, mut ipc_field) = get_dtype(field, extension, false)?;
252            ipc_field.dictionary_id = Some(dictionary.id()?);
253            return Ok((
254                ArrowDataType::Dictionary(index_type, Box::new(inner), dictionary.is_ordered()?),
255                ipc_field,
256            ));
257        }
258    }
259
260    if let Some(extension) = extension {
261        let (name, metadata) = extension;
262        let (dtype, fields) = get_dtype(field, None, false)?;
263        return Ok((
264            ArrowDataType::Extension(Box::new(ExtensionType {
265                name,
266                inner: dtype,
267                metadata,
268            })),
269            fields,
270        ));
271    }
272
273    let type_ = field
274        .type_()?
275        .ok_or_else(|| polars_err!(oos = "IPC: field type is mandatory"))?;
276
277    use arrow_format::ipc::TypeRef::*;
278    Ok(match type_ {
279        Null(_) => (ArrowDataType::Null, IpcField::default()),
280        Bool(_) => (ArrowDataType::Boolean, IpcField::default()),
281        Int(int) => {
282            let dtype = deserialize_integer(int)?.into();
283            (dtype, IpcField::default())
284        },
285        Binary(_) => (ArrowDataType::Binary, IpcField::default()),
286        LargeBinary(_) => (ArrowDataType::LargeBinary, IpcField::default()),
287        Utf8(_) => (ArrowDataType::Utf8, IpcField::default()),
288        LargeUtf8(_) => (ArrowDataType::LargeUtf8, IpcField::default()),
289        BinaryView(_) => (ArrowDataType::BinaryView, IpcField::default()),
290        Utf8View(_) => (ArrowDataType::Utf8View, IpcField::default()),
291        FixedSizeBinary(fixed) => (
292            ArrowDataType::FixedSizeBinary(
293                fixed
294                    .byte_width()?
295                    .try_into()
296                    .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?,
297            ),
298            IpcField::default(),
299        ),
300        FloatingPoint(float) => {
301            let dtype = match float.precision()? {
302                arrow_format::ipc::Precision::Half => ArrowDataType::Float16,
303                arrow_format::ipc::Precision::Single => ArrowDataType::Float32,
304                arrow_format::ipc::Precision::Double => ArrowDataType::Float64,
305            };
306            (dtype, IpcField::default())
307        },
308        Date(date) => {
309            let dtype = match date.unit()? {
310                arrow_format::ipc::DateUnit::Day => ArrowDataType::Date32,
311                arrow_format::ipc::DateUnit::Millisecond => ArrowDataType::Date64,
312            };
313            (dtype, IpcField::default())
314        },
315        Time(time) => deserialize_time(time)?,
316        Timestamp(timestamp) => deserialize_timestamp(timestamp)?,
317        Interval(interval) => {
318            let dtype = match interval.unit()? {
319                arrow_format::ipc::IntervalUnit::YearMonth => {
320                    ArrowDataType::Interval(IntervalUnit::YearMonth)
321                },
322                arrow_format::ipc::IntervalUnit::DayTime => {
323                    ArrowDataType::Interval(IntervalUnit::DayTime)
324                },
325                arrow_format::ipc::IntervalUnit::MonthDayNano => {
326                    ArrowDataType::Interval(IntervalUnit::MonthDayNano)
327                },
328            };
329            (dtype, IpcField::default())
330        },
331        Duration(duration) => {
332            let time_unit = deserialize_timeunit(duration.unit()?)?;
333            (ArrowDataType::Duration(time_unit), IpcField::default())
334        },
335        Decimal(decimal) => {
336            let bit_width: usize = decimal
337                .bit_width()?
338                .try_into()
339                .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?;
340            let precision: usize = decimal
341                .precision()?
342                .try_into()
343                .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?;
344            let scale: usize = decimal
345                .scale()?
346                .try_into()
347                .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?;
348
349            let dtype = match bit_width {
350                128 => ArrowDataType::Decimal(precision, scale),
351                256 => ArrowDataType::Decimal256(precision, scale),
352                _ => return Err(polars_err!(oos = OutOfSpecKind::NegativeFooterLength)),
353            };
354
355            (dtype, IpcField::default())
356        },
357        List(_) => deserialize_list(field)?,
358        LargeList(_) => deserialize_large_list(field)?,
359        FixedSizeList(list) => deserialize_fixed_size_list(list, field)?,
360        Struct(_) => deserialize_struct(field)?,
361        Union(union_) => deserialize_union(union_, field)?,
362        Map(map) => deserialize_map(map, field)?,
363        RunEndEncoded(_) => todo!(),
364        LargeListView(_) | ListView(_) => todo!(),
365    })
366}
367
368/// Deserialize an flatbuffers-encoded Schema message into [`ArrowSchema`] and [`IpcSchema`].
369pub fn deserialize_schema(
370    message: &[u8],
371) -> PolarsResult<(ArrowSchema, IpcSchema, Option<Metadata>)> {
372    let message = arrow_format::ipc::MessageRef::read_as_root(message)
373        .map_err(|err| polars_err!(oos = format!("Unable deserialize message: {err:?}")))?;
374
375    let schema = match message
376        .header()?
377        .ok_or_else(|| polars_err!(oos = "Unable to convert header to a schema".to_string()))?
378    {
379        arrow_format::ipc::MessageHeaderRef::Schema(schema) => PolarsResult::Ok(schema),
380        _ => polars_bail!(ComputeError: "The message is expected to be a Schema message"),
381    }?;
382
383    fb_to_schema(schema)
384}
385
386/// Deserialize the raw Schema table from IPC format to Schema data type
387pub(super) fn fb_to_schema(
388    schema: arrow_format::ipc::SchemaRef,
389) -> PolarsResult<(ArrowSchema, IpcSchema, Option<Metadata>)> {
390    let fields = schema
391        .fields()?
392        .ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingFields))?;
393
394    let mut arrow_schema = ArrowSchema::with_capacity(fields.len());
395    let mut ipc_fields = Vec::with_capacity(fields.len());
396
397    for field in fields {
398        let (field, ipc_field) = deserialize_field(field?)?;
399        arrow_schema.insert(field.name.clone(), field);
400        ipc_fields.push(ipc_field);
401    }
402
403    let is_little_endian = match schema.endianness()? {
404        arrow_format::ipc::Endianness::Little => true,
405        arrow_format::ipc::Endianness::Big => false,
406    };
407
408    let custom_schema_metadata = match schema.custom_metadata()? {
409        None => None,
410        Some(metadata) => {
411            let metadata: Metadata = metadata
412                .into_iter()
413                .filter_map(|kv_result| {
414                    // FIXME: silently hiding errors here
415                    let kv_ref = kv_result.ok()?;
416                    Some((kv_ref.key().ok()??.into(), kv_ref.value().ok()??.into()))
417                })
418                .collect();
419
420            if metadata.is_empty() {
421                None
422            } else {
423                Some(metadata)
424            }
425        },
426    };
427
428    Ok((
429        arrow_schema,
430        IpcSchema {
431            fields: ipc_fields,
432            is_little_endian,
433        },
434        custom_schema_metadata,
435    ))
436}
437
438pub(super) fn deserialize_stream_metadata(meta: &[u8]) -> PolarsResult<StreamMetadata> {
439    let message = arrow_format::ipc::MessageRef::read_as_root(meta)
440        .map_err(|err| polars_err!(oos = format!("Unable to get root as message: {err:?}")))?;
441    let version = message.version()?;
442    // message header is a Schema, so read it
443    let header = message
444        .header()?
445        .ok_or_else(|| polars_err!(oos = "Unable to read the first IPC message"))?;
446    let schema = if let arrow_format::ipc::MessageHeaderRef::Schema(schema) = header {
447        schema
448    } else {
449        polars_bail!(oos = "The first IPC message of the stream must be a schema")
450    };
451    let (schema, ipc_schema, custom_schema_metadata) = fb_to_schema(schema)?;
452
453    Ok(StreamMetadata {
454        schema,
455        version,
456        ipc_schema,
457        custom_schema_metadata,
458    })
459}