1use std::any::TypeId;
2use std::sync::Arc;
3
4use num_complex::Complex;
5use num_integer::div_ceil;
6use num_traits::Zero;
7
8use crate::common::{fft_error_inplace, fft_error_outofplace};
9use crate::{array_utils, twiddles, FftDirection};
10use crate::{Direction, Fft, FftNum, Length};
11
12use super::CommonSimdData;
13use super::{
14 avx_vector::{AvxArray, AvxArrayMut, AvxVector, AvxVector128, AvxVector256},
15 AvxNum,
16};
17
/// AVX-accelerated FFT of arbitrary size `len`, computed with Bluestein's algorithm on top of an
/// inner FFT whose length must be at least `len * 2 - 1`.
///
/// The struct carries two scalar type parameters: `A` is the type the AVX vector math is written
/// against, and `T` is the element type of the buffers handed in by callers. `new` asserts at
/// runtime that both name the same type.
pub struct BluesteinsAvx<A: AvxNum, T> {
    /// Precomputed multiplier applied between the two inner FFT passes: the inner FFT of the
    /// scaled Bluestein twiddles, conjugated and packed one AVX vector per entry.
    inner_fft_multiplier: Box<[A::VectorType]>,
    /// Inner FFT instance, per-element twiddle vectors, FFT length, direction and scratch sizes.
    common_data: CommonSimdData<T, A::VectorType>,
    /// Zero-sized marker tying the struct to `T`.
    _phantom: std::marker::PhantomData<T>,
}
// NOTE(review): this macro is defined outside this file. It presumably generates the public FFT
// trait implementations for `BluesteinsAvx` by delegating to `perform_fft_inplace` and
// `perform_fft_out_of_place` below -- confirm against the macro definition.
boilerplate_avx_fft_commondata!(BluesteinsAvx);
34
35impl<A: AvxNum, T: FftNum> BluesteinsAvx<A, T> {
36 #[inline(always)]
39 unsafe fn mul_complex_conjugated<V: AvxVector>(left: V, right: V) -> V {
40 let (left_real, left_imag) = V::duplicate_complex_components(left);
42
43 let right_shuffled = V::swap_complex_components(right);
45
46 let output_right = V::mul(left_imag, right_shuffled);
48
49 V::fmsubadd(left_real, right, output_right)
52 }
53
54 #[inline]
57 pub fn new(len: usize, inner_fft: Arc<dyn Fft<T>>) -> Result<Self, ()> {
58 let id_a = TypeId::of::<A>();
62 let id_t = TypeId::of::<T>();
63 assert_eq!(id_a, id_t);
64
65 let has_avx = is_x86_feature_detected!("avx");
66 let has_fma = is_x86_feature_detected!("fma");
67 if has_avx && has_fma {
68 Ok(unsafe { Self::new_with_avx(len, inner_fft) })
70 } else {
71 Err(())
72 }
73 }
74
75 #[target_feature(enable = "avx")]
76 unsafe fn new_with_avx(len: usize, inner_fft: Arc<dyn Fft<T>>) -> Self {
77 let inner_fft_len = inner_fft.len();
78 assert!(len * 2 - 1 <= inner_fft_len, "Bluestein's algorithm requires inner_fft.len() >= self.len() * 2 - 1. Expected >= {}, got {}", len * 2 - 1, inner_fft_len);
79 assert_eq!(inner_fft_len % A::VectorType::COMPLEX_PER_VECTOR, 0, "BluesteinsAvx requires its inner_fft.len() to be a multiple of {} (IE the number of complex numbers in a single vector) inner_fft.len() = {}", A::VectorType::COMPLEX_PER_VECTOR, inner_fft_len);
80
81 let inner_fft_scale = A::one() / A::from_usize(inner_fft_len).unwrap();
83 let direction = inner_fft.fft_direction();
84
85 let mut inner_fft_input = vec![Complex::zero(); inner_fft_len];
87 twiddles::fill_bluesteins_twiddles(
88 &mut inner_fft_input[..len],
89 direction.opposite_direction(),
90 );
91
92 inner_fft_input[0] = inner_fft_input[0] * inner_fft_scale;
94 for i in 1..len {
95 let twiddle = inner_fft_input[i] * inner_fft_scale;
96 inner_fft_input[i] = twiddle;
97 inner_fft_input[inner_fft_len - i] = twiddle;
98 }
99
100 let mut inner_fft_scratch = vec![Complex::zero(); inner_fft.get_inplace_scratch_len()];
102
103 {
104 let transmuted_input: &mut [Complex<T>] =
106 array_utils::workaround_transmute_mut(&mut inner_fft_input);
107
108 inner_fft.process_with_scratch(transmuted_input, &mut inner_fft_scratch);
109 }
110
111 let conjugation_mask =
113 AvxVector256::broadcast_complex_elements(Complex::new(A::zero(), -A::zero()));
114 let inner_fft_multiplier = inner_fft_input
115 .chunks_exact(A::VectorType::COMPLEX_PER_VECTOR)
116 .map(|chunk| {
117 let chunk_vector = chunk.load_complex(0);
118 AvxVector::xor(chunk_vector, conjugation_mask) })
120 .collect::<Vec<_>>()
121 .into_boxed_slice();
122
123 let chunk_count = div_ceil(len, A::VectorType::COMPLEX_PER_VECTOR);
125 let twiddle_count = chunk_count * A::VectorType::COMPLEX_PER_VECTOR;
126 let mut twiddles_scalar: Vec<Complex<A>> = vec![Complex::zero(); twiddle_count];
127 twiddles::fill_bluesteins_twiddles(&mut twiddles_scalar[..len], direction);
128
129 let twiddles: Vec<_> = twiddles_scalar
131 .chunks_exact(A::VectorType::COMPLEX_PER_VECTOR)
132 .map(|chunk| chunk.load_complex(0))
133 .collect();
134
135 let required_scratch = inner_fft_input.len() + inner_fft_scratch.len();
136
137 Self {
138 inner_fft_multiplier,
139 common_data: CommonSimdData {
140 inner_fft,
141 twiddles: twiddles.into_boxed_slice(),
142
143 len,
144
145 inplace_scratch_len: required_scratch,
146 outofplace_scratch_len: required_scratch,
147
148 direction,
149 },
150 _phantom: std::marker::PhantomData,
151 }
152 }
153
154 #[target_feature(enable = "avx", enable = "fma")]
156 unsafe fn prepare_bluesteins(
157 &self,
158 input: &[Complex<A>],
159 mut inner_fft_buffer: &mut [Complex<A>],
160 ) {
161 let chunk_count = self.common_data.twiddles.len() - 1;
162 let remainder = self.len() - chunk_count * A::VectorType::COMPLEX_PER_VECTOR;
163
164 for (i, twiddle) in self.common_data.twiddles[..chunk_count].iter().enumerate() {
166 let index = i * A::VectorType::COMPLEX_PER_VECTOR;
167 let input_vector = input.load_complex(index);
168 let product_vector = AvxVector::mul_complex(input_vector, *twiddle);
169 inner_fft_buffer.store_complex(product_vector, index);
170 }
171
172 {
175 let remainder_twiddle = self.common_data.twiddles[chunk_count];
176
177 let remainder_index = chunk_count * A::VectorType::COMPLEX_PER_VECTOR;
178 let remainder_data = match remainder {
179 1 => input.load_partial1_complex(remainder_index).zero_extend(),
180 2 => {
181 if A::VectorType::COMPLEX_PER_VECTOR == 2 {
182 input.load_complex(remainder_index)
183 } else {
184 input.load_partial2_complex(remainder_index).zero_extend()
185 }
186 }
187 3 => input.load_partial3_complex(remainder_index),
188 4 => input.load_complex(remainder_index),
189 _ => unreachable!(),
190 };
191
192 let twiddled_remainder = AvxVector::mul_complex(remainder_twiddle, remainder_data);
193 inner_fft_buffer.store_complex(twiddled_remainder, remainder_index);
194 }
195
196 let zerofill_start = chunk_count + 1;
198 for i in zerofill_start..(inner_fft_buffer.len() / A::VectorType::COMPLEX_PER_VECTOR) {
199 let index = i * A::VectorType::COMPLEX_PER_VECTOR;
200 inner_fft_buffer.store_complex(AvxVector::zero(), index);
201 }
202 }
203
204 #[target_feature(enable = "avx", enable = "fma")]
206 unsafe fn finalize_bluesteins(
207 &self,
208 inner_fft_buffer: &[Complex<A>],
209 mut output: &mut [Complex<A>],
210 ) {
211 let chunk_count = self.common_data.twiddles.len() - 1;
212 let remainder = self.len() - chunk_count * A::VectorType::COMPLEX_PER_VECTOR;
213
214 for (i, twiddle) in self.common_data.twiddles[..chunk_count].iter().enumerate() {
216 let index = i * A::VectorType::COMPLEX_PER_VECTOR;
217 let inner_vector = inner_fft_buffer.load_complex(index);
218 let product_vector = Self::mul_complex_conjugated(inner_vector, *twiddle);
219 output.store_complex(product_vector, index);
220 }
221
222 {
224 let remainder_twiddle = self.common_data.twiddles[chunk_count];
225
226 let remainder_index = chunk_count * A::VectorType::COMPLEX_PER_VECTOR;
227 let inner_vector = inner_fft_buffer.load_complex(remainder_index);
228 let product_vector = Self::mul_complex_conjugated(inner_vector, remainder_twiddle);
229
230 match remainder {
231 1 => output.store_partial1_complex(product_vector.lo(), remainder_index),
232 2 => {
233 if A::VectorType::COMPLEX_PER_VECTOR == 2 {
234 output.store_complex(product_vector, remainder_index)
235 } else {
236 output.store_partial2_complex(product_vector.lo(), remainder_index)
237 }
238 }
239 3 => output.store_partial3_complex(product_vector, remainder_index),
240 4 => output.store_complex(product_vector, remainder_index),
241 _ => unreachable!(),
242 };
243 }
244 }
245
246 #[target_feature(enable = "avx", enable = "fma")]
248 unsafe fn pairwise_complex_multiply_conjugated(
249 mut buffer: impl AvxArrayMut<A>,
250 multiplier: &[A::VectorType],
251 ) {
252 for (i, right) in multiplier.iter().enumerate() {
253 let left = buffer.load_complex(i * A::VectorType::COMPLEX_PER_VECTOR);
254
255 let product = Self::mul_complex_conjugated(left, *right);
257
258 buffer.store_complex(product, i * A::VectorType::COMPLEX_PER_VECTOR);
260 }
261 }
262
263 fn perform_fft_inplace(&self, buffer: &mut [Complex<T>], scratch: &mut [Complex<T>]) {
264 let (inner_input, inner_scratch) = scratch
265 .split_at_mut(self.inner_fft_multiplier.len() * A::VectorType::COMPLEX_PER_VECTOR);
266
267 unsafe {
269 let transmuted_buffer: &mut [Complex<A>] =
271 array_utils::workaround_transmute_mut(buffer);
272 let transmuted_inner_input: &mut [Complex<A>] =
273 array_utils::workaround_transmute_mut(inner_input);
274
275 self.prepare_bluesteins(transmuted_buffer, transmuted_inner_input);
276 }
277
278 self.common_data
280 .inner_fft
281 .process_with_scratch(inner_input, inner_scratch);
282
283 unsafe {
287 let transmuted_inner_input: &mut [Complex<A>] =
289 array_utils::workaround_transmute_mut(inner_input);
290
291 Self::pairwise_complex_multiply_conjugated(
292 transmuted_inner_input,
293 &self.inner_fft_multiplier,
294 )
295 };
296
297 self.common_data
299 .inner_fft
300 .process_with_scratch(inner_input, inner_scratch);
301
302 unsafe {
304 let transmuted_buffer: &mut [Complex<A>] =
306 array_utils::workaround_transmute_mut(buffer);
307 let transmuted_inner_input: &mut [Complex<A>] =
308 array_utils::workaround_transmute_mut(inner_input);
309
310 self.finalize_bluesteins(transmuted_inner_input, transmuted_buffer);
311 }
312 }
313
314 fn perform_fft_out_of_place(
315 &self,
316 input: &[Complex<T>],
317 output: &mut [Complex<T>],
318 scratch: &mut [Complex<T>],
319 ) {
320 let (inner_input, inner_scratch) = scratch
321 .split_at_mut(self.inner_fft_multiplier.len() * A::VectorType::COMPLEX_PER_VECTOR);
322
323 unsafe {
325 let transmuted_input: &[Complex<A>] = array_utils::workaround_transmute(input);
327 let transmuted_inner_input: &mut [Complex<A>] =
328 array_utils::workaround_transmute_mut(inner_input);
329
330 self.prepare_bluesteins(transmuted_input, transmuted_inner_input)
331 }
332
333 self.common_data
335 .inner_fft
336 .process_with_scratch(inner_input, inner_scratch);
337
338 unsafe {
342 let transmuted_inner_input: &mut [Complex<A>] =
344 array_utils::workaround_transmute_mut(inner_input);
345
346 Self::pairwise_complex_multiply_conjugated(
347 transmuted_inner_input,
348 &self.inner_fft_multiplier,
349 )
350 };
351
352 self.common_data
354 .inner_fft
355 .process_with_scratch(inner_input, inner_scratch);
356
357 unsafe {
359 let transmuted_output: &mut [Complex<A>] =
361 array_utils::workaround_transmute_mut(output);
362 let transmuted_inner_input: &mut [Complex<A>] =
363 array_utils::workaround_transmute_mut(inner_input);
364
365 self.finalize_bluesteins(transmuted_inner_input, transmuted_output)
366 }
367 }
368}
369
370#[cfg(test)]
371mod unit_tests {
372 use num_traits::Float;
373 use rand::distributions::uniform::SampleUniform;
374
375 use super::*;
376 use crate::algorithm::Dft;
377 use crate::test_utils::check_fft_algorithm;
378 use std::sync::Arc;
379
380 #[test]
381 fn test_bluesteins_avx_f32() {
382 for len in 2..16 {
383 let minimum_inner: usize = len * 2 - 1;
386 let remainder = minimum_inner % 4;
387
388 let next_multiple_of_4 = minimum_inner - remainder + 4;
390 let maximum_inner = minimum_inner.checked_next_power_of_two().unwrap() + 1;
391
392 for inner_len in (next_multiple_of_4..maximum_inner).step_by(4) {
394 test_bluesteins_avx_with_length::<f32>(len, inner_len, FftDirection::Forward);
395 test_bluesteins_avx_with_length::<f32>(len, inner_len, FftDirection::Inverse);
396 }
397 }
398 }
399
400 #[test]
401 fn test_bluesteins_avx_f64() {
402 for len in 2..16 {
403 let minimum_inner: usize = len * 2 - 1;
406 let remainder = minimum_inner % 2;
407
408 let next_multiple_of_2 = minimum_inner + remainder;
409 let maximum_inner = minimum_inner.checked_next_power_of_two().unwrap() + 1;
410
411 for inner_len in (next_multiple_of_2..maximum_inner).step_by(2) {
413 test_bluesteins_avx_with_length::<f64>(len, inner_len, FftDirection::Forward);
414 test_bluesteins_avx_with_length::<f64>(len, inner_len, FftDirection::Inverse);
415 }
416 }
417 }
418
419 fn test_bluesteins_avx_with_length<T: AvxNum + Float + SampleUniform>(
420 len: usize,
421 inner_len: usize,
422 direction: FftDirection,
423 ) {
424 let inner_fft = Arc::new(Dft::new(inner_len, direction));
425 let fft: BluesteinsAvx<T, T> = BluesteinsAvx::new(len, inner_fft).expect(
426 "Can't run test because this machine doesn't have the required instruction sets",
427 );
428
429 check_fft_algorithm::<T>(&fft, len, direction);
430 }
431}