1use std::sync::Arc;
2
3use arrow_format::ipc::planus::ReadAsRoot;
4use arrow_format::ipc::{FieldRef, FixedSizeListRef, MapRef, TimeRef, TimestampRef, UnionRef};
5use polars_error::{polars_bail, polars_err, PolarsResult};
6use polars_utils::pl_str::PlSmallStr;
7
8use super::super::{IpcField, IpcSchema};
9use super::{OutOfSpecKind, StreamMetadata};
10use crate::datatypes::{
11 get_extension, ArrowDataType, ArrowSchema, Extension, ExtensionType, Field, IntegerType,
12 IntervalUnit, Metadata, TimeUnit, UnionMode, UnionType,
13};
14
15fn try_unzip_vec<A, B, I: Iterator<Item = PolarsResult<(A, B)>>>(
16 iter: I,
17) -> PolarsResult<(Vec<A>, Vec<B>)> {
18 let mut a = vec![];
19 let mut b = vec![];
20 for maybe_item in iter {
21 let (a_i, b_i) = maybe_item?;
22 a.push(a_i);
23 b.push(b_i);
24 }
25
26 Ok((a, b))
27}
28
29fn deserialize_field(ipc_field: arrow_format::ipc::FieldRef) -> PolarsResult<(Field, IpcField)> {
30 let metadata = read_metadata(&ipc_field)?;
31
32 let extension = metadata.as_ref().and_then(get_extension);
33
34 let (dtype, ipc_field_) = get_dtype(ipc_field, extension, true)?;
35
36 let field = Field {
37 name: PlSmallStr::from_str(
38 ipc_field
39 .name()?
40 .ok_or_else(|| polars_err!(oos = "Every field in IPC must have a name"))?,
41 ),
42 dtype,
43 is_nullable: ipc_field.nullable()?,
44 metadata: metadata.map(Arc::new),
45 };
46
47 Ok((field, ipc_field_))
48}
49
50fn read_metadata(field: &arrow_format::ipc::FieldRef) -> PolarsResult<Option<Metadata>> {
51 Ok(if let Some(list) = field.custom_metadata()? {
52 let mut metadata_map = Metadata::new();
53 for kv in list {
54 let kv = kv?;
55 if let (Some(k), Some(v)) = (kv.key()?, kv.value()?) {
56 metadata_map.insert(PlSmallStr::from_str(k), PlSmallStr::from_str(v));
57 }
58 }
59 Some(metadata_map)
60 } else {
61 None
62 })
63}
64
65fn deserialize_integer(int: arrow_format::ipc::IntRef) -> PolarsResult<IntegerType> {
66 Ok(match (int.bit_width()?, int.is_signed()?) {
67 (8, true) => IntegerType::Int8,
68 (8, false) => IntegerType::UInt8,
69 (16, true) => IntegerType::Int16,
70 (16, false) => IntegerType::UInt16,
71 (32, true) => IntegerType::Int32,
72 (32, false) => IntegerType::UInt32,
73 (64, true) => IntegerType::Int64,
74 (64, false) => IntegerType::UInt64,
75 (128, true) => IntegerType::Int128,
76 _ => polars_bail!(oos = "IPC: indexType can only be 8, 16, 32, 64 or 128."),
77 })
78}
79
80fn deserialize_timeunit(time_unit: arrow_format::ipc::TimeUnit) -> PolarsResult<TimeUnit> {
81 use arrow_format::ipc::TimeUnit::*;
82 Ok(match time_unit {
83 Second => TimeUnit::Second,
84 Millisecond => TimeUnit::Millisecond,
85 Microsecond => TimeUnit::Microsecond,
86 Nanosecond => TimeUnit::Nanosecond,
87 })
88}
89
90fn deserialize_time(time: TimeRef) -> PolarsResult<(ArrowDataType, IpcField)> {
91 let unit = deserialize_timeunit(time.unit()?)?;
92
93 let dtype = match (time.bit_width()?, unit) {
94 (32, TimeUnit::Second) => ArrowDataType::Time32(TimeUnit::Second),
95 (32, TimeUnit::Millisecond) => ArrowDataType::Time32(TimeUnit::Millisecond),
96 (64, TimeUnit::Microsecond) => ArrowDataType::Time64(TimeUnit::Microsecond),
97 (64, TimeUnit::Nanosecond) => ArrowDataType::Time64(TimeUnit::Nanosecond),
98 (bits, precision) => {
99 polars_bail!(ComputeError:
100 "Time type with bit width of {bits} and unit of {precision:?}"
101 )
102 },
103 };
104 Ok((dtype, IpcField::default()))
105}
106
107fn deserialize_timestamp(timestamp: TimestampRef) -> PolarsResult<(ArrowDataType, IpcField)> {
108 let timezone = timestamp.timezone()?;
109 let time_unit = deserialize_timeunit(timestamp.unit()?)?;
110 Ok((
111 ArrowDataType::Timestamp(time_unit, timezone.map(PlSmallStr::from_str)),
112 IpcField::default(),
113 ))
114}
115
116fn deserialize_union(union_: UnionRef, field: FieldRef) -> PolarsResult<(ArrowDataType, IpcField)> {
117 let mode = UnionMode::sparse(union_.mode()? == arrow_format::ipc::UnionMode::Sparse);
118 let ids = union_.type_ids()?.map(|x| x.iter().collect());
119
120 let fields = field
121 .children()?
122 .ok_or_else(|| polars_err!(oos = "IPC: Union must contain children"))?;
123 if fields.is_empty() {
124 polars_bail!(oos = "IPC: Union must contain at least one child");
125 }
126
127 let (fields, ipc_fields) = try_unzip_vec(fields.iter().map(|field| {
128 let (field, fields) = deserialize_field(field?)?;
129 Ok((field, fields))
130 }))?;
131 let ipc_field = IpcField {
132 fields: ipc_fields,
133 dictionary_id: None,
134 };
135 Ok((
136 ArrowDataType::Union(Box::new(UnionType { fields, ids, mode })),
137 ipc_field,
138 ))
139}
140
141fn deserialize_map(map: MapRef, field: FieldRef) -> PolarsResult<(ArrowDataType, IpcField)> {
142 let is_sorted = map.keys_sorted()?;
143
144 let children = field
145 .children()?
146 .ok_or_else(|| polars_err!(oos = "IPC: Map must contain children"))?;
147 let inner = children
148 .get(0)
149 .ok_or_else(|| polars_err!(oos = "IPC: Map must contain one child"))??;
150 let (field, ipc_field) = deserialize_field(inner)?;
151
152 let dtype = ArrowDataType::Map(Box::new(field), is_sorted);
153 Ok((
154 dtype,
155 IpcField {
156 fields: vec![ipc_field],
157 dictionary_id: None,
158 },
159 ))
160}
161
162fn deserialize_struct(field: FieldRef) -> PolarsResult<(ArrowDataType, IpcField)> {
163 let fields = field
164 .children()?
165 .ok_or_else(|| polars_err!(oos = "IPC: Struct must contain children"))?;
166 let (fields, ipc_fields) = try_unzip_vec(fields.iter().map(|field| {
167 let (field, fields) = deserialize_field(field?)?;
168 Ok((field, fields))
169 }))?;
170 let ipc_field = IpcField {
171 fields: ipc_fields,
172 dictionary_id: None,
173 };
174 Ok((ArrowDataType::Struct(fields), ipc_field))
175}
176
177fn deserialize_list(field: FieldRef) -> PolarsResult<(ArrowDataType, IpcField)> {
178 let children = field
179 .children()?
180 .ok_or_else(|| polars_err!(oos = "IPC: List must contain children"))?;
181 let inner = children
182 .get(0)
183 .ok_or_else(|| polars_err!(oos = "IPC: List must contain one child"))??;
184 let (field, ipc_field) = deserialize_field(inner)?;
185
186 Ok((
187 ArrowDataType::List(Box::new(field)),
188 IpcField {
189 fields: vec![ipc_field],
190 dictionary_id: None,
191 },
192 ))
193}
194
195fn deserialize_large_list(field: FieldRef) -> PolarsResult<(ArrowDataType, IpcField)> {
196 let children = field
197 .children()?
198 .ok_or_else(|| polars_err!(oos = "IPC: List must contain children"))?;
199 let inner = children
200 .get(0)
201 .ok_or_else(|| polars_err!(oos = "IPC: List must contain one child"))??;
202 let (field, ipc_field) = deserialize_field(inner)?;
203
204 Ok((
205 ArrowDataType::LargeList(Box::new(field)),
206 IpcField {
207 fields: vec![ipc_field],
208 dictionary_id: None,
209 },
210 ))
211}
212
213fn deserialize_fixed_size_list(
214 list: FixedSizeListRef,
215 field: FieldRef,
216) -> PolarsResult<(ArrowDataType, IpcField)> {
217 let children = field
218 .children()?
219 .ok_or_else(|| polars_err!(oos = "IPC: FixedSizeList must contain children"))?;
220 let inner = children
221 .get(0)
222 .ok_or_else(|| polars_err!(oos = "IPC: FixedSizeList must contain one child"))??;
223 let (field, ipc_field) = deserialize_field(inner)?;
224
225 let size = list
226 .list_size()?
227 .try_into()
228 .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?;
229
230 Ok((
231 ArrowDataType::FixedSizeList(Box::new(field), size),
232 IpcField {
233 fields: vec![ipc_field],
234 dictionary_id: None,
235 },
236 ))
237}
238
239fn get_dtype(
241 field: arrow_format::ipc::FieldRef,
242 extension: Extension,
243 may_be_dictionary: bool,
244) -> PolarsResult<(ArrowDataType, IpcField)> {
245 if let Some(dictionary) = field.dictionary()? {
246 if may_be_dictionary {
247 let int = dictionary
248 .index_type()?
249 .ok_or_else(|| polars_err!(oos = "indexType is mandatory in Dictionary."))?;
250 let index_type = deserialize_integer(int)?;
251 let (inner, mut ipc_field) = get_dtype(field, extension, false)?;
252 ipc_field.dictionary_id = Some(dictionary.id()?);
253 return Ok((
254 ArrowDataType::Dictionary(index_type, Box::new(inner), dictionary.is_ordered()?),
255 ipc_field,
256 ));
257 }
258 }
259
260 if let Some(extension) = extension {
261 let (name, metadata) = extension;
262 let (dtype, fields) = get_dtype(field, None, false)?;
263 return Ok((
264 ArrowDataType::Extension(Box::new(ExtensionType {
265 name,
266 inner: dtype,
267 metadata,
268 })),
269 fields,
270 ));
271 }
272
273 let type_ = field
274 .type_()?
275 .ok_or_else(|| polars_err!(oos = "IPC: field type is mandatory"))?;
276
277 use arrow_format::ipc::TypeRef::*;
278 Ok(match type_ {
279 Null(_) => (ArrowDataType::Null, IpcField::default()),
280 Bool(_) => (ArrowDataType::Boolean, IpcField::default()),
281 Int(int) => {
282 let dtype = deserialize_integer(int)?.into();
283 (dtype, IpcField::default())
284 },
285 Binary(_) => (ArrowDataType::Binary, IpcField::default()),
286 LargeBinary(_) => (ArrowDataType::LargeBinary, IpcField::default()),
287 Utf8(_) => (ArrowDataType::Utf8, IpcField::default()),
288 LargeUtf8(_) => (ArrowDataType::LargeUtf8, IpcField::default()),
289 BinaryView(_) => (ArrowDataType::BinaryView, IpcField::default()),
290 Utf8View(_) => (ArrowDataType::Utf8View, IpcField::default()),
291 FixedSizeBinary(fixed) => (
292 ArrowDataType::FixedSizeBinary(
293 fixed
294 .byte_width()?
295 .try_into()
296 .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?,
297 ),
298 IpcField::default(),
299 ),
300 FloatingPoint(float) => {
301 let dtype = match float.precision()? {
302 arrow_format::ipc::Precision::Half => ArrowDataType::Float16,
303 arrow_format::ipc::Precision::Single => ArrowDataType::Float32,
304 arrow_format::ipc::Precision::Double => ArrowDataType::Float64,
305 };
306 (dtype, IpcField::default())
307 },
308 Date(date) => {
309 let dtype = match date.unit()? {
310 arrow_format::ipc::DateUnit::Day => ArrowDataType::Date32,
311 arrow_format::ipc::DateUnit::Millisecond => ArrowDataType::Date64,
312 };
313 (dtype, IpcField::default())
314 },
315 Time(time) => deserialize_time(time)?,
316 Timestamp(timestamp) => deserialize_timestamp(timestamp)?,
317 Interval(interval) => {
318 let dtype = match interval.unit()? {
319 arrow_format::ipc::IntervalUnit::YearMonth => {
320 ArrowDataType::Interval(IntervalUnit::YearMonth)
321 },
322 arrow_format::ipc::IntervalUnit::DayTime => {
323 ArrowDataType::Interval(IntervalUnit::DayTime)
324 },
325 arrow_format::ipc::IntervalUnit::MonthDayNano => {
326 ArrowDataType::Interval(IntervalUnit::MonthDayNano)
327 },
328 };
329 (dtype, IpcField::default())
330 },
331 Duration(duration) => {
332 let time_unit = deserialize_timeunit(duration.unit()?)?;
333 (ArrowDataType::Duration(time_unit), IpcField::default())
334 },
335 Decimal(decimal) => {
336 let bit_width: usize = decimal
337 .bit_width()?
338 .try_into()
339 .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?;
340 let precision: usize = decimal
341 .precision()?
342 .try_into()
343 .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?;
344 let scale: usize = decimal
345 .scale()?
346 .try_into()
347 .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?;
348
349 let dtype = match bit_width {
350 128 => ArrowDataType::Decimal(precision, scale),
351 256 => ArrowDataType::Decimal256(precision, scale),
352 _ => return Err(polars_err!(oos = OutOfSpecKind::NegativeFooterLength)),
353 };
354
355 (dtype, IpcField::default())
356 },
357 List(_) => deserialize_list(field)?,
358 LargeList(_) => deserialize_large_list(field)?,
359 FixedSizeList(list) => deserialize_fixed_size_list(list, field)?,
360 Struct(_) => deserialize_struct(field)?,
361 Union(union_) => deserialize_union(union_, field)?,
362 Map(map) => deserialize_map(map, field)?,
363 RunEndEncoded(_) => todo!(),
364 LargeListView(_) | ListView(_) => todo!(),
365 })
366}
367
368pub fn deserialize_schema(
370 message: &[u8],
371) -> PolarsResult<(ArrowSchema, IpcSchema, Option<Metadata>)> {
372 let message = arrow_format::ipc::MessageRef::read_as_root(message)
373 .map_err(|err| polars_err!(oos = format!("Unable deserialize message: {err:?}")))?;
374
375 let schema = match message
376 .header()?
377 .ok_or_else(|| polars_err!(oos = "Unable to convert header to a schema".to_string()))?
378 {
379 arrow_format::ipc::MessageHeaderRef::Schema(schema) => PolarsResult::Ok(schema),
380 _ => polars_bail!(ComputeError: "The message is expected to be a Schema message"),
381 }?;
382
383 fb_to_schema(schema)
384}
385
386pub(super) fn fb_to_schema(
388 schema: arrow_format::ipc::SchemaRef,
389) -> PolarsResult<(ArrowSchema, IpcSchema, Option<Metadata>)> {
390 let fields = schema
391 .fields()?
392 .ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingFields))?;
393
394 let mut arrow_schema = ArrowSchema::with_capacity(fields.len());
395 let mut ipc_fields = Vec::with_capacity(fields.len());
396
397 for field in fields {
398 let (field, ipc_field) = deserialize_field(field?)?;
399 arrow_schema.insert(field.name.clone(), field);
400 ipc_fields.push(ipc_field);
401 }
402
403 let is_little_endian = match schema.endianness()? {
404 arrow_format::ipc::Endianness::Little => true,
405 arrow_format::ipc::Endianness::Big => false,
406 };
407
408 let custom_schema_metadata = match schema.custom_metadata()? {
409 None => None,
410 Some(metadata) => {
411 let metadata: Metadata = metadata
412 .into_iter()
413 .filter_map(|kv_result| {
414 let kv_ref = kv_result.ok()?;
416 Some((kv_ref.key().ok()??.into(), kv_ref.value().ok()??.into()))
417 })
418 .collect();
419
420 if metadata.is_empty() {
421 None
422 } else {
423 Some(metadata)
424 }
425 },
426 };
427
428 Ok((
429 arrow_schema,
430 IpcSchema {
431 fields: ipc_fields,
432 is_little_endian,
433 },
434 custom_schema_metadata,
435 ))
436}
437
438pub(super) fn deserialize_stream_metadata(meta: &[u8]) -> PolarsResult<StreamMetadata> {
439 let message = arrow_format::ipc::MessageRef::read_as_root(meta)
440 .map_err(|err| polars_err!(oos = format!("Unable to get root as message: {err:?}")))?;
441 let version = message.version()?;
442 let header = message
444 .header()?
445 .ok_or_else(|| polars_err!(oos = "Unable to read the first IPC message"))?;
446 let schema = if let arrow_format::ipc::MessageHeaderRef::Schema(schema) = header {
447 schema
448 } else {
449 polars_bail!(oos = "The first IPC message of the stream must be a schema")
450 };
451 let (schema, ipc_schema, custom_schema_metadata) = fb_to_schema(schema)?;
452
453 Ok(StreamMetadata {
454 schema,
455 version,
456 ipc_schema,
457 custom_schema_metadata,
458 })
459}