1use std::fmt::Debug;
2
3use polars_utils::pl_str::PlSmallStr;
4
5use crate::prelude::*;
6use crate::utils::try_get_supertype;
7
8pub type SchemaRef = Arc<Schema>;
9pub type Schema = polars_schema::Schema<DataType>;
10
11pub trait SchemaExt {
12 fn from_arrow_schema(value: &ArrowSchema) -> Self;
13
14 fn get_field(&self, name: &str) -> Option<Field>;
15
16 fn try_get_field(&self, name: &str) -> PolarsResult<Field>;
17
18 fn to_arrow(&self, compat_level: CompatLevel) -> ArrowSchema;
19
20 fn iter_fields(&self) -> impl ExactSizeIterator<Item = Field> + '_;
21
22 fn to_supertype(&mut self, other: &Schema) -> PolarsResult<bool>;
23}
24
25impl SchemaExt for Schema {
26 fn from_arrow_schema(value: &ArrowSchema) -> Self {
27 value
28 .iter_values()
29 .map(|x| (x.name.clone(), DataType::from_arrow_field(x)))
30 .collect()
31 }
32
33 fn get_field(&self, name: &str) -> Option<Field> {
40 self.get_full(name)
41 .map(|(_, name, dtype)| Field::new(name.clone(), dtype.clone()))
42 }
43
44 fn try_get_field(&self, name: &str) -> PolarsResult<Field> {
51 self.get_full(name)
52 .ok_or_else(|| polars_err!(SchemaFieldNotFound: "{}", name))
53 .map(|(_, name, dtype)| Field::new(name.clone(), dtype.clone()))
54 }
55
56 fn to_arrow(&self, compat_level: CompatLevel) -> ArrowSchema {
58 self.iter()
59 .map(|(name, dtype)| {
60 (
61 name.clone(),
62 dtype.to_arrow_field(name.clone(), compat_level),
63 )
64 })
65 .collect()
66 }
67
68 fn iter_fields(&self) -> impl ExactSizeIterator<Item = Field> + '_ {
73 self.iter()
74 .map(|(name, dtype)| Field::new(name.clone(), dtype.clone()))
75 }
76
77 fn to_supertype(&mut self, other: &Schema) -> PolarsResult<bool> {
79 polars_ensure!(self.len() == other.len(), ComputeError: "schema lengths differ");
80
81 let mut changed = false;
82 for ((k, dt), (other_k, other_dt)) in self.iter_mut().zip(other.iter()) {
83 polars_ensure!(k == other_k, ComputeError: "schema names differ: got {}, expected {}", k, other_k);
84
85 let st = try_get_supertype(dt, other_dt)?;
86 changed |= (&st != dt) || (&st != other_dt);
87 *dt = st
88 }
89 Ok(changed)
90 }
91}
92
93pub trait SchemaNamesAndDtypes {
94 const IS_ARROW: bool;
95 type DataType: Debug + Clone + Default + PartialEq;
96
97 fn iter_names_and_dtypes(
98 &self,
99 ) -> impl ExactSizeIterator<Item = (&PlSmallStr, &Self::DataType)>;
100}
101
102impl SchemaNamesAndDtypes for ArrowSchema {
103 const IS_ARROW: bool = true;
104 type DataType = ArrowDataType;
105
106 fn iter_names_and_dtypes(
107 &self,
108 ) -> impl ExactSizeIterator<Item = (&PlSmallStr, &Self::DataType)> {
109 self.iter_values().map(|x| (&x.name, &x.dtype))
110 }
111}
112
113impl SchemaNamesAndDtypes for Schema {
114 const IS_ARROW: bool = false;
115 type DataType = DataType;
116
117 fn iter_names_and_dtypes(
118 &self,
119 ) -> impl ExactSizeIterator<Item = (&PlSmallStr, &Self::DataType)> {
120 self.iter()
121 }
122}
123
124pub fn ensure_matching_schema<D>(
125 lhs: &polars_schema::Schema<D>,
126 rhs: &polars_schema::Schema<D>,
127) -> PolarsResult<()>
128where
129 polars_schema::Schema<D>: SchemaNamesAndDtypes,
130{
131 let lhs = lhs.iter_names_and_dtypes();
132 let rhs = rhs.iter_names_and_dtypes();
133
134 if lhs.len() != rhs.len() {
135 polars_bail!(
136 SchemaMismatch:
137 "schemas contained differing number of columns: {} != {}",
138 lhs.len(), rhs.len(),
139 );
140 }
141
142 for (i, ((l_name, l_dtype), (r_name, r_dtype))) in lhs.zip(rhs).enumerate() {
143 if l_name != r_name {
144 polars_bail!(
145 SchemaMismatch:
146 "schema names differ at index {}: {} != {}",
147 i, l_name, r_name
148 )
149 }
150 if l_dtype != r_dtype
151 && (!polars_schema::Schema::<D>::IS_ARROW
152 || unsafe {
153 DataType::from_arrow_dtype(std::mem::transmute::<
155 &<polars_schema::Schema<D> as SchemaNamesAndDtypes>::DataType,
156 &ArrowDataType,
157 >(l_dtype))
158 != DataType::from_arrow_dtype(std::mem::transmute::<
159 &<polars_schema::Schema<D> as SchemaNamesAndDtypes>::DataType,
160 &ArrowDataType,
161 >(r_dtype))
162 })
163 {
164 polars_bail!(
165 SchemaMismatch:
166 "schema dtypes differ at index {} for column {}: {:?} != {:?}",
167 i, l_name, l_dtype, r_dtype
168 )
169 }
170 }
171
172 Ok(())
173}