rustfft/avx/
avx_bluesteins.rs

1use std::any::TypeId;
2use std::sync::Arc;
3
4use num_complex::Complex;
5use num_integer::div_ceil;
6use num_traits::Zero;
7
8use crate::common::{fft_error_inplace, fft_error_outofplace};
9use crate::{array_utils, twiddles, FftDirection};
10use crate::{Direction, Fft, FftNum, Length};
11
12use super::CommonSimdData;
13use super::{
14    avx_vector::{AvxArray, AvxArrayMut, AvxVector, AvxVector128, AvxVector256},
15    AvxNum,
16};
17
18/// Implementation of Bluestein's Algorithm
19///
20/// This algorithm computes an arbitrary-sized FFT in O(nlogn) time. It does this by converting this size n FFT into a
21/// size M where M >= 2N - 1. The most obvious choice for M is a power of two, although that isn't a requirement.
22///
23/// It requires a large scratch space, so it's probably inconvenient to use as an inner FFT to other algorithms.
24///
25/// Bluestein's Algorithm is relatively expensive compared to other FFT algorithms. Benchmarking shows that it is up to
26/// an order of magnitude slower than similar composite sizes.
27
28pub struct BluesteinsAvx<A: AvxNum, T> {
29    inner_fft_multiplier: Box<[A::VectorType]>,
30    common_data: CommonSimdData<T, A::VectorType>,
31    _phantom: std::marker::PhantomData<T>,
32}
33boilerplate_avx_fft_commondata!(BluesteinsAvx);
34
35impl<A: AvxNum, T: FftNum> BluesteinsAvx<A, T> {
36    /// Pairwise multiply the complex numbers in `left` with the complex numbers in `right`.
37    /// This is exactly the same as `mul_complex` in `AvxVector`, but this implementation also conjugates the `left` input before multiplying
38    #[inline(always)]
39    unsafe fn mul_complex_conjugated<V: AvxVector>(left: V, right: V) -> V {
40        // Extract the real and imaginary components from left into 2 separate registers
41        let (left_real, left_imag) = V::duplicate_complex_components(left);
42
43        // create a shuffled version of right where the imaginary values are swapped with the reals
44        let right_shuffled = V::swap_complex_components(right);
45
46        // multiply our duplicated imaginary left vector by our shuffled right vector. that will give us the right side of the traditional complex multiplication formula
47        let output_right = V::mul(left_imag, right_shuffled);
48
49        // use a FMA instruction to multiply together left side of the complex multiplication formula, then alternatingly add and subtract the left side from the right
50        // By using subadd instead of addsub, we can conjugate the left side for free.
51        V::fmsubadd(left_real, right, output_right)
52    }
53
54    /// Preallocates necessary arrays and precomputes necessary data to efficiently compute the FFT
55    /// Returns Ok() if this machine has the required instruction sets, Err() if some instruction sets are missing
56    #[inline]
57    pub fn new(len: usize, inner_fft: Arc<dyn Fft<T>>) -> Result<Self, ()> {
58        // Internal sanity check: Make sure that A == T.
59        // This struct has two generic parameters A and T, but they must always be the same, and are only kept separate to help work around the lack of specialization.
60        // It would be cool if we could do this as a static_assert instead
61        let id_a = TypeId::of::<A>();
62        let id_t = TypeId::of::<T>();
63        assert_eq!(id_a, id_t);
64
65        let has_avx = is_x86_feature_detected!("avx");
66        let has_fma = is_x86_feature_detected!("fma");
67        if has_avx && has_fma {
68            // Safety: new_with_avx requires the "avx" feature set. Since we know it's present, we're safe
69            Ok(unsafe { Self::new_with_avx(len, inner_fft) })
70        } else {
71            Err(())
72        }
73    }
74
75    #[target_feature(enable = "avx")]
76    unsafe fn new_with_avx(len: usize, inner_fft: Arc<dyn Fft<T>>) -> Self {
77        let inner_fft_len = inner_fft.len();
78        assert!(len * 2 - 1 <= inner_fft_len, "Bluestein's algorithm requires inner_fft.len() >= self.len() * 2 - 1. Expected >= {}, got {}", len * 2 - 1, inner_fft_len);
79        assert_eq!(inner_fft_len % A::VectorType::COMPLEX_PER_VECTOR, 0, "BluesteinsAvx requires its inner_fft.len() to be a multiple of {} (IE the number of complex numbers in a single vector) inner_fft.len() = {}", A::VectorType::COMPLEX_PER_VECTOR, inner_fft_len);
80
81        // when computing FFTs, we're going to run our inner multiply pairwise by some precomputed data, then run an inverse inner FFT. We need to precompute that inner data here
82        let inner_fft_scale = A::one() / A::from_usize(inner_fft_len).unwrap();
83        let direction = inner_fft.fft_direction();
84
85        // Compute twiddle factors that we'll run our inner FFT on
86        let mut inner_fft_input = vec![Complex::zero(); inner_fft_len];
87        twiddles::fill_bluesteins_twiddles(
88            &mut inner_fft_input[..len],
89            direction.opposite_direction(),
90        );
91
92        // Scale the computed twiddles and copy them to the end of the array
93        inner_fft_input[0] = inner_fft_input[0] * inner_fft_scale;
94        for i in 1..len {
95            let twiddle = inner_fft_input[i] * inner_fft_scale;
96            inner_fft_input[i] = twiddle;
97            inner_fft_input[inner_fft_len - i] = twiddle;
98        }
99
100        //Compute the inner fft
101        let mut inner_fft_scratch = vec![Complex::zero(); inner_fft.get_inplace_scratch_len()];
102
103        {
104            // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary
105            let transmuted_input: &mut [Complex<T>] =
106                array_utils::workaround_transmute_mut(&mut inner_fft_input);
107
108            inner_fft.process_with_scratch(transmuted_input, &mut inner_fft_scratch);
109        }
110
111        // When computing the FFT, we'll want this array to be pre-conjugated, so conjugate it now
112        let conjugation_mask =
113            AvxVector256::broadcast_complex_elements(Complex::new(A::zero(), -A::zero()));
114        let inner_fft_multiplier = inner_fft_input
115            .chunks_exact(A::VectorType::COMPLEX_PER_VECTOR)
116            .map(|chunk| {
117                let chunk_vector = chunk.load_complex(0);
118                AvxVector::xor(chunk_vector, conjugation_mask) // compute our conjugation by xoring our data with a precomputed mask
119            })
120            .collect::<Vec<_>>()
121            .into_boxed_slice();
122
123        // also compute some more mundane twiddle factors to start and end with.
124        let chunk_count = div_ceil(len, A::VectorType::COMPLEX_PER_VECTOR);
125        let twiddle_count = chunk_count * A::VectorType::COMPLEX_PER_VECTOR;
126        let mut twiddles_scalar: Vec<Complex<A>> = vec![Complex::zero(); twiddle_count];
127        twiddles::fill_bluesteins_twiddles(&mut twiddles_scalar[..len], direction);
128
129        // We have the twiddles in scalar format, last step is to copy them over to AVX vectors
130        let twiddles: Vec<_> = twiddles_scalar
131            .chunks_exact(A::VectorType::COMPLEX_PER_VECTOR)
132            .map(|chunk| chunk.load_complex(0))
133            .collect();
134
135        let required_scratch = inner_fft_input.len() + inner_fft_scratch.len();
136
137        Self {
138            inner_fft_multiplier,
139            common_data: CommonSimdData {
140                inner_fft,
141                twiddles: twiddles.into_boxed_slice(),
142
143                len,
144
145                inplace_scratch_len: required_scratch,
146                outofplace_scratch_len: required_scratch,
147
148                direction,
149            },
150            _phantom: std::marker::PhantomData,
151        }
152    }
153
154    // Do the necessary setup for bluestein's algorithm: copy the data to the inner buffers, apply some twiddle factors, zero out the rest of the inner buffer
155    #[target_feature(enable = "avx", enable = "fma")]
156    unsafe fn prepare_bluesteins(
157        &self,
158        input: &[Complex<A>],
159        mut inner_fft_buffer: &mut [Complex<A>],
160    ) {
161        let chunk_count = self.common_data.twiddles.len() - 1;
162        let remainder = self.len() - chunk_count * A::VectorType::COMPLEX_PER_VECTOR;
163
164        // Copy the buffer into our inner FFT input, applying twiddle factors as we go. the buffer will only fill part of the FFT input, so zero fill the rest
165        for (i, twiddle) in self.common_data.twiddles[..chunk_count].iter().enumerate() {
166            let index = i * A::VectorType::COMPLEX_PER_VECTOR;
167            let input_vector = input.load_complex(index);
168            let product_vector = AvxVector::mul_complex(input_vector, *twiddle);
169            inner_fft_buffer.store_complex(product_vector, index);
170        }
171
172        // the buffer will almost certainly have a remainder. it's so likely, in fact, that we're just going to apply a remainder unconditionally
173        // it uses a couple more instructions in the rare case when our FFT size is a multiple of 4, but saves instructions when it's not
174        {
175            let remainder_twiddle = self.common_data.twiddles[chunk_count];
176
177            let remainder_index = chunk_count * A::VectorType::COMPLEX_PER_VECTOR;
178            let remainder_data = match remainder {
179                1 => input.load_partial1_complex(remainder_index).zero_extend(),
180                2 => {
181                    if A::VectorType::COMPLEX_PER_VECTOR == 2 {
182                        input.load_complex(remainder_index)
183                    } else {
184                        input.load_partial2_complex(remainder_index).zero_extend()
185                    }
186                }
187                3 => input.load_partial3_complex(remainder_index),
188                4 => input.load_complex(remainder_index),
189                _ => unreachable!(),
190            };
191
192            let twiddled_remainder = AvxVector::mul_complex(remainder_twiddle, remainder_data);
193            inner_fft_buffer.store_complex(twiddled_remainder, remainder_index);
194        }
195
196        // zero fill the rest of the `inner` array
197        let zerofill_start = chunk_count + 1;
198        for i in zerofill_start..(inner_fft_buffer.len() / A::VectorType::COMPLEX_PER_VECTOR) {
199            let index = i * A::VectorType::COMPLEX_PER_VECTOR;
200            inner_fft_buffer.store_complex(AvxVector::zero(), index);
201        }
202    }
203
204    // Do the necessary finalization for bluestein's algorithm: Conjugate the inner FFT buffer, apply some twiddle factors, zero out the rest of the inner buffer
205    #[target_feature(enable = "avx", enable = "fma")]
206    unsafe fn finalize_bluesteins(
207        &self,
208        inner_fft_buffer: &[Complex<A>],
209        mut output: &mut [Complex<A>],
210    ) {
211        let chunk_count = self.common_data.twiddles.len() - 1;
212        let remainder = self.len() - chunk_count * A::VectorType::COMPLEX_PER_VECTOR;
213
214        // copy our data to the output, applying twiddle factors again as we go. Also conjugate inner_fft_buffer to complete the inverse FFT
215        for (i, twiddle) in self.common_data.twiddles[..chunk_count].iter().enumerate() {
216            let index = i * A::VectorType::COMPLEX_PER_VECTOR;
217            let inner_vector = inner_fft_buffer.load_complex(index);
218            let product_vector = Self::mul_complex_conjugated(inner_vector, *twiddle);
219            output.store_complex(product_vector, index);
220        }
221
222        // again, unconditionally apply a remainder
223        {
224            let remainder_twiddle = self.common_data.twiddles[chunk_count];
225
226            let remainder_index = chunk_count * A::VectorType::COMPLEX_PER_VECTOR;
227            let inner_vector = inner_fft_buffer.load_complex(remainder_index);
228            let product_vector = Self::mul_complex_conjugated(inner_vector, remainder_twiddle);
229
230            match remainder {
231                1 => output.store_partial1_complex(product_vector.lo(), remainder_index),
232                2 => {
233                    if A::VectorType::COMPLEX_PER_VECTOR == 2 {
234                        output.store_complex(product_vector, remainder_index)
235                    } else {
236                        output.store_partial2_complex(product_vector.lo(), remainder_index)
237                    }
238                }
239                3 => output.store_partial3_complex(product_vector, remainder_index),
240                4 => output.store_complex(product_vector, remainder_index),
241                _ => unreachable!(),
242            };
243        }
244    }
245
246    // compute buffer[i] = buffer[i].conj() * multiplier[i] pairwise complex multiplication for each element.
247    #[target_feature(enable = "avx", enable = "fma")]
248    unsafe fn pairwise_complex_multiply_conjugated(
249        mut buffer: impl AvxArrayMut<A>,
250        multiplier: &[A::VectorType],
251    ) {
252        for (i, right) in multiplier.iter().enumerate() {
253            let left = buffer.load_complex(i * A::VectorType::COMPLEX_PER_VECTOR);
254
255            // Do a complex multiplication between `left` and `right`
256            let product = Self::mul_complex_conjugated(left, *right);
257
258            // Store the result
259            buffer.store_complex(product, i * A::VectorType::COMPLEX_PER_VECTOR);
260        }
261    }
262
263    fn perform_fft_inplace(&self, buffer: &mut [Complex<T>], scratch: &mut [Complex<T>]) {
264        let (inner_input, inner_scratch) = scratch
265            .split_at_mut(self.inner_fft_multiplier.len() * A::VectorType::COMPLEX_PER_VECTOR);
266
267        // do the necessary setup for bluestein's algorithm: copy the data to the inner buffers, apply some twiddle factors, zero out the rest of the inner buffer
268        unsafe {
269            // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary
270            let transmuted_buffer: &mut [Complex<A>] =
271                array_utils::workaround_transmute_mut(buffer);
272            let transmuted_inner_input: &mut [Complex<A>] =
273                array_utils::workaround_transmute_mut(inner_input);
274
275            self.prepare_bluesteins(transmuted_buffer, transmuted_inner_input);
276        }
277
278        // run our inner forward FFT
279        self.common_data
280            .inner_fft
281            .process_with_scratch(inner_input, inner_scratch);
282
283        // Multiply our inner FFT output by our precomputed data. Then, conjugate the result to set up for an inverse FFT.
284        // We can conjugate the result of multiplication by conjugating both inputs. We pre-conjugated the multiplier array,
285        // so we just need to conjugate inner_input, which the pairwise_complex_multiply_conjugated function will handle
286        unsafe {
287            // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary
288            let transmuted_inner_input: &mut [Complex<A>] =
289                array_utils::workaround_transmute_mut(inner_input);
290
291            Self::pairwise_complex_multiply_conjugated(
292                transmuted_inner_input,
293                &self.inner_fft_multiplier,
294            )
295        };
296
297        // inverse FFT. we're computing a forward but we're massaging it into an inverse by conjugating the inputs and outputs
298        self.common_data
299            .inner_fft
300            .process_with_scratch(inner_input, inner_scratch);
301
302        // finalize the result
303        unsafe {
304            // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary
305            let transmuted_buffer: &mut [Complex<A>] =
306                array_utils::workaround_transmute_mut(buffer);
307            let transmuted_inner_input: &mut [Complex<A>] =
308                array_utils::workaround_transmute_mut(inner_input);
309
310            self.finalize_bluesteins(transmuted_inner_input, transmuted_buffer);
311        }
312    }
313
314    fn perform_fft_out_of_place(
315        &self,
316        input: &[Complex<T>],
317        output: &mut [Complex<T>],
318        scratch: &mut [Complex<T>],
319    ) {
320        let (inner_input, inner_scratch) = scratch
321            .split_at_mut(self.inner_fft_multiplier.len() * A::VectorType::COMPLEX_PER_VECTOR);
322
323        // do the necessary setup for bluestein's algorithm: copy the data to the inner buffers, apply some twiddle factors, zero out the rest of the inner buffer
324        unsafe {
325            // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary
326            let transmuted_input: &[Complex<A>] = array_utils::workaround_transmute(input);
327            let transmuted_inner_input: &mut [Complex<A>] =
328                array_utils::workaround_transmute_mut(inner_input);
329
330            self.prepare_bluesteins(transmuted_input, transmuted_inner_input)
331        }
332
333        // run our inner forward FFT
334        self.common_data
335            .inner_fft
336            .process_with_scratch(inner_input, inner_scratch);
337
338        // Multiply our inner FFT output by our precomputed data. Then, conjugate the result to set up for an inverse FFT.
339        // We can conjugate the result of multiplication by conjugating both inputs. We pre-conjugated the multiplier array,
340        // so we just need to conjugate inner_input, which the pairwise_complex_multiply_conjugated function will handle
341        unsafe {
342            // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary
343            let transmuted_inner_input: &mut [Complex<A>] =
344                array_utils::workaround_transmute_mut(inner_input);
345
346            Self::pairwise_complex_multiply_conjugated(
347                transmuted_inner_input,
348                &self.inner_fft_multiplier,
349            )
350        };
351
352        // inverse FFT. we're computing a forward but we're massaging it into an inverse by conjugating the inputs and outputs
353        self.common_data
354            .inner_fft
355            .process_with_scratch(inner_input, inner_scratch);
356
357        // finalize the result
358        unsafe {
359            // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary
360            let transmuted_output: &mut [Complex<A>] =
361                array_utils::workaround_transmute_mut(output);
362            let transmuted_inner_input: &mut [Complex<A>] =
363                array_utils::workaround_transmute_mut(inner_input);
364
365            self.finalize_bluesteins(transmuted_inner_input, transmuted_output)
366        }
367    }
368}
369
#[cfg(test)]
mod unit_tests {
    use num_traits::Float;
    use rand::distributions::uniform::SampleUniform;

    use super::*;
    use crate::algorithm::Dft;
    use crate::test_utils::check_fft_algorithm;
    use std::sync::Arc;

    /// Yields every inner FFT length worth testing for a Bluestein's FFT of size `len`.
    ///
    /// These are the multiples of `multiple` (the number of complex numbers per AVX vector),
    /// starting from the smallest multiple that is >= `len * 2 - 1`, up to and including the
    /// next power of two >= `len * 2 - 1`.
    fn inner_fft_lengths(len: usize, multiple: usize) -> impl Iterator<Item = usize> {
        let minimum_inner = len * 2 - 1;
        let first = div_ceil(minimum_inner, multiple) * multiple;
        let last = minimum_inner.checked_next_power_of_two().unwrap();
        (first..=last).step_by(multiple)
    }

    #[test]
    fn test_bluesteins_avx_f32() {
        // Bluesteins AVX f32 requires its inner FFT length to be a multiple of 4
        for len in 2..16 {
            for inner_len in inner_fft_lengths(len, 4) {
                test_bluesteins_avx_with_length::<f32>(len, inner_len, FftDirection::Forward);
                test_bluesteins_avx_with_length::<f32>(len, inner_len, FftDirection::Inverse);
            }
        }
    }

    #[test]
    fn test_bluesteins_avx_f64() {
        // Bluesteins AVX f64 requires its inner FFT length to be a multiple of 2
        for len in 2..16 {
            for inner_len in inner_fft_lengths(len, 2) {
                test_bluesteins_avx_with_length::<f64>(len, inner_len, FftDirection::Forward);
                test_bluesteins_avx_with_length::<f64>(len, inner_len, FftDirection::Inverse);
            }
        }
    }

    /// Builds a `BluesteinsAvx` of size `len` on top of a `Dft` of size `inner_len` and checks it
    /// against the reference implementation in the given direction.
    fn test_bluesteins_avx_with_length<T: AvxNum + Float + SampleUniform>(
        len: usize,
        inner_len: usize,
        direction: FftDirection,
    ) {
        let inner_fft = Arc::new(Dft::new(inner_len, direction));
        let fft: BluesteinsAvx<T, T> = BluesteinsAvx::new(len, inner_fft).expect(
            "Can't run test because this machine doesn't have the required instruction sets",
        );

        check_fft_algorithm::<T>(&fft, len, direction);
    }
}