polars_schema/
schema.rs

1use core::fmt::{Debug, Formatter};
2use core::hash::{Hash, Hasher};
3
4use indexmap::map::MutableKeys;
5use polars_error::{polars_bail, polars_ensure, polars_err, PolarsResult};
6use polars_utils::aliases::{InitHashMaps, PlIndexMap};
7use polars_utils::pl_str::PlSmallStr;
8
9#[derive(Clone, Default)]
10#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
11pub struct Schema<D> {
12    fields: PlIndexMap<PlSmallStr, D>,
13}
14
15impl<D: Eq> Eq for Schema<D> {}
16
17impl<D> Schema<D> {
18    pub fn with_capacity(capacity: usize) -> Self {
19        let fields = PlIndexMap::with_capacity(capacity);
20        Self { fields }
21    }
22
23    /// Reserve `additional` memory spaces in the schema.
24    pub fn reserve(&mut self, additional: usize) {
25        self.fields.reserve(additional);
26    }
27
28    /// The number of fields in the schema.
29    #[inline]
30    pub fn len(&self) -> usize {
31        self.fields.len()
32    }
33
34    #[inline]
35    pub fn is_empty(&self) -> bool {
36        self.fields.is_empty()
37    }
38
39    /// Rename field `old` to `new`, and return the (owned) old name.
40    ///
41    /// If `old` is not present in the schema, the schema is not modified and `None` is returned. Otherwise the schema
42    /// is updated and `Some(old_name)` is returned.
43    pub fn rename(&mut self, old: &str, new: PlSmallStr) -> Option<PlSmallStr> {
44        // Remove `old`, get the corresponding index and dtype, and move the last item in the map to that position
45        let (old_index, old_name, dtype) = self.fields.swap_remove_full(old)?;
46        // Insert the same dtype under the new name at the end of the map and store that index
47        let (new_index, _) = self.fields.insert_full(new, dtype);
48        // Swap the two indices to move the originally last element back to the end and to move the new element back to
49        // its original position
50        self.fields.swap_indices(old_index, new_index);
51
52        Some(old_name)
53    }
54
55    pub fn insert(&mut self, key: PlSmallStr, value: D) -> Option<D> {
56        self.fields.insert(key, value)
57    }
58
59    /// Insert a field with `name` and `dtype` at the given `index` into this schema.
60    ///
61    /// If a field named `name` already exists, it is updated with the new dtype. Regardless, the field named `name` is
62    /// always moved to the given index. Valid indices range from `0` (front of the schema) to `self.len()` (after the
63    /// end of the schema).
64    ///
65    /// For a non-mutating version that clones the schema, see [`new_inserting_at_index`][Self::new_inserting_at_index].
66    ///
67    /// Runtime: **O(n)** where `n` is the number of fields in the schema.
68    ///
69    /// Returns:
70    /// - If index is out of bounds, `Err(PolarsError)`
71    /// - Else if `name` was already in the schema, `Ok(Some(old_dtype))`
72    /// - Else `Ok(None)`
73    pub fn insert_at_index(
74        &mut self,
75        mut index: usize,
76        name: PlSmallStr,
77        dtype: D,
78    ) -> PolarsResult<Option<D>> {
79        polars_ensure!(
80            index <= self.len(),
81            OutOfBounds:
82                "index {} is out of bounds for schema with length {} (the max index allowed is self.len())",
83                    index,
84                    self.len()
85        );
86
87        let (old_index, old_dtype) = self.fields.insert_full(name, dtype);
88
89        // If we're moving an existing field, one-past-the-end will actually be out of bounds. Also, self.len() won't
90        // have changed after inserting, so `index == self.len()` is the same as it was before inserting.
91        if old_dtype.is_some() && index == self.len() {
92            index -= 1;
93        }
94        self.fields.move_index(old_index, index);
95        Ok(old_dtype)
96    }
97
98    /// Get a reference to the dtype of the field named `name`, or `None` if the field doesn't exist.
99    pub fn get(&self, name: &str) -> Option<&D> {
100        self.fields.get(name)
101    }
102
103    /// Get a reference to the dtype of the field named `name`, or `Err(PolarsErr)` if the field doesn't exist.
104    pub fn try_get(&self, name: &str) -> PolarsResult<&D> {
105        self.get(name)
106            .ok_or_else(|| polars_err!(SchemaFieldNotFound: "{}", name))
107    }
108
109    /// Get a mutable reference to the dtype of the field named `name`, or `Err(PolarsErr)` if the field doesn't exist.
110    pub fn try_get_mut(&mut self, name: &str) -> PolarsResult<&mut D> {
111        self.fields
112            .get_mut(name)
113            .ok_or_else(|| polars_err!(SchemaFieldNotFound: "{}", name))
114    }
115
116    /// Return all data about the field named `name`: its index in the schema, its name, and its dtype.
117    ///
118    /// Returns `Some((index, &name, &dtype))` if the field exists, `None` if it doesn't.
119    pub fn get_full(&self, name: &str) -> Option<(usize, &PlSmallStr, &D)> {
120        self.fields.get_full(name)
121    }
122
123    /// Return all data about the field named `name`: its index in the schema, its name, and its dtype.
124    ///
125    /// Returns `Ok((index, &name, &dtype))` if the field exists, `Err(PolarsErr)` if it doesn't.
126    pub fn try_get_full(&self, name: &str) -> PolarsResult<(usize, &PlSmallStr, &D)> {
127        self.fields
128            .get_full(name)
129            .ok_or_else(|| polars_err!(SchemaFieldNotFound: "{}", name))
130    }
131
132    /// Get references to the name and dtype of the field at `index`.
133    ///
134    /// If `index` is inbounds, returns `Some((&name, &dtype))`, else `None`. See
135    /// [`get_at_index_mut`][Self::get_at_index_mut] for a mutable version.
136    pub fn get_at_index(&self, index: usize) -> Option<(&PlSmallStr, &D)> {
137        self.fields.get_index(index)
138    }
139
140    pub fn try_get_at_index(&self, index: usize) -> PolarsResult<(&PlSmallStr, &D)> {
141        self.fields.get_index(index).ok_or_else(|| polars_err!(ComputeError: "index {index} out of bounds with 'schema' of len: {}", self.len()))
142    }
143
144    /// Get mutable references to the name and dtype of the field at `index`.
145    ///
146    /// If `index` is inbounds, returns `Some((&mut name, &mut dtype))`, else `None`. See
147    /// [`get_at_index`][Self::get_at_index] for an immutable version.
148    pub fn get_at_index_mut(&mut self, index: usize) -> Option<(&mut PlSmallStr, &mut D)> {
149        self.fields.get_index_mut2(index)
150    }
151
152    /// Swap-remove a field by name and, if the field existed, return its dtype.
153    ///
154    /// If the field does not exist, the schema is not modified and `None` is returned.
155    ///
156    /// This method does a `swap_remove`, which is O(1) but **changes the order of the schema**: the field named `name`
157    /// is replaced by the last field, which takes its position. For a slower, but order-preserving, method, use
158    /// [`shift_remove`][Self::shift_remove].
159    pub fn remove(&mut self, name: &str) -> Option<D> {
160        self.fields.swap_remove(name)
161    }
162
163    /// Remove a field by name, preserving order, and, if the field existed, return its dtype.
164    ///
165    /// If the field does not exist, the schema is not modified and `None` is returned.
166    ///
167    /// This method does a `shift_remove`, which preserves the order of the fields in the schema but **is O(n)**. For a
168    /// faster, but not order-preserving, method, use [`remove`][Self::remove].
169    pub fn shift_remove(&mut self, name: &str) -> Option<D> {
170        self.fields.shift_remove(name)
171    }
172
173    /// Remove a field by name, preserving order, and, if the field existed, return its dtype.
174    ///
175    /// If the field does not exist, the schema is not modified and `None` is returned.
176    ///
177    /// This method does a `shift_remove`, which preserves the order of the fields in the schema but **is O(n)**. For a
178    /// faster, but not order-preserving, method, use [`remove`][Self::remove].
179    pub fn shift_remove_index(&mut self, index: usize) -> Option<(PlSmallStr, D)> {
180        self.fields.shift_remove_index(index)
181    }
182
183    /// Whether the schema contains a field named `name`.
184    pub fn contains(&self, name: &str) -> bool {
185        self.get(name).is_some()
186    }
187
188    /// Change the field named `name` to the given `dtype` and return the previous dtype.
189    ///
190    /// If `name` doesn't already exist in the schema, the schema is not modified and `None` is returned. Otherwise
191    /// returns `Some(old_dtype)`.
192    ///
193    /// This method only ever modifies an existing field and never adds a new field to the schema. To add a new field,
194    /// use [`with_column`][Self::with_column] or [`insert_at_index`][Self::insert_at_index].
195    pub fn set_dtype(&mut self, name: &str, dtype: D) -> Option<D> {
196        let old_dtype = self.fields.get_mut(name)?;
197        Some(std::mem::replace(old_dtype, dtype))
198    }
199
200    /// Change the field at the given index to the given `dtype` and return the previous dtype.
201    ///
202    /// If the index is out of bounds, the schema is not modified and `None` is returned. Otherwise returns
203    /// `Some(old_dtype)`.
204    ///
205    /// This method only ever modifies an existing index and never adds a new field to the schema. To add a new field,
206    /// use [`with_column`][Self::with_column] or [`insert_at_index`][Self::insert_at_index].
207    pub fn set_dtype_at_index(&mut self, index: usize, dtype: D) -> Option<D> {
208        let (_, old_dtype) = self.fields.get_index_mut(index)?;
209        Some(std::mem::replace(old_dtype, dtype))
210    }
211
212    /// Insert a new column in the [`Schema`].
213    ///
214    /// If an equivalent name already exists in the schema: the name remains and
215    /// retains in its place in the order, its corresponding value is updated
216    /// with [`D`] and the older dtype is returned inside `Some(_)`.
217    ///
218    /// If no equivalent key existed in the map: the new name-dtype pair is
219    /// inserted, last in order, and `None` is returned.
220    ///
221    /// To enforce the index of the resulting field, use [`insert_at_index`][Self::insert_at_index].
222    ///
223    /// Computes in **O(1)** time (amortized average).
224    pub fn with_column(&mut self, name: PlSmallStr, dtype: D) -> Option<D> {
225        self.fields.insert(name, dtype)
226    }
227
228    /// Merge `other` into `self`.
229    ///
230    /// Merging logic:
231    /// - Fields that occur in `self` but not `other` are unmodified
232    /// - Fields that occur in `other` but not `self` are appended, in order, to the end of `self`
233    /// - Fields that occur in both `self` and `other` are updated with the dtype from `other`, but keep their original
234    ///   index
235    pub fn merge(&mut self, other: Self) {
236        self.fields.extend(other.fields)
237    }
238
239    /// Iterates over the `(&name, &dtype)` pairs in this schema.
240    ///
241    /// For an owned version, use [`iter_fields`][Self::iter_fields], which clones the data to iterate owned `Field`s
242    pub fn iter(&self) -> impl ExactSizeIterator<Item = (&PlSmallStr, &D)> + '_ {
243        self.fields.iter()
244    }
245
246    pub fn iter_mut(&mut self) -> impl ExactSizeIterator<Item = (&PlSmallStr, &mut D)> + '_ {
247        self.fields.iter_mut()
248    }
249
250    /// Iterates over references to the names in this schema.
251    pub fn iter_names(&self) -> impl '_ + ExactSizeIterator<Item = &PlSmallStr> {
252        self.fields.iter().map(|(name, _dtype)| name)
253    }
254
255    pub fn iter_names_cloned(&self) -> impl '_ + ExactSizeIterator<Item = PlSmallStr> {
256        self.iter_names().cloned()
257    }
258
259    /// Iterates over references to the dtypes in this schema.
260    pub fn iter_values(&self) -> impl '_ + ExactSizeIterator<Item = &D> {
261        self.fields.iter().map(|(_name, dtype)| dtype)
262    }
263
264    pub fn into_iter_values(self) -> impl ExactSizeIterator<Item = D> {
265        self.fields.into_values()
266    }
267
268    /// Iterates over mut references to the dtypes in this schema.
269    pub fn iter_values_mut(&mut self) -> impl '_ + ExactSizeIterator<Item = &mut D> {
270        self.fields.iter_mut().map(|(_name, dtype)| dtype)
271    }
272
273    pub fn index_of(&self, name: &str) -> Option<usize> {
274        self.fields.get_index_of(name)
275    }
276
277    pub fn try_index_of(&self, name: &str) -> PolarsResult<usize> {
278        let Some(i) = self.fields.get_index_of(name) else {
279            polars_bail!(
280                ColumnNotFound:
281                "unable to find column {:?}; valid columns: {:?}",
282                name, self.iter_names().collect::<Vec<_>>(),
283            )
284        };
285
286        Ok(i)
287    }
288
289    /// Compare the fields between two schema returning the additional columns that each schema has.
290    pub fn field_compare<'a, 'b>(
291        &'a self,
292        other: &'b Self,
293        self_extra: &mut Vec<(usize, (&'a PlSmallStr, &'a D))>,
294        other_extra: &mut Vec<(usize, (&'b PlSmallStr, &'b D))>,
295    ) {
296        self_extra.extend(
297            self.iter()
298                .enumerate()
299                .filter(|(_, (n, _))| !other.contains(n)),
300        );
301        other_extra.extend(
302            other
303                .iter()
304                .enumerate()
305                .filter(|(_, (n, _))| !self.contains(n)),
306        );
307    }
308}
309
310impl<D> Schema<D>
311where
312    D: Clone + Default,
313{
314    /// Create a new schema from this one, inserting a field with `name` and `dtype` at the given `index`.
315    ///
316    /// If a field named `name` already exists, it is updated with the new dtype. Regardless, the field named `name` is
317    /// always moved to the given index. Valid indices range from `0` (front of the schema) to `self.len()` (after the
318    /// end of the schema).
319    ///
320    /// For a mutating version that doesn't clone, see [`insert_at_index`][Self::insert_at_index].
321    ///
322    /// Runtime: **O(m * n)** where `m` is the (average) length of the field names and `n` is the number of fields in
323    /// the schema. This method clones every field in the schema.
324    ///
325    /// Returns: `Ok(new_schema)` if `index <= self.len()`, else `Err(PolarsError)`
326    pub fn new_inserting_at_index(
327        &self,
328        index: usize,
329        name: PlSmallStr,
330        field: D,
331    ) -> PolarsResult<Self> {
332        polars_ensure!(
333            index <= self.len(),
334            OutOfBounds:
335                "index {} is out of bounds for schema with length {} (the max index allowed is self.len())",
336                    index,
337                    self.len()
338        );
339
340        let mut new = Self::default();
341        let mut iter = self.fields.iter().filter_map(|(fld_name, dtype)| {
342            (fld_name != &name).then_some((fld_name.clone(), dtype.clone()))
343        });
344        new.fields.extend(iter.by_ref().take(index));
345        new.fields.insert(name.clone(), field);
346        new.fields.extend(iter);
347        Ok(new)
348    }
349
350    /// Merge borrowed `other` into `self`.
351    ///
352    /// Merging logic:
353    /// - Fields that occur in `self` but not `other` are unmodified
354    /// - Fields that occur in `other` but not `self` are appended, in order, to the end of `self`
355    /// - Fields that occur in both `self` and `other` are updated with the dtype from `other`, but keep their original
356    ///   index
357    pub fn merge_from_ref(&mut self, other: &Self) {
358        self.fields.extend(
359            other
360                .iter()
361                .map(|(column, field)| (column.clone(), field.clone())),
362        )
363    }
364
365    /// Generates another schema with just the specified columns selected from this one.
366    pub fn try_project<I>(&self, columns: I) -> PolarsResult<Self>
367    where
368        I: IntoIterator,
369        I::Item: AsRef<str>,
370    {
371        let schema = columns
372            .into_iter()
373            .map(|c| {
374                let name = c.as_ref();
375                let (_, name, dtype) = self
376                    .fields
377                    .get_full(name)
378                    .ok_or_else(|| polars_err!(col_not_found = name))?;
379                PolarsResult::Ok((name.clone(), dtype.clone()))
380            })
381            .collect::<PolarsResult<PlIndexMap<PlSmallStr, _>>>()?;
382        Ok(Self::from(schema))
383    }
384
385    pub fn try_project_indices(&self, indices: &[usize]) -> PolarsResult<Self> {
386        let fields = indices
387            .iter()
388            .map(|&i| {
389                let Some((k, v)) = self.fields.get_index(i) else {
390                    polars_bail!(
391                        SchemaFieldNotFound:
392                        "projection index {} is out of bounds for schema of length {}",
393                        i, self.fields.len()
394                    );
395                };
396
397                Ok((k.clone(), v.clone()))
398            })
399            .collect::<PolarsResult<PlIndexMap<_, _>>>()?;
400
401        Ok(Self { fields })
402    }
403
404    /// Returns a new [`Schema`] with a subset of all fields whose `predicate`
405    /// evaluates to true.
406    pub fn filter<F: Fn(usize, &D) -> bool>(self, predicate: F) -> Self {
407        let fields = self
408            .fields
409            .into_iter()
410            .enumerate()
411            .filter_map(|(index, (name, d))| {
412                if (predicate)(index, &d) {
413                    Some((name, d))
414                } else {
415                    None
416                }
417            })
418            .collect();
419
420        Self { fields }
421    }
422}
423
424pub fn debug_ensure_matching_schema_names<D>(lhs: &Schema<D>, rhs: &Schema<D>) -> PolarsResult<()> {
425    if cfg!(debug_assertions) {
426        let lhs = lhs.iter_names().collect::<Vec<_>>();
427        let rhs = rhs.iter_names().collect::<Vec<_>>();
428
429        if lhs != rhs {
430            polars_bail!(
431                SchemaMismatch:
432                "lhs: {:?} rhs: {:?}",
433                lhs, rhs
434            )
435        }
436    }
437
438    Ok(())
439}
440
441impl<D: Debug> Debug for Schema<D> {
442    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
443        writeln!(f, "Schema:")?;
444        for (name, field) in self.fields.iter() {
445            writeln!(f, "name: {name}, field: {field:?}")?;
446        }
447        Ok(())
448    }
449}
450
451impl<D: Hash> Hash for Schema<D> {
452    fn hash<H: Hasher>(&self, state: &mut H) {
453        self.fields.iter().for_each(|v| v.hash(state))
454    }
455}
456
457// Schemas will only compare equal if they have the same fields in the same order. We can't use `self.inner ==
458// other.inner` because [`IndexMap`] ignores order when checking equality, but we don't want to ignore it.
459impl<D: PartialEq> PartialEq for Schema<D> {
460    fn eq(&self, other: &Self) -> bool {
461        self.fields.len() == other.fields.len()
462            && self
463                .fields
464                .iter()
465                .zip(other.fields.iter())
466                .all(|(a, b)| a == b)
467    }
468}
469
470impl<D> From<PlIndexMap<PlSmallStr, D>> for Schema<D> {
471    fn from(fields: PlIndexMap<PlSmallStr, D>) -> Self {
472        Self { fields }
473    }
474}
475
476impl<F, D> FromIterator<F> for Schema<D>
477where
478    F: Into<(PlSmallStr, D)>,
479{
480    fn from_iter<I: IntoIterator<Item = F>>(iter: I) -> Self {
481        let fields = PlIndexMap::from_iter(iter.into_iter().map(|x| x.into()));
482        Self { fields }
483    }
484}
485
486impl<F, D> Extend<F> for Schema<D>
487where
488    F: Into<(PlSmallStr, D)>,
489{
490    fn extend<T: IntoIterator<Item = F>>(&mut self, iter: T) {
491        self.fields.extend(iter.into_iter().map(|x| x.into()))
492    }
493}
494
495impl<D> IntoIterator for Schema<D> {
496    type IntoIter = <PlIndexMap<PlSmallStr, D> as IntoIterator>::IntoIter;
497    type Item = (PlSmallStr, D);
498
499    fn into_iter(self) -> Self::IntoIter {
500        self.fields.into_iter()
501    }
502}