1mod field;
4mod physical_type;
5pub mod reshape;
6mod schema;
7
8use std::collections::BTreeMap;
9use std::sync::Arc;
10
11pub use field::{Field, DTYPE_CATEGORICAL, DTYPE_ENUM_VALUES};
12pub use physical_type::*;
13use polars_utils::pl_str::PlSmallStr;
14pub use schema::{ArrowSchema, ArrowSchemaRef};
15#[cfg(feature = "serde")]
16use serde::{Deserialize, Serialize};
17
/// String-to-string metadata map. Backed by a `BTreeMap`, so keys iterate in
/// sorted order.
pub type Metadata = BTreeMap<PlSmallStr, PlSmallStr>;
/// Optional `(name, metadata)` pair describing an Arrow extension type, as
/// produced by [`get_extension`].
pub(crate) type Extension = Option<(PlSmallStr, Option<PlSmallStr>)>;
22
/// The logical Arrow data types.
///
/// Every logical type is backed by a [`PhysicalType`]; see
/// [`ArrowDataType::to_physical_type`] for the mapping. The `Default` value is
/// [`ArrowDataType::Null`].
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum ArrowDataType {
    /// The null type; also the `Default` variant.
    #[default]
    Null,
    /// Boolean values.
    Boolean,
    /// Signed 8-bit integer.
    Int8,
    /// Signed 16-bit integer.
    Int16,
    /// Signed 32-bit integer.
    Int32,
    /// Signed 64-bit integer.
    Int64,
    /// Signed 128-bit integer.
    Int128,
    /// Unsigned 8-bit integer.
    UInt8,
    /// Unsigned 16-bit integer.
    UInt16,
    /// Unsigned 32-bit integer.
    UInt32,
    /// Unsigned 64-bit integer.
    UInt64,
    /// 16-bit floating point number.
    Float16,
    /// 32-bit floating point number.
    Float32,
    /// 64-bit floating point number.
    Float64,
    /// A timestamp with the given [`TimeUnit`], stored as an `Int64` primitive.
    /// The optional string is presumably the timezone, as in the Arrow
    /// specification (TODO confirm).
    Timestamp(TimeUnit, Option<PlSmallStr>),
    /// A date stored as an `Int32` primitive (presumably days since the UNIX
    /// epoch, per the Arrow specification).
    Date32,
    /// A date stored as an `Int64` primitive (presumably milliseconds since the
    /// UNIX epoch, per the Arrow specification).
    Date64,
    /// A time of day with the given [`TimeUnit`], stored as an `Int32` primitive.
    Time32(TimeUnit),
    /// A time of day with the given [`TimeUnit`], stored as an `Int64` primitive.
    Time64(TimeUnit),
    /// A duration with the given [`TimeUnit`], stored as an `Int64` primitive.
    Duration(TimeUnit),
    /// A calendar interval. The physical type depends on the [`IntervalUnit`]:
    /// `YearMonth` uses `Int32`, `DayTime` uses `DaysMs` and `MonthDayNano` uses
    /// `MonthDayNano`.
    Interval(IntervalUnit),
    /// Variable-length binary data.
    Binary,
    /// Fixed-width binary data; the `usize` is presumably the byte width of each
    /// value.
    FixedSizeBinary(usize),
    /// Variable-length binary data with the "large" physical layout.
    LargeBinary,
    /// Variable-length UTF-8 strings.
    Utf8,
    /// Variable-length UTF-8 strings with the "large" physical layout.
    LargeUtf8,
    /// A variable-length list whose items are described by the child field.
    List(Box<Field>),
    /// A list with a fixed number of items (the `usize`) per entry, whose items
    /// are described by the child field.
    FixedSizeList(Box<Field>, usize),
    /// A variable-length list with the "large" physical layout.
    LargeList(Box<Field>),
    /// A struct with the given child fields.
    Struct(Vec<Field>),
    /// A map whose entries are described by the child field. The `bool` is
    /// presumably whether keys are sorted (TODO confirm).
    Map(Box<Field>, bool),
    /// Dictionary-encoded values. The fields are: the integer key type (which
    /// determines the physical type), the data type of the dictionary values,
    /// and a `bool` flag that is presumably "dictionary is sorted" (TODO confirm).
    Dictionary(IntegerType, Box<ArrowDataType>, bool),
    /// A decimal stored as an `Int128` primitive; the two `usize`s are presumably
    /// precision and scale.
    Decimal(usize, usize),
    /// A decimal stored as an `Int256` primitive; the two `usize`s are presumably
    /// precision and scale.
    Decimal256(usize, usize),
    /// An extension type; physical and logical resolution defer to its `inner`
    /// type.
    Extension(Box<ExtensionType>),
    /// Binary data in the view layout (see [`ArrowDataType::is_view`]).
    BinaryView,
    /// UTF-8 strings in the view layout (see [`ArrowDataType::is_view`]).
    Utf8View,
    /// Placeholder for a type with no supported representation;
    /// [`ArrowDataType::to_physical_type`] panics on it.
    Unknown,
    /// A union of child fields. Skipped during (de)serialization when the
    /// `serde` feature is enabled.
    #[cfg_attr(feature = "serde", serde(skip))]
    Union(Box<UnionType>),
}
177
/// Definition of an extension type that wraps an underlying storage type.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ExtensionType {
    /// Name of the extension type.
    pub name: PlSmallStr,
    /// Wrapped type. `ArrowDataType::to_physical_type` and
    /// `ArrowDataType::to_logical_type` resolve through it.
    pub inner: ArrowDataType,
    /// Optional extension metadata string; its format is presumably defined by
    /// the extension itself.
    pub metadata: Option<PlSmallStr>,
}
185
/// Description of a union type: its children, optional ids and layout mode.
///
/// Unlike the sibling types, this struct has no serde derive. The `Union`
/// variant of `ArrowDataType` is marked `serde(skip)`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct UnionType {
    /// Child fields of the union.
    pub fields: Vec<Field>,
    /// Optional `i32` ids; presumably one type id per child field (TODO confirm).
    pub ids: Option<Vec<i32>>,
    /// Whether the union uses the dense or sparse layout.
    pub mode: UnionMode,
}
192
/// Layout mode of a union: dense or sparse.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum UnionMode {
    /// Dense union layout.
    Dense,
    /// Sparse union layout.
    Sparse,
}

impl UnionMode {
    /// Builds a mode from a flag: `true` gives [`UnionMode::Sparse`], `false`
    /// gives [`UnionMode::Dense`].
    pub fn sparse(is_sparse: bool) -> Self {
        match is_sparse {
            true => Self::Sparse,
            false => Self::Dense,
        }
    }

    /// Returns `true` only for [`UnionMode::Sparse`].
    pub fn is_sparse(&self) -> bool {
        *self == Self::Sparse
    }

    /// Returns `true` only for [`UnionMode::Dense`].
    pub fn is_dense(&self) -> bool {
        *self == Self::Dense
    }
}
224
/// Resolution used by the temporal types (`Timestamp`, `Time32`, `Time64`,
/// `Duration`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum TimeUnit {
    /// Seconds.
    Second,
    /// Milliseconds.
    Millisecond,
    /// Microseconds.
    Microsecond,
    /// Nanoseconds.
    Nanosecond,
}
238
/// Unit of an `Interval` type; it also selects the physical representation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum IntervalUnit {
    /// Year/month interval, stored as an `Int32` primitive.
    YearMonth,
    /// Day/time interval, stored as the `DaysMs` primitive.
    DayTime,
    /// Month/day/nanosecond interval, stored as the `MonthDayNano` primitive.
    MonthDayNano,
}
251
252impl ArrowDataType {
253 pub fn to_physical_type(&self) -> PhysicalType {
255 use ArrowDataType::*;
256 match self {
257 Null => PhysicalType::Null,
258 Boolean => PhysicalType::Boolean,
259 Int8 => PhysicalType::Primitive(PrimitiveType::Int8),
260 Int16 => PhysicalType::Primitive(PrimitiveType::Int16),
261 Int32 | Date32 | Time32(_) | Interval(IntervalUnit::YearMonth) => {
262 PhysicalType::Primitive(PrimitiveType::Int32)
263 },
264 Int64 | Date64 | Timestamp(_, _) | Time64(_) | Duration(_) => {
265 PhysicalType::Primitive(PrimitiveType::Int64)
266 },
267 Decimal(_, _) => PhysicalType::Primitive(PrimitiveType::Int128),
268 Decimal256(_, _) => PhysicalType::Primitive(PrimitiveType::Int256),
269 UInt8 => PhysicalType::Primitive(PrimitiveType::UInt8),
270 UInt16 => PhysicalType::Primitive(PrimitiveType::UInt16),
271 UInt32 => PhysicalType::Primitive(PrimitiveType::UInt32),
272 UInt64 => PhysicalType::Primitive(PrimitiveType::UInt64),
273 Float16 => PhysicalType::Primitive(PrimitiveType::Float16),
274 Float32 => PhysicalType::Primitive(PrimitiveType::Float32),
275 Float64 => PhysicalType::Primitive(PrimitiveType::Float64),
276 Int128 => PhysicalType::Primitive(PrimitiveType::Int128),
277 Interval(IntervalUnit::DayTime) => PhysicalType::Primitive(PrimitiveType::DaysMs),
278 Interval(IntervalUnit::MonthDayNano) => {
279 PhysicalType::Primitive(PrimitiveType::MonthDayNano)
280 },
281 Binary => PhysicalType::Binary,
282 FixedSizeBinary(_) => PhysicalType::FixedSizeBinary,
283 LargeBinary => PhysicalType::LargeBinary,
284 Utf8 => PhysicalType::Utf8,
285 LargeUtf8 => PhysicalType::LargeUtf8,
286 BinaryView => PhysicalType::BinaryView,
287 Utf8View => PhysicalType::Utf8View,
288 List(_) => PhysicalType::List,
289 FixedSizeList(_, _) => PhysicalType::FixedSizeList,
290 LargeList(_) => PhysicalType::LargeList,
291 Struct(_) => PhysicalType::Struct,
292 Union(_) => PhysicalType::Union,
293 Map(_, _) => PhysicalType::Map,
294 Dictionary(key, _, _) => PhysicalType::Dictionary(*key),
295 Extension(ext) => ext.inner.to_physical_type(),
296 Unknown => unimplemented!(),
297 }
298 }
299
300 pub fn underlying_physical_type(&self) -> ArrowDataType {
302 use ArrowDataType::*;
303 match self {
304 Date32 | Time32(_) | Interval(IntervalUnit::YearMonth) => Int32,
305 Date64
306 | Timestamp(_, _)
307 | Time64(_)
308 | Duration(_)
309 | Interval(IntervalUnit::DayTime) => Int64,
310 Interval(IntervalUnit::MonthDayNano) => unimplemented!(),
311 Binary => Binary,
312 List(field) => List(Box::new(Field {
313 dtype: field.dtype.underlying_physical_type(),
314 ..*field.clone()
315 })),
316 LargeList(field) => LargeList(Box::new(Field {
317 dtype: field.dtype.underlying_physical_type(),
318 ..*field.clone()
319 })),
320 FixedSizeList(field, width) => FixedSizeList(
321 Box::new(Field {
322 dtype: field.dtype.underlying_physical_type(),
323 ..*field.clone()
324 }),
325 *width,
326 ),
327 Struct(fields) => Struct(
328 fields
329 .iter()
330 .map(|field| Field {
331 dtype: field.dtype.underlying_physical_type(),
332 ..field.clone()
333 })
334 .collect(),
335 ),
336 Dictionary(keys, _, _) => (*keys).into(),
337 Union(_) => unimplemented!(),
338 Map(_, _) => unimplemented!(),
339 Extension(ext) => ext.inner.underlying_physical_type(),
340 _ => self.clone(),
341 }
342 }
343
344 pub fn to_logical_type(&self) -> &ArrowDataType {
348 use ArrowDataType::*;
349 match self {
350 Extension(ext) => ext.inner.to_logical_type(),
351 _ => self,
352 }
353 }
354
355 pub fn inner_dtype(&self) -> Option<&ArrowDataType> {
356 match self {
357 ArrowDataType::List(inner) => Some(inner.dtype()),
358 ArrowDataType::LargeList(inner) => Some(inner.dtype()),
359 ArrowDataType::FixedSizeList(inner, _) => Some(inner.dtype()),
360 _ => None,
361 }
362 }
363
364 pub fn is_nested(&self) -> bool {
365 use ArrowDataType as D;
366
367 matches!(
368 self,
369 D::List(_)
370 | D::LargeList(_)
371 | D::FixedSizeList(_, _)
372 | D::Struct(_)
373 | D::Union(_)
374 | D::Map(_, _)
375 | D::Dictionary(_, _, _)
376 | D::Extension(_)
377 )
378 }
379
380 pub fn is_view(&self) -> bool {
381 matches!(self, ArrowDataType::Utf8View | ArrowDataType::BinaryView)
382 }
383
384 pub fn is_numeric(&self) -> bool {
385 use ArrowDataType as D;
386 matches!(
387 self,
388 D::Int8
389 | D::Int16
390 | D::Int32
391 | D::Int64
392 | D::Int128
393 | D::UInt8
394 | D::UInt16
395 | D::UInt32
396 | D::UInt64
397 | D::Float32
398 | D::Float64
399 | D::Decimal(_, _)
400 | D::Decimal256(_, _)
401 )
402 }
403
404 pub fn to_fixed_size_list(self, size: usize, is_nullable: bool) -> ArrowDataType {
405 ArrowDataType::FixedSizeList(
406 Box::new(Field::new(
407 PlSmallStr::from_static("item"),
408 self,
409 is_nullable,
410 )),
411 size,
412 )
413 }
414
415 pub fn contains_dictionary(&self) -> bool {
417 use ArrowDataType as D;
418 match self {
419 D::Null
420 | D::Boolean
421 | D::Int8
422 | D::Int16
423 | D::Int32
424 | D::Int64
425 | D::UInt8
426 | D::UInt16
427 | D::UInt32
428 | D::UInt64
429 | D::Int128
430 | D::Float16
431 | D::Float32
432 | D::Float64
433 | D::Timestamp(_, _)
434 | D::Date32
435 | D::Date64
436 | D::Time32(_)
437 | D::Time64(_)
438 | D::Duration(_)
439 | D::Interval(_)
440 | D::Binary
441 | D::FixedSizeBinary(_)
442 | D::LargeBinary
443 | D::Utf8
444 | D::LargeUtf8
445 | D::Decimal(_, _)
446 | D::Decimal256(_, _)
447 | D::BinaryView
448 | D::Utf8View
449 | D::Unknown => false,
450 D::List(field)
451 | D::FixedSizeList(field, _)
452 | D::Map(field, _)
453 | D::LargeList(field) => field.dtype().contains_dictionary(),
454 D::Struct(fields) => fields.iter().any(|f| f.dtype().contains_dictionary()),
455 D::Union(union) => union.fields.iter().any(|f| f.dtype().contains_dictionary()),
456 D::Dictionary(_, _, _) => true,
457 D::Extension(ext) => ext.inner.contains_dictionary(),
458 }
459 }
460}
461
462impl From<IntegerType> for ArrowDataType {
463 fn from(item: IntegerType) -> Self {
464 match item {
465 IntegerType::Int8 => ArrowDataType::Int8,
466 IntegerType::Int16 => ArrowDataType::Int16,
467 IntegerType::Int32 => ArrowDataType::Int32,
468 IntegerType::Int64 => ArrowDataType::Int64,
469 IntegerType::Int128 => ArrowDataType::Int128,
470 IntegerType::UInt8 => ArrowDataType::UInt8,
471 IntegerType::UInt16 => ArrowDataType::UInt16,
472 IntegerType::UInt32 => ArrowDataType::UInt32,
473 IntegerType::UInt64 => ArrowDataType::UInt64,
474 }
475 }
476}
477
478impl From<PrimitiveType> for ArrowDataType {
479 fn from(item: PrimitiveType) -> Self {
480 match item {
481 PrimitiveType::Int8 => ArrowDataType::Int8,
482 PrimitiveType::Int16 => ArrowDataType::Int16,
483 PrimitiveType::Int32 => ArrowDataType::Int32,
484 PrimitiveType::Int64 => ArrowDataType::Int64,
485 PrimitiveType::UInt8 => ArrowDataType::UInt8,
486 PrimitiveType::UInt16 => ArrowDataType::UInt16,
487 PrimitiveType::UInt32 => ArrowDataType::UInt32,
488 PrimitiveType::UInt64 => ArrowDataType::UInt64,
489 PrimitiveType::Int128 => ArrowDataType::Int128,
490 PrimitiveType::Int256 => ArrowDataType::Decimal256(32, 32),
491 PrimitiveType::Float16 => ArrowDataType::Float16,
492 PrimitiveType::Float32 => ArrowDataType::Float32,
493 PrimitiveType::Float64 => ArrowDataType::Float64,
494 PrimitiveType::DaysMs => ArrowDataType::Interval(IntervalUnit::DayTime),
495 PrimitiveType::MonthDayNano => ArrowDataType::Interval(IntervalUnit::MonthDayNano),
496 PrimitiveType::UInt128 => unimplemented!(),
497 }
498 }
499}
500
/// Shared, reference-counted (`Arc`) handle to an [`ArrowSchema`].
pub type SchemaRef = Arc<ArrowSchema>;
503
504pub fn get_extension(metadata: &Metadata) -> Extension {
506 if let Some(name) = metadata.get(&PlSmallStr::from_static("ARROW:extension:name")) {
507 let metadata = metadata
508 .get(&PlSmallStr::from_static("ARROW:extension:metadata"))
509 .cloned();
510 Some((name.clone(), metadata))
511 } else {
512 None
513 }
514}
515
/// Index array type: `UInt32Array` unless the `bigidx` feature is enabled.
#[cfg(not(feature = "bigidx"))]
pub type IdxArr = super::array::UInt32Array;
/// Index array type: `UInt64Array` when the `bigidx` feature is enabled.
#[cfg(feature = "bigidx")]
pub type IdxArr = super::array::UInt64Array;