rustfft/
array_utils.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
use crate::Complex;
use crate::FftNum;
use std::ops::{Deref, DerefMut};

/// Transposes a flattened, row-major 2D array of `width` columns and `height` rows.
///
/// Element `(x, y)` of `input` (at index `x + y * width`) is copied to index
/// `y + x * height` of `output`. No loop tiling is used: benchmarking shows that tiling
/// isn't effective for small arrays (in the range of 50x50 or smaller).
///
/// # Safety
/// All accesses are unchecked. Both `input` and `output` must hold at least
/// `width * height` elements.
pub unsafe fn transpose_small<T: Copy>(width: usize, height: usize, input: &[T], output: &mut [T]) {
    for col in 0..width {
        for row in 0..height {
            // Column `col` of the input becomes row `col` of the output
            let src = row * width + col;
            let dst = col * height + row;

            let value = *input.get_unchecked(src);
            *output.get_unchecked_mut(dst) = value;
        }
    }
}

/// Reinterprets a shared slice of `T` as a slice of `U` that starts at the same address
/// and reports the same number of elements.
///
/// # Safety
/// The element count is carried over unchanged; it is not rescaled by element size.
/// The caller must guarantee that `slice.len()` values of `U` can validly be read from
/// the slice's address (size, alignment and bit-validity of `U`).
#[allow(unused)]
pub unsafe fn workaround_transmute<T, U>(slice: &[T]) -> &[U] {
    std::slice::from_raw_parts(slice.as_ptr().cast::<U>(), slice.len())
}
/// Reinterprets a mutable slice of `T` as a mutable slice of `U` that starts at the same
/// address and reports the same number of elements.
///
/// # Safety
/// The element count is carried over unchanged; it is not rescaled by element size.
/// The caller must guarantee that `slice.len()` values of `U` can validly be read from
/// and written to the slice's address (size, alignment and bit-validity of `U`).
#[allow(unused)]
pub unsafe fn workaround_transmute_mut<T, U>(slice: &mut [T]) -> &mut [U] {
    std::slice::from_raw_parts_mut(slice.as_mut_ptr().cast::<U>(), slice.len())
}

/// Indexed read/write access to a buffer of `Complex<T>` values.
///
/// The implementations in this file cover `&mut [Complex<T>]`, `&mut [Complex<T>; N]`
/// and `DoubleBuf`. For `DoubleBuf`, loads and stores go to two different slices.
pub(crate) trait LoadStore<T: FftNum>: DerefMut {
    /// Returns the element at position `idx`.
    ///
    /// # Safety
    /// The implementations in this file index without bounds checks (only a
    /// `debug_assert!`), so `idx` must be in bounds of the buffer being read.
    unsafe fn load(&self, idx: usize) -> Complex<T>;
    /// Writes `val` to position `idx`.
    ///
    /// # Safety
    /// The implementations in this file index without bounds checks (only a
    /// `debug_assert!`), so `idx` must be in bounds of the buffer being written.
    unsafe fn store(&mut self, val: Complex<T>, idx: usize);
}

/// In-place access: loads and stores both hit the same mutable slice.
impl<T: FftNum> LoadStore<T> for &mut [Complex<T>] {
    /// Reads element `idx`; the bound is only checked in debug builds.
    #[inline(always)]
    unsafe fn load(&self, idx: usize) -> Complex<T> {
        debug_assert!(idx < self.len());
        let elements: &[Complex<T>] = &**self;
        *elements.get_unchecked(idx)
    }
    /// Overwrites element `idx` with `val`; the bound is only checked in debug builds.
    #[inline(always)]
    unsafe fn store(&mut self, val: Complex<T>, idx: usize) {
        debug_assert!(idx < self.len());
        let slot = self.get_unchecked_mut(idx);
        *slot = val;
    }
}
/// In-place access on a fixed-size array: loads and stores both hit the same array.
impl<T: FftNum, const N: usize> LoadStore<T> for &mut [Complex<T>; N] {
    /// Reads element `idx`; the bound is only checked in debug builds.
    #[inline(always)]
    unsafe fn load(&self, idx: usize) -> Complex<T> {
        debug_assert!(idx < self.len());
        let elements: &[Complex<T>; N] = &**self;
        *elements.get_unchecked(idx)
    }
    /// Overwrites element `idx` with `val`; the bound is only checked in debug builds.
    #[inline(always)]
    unsafe fn store(&mut self, val: Complex<T>, idx: usize) {
        debug_assert!(idx < self.len());
        let slot = self.get_unchecked_mut(idx);
        *slot = val;
    }
}

/// A pair of distinct buffers: a read-only `input` and a writable `output`.
///
/// The trait impls for this type in this file read from `input` (`Deref`, `LoadStore::load`)
/// and write to `output` (`DerefMut`, `LoadStore::store`).
pub(crate) struct DoubleBuf<'a, T> {
    /// Slice that reads are served from.
    pub input: &'a [Complex<T>],
    /// Slice that writes are directed to.
    pub output: &'a mut [Complex<T>],
}
/// Shared dereference exposes the `input` slice.
///
/// Note that this is a different slice from the one `DerefMut` hands out (`output`).
impl<'a, T> Deref for DoubleBuf<'a, T> {
    type Target = [Complex<T>];
    fn deref(&self) -> &[Complex<T>] {
        &*self.input
    }
}
impl<'a, T> DerefMut for DoubleBuf<'a, T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.output
    }
}
/// Out-of-place access: loads come from `input`, stores go to `output`.
impl<'a, T: FftNum> LoadStore<T> for DoubleBuf<'a, T> {
    /// Reads element `idx` of `input`; the bound is only checked in debug builds.
    #[inline(always)]
    unsafe fn load(&self, idx: usize) -> Complex<T> {
        debug_assert!(idx < self.input.len());
        let source: &[Complex<T>] = self.input;
        *source.get_unchecked(idx)
    }
    /// Overwrites element `idx` of `output` with `val`; the bound is only checked in debug builds.
    #[inline(always)]
    unsafe fn store(&mut self, val: Complex<T>, idx: usize) {
        debug_assert!(idx < self.output.len());
        let slot = self.output.get_unchecked_mut(idx);
        *slot = val;
    }
}

#[cfg(test)]
mod unit_tests {
    use super::*;
    use crate::test_utils::random_signal;
    use num_complex::Complex;
    use num_traits::Zero;

    /// Transposes every width/height combination from 1x1 up to 15x15 and verifies that
    /// each input element ended up at its transposed index in the output.
    #[test]
    fn test_transpose() {
        for width in 1..16usize {
            for height in 1..16usize {
                let len = width * height;

                let input: Vec<Complex<f32>> = random_signal(len);
                let mut output = vec![Zero::zero(); len];

                unsafe { transpose_small(width, height, &input, &mut output) };

                // Walk every (x, y) coordinate and compare source and destination slots
                for x in 0..width {
                    for y in 0..height {
                        let source_index = x + y * width;
                        let dest_index = y + x * height;
                        assert_eq!(
                            input[source_index], output[dest_index],
                            "x = {}, y = {}",
                            x, y
                        );
                    }
                }
            }
        }
    }
}

/// Loop over exact chunks of the provided buffer. Very similar in semantics to ChunksExactMut,
/// but generates smaller code and requires no modulo operations.
///
/// `chunk_fn` is called once for each consecutive, non-overlapping chunk of exactly
/// `chunk_size` elements, from the front of `buffer` to the back.
///
/// Returns `Ok(())` if every element ended up in a chunk, `Err(())` if there was a remainder.
///
/// A `chunk_size` of zero never produces a chunk: `chunk_fn` is not called, and the result is
/// `Ok(())` for an empty buffer and `Err(())` otherwise.
pub fn iter_chunks<T>(
    mut buffer: &mut [T],
    chunk_size: usize,
    mut chunk_fn: impl FnMut(&mut [T]),
) -> Result<(), ()> {
    // Loop over the buffer, splicing off chunk_size at a time, and calling chunk_fn on each.
    // A zero chunk_size would split off empty heads forever without ever shrinking the buffer,
    // so the loop is skipped entirely in that case.
    while chunk_size > 0 && buffer.len() >= chunk_size {
        let (head, tail) = buffer.split_at_mut(chunk_size);
        buffer = tail;

        chunk_fn(head);
    }

    // We have a remainder if there's data still in the buffer -- in which case we want to
    // indicate to the caller that there was an unwanted remainder
    if buffer.is_empty() {
        Ok(())
    } else {
        Err(())
    }
}

/// Loop over exact zipped chunks of the 2 provided buffers. Very similar in semantics to
/// ChunksExactMut.zip(ChunksExactMut), but generates smaller code and requires no modulo operations.
///
/// `chunk_fn` is called once for each pair of corresponding `chunk_size`-element chunks, from
/// the front of the buffers to the back. If the buffers differ in length, only their common
/// prefix is chunked.
///
/// Returns `Ok(())` if every element of both buffers ended up in a chunk, `Err(())` if there
/// was a remainder. That covers both a length mismatch between the buffers (whichever one is
/// longer) and a common length that is not a multiple of `chunk_size`.
///
/// A `chunk_size` of zero never produces a chunk: `chunk_fn` is not called, and the result is
/// `Ok(())` only if both buffers are empty.
pub fn iter_chunks_zipped<T>(
    mut buffer1: &mut [T],
    mut buffer2: &mut [T],
    chunk_size: usize,
    mut chunk_fn: impl FnMut(&mut [T], &mut [T]),
) -> Result<(), ()> {
    // If the two buffers aren't the same size, record the fact that they're different,
    // then snip them to be the same size
    let uneven = buffer1.len() != buffer2.len();
    let common_len = buffer1.len().min(buffer2.len());
    buffer1 = &mut buffer1[..common_len];
    buffer2 = &mut buffer2[..common_len];

    // Now that the two slices are the same length, loop over each one, splicing off chunk_size
    // at a time, and calling chunk_fn on each pair. A zero chunk_size would never shrink the
    // buffers, so the loop is skipped entirely in that case.
    while chunk_size > 0 && buffer1.len() >= chunk_size {
        let (head1, tail1) = buffer1.split_at_mut(chunk_size);
        buffer1 = tail1;

        let (head2, tail2) = buffer2.split_at_mut(chunk_size);
        buffer2 = tail2;

        chunk_fn(head1, head2);
    }

    // We have a remainder if the 2 buffers were uneven to start with, or if there's still data
    // in the buffers -- in which case we want to indicate to the caller that there was an
    // unwanted remainder
    if !uneven && buffer1.is_empty() {
        Ok(())
    } else {
        Err(())
    }
}