polars_compute/
propagate_dictionary.rs

1use arrow::array::{Array, BinaryViewArray, PrimitiveArray, Utf8ViewArray};
2use arrow::bitmap::Bitmap;
3use arrow::datatypes::ArrowDataType::UInt32;
4
5/// Propagate the nulls from the dictionary values into the keys and remove those nulls from the
6/// values.
7pub fn propagate_dictionary_value_nulls(
8    keys: &PrimitiveArray<u32>,
9    values: &Utf8ViewArray,
10) -> (PrimitiveArray<u32>, Utf8ViewArray) {
11    let Some(values_validity) = values.validity() else {
12        return (keys.clone(), values.clone().with_validity(None));
13    };
14    if values_validity.unset_bits() == 0 {
15        return (keys.clone(), values.clone().with_validity(None));
16    }
17
18    let num_values = values.len();
19
20    // Create a map from the old indices to indices with nulls filtered out
21    let mut offset = 0;
22    let new_idx_map: Vec<u32> = (0..num_values)
23        .map(|i| {
24            let is_valid = unsafe { values_validity.get_bit_unchecked(i) };
25            offset += usize::from(!is_valid);
26            if is_valid {
27                (i - offset) as u32
28            } else {
29                0
30            }
31        })
32        .collect();
33
34    let keys = match keys.validity() {
35        None => {
36            let values = keys
37                .values()
38                .iter()
39                .map(|&k| unsafe {
40                    // SAFETY: Arrow invariant that all keys are in range of values
41                    *new_idx_map.get_unchecked(k as usize)
42                })
43                .collect();
44            let validity = Bitmap::from_iter(keys.values().iter().map(|&k| unsafe {
45                // SAFETY: Arrow invariant that all keys are in range of values
46                values_validity.get_bit_unchecked(k as usize)
47            }));
48
49            PrimitiveArray::new(UInt32, values, Some(validity))
50        },
51        Some(keys_validity) => {
52            let values = keys
53                .values()
54                .iter()
55                .map(|&k| {
56                    // deal with nulls in keys
57                    let idx = (k as usize).min(num_values);
58                    // SAFETY: Arrow invariant that all keys are in range of values
59                    *unsafe { new_idx_map.get_unchecked(idx) }
60                })
61                .collect();
62            let propagated_validity = Bitmap::from_iter(keys.values().iter().map(|&k| {
63                // deal with nulls in keys
64                let idx = (k as usize).min(num_values);
65                // SAFETY: Arrow invariant that all keys are in range of values
66                unsafe { values_validity.get_bit_unchecked(idx) }
67            }));
68
69            let validity = &propagated_validity & keys_validity;
70            PrimitiveArray::new(UInt32, values, Some(validity))
71        },
72    };
73
74    // Filter only handles binary
75    let values = values.to_binview();
76
77    // Filter out the null values
78    let values = crate::filter::filter_with_bitmap(&values, values_validity);
79    let values = values.as_any().downcast_ref::<BinaryViewArray>().unwrap();
80    let values = unsafe { values.to_utf8view_unchecked() }.clone();
81
82    // Explicitly set the values validity to none.
83    assert_eq!(values.null_count(), 0);
84    let values = values.with_validity(None);
85
86    (keys, values)
87}