//! SSE-accelerated radix-4 FFT (rustfft/sse/sse_radix4.rs).

1use num_complex::Complex;
2
3use std::any::TypeId;
4use std::sync::Arc;
5
6use crate::array_utils::{self, bitreversed_transpose, workaround_transmute_mut};
7use crate::common::{fft_error_inplace, fft_error_outofplace};
8use crate::{common::FftNum, FftDirection};
9use crate::{Direction, Fft, Length};
10
11use super::SseNum;
12
13use super::sse_vector::{Rotation90, SseArray, SseArrayMut, SseVector};
14
/// FFT algorithm optimized for power-of-two sizes, SSE accelerated version.
/// This is designed to be used via a Planner, and not created directly.

pub struct SseRadix4<S: SseNum, T> {
    // Twiddle factors for every cross-FFT layer, packed into one flat array:
    // bottom (smallest cross_fft_len) layer first, then each larger layer.
    twiddles: Box<[S::VectorType]>,
    // Precomputed rotation constant (built by `SseVector::make_rotate90`) that
    // the radix-4 butterflies use; depends on the FFT direction.
    rotation: Rotation90<S::VectorType>,

    // Inner FFT applied at the bottom of the recursion; `len` is a 4^k multiple
    // of `base_len`.
    base_fft: Arc<dyn Fft<T>>,
    base_len: usize,

    // Total FFT length: base_len * 4^k.
    len: usize,
    direction: FftDirection,
}
28
impl<S: SseNum, T: FftNum> SseRadix4<S, T> {
    /// Constructs a new SseRadix4 which computes FFTs of size 4^k * base_fft.len()
    ///
    /// Returns `Err(())` if the running CPU does not support the `sse4.1`
    /// instruction set.
    #[inline]
    pub fn new(k: u32, base_fft: Arc<dyn Fft<T>>) -> Result<Self, ()> {
        // Internal sanity check: Make sure that S == T.
        // This struct has two generic parameters S and T, but they must always be the same, and are only kept separate to help work around the lack of specialization.
        // It would be cool if we could do this as a static_assert instead
        let id_a = TypeId::of::<S>();
        let id_t = TypeId::of::<T>();
        assert_eq!(id_a, id_t);

        let has_sse = is_x86_feature_detected!("sse4.1");
        if has_sse {
            // Safety: new_with_sse requires the "sse4.1" feature set. Since we know it's present, we're safe
            Ok(unsafe { Self::new_with_sse(k, base_fft) })
        } else {
            Err(())
        }
    }

    /// Builds the instance and precomputes all twiddle factors.
    ///
    /// # Safety
    /// The caller must guarantee the CPU supports `sse4.1` (checked in `new`).
    #[target_feature(enable = "sse4.1")]
    unsafe fn new_with_sse(k: u32, base_fft: Arc<dyn Fft<T>>) -> Self {
        let direction = base_fft.fft_direction();
        let base_len = base_fft.len();

        // note that we can eventually release this restriction - we just need to update the rest of the code in here to handle remainders
        assert!(base_len % (2 * S::VectorType::COMPLEX_PER_VECTOR) == 0 && base_len > 0);

        // 4^k == 1 << (2k); total length is base_len * 4^k.
        let len = base_len * (1 << (k * 2));

        // precompute the twiddle factors this algorithm will use.
        // we're doing the same precomputation of twiddle factors as the mixed radix algorithm where width=4 and height=len/4
        // but mixed radix only does one step and then calls itself recusrively, and this algorithm does every layer all the way down
        // so we're going to pack all the "layers" of twiddle factors into a single array, starting with the bottom layer and going up
        const ROW_COUNT: usize = 4;
        let mut cross_fft_len = base_len * ROW_COUNT;
        // `len * 2` is an upper bound on the total twiddle count across all layers.
        let mut twiddle_factors = Vec::with_capacity(len * 2);
        while cross_fft_len <= len {
            let num_scalar_columns = cross_fft_len / ROW_COUNT;
            let num_vector_columns = num_scalar_columns / S::VectorType::COMPLEX_PER_VECTOR;

            for i in 0..num_vector_columns {
                // note: this `k` is the row index within the butterfly (1..=3)
                // and shadows the outer `k: u32` power parameter
                for k in 1..ROW_COUNT {
                    twiddle_factors.push(SseVector::make_mixedradix_twiddle_chunk(
                        i * S::VectorType::COMPLEX_PER_VECTOR,
                        k,
                        cross_fft_len,
                        direction,
                    ));
                }
            }
            cross_fft_len *= ROW_COUNT;
        }

        Self {
            twiddles: twiddle_factors.into_boxed_slice(),
            rotation: SseVector::make_rotate90(direction),

            base_fft,
            base_len,

            len,
            direction,
        }
    }

    /// Computes the FFT of `input` into `output`: a bit-reversed transpose,
    /// the base FFTs, then one layer of radix-4 cross-FFTs per power of 4.
    ///
    /// # Safety
    /// The caller must guarantee the CPU supports `sse4.1` (checked in `new`).
    #[target_feature(enable = "sse4.1")]
    unsafe fn perform_fft_out_of_place(
        &self,
        input: &[Complex<T>],
        output: &mut [Complex<T>],
        _scratch: &mut [Complex<T>],
    ) {
        // copy the data into the output vector
        if self.len() == self.base_len {
            // k == 0: no radix-4 layers, so no reordering is needed
            output.copy_from_slice(input);
        } else {
            bitreversed_transpose::<Complex<T>, 4>(self.base_len, input, output);
        }

        // Base-level FFTs
        self.base_fft.process_with_scratch(output, &mut []);

        // cross-FFTs
        const ROW_COUNT: usize = 4;
        let mut cross_fft_len = self.base_len * ROW_COUNT;
        let mut layer_twiddles: &[S::VectorType] = &self.twiddles;

        while cross_fft_len <= input.len() {
            let num_rows = input.len() / cross_fft_len;
            let num_scalar_columns = cross_fft_len / ROW_COUNT;
            let num_vector_columns = num_scalar_columns / S::VectorType::COMPLEX_PER_VECTOR;

            // each chunk of `cross_fft_len` gets its own independent butterfly pass
            for i in 0..num_rows {
                butterfly_4::<S, T>(
                    &mut output[i * cross_fft_len..],
                    layer_twiddles,
                    num_scalar_columns,
                    &self.rotation,
                )
            }

            // skip past all the twiddle factors used in this layer
            let twiddle_offset = num_vector_columns * (ROW_COUNT - 1);
            layer_twiddles = &layer_twiddles[twiddle_offset..];

            cross_fft_len *= ROW_COUNT;
        }
    }
}
139boilerplate_fft_sse_oop!(SseRadix4, |this: &SseRadix4<_, _>| this.len);
140
/// Computes one layer of radix-4 cross-FFT butterflies, in place, SSE accelerated.
///
/// `data` is treated as a 4-row matrix with `num_ffts` complex elements per
/// row (row `r`, column `c` lives at index `c + r * num_ffts`). Each column is
/// multiplied by its 3 twiddle factors and run through a 4-point butterfly.
///
/// The loop is unrolled by two vectors: each iteration processes
/// `2 * COMPLEX_PER_VECTOR` columns (`scratcha` holds the first vector's worth
/// of columns, `scratchb` the second) and consumes 6 twiddle vectors — 3 per
/// vector of columns, matching the packing order in `new_with_sse`.
///
/// # Safety
/// - The caller must guarantee the CPU supports `sse4.1`.
/// - The load/store indices assume `data` holds at least `4 * num_ffts`
///   elements and `twiddles` holds this layer's factors — TODO confirm whether
///   `load_complex`/`store_complex` bounds-check or rely on the caller.
#[target_feature(enable = "sse4.1")]
unsafe fn butterfly_4<S: SseNum, T: FftNum>(
    data: &mut [Complex<T>],
    twiddles: &[S::VectorType],
    num_ffts: usize,
    rotation: &Rotation90<S::VectorType>,
) {
    let unroll_offset = S::VectorType::COMPLEX_PER_VECTOR;

    let mut idx = 0usize;
    // Reinterpret Complex<T> as Complex<S>; SseRadix4::new asserts S == T.
    let mut buffer: &mut [Complex<S>] = workaround_transmute_mut(data);
    for tw in twiddles
        .chunks_exact(6)
        .take(num_ffts / (S::VectorType::COMPLEX_PER_VECTOR * 2))
    {
        // First vector of columns: one load from each of the 4 rows.
        let mut scratcha = [
            buffer.load_complex(idx + 0 * num_ffts),
            buffer.load_complex(idx + 1 * num_ffts),
            buffer.load_complex(idx + 2 * num_ffts),
            buffer.load_complex(idx + 3 * num_ffts),
        ];
        // Second vector of columns, offset by one vector width.
        let mut scratchb = [
            buffer.load_complex(idx + 0 * num_ffts + unroll_offset),
            buffer.load_complex(idx + 1 * num_ffts + unroll_offset),
            buffer.load_complex(idx + 2 * num_ffts + unroll_offset),
            buffer.load_complex(idx + 3 * num_ffts + unroll_offset),
        ];

        // Apply twiddles to rows 1..=3 (row 0's twiddle is always 1).
        scratcha[1] = SseVector::mul_complex(scratcha[1], tw[0]);
        scratcha[2] = SseVector::mul_complex(scratcha[2], tw[1]);
        scratcha[3] = SseVector::mul_complex(scratcha[3], tw[2]);
        scratchb[1] = SseVector::mul_complex(scratchb[1], tw[3]);
        scratchb[2] = SseVector::mul_complex(scratchb[2], tw[4]);
        scratchb[3] = SseVector::mul_complex(scratchb[3], tw[5]);

        let scratcha = SseVector::column_butterfly4(scratcha, *rotation);
        let scratchb = SseVector::column_butterfly4(scratchb, *rotation);

        // Write results back to the same locations they were loaded from.
        buffer.store_complex(scratcha[0], idx + 0 * num_ffts);
        buffer.store_complex(scratchb[0], idx + 0 * num_ffts + unroll_offset);
        buffer.store_complex(scratcha[1], idx + 1 * num_ffts);
        buffer.store_complex(scratchb[1], idx + 1 * num_ffts + unroll_offset);
        buffer.store_complex(scratcha[2], idx + 2 * num_ffts);
        buffer.store_complex(scratchb[2], idx + 2 * num_ffts + unroll_offset);
        buffer.store_complex(scratcha[3], idx + 3 * num_ffts);
        buffer.store_complex(scratchb[3], idx + 3 * num_ffts + unroll_offset);

        idx += S::VectorType::COMPLEX_PER_VECTOR * 2;
    }
}
191
#[cfg(test)]
mod unit_tests {
    use super::*;
    use crate::test_utils::{check_fft_algorithm, construct_base};

    /// Builds an f64 SseRadix4 of size `base_fft.len() * 4^k` and runs the
    /// shared algorithm checker against it.
    fn verify_f64(k: u32, base_fft: Arc<dyn Fft<f64>>) {
        let expected_len = base_fft.len() * 4usize.pow(k);
        let direction = base_fft.fft_direction();
        let fft = SseRadix4::<f64, f64>::new(k, base_fft).unwrap();
        check_fft_algorithm::<f64>(&fft, expected_len, direction);
    }

    #[test]
    fn test_sse_radix4_64() {
        for base in [2, 4, 6, 8, 12, 16] {
            // One base FFT per direction, each reused across all values of k.
            let bases = [
                construct_base(base, FftDirection::Forward),
                construct_base(base, FftDirection::Inverse),
            ];
            for k in 0..4 {
                for base_fft in bases.iter() {
                    verify_f64(k, Arc::clone(base_fft));
                }
            }
        }
    }

    /// Builds an f32 SseRadix4 of size `base_fft.len() * 4^k` and runs the
    /// shared algorithm checker against it.
    fn verify_f32(k: u32, base_fft: Arc<dyn Fft<f32>>) {
        let expected_len = base_fft.len() * 4usize.pow(k);
        let direction = base_fft.fft_direction();
        let fft = SseRadix4::<f32, f32>::new(k, base_fft).unwrap();
        check_fft_algorithm::<f32>(&fft, expected_len, direction);
    }

    #[test]
    fn test_sse_radix4_32() {
        for base in [4, 8, 12, 16] {
            // One base FFT per direction, each reused across all values of k.
            let bases = [
                construct_base(base, FftDirection::Forward),
                construct_base(base, FftDirection::Inverse),
            ];
            for k in 0..4 {
                for base_fft in bases.iter() {
                    verify_f32(k, Arc::clone(base_fft));
                }
            }
        }
    }
}
234}