//! AVX-accelerated butterfly FFT implementations for f32 data.
//! (rustfft/avx/avx32_butterflies.rs)

1use std::arch::x86_64::*;
2use std::marker::PhantomData;
3use std::mem::MaybeUninit;
4
5use num_complex::Complex;
6
7use crate::array_utils;
8use crate::array_utils::workaround_transmute_mut;
9use crate::array_utils::DoubleBuf;
10use crate::common::{fft_error_inplace, fft_error_outofplace};
11use crate::{common::FftNum, twiddles};
12use crate::{Direction, Fft, FftDirection, Length};
13
14use super::avx32_utils;
15use super::avx_vector::{self, AvxArray};
16use super::avx_vector::{AvxArrayMut, AvxVector, AvxVector128, AvxVector256, Rotation90};
17
// Safety: This macro will call `self::perform_fft_f32()` which probably has a #[target_feature(enable = "...")] annotation on it.
// Calling functions with that annotation is unsafe, because it doesn't actually check if the CPU has the required features.
// Callers of this macro must guarantee that users can't even obtain an instance of $struct_name if their CPU doesn't have the required CPU features.
//
// Generates the public constructor plus the `Fft`, `Length`, and `Direction` trait impls
// for an AVX butterfly struct of fixed length `$len` that requires no scratch space.
macro_rules! boilerplate_fft_simd_butterfly {
    ($struct_name:ident, $len:expr) => {
        impl $struct_name<f32> {
            /// Returns true if the current CPU has the "avx" and "fma" features this FFT requires.
            #[inline]
            pub fn is_supported_by_cpu() -> bool {
                is_x86_feature_detected!("avx") && is_x86_feature_detected!("fma")
            }
            /// Constructs this butterfly, or returns `Err(())` if the CPU lacks AVX+FMA support.
            #[inline]
            pub fn new(direction: FftDirection) -> Result<Self, ()> {
                if Self::is_supported_by_cpu() {
                    // Safety: new_internal requires the "avx" feature set. Since we know it's present, we're safe
                    Ok(unsafe { Self::new_with_avx(direction) })
                } else {
                    Err(())
                }
            }
        }

        // Implemented on the concrete `$struct_name<f32>` but generic over `T: FftNum` --
        // part of the specialization workaround described in FftPlannerAvx::new();
        // the `Complex<T>` slices are transmuted to `Complex<f32>` below.
        impl<T: FftNum> Fft<T> for $struct_name<f32> {
            fn process_outofplace_with_scratch(
                &self,
                input: &mut [Complex<T>],
                output: &mut [Complex<T>],
                _scratch: &mut [Complex<T>],
            ) {
                if input.len() < self.len() || output.len() != input.len() {
                    // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us
                    fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0);
                    return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here
                }

                // Process the buffers one self.len()-sized chunk at a time: one FFT per chunk
                let result = array_utils::iter_chunks_zipped(
                    input,
                    output,
                    self.len(),
                    |in_chunk, out_chunk| {
                        unsafe {
                            // Specialization workaround: See the comments in FftPlannerAvx::new() for why we have to transmute these slices
                            let input_slice = workaround_transmute_mut(in_chunk);
                            let output_slice = workaround_transmute_mut(out_chunk);
                            self.perform_fft_f32(DoubleBuf {
                                input: input_slice,
                                output: output_slice,
                            });
                        }
                    },
                );

                if result.is_err() {
                    // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size,
                    // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us
                    fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0);
                }
            }
            fn process_with_scratch(&self, buffer: &mut [Complex<T>], _scratch: &mut [Complex<T>]) {
                if buffer.len() < self.len() {
                    // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us
                    fft_error_inplace(self.len(), buffer.len(), 0, 0);
                    return; // Unreachable, because fft_error_inplace asserts, but it helps codegen to put it here
                }

                // In-place variant: one FFT per self.len()-sized chunk of the buffer
                let result = array_utils::iter_chunks(buffer, self.len(), |chunk| {
                    unsafe {
                        // Specialization workaround: See the comments in FftPlannerAvx::new() for why we have to transmute these slices
                        self.perform_fft_f32(workaround_transmute_mut::<_, Complex<f32>>(chunk));
                    }
                });

                if result.is_err() {
                    // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size,
                    // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us
                    fft_error_inplace(self.len(), buffer.len(), 0, 0);
                }
            }
            // These butterflies need no scratch space, in-place or out-of-place
            #[inline(always)]
            fn get_inplace_scratch_len(&self) -> usize {
                0
            }
            #[inline(always)]
            fn get_outofplace_scratch_len(&self) -> usize {
                0
            }
        }
        impl<T> Length for $struct_name<T> {
            #[inline(always)]
            fn len(&self) -> usize {
                $len
            }
        }
        impl<T> Direction for $struct_name<T> {
            #[inline(always)]
            fn fft_direction(&self) -> FftDirection {
                self.direction
            }
        }
    };
}
118
// Safety: This macro will call `self::column_butterflies_and_transpose and self::row_butterflies()` which probably has a #[target_feature(enable = "...")] annotation on it.
// Calling functions with that annotation is unsafe, because it doesn't actually check if the CPU has the required features.
// Callers of this macro must guarantee that users can't even obtain an instance of $struct_name if their CPU doesn't have the required CPU features.
//
// Generates the constructor, the column+row FFT helpers, and the `Fft`/`Length`/`Direction`
// trait impls for an AVX butterfly of fixed length `$len` that computes the FFT in two
// passes (column butterflies + transpose, then row butterflies) and therefore needs
// `$len` elements of scratch for in-place processing.
macro_rules! boilerplate_fft_simd_butterfly_with_scratch {
    ($struct_name:ident, $len:expr) => {
        impl $struct_name<f32> {
            /// Constructs this butterfly, or returns `Err(())` if the CPU lacks AVX+FMA support.
            #[inline]
            pub fn new(direction: FftDirection) -> Result<Self, ()> {
                let has_avx = is_x86_feature_detected!("avx");
                let has_fma = is_x86_feature_detected!("fma");
                if has_avx && has_fma {
                    // Safety: new_internal requires the "avx" feature set. Since we know it's present, we're safe
                    Ok(unsafe { Self::new_with_avx(direction) })
                } else {
                    Err(())
                }
            }
        }
        impl<T: FftNum> $struct_name<T> {
            // Computes one in-place FFT: column pass writes into `scratch`, row pass
            // copies the results back into `buffer`.
            #[inline]
            fn perform_fft_inplace(
                &self,
                buffer: &mut [Complex<f32>],
                scratch: &mut [Complex<f32>],
            ) {
                // Perform the column FFTs
                // Safety: self.perform_column_butterflies() requires the "avx" and "fma" instruction sets, and we return Err() in our constructor if the instructions aren't available
                unsafe { self.column_butterflies_and_transpose(buffer, scratch) };

                // process the row FFTs, and copy from the scratch back to the buffer as we go
                // Safety: self.transpose() requires the "avx" instruction set, and we return Err() in our constructor if the instructions aren't available
                unsafe {
                    self.row_butterflies(DoubleBuf {
                        input: scratch,
                        output: buffer,
                    })
                };
            }

            // Computes one FFT from `input` into `output`. Note that `input` is
            // mutable: the column pass uses it as working memory.
            #[inline]
            fn perform_fft_out_of_place(
                &self,
                input: &mut [Complex<f32>],
                output: &mut [Complex<f32>],
            ) {
                // Perform the column FFTs
                // Safety: self.perform_column_butterflies() requires the "avx" and "fma" instruction sets, and we return Err() in our constructor if the instructions aren't available
                unsafe { self.column_butterflies_and_transpose(input, output) };

                // process the row FFTs in-place in the output buffer
                // Safety: self.transpose() requires the "avx" instruction set, and we return Err() in our constructor if the instructions aren't available
                unsafe { self.row_butterflies(output) };
            }
        }
        // Implemented on the concrete `$struct_name<f32>` but generic over `T: FftNum` --
        // part of the specialization workaround described in FftPlannerAvx::new().
        impl<T: FftNum> Fft<T> for $struct_name<f32> {
            fn process_outofplace_with_scratch(
                &self,
                input: &mut [Complex<T>],
                output: &mut [Complex<T>],
                _scratch: &mut [Complex<T>],
            ) {
                if input.len() < self.len() || output.len() != input.len() {
                    // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us
                    fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0);
                    return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here
                }

                // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary
                let transmuted_input: &mut [Complex<f32>] =
                    unsafe { array_utils::workaround_transmute_mut(input) };
                let transmuted_output: &mut [Complex<f32>] =
                    unsafe { array_utils::workaround_transmute_mut(output) };
                let result = array_utils::iter_chunks_zipped(
                    transmuted_input,
                    transmuted_output,
                    self.len(),
                    |in_chunk, out_chunk| self.perform_fft_out_of_place(in_chunk, out_chunk),
                );

                if result.is_err() {
                    // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size,
                    // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us
                    fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0);
                }
            }
            fn process_with_scratch(&self, buffer: &mut [Complex<T>], scratch: &mut [Complex<T>]) {
                let required_scratch = self.len();
                if scratch.len() < required_scratch || buffer.len() < self.len() {
                    // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us
                    fft_error_inplace(self.len(), buffer.len(), self.len(), scratch.len());
                    return; // Unreachable, because fft_error_inplace asserts, but it helps codegen to put it here
                }

                let scratch = &mut scratch[..required_scratch];

                // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary
                let transmuted_buffer: &mut [Complex<f32>] =
                    unsafe { array_utils::workaround_transmute_mut(buffer) };
                let transmuted_scratch: &mut [Complex<f32>] =
                    unsafe { array_utils::workaround_transmute_mut(scratch) };
                let result = array_utils::iter_chunks(transmuted_buffer, self.len(), |chunk| {
                    self.perform_fft_inplace(chunk, transmuted_scratch)
                });

                if result.is_err() {
                    // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size,
                    // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us
                    fft_error_inplace(self.len(), buffer.len(), self.len(), scratch.len());
                }
            }
            // The two-pass algorithm needs a full FFT-sized scratch for in-place processing,
            // but none when processing out-of-place (the input buffer serves as scratch)
            #[inline(always)]
            fn get_inplace_scratch_len(&self) -> usize {
                $len
            }
            #[inline(always)]
            fn get_outofplace_scratch_len(&self) -> usize {
                0
            }
        }
        impl<T> Length for $struct_name<T> {
            #[inline(always)]
            fn len(&self) -> usize {
                $len
            }
        }
        impl<T> Direction for $struct_name<T> {
            #[inline(always)]
            fn fft_direction(&self) -> FftDirection {
                self.direction
            }
        }
    };
}
252
// Generates a fixed-size array of packed AVX twiddle vectors for a mixed-radix
// butterfly of size `$num_rows * $num_cols`, ordered row-major ("interleaved
// columns"): all the vectors for row 1, then all the vectors for row 2, etc.
// Each vector covers 4 consecutive columns, starting at column `$skip_cols`.
// Row 0 is skipped entirely (y starts at 1).
// NOTE(review): TWIDDLE_COLS / 4 is a truncating division, so this presumably
// assumes ($num_cols - $skip_cols) is a multiple of 4 -- verify at call sites.
macro_rules! gen_butterfly_twiddles_interleaved_columns {
    ($num_rows:expr, $num_cols:expr, $skip_cols:expr, $direction: expr) => {{
        const FFT_LEN: usize = $num_rows * $num_cols;
        const TWIDDLE_ROWS: usize = $num_rows - 1; // no twiddles generated for row 0
        const TWIDDLE_COLS: usize = $num_cols - $skip_cols;
        const TWIDDLE_VECTOR_COLS: usize = TWIDDLE_COLS / 4; // 4 columns per AVX vector
        const TWIDDLE_VECTOR_COUNT: usize = TWIDDLE_VECTOR_COLS * TWIDDLE_ROWS;
        let mut twiddles = [AvxVector::zero(); TWIDDLE_VECTOR_COUNT];
        for index in 0..TWIDDLE_VECTOR_COUNT {
            // Row-major: consecutive indices walk across the columns of one row first
            let y = (index / TWIDDLE_VECTOR_COLS) + 1;
            let x = (index % TWIDDLE_VECTOR_COLS) * 4 + $skip_cols;

            twiddles[index] = AvxVector::make_mixedradix_twiddle_chunk(x, y, FFT_LEN, $direction);
        }
        twiddles
    }};
}
270
// Same as gen_butterfly_twiddles_interleaved_columns, but ordered column-major
// ("separated columns"): all the rows for one 4-column group, then the rows for
// the next group, etc. Row 0 is still skipped (y starts at 1).
// NOTE(review): same truncating `/ 4` caveat as the interleaved variant --
// ($num_cols - $skip_cols) is presumably always a multiple of 4 at call sites.
macro_rules! gen_butterfly_twiddles_separated_columns {
    ($num_rows:expr, $num_cols:expr, $skip_cols:expr, $direction: expr) => {{
        const FFT_LEN: usize = $num_rows * $num_cols;
        const TWIDDLE_ROWS: usize = $num_rows - 1; // no twiddles generated for row 0
        const TWIDDLE_COLS: usize = $num_cols - $skip_cols;
        const TWIDDLE_VECTOR_COLS: usize = TWIDDLE_COLS / 4; // 4 columns per AVX vector
        const TWIDDLE_VECTOR_COUNT: usize = TWIDDLE_VECTOR_COLS * TWIDDLE_ROWS;
        let mut twiddles = [AvxVector::zero(); TWIDDLE_VECTOR_COUNT];
        for index in 0..TWIDDLE_VECTOR_COUNT {
            // Column-major: consecutive indices walk down the rows of one column group first
            let y = (index % TWIDDLE_ROWS) + 1;
            let x = (index / TWIDDLE_ROWS) * 4 + $skip_cols;

            twiddles[index] = AvxVector::make_mixedradix_twiddle_chunk(x, y, FFT_LEN, $direction);
        }
        twiddles
    }};
}
288
/// AVX-accelerated size-5 FFT butterfly for f32 data.
pub struct Butterfly5Avx<T> {
    // Packed twiddle factors, each laid out in memory as [re, re, im, im]:
    // [0] = twiddle(1,5), [1] = twiddle(2,5), [2] = conjugate of twiddle(1,5)
    twiddles: [__m128; 3],
    // FFT direction (forward/inverse) this instance was constructed for
    direction: FftDirection,
    // Generic over T only as part of the specialization workaround; no T value is stored
    _phantom_t: std::marker::PhantomData<T>,
}
boilerplate_fft_simd_butterfly!(Butterfly5Avx, 5);
295impl Butterfly5Avx<f32> {
296    #[target_feature(enable = "avx")]
297    unsafe fn new_with_avx(direction: FftDirection) -> Self {
298        let twiddle1 = twiddles::compute_twiddle(1, 5, direction);
299        let twiddle2 = twiddles::compute_twiddle(2, 5, direction);
300        Self {
301            twiddles: [
302                _mm_set_ps(twiddle1.im, twiddle1.im, twiddle1.re, twiddle1.re),
303                _mm_set_ps(twiddle2.im, twiddle2.im, twiddle2.re, twiddle2.re),
304                _mm_set_ps(-twiddle1.im, -twiddle1.im, twiddle1.re, twiddle1.re),
305            ],
306            direction,
307            _phantom_t: PhantomData,
308        }
309    }
310}
impl<T> Butterfly5Avx<T> {
    /// Computes a single size-5 FFT, reading from the buffer's input side and
    /// writing to its output side (the same memory when `buffer` is a plain slice).
    ///
    /// The 4 non-first inputs are processed in symmetric pairs (1,4) and (2,3),
    /// so most of the work runs on 2-complex 128-bit vectors.
    ///
    /// # Safety
    /// Requires the "avx" and "fma" instruction sets, and `buffer` must hold at
    /// least 5 complex elements.
    #[target_feature(enable = "avx", enable = "fma")]
    unsafe fn perform_fft_f32(&self, mut buffer: impl AvxArrayMut<f32>) {
        let input0 = _mm_castpd_ps(_mm_load1_pd(buffer.input_ptr() as *const f64)); // load the first element of the input, and duplicate it into both complex number slots of input0
        let input12 = buffer.load_partial2_complex(1);
        let input34 = buffer.load_partial2_complex(3);

        // swap elements for inputs 3 and 4
        let input43 = AvxVector::reverse_complex_elements(input34);

        // do some prep work before we can start applying twiddle factors
        let [sum12, diff43] = AvxVector::column_butterfly2([input12, input43]);

        // The rotation here is always Inverse regardless of self.direction; the FFT
        // direction was already baked into self.twiddles by the constructor
        let rotation = AvxVector::make_rotation90(FftDirection::Inverse);
        let rotated43 = AvxVector::rotate90(diff43, rotation);

        let [mid14, mid23] = AvxVector::unpack_complex([sum12, rotated43]);

        // to compute the first output, compute the sum of all elements. mid14[0] and mid23[0] already have the sum of 1+4 and 2+3 respectively, so if we add them, we'll get the sum of all 4
        let sum1234 = AvxVector::add(mid14, mid23);
        let output0 = AvxVector::add(input0, sum1234);

        // apply twiddle factors. Note that both initial products deliberately start
        // from mid14; the mid23 contributions are folded in with the FMAs below
        let twiddled14_mid = AvxVector::mul(mid14, self.twiddles[0]);
        let twiddled23_mid = AvxVector::mul(mid14, self.twiddles[1]);
        let twiddled14 = AvxVector::fmadd(mid23, self.twiddles[1], twiddled14_mid);
        let twiddled23 = AvxVector::fmadd(mid23, self.twiddles[2], twiddled23_mid);

        // unpack the data for the last butterfly 2
        let [twiddled12, twiddled43] = AvxVector::unpack_complex([twiddled14, twiddled23]);
        let [output12, output43] = AvxVector::column_butterfly2([twiddled12, twiddled43]);

        // swap the elements in output43 before writing them out, and add the first input to everything
        let final12 = AvxVector::add(input0, output12);
        let output34 = AvxVector::reverse_complex_elements(output43);
        let final34 = AvxVector::add(input0, output34);

        buffer.store_partial1_complex(output0, 0);
        buffer.store_partial2_complex(final12, 1);
        buffer.store_partial2_complex(final34, 3);
    }
}
353
/// AVX-accelerated size-7 FFT butterfly for f32 data.
pub struct Butterfly7Avx<T> {
    // Packed twiddle factors, each laid out in memory as [re, re, im, im]:
    // [0..3] = twiddles 1,2,3 of 7; [3] = conj(twiddle 3); [4] = conj(twiddle 1)
    twiddles: [__m128; 5],
    // FFT direction (forward/inverse) this instance was constructed for
    direction: FftDirection,
    // Generic over T only as part of the specialization workaround; no T value is stored
    _phantom_t: std::marker::PhantomData<T>,
}
boilerplate_fft_simd_butterfly!(Butterfly7Avx, 7);
360impl Butterfly7Avx<f32> {
361    #[target_feature(enable = "avx")]
362    unsafe fn new_with_avx(direction: FftDirection) -> Self {
363        let twiddle1 = twiddles::compute_twiddle(1, 7, direction);
364        let twiddle2 = twiddles::compute_twiddle(2, 7, direction);
365        let twiddle3 = twiddles::compute_twiddle(3, 7, direction);
366        Self {
367            twiddles: [
368                _mm_set_ps(twiddle1.im, twiddle1.im, twiddle1.re, twiddle1.re),
369                _mm_set_ps(twiddle2.im, twiddle2.im, twiddle2.re, twiddle2.re),
370                _mm_set_ps(twiddle3.im, twiddle3.im, twiddle3.re, twiddle3.re),
371                _mm_set_ps(-twiddle3.im, -twiddle3.im, twiddle3.re, twiddle3.re),
372                _mm_set_ps(-twiddle1.im, -twiddle1.im, twiddle1.re, twiddle1.re),
373            ],
374            direction,
375            _phantom_t: PhantomData,
376        }
377    }
378}
impl<T> Butterfly7Avx<T> {
    /// Computes a single size-7 FFT, reading from the buffer's input side and
    /// writing to its output side (the same memory when `buffer` is a plain slice).
    ///
    /// The 6 non-first inputs are processed in symmetric pairs (1,6), (2,5), (3,4).
    ///
    /// # Safety
    /// Requires the "avx" and "fma" instruction sets, and `buffer` must hold at
    /// least 7 complex elements.
    #[target_feature(enable = "avx", enable = "fma")]
    unsafe fn perform_fft_f32(&self, mut buffer: impl AvxArrayMut<f32>) {
        // load the first element of the input, and duplicate it into both complex number slots of input0
        let input0 = _mm_castpd_ps(_mm_load1_pd(buffer.input_ptr() as *const f64));

        // we want to load 3 elements into 123 and 3 elements into 456, but we can only load 4, so we're going to do slightly overlapping reads here
        // we have to reverse 456 immediately after loading, and that'll be easiest if we load the 456 into the latter 3 slots of the register, rather than the front 3 slots
        // as a bonus, that also means we don't need masked reads or anything
        let input123 = buffer.load_complex(1);
        let input456 = buffer.load_complex(3);

        // reverse the order of input456
        let input654 = AvxVector::reverse_complex_elements(input456);

        // do some prep work before we can start applying twiddle factors
        let [sum123, diff654] = AvxVector::column_butterfly2([input123, input654]);
        // The rotation here is always Inverse regardless of self.direction; the FFT
        // direction was already baked into self.twiddles by the constructor
        let rotation = AvxVector::make_rotation90(FftDirection::Inverse);
        let rotated654 = AvxVector::rotate90(diff654, rotation);

        let [mid1634, mid25] = AvxVector::unpack_complex([sum123, rotated654]);

        // split the 256-bit intermediates into 128-bit halves: everything from
        // here on runs on 2-complex vectors
        let mid16 = mid1634.lo();
        let mid25 = mid25.lo();
        let mid34 = mid1634.hi();

        // to compute the first output, compute the sum of all elements. mid16[0], mid25[0], and mid34[0] already have the sum of 1+6, 2+5 and 3+4 respectively, so if we add them, we'll get 1+2+3+4+5+6
        let output0_left = AvxVector::add(mid16, mid25);
        let output0_right = AvxVector::add(input0, mid34);
        let output0 = AvxVector::add(output0_left, output0_right);
        buffer.store_partial1_complex(output0, 0);

        // we're done with 256-bit registers; zero the upper halves to avoid AVX->SSE transition penalties
        _mm256_zeroupper();

        // apply twiddle factors. All three initial products deliberately start from
        // mid16; the mid25 and mid34 contributions are folded in with the FMAs below
        let twiddled16_intermediate1 = AvxVector::mul(mid16, self.twiddles[0]);
        let twiddled25_intermediate1 = AvxVector::mul(mid16, self.twiddles[1]);
        let twiddled34_intermediate1 = AvxVector::mul(mid16, self.twiddles[2]);

        let twiddled16_intermediate2 =
            AvxVector::fmadd(mid25, self.twiddles[1], twiddled16_intermediate1);
        let twiddled25_intermediate2 =
            AvxVector::fmadd(mid25, self.twiddles[3], twiddled25_intermediate1);
        let twiddled34_intermediate2 =
            AvxVector::fmadd(mid25, self.twiddles[4], twiddled34_intermediate1);

        let twiddled16 = AvxVector::fmadd(mid34, self.twiddles[2], twiddled16_intermediate2);
        let twiddled25 = AvxVector::fmadd(mid34, self.twiddles[4], twiddled25_intermediate2);
        let twiddled34 = AvxVector::fmadd(mid34, self.twiddles[1], twiddled34_intermediate2);

        // unpack the data for the last butterfly 2
        let [twiddled12, twiddled65] = AvxVector::unpack_complex([twiddled16, twiddled25]);
        let [twiddled33, twiddled44] = AvxVector::unpack_complex([twiddled34, twiddled34]);

        // we can save one add if we add input0 to twiddled33 now. normally we'd add input0 to the final output, but the arrangement of data makes that a little awkward
        let twiddled033 = AvxVector::add(twiddled33, input0);

        let [output12, output65] = AvxVector::column_butterfly2([twiddled12, twiddled65]);
        let [output033, output044] = AvxVector::column_butterfly2([twiddled033, twiddled44]);
        let output56 = AvxVector::reverse_complex_elements(output65);

        buffer.store_partial2_complex(AvxVector::add(output12, input0), 1);
        buffer.store_partial1_complex(output033, 3);
        buffer.store_partial1_complex(output044, 4);
        buffer.store_partial2_complex(AvxVector::add(output56, input0), 5);
    }
}
446
/// AVX-accelerated size-11 FFT butterfly for f32 data.
pub struct Butterfly11Avx<T> {
    // 256-bit vectors, each merging two packed [re, re, im, im] twiddle halves,
    // arranged by the constructor to match the mid-value layout of perform_fft_f32
    twiddles: [__m256; 10],
    // 128-bit copies of individual packed twiddles, used for the (5,6) input pair
    // which is processed on 128-bit vectors
    twiddle_lo_4: __m128,
    twiddle_lo_9: __m128,
    twiddle_lo_3: __m128,
    twiddle_lo_8: __m128,
    twiddle_lo_2: __m128,
    // FFT direction (forward/inverse) this instance was constructed for
    direction: FftDirection,
    // Generic over T only as part of the specialization workaround; no T value is stored
    _phantom_t: std::marker::PhantomData<T>,
}
boilerplate_fft_simd_butterfly!(Butterfly11Avx, 11);
impl Butterfly11Avx<f32> {
    /// Precomputes the packed twiddle factors for a size-11 FFT.
    ///
    /// # Safety
    /// Requires the "avx" instruction set. Callers must verify CPU support before calling.
    #[target_feature(enable = "avx")]
    unsafe fn new_with_avx(direction: FftDirection) -> Self {
        let twiddle1 = twiddles::compute_twiddle(1, 11, direction);
        let twiddle2 = twiddles::compute_twiddle(2, 11, direction);
        let twiddle3 = twiddles::compute_twiddle(3, 11, direction);
        let twiddle4 = twiddles::compute_twiddle(4, 11, direction);
        let twiddle5 = twiddles::compute_twiddle(5, 11, direction);

        // Each entry is packed as [re, re, im, im] in memory (_mm_set_ps lists lanes
        // high to low). Indices 0..=4 hold twiddles 1..=5; indices 5..=9 hold the
        // conjugates of twiddles 5 down to 1 (negated imaginary parts).
        let twiddles_lo = [
            _mm_set_ps(twiddle1.im, twiddle1.im, twiddle1.re, twiddle1.re),
            _mm_set_ps(twiddle2.im, twiddle2.im, twiddle2.re, twiddle2.re),
            _mm_set_ps(twiddle3.im, twiddle3.im, twiddle3.re, twiddle3.re),
            _mm_set_ps(twiddle4.im, twiddle4.im, twiddle4.re, twiddle4.re),
            _mm_set_ps(twiddle5.im, twiddle5.im, twiddle5.re, twiddle5.re),
            _mm_set_ps(-twiddle5.im, -twiddle5.im, twiddle5.re, twiddle5.re),
            _mm_set_ps(-twiddle4.im, -twiddle4.im, twiddle4.re, twiddle4.re),
            _mm_set_ps(-twiddle3.im, -twiddle3.im, twiddle3.re, twiddle3.re),
            _mm_set_ps(-twiddle2.im, -twiddle2.im, twiddle2.re, twiddle2.re),
            _mm_set_ps(-twiddle1.im, -twiddle1.im, twiddle1.re, twiddle1.re),
        ];

        Self {
            // Merge pairs of 128-bit twiddles into 256-bit vectors. The pairing is
            // chosen to line up with the duplicated mid-value vectors (mid110, mid29,
            // mid38, mid47, mid56) that perform_fft_f32 multiplies against -- do not
            // reorder these without updating that function.
            twiddles: [
                AvxVector256::merge(twiddles_lo[0], twiddles_lo[2]),
                AvxVector256::merge(twiddles_lo[1], twiddles_lo[3]),
                AvxVector256::merge(twiddles_lo[1], twiddles_lo[5]),
                AvxVector256::merge(twiddles_lo[3], twiddles_lo[7]),
                AvxVector256::merge(twiddles_lo[2], twiddles_lo[8]),
                AvxVector256::merge(twiddles_lo[5], twiddles_lo[0]),
                AvxVector256::merge(twiddles_lo[3], twiddles_lo[0]),
                AvxVector256::merge(twiddles_lo[7], twiddles_lo[4]),
                AvxVector256::merge(twiddles_lo[4], twiddles_lo[3]),
                AvxVector256::merge(twiddles_lo[9], twiddles_lo[8]),
            ],
            // 128-bit twiddles used for the (5,6) pair in perform_fft_f32
            twiddle_lo_4: twiddles_lo[4],
            twiddle_lo_9: twiddles_lo[9],
            twiddle_lo_3: twiddles_lo[3],
            twiddle_lo_8: twiddles_lo[8],
            twiddle_lo_2: twiddles_lo[2],
            direction,
            _phantom_t: PhantomData,
        }
    }
}
impl<T> Butterfly11Avx<T> {
    /// Computes a single size-11 FFT, reading from the buffer's input side and
    /// writing to its output side (the same memory when `buffer` is a plain slice).
    ///
    /// The 10 non-first inputs are processed in symmetric pairs (1,10), (2,9),
    /// (3,8), (4,7) on 256-bit vectors, plus (5,6) on 128-bit vectors.
    ///
    /// # Safety
    /// Requires the "avx" and "fma" instruction sets, and `buffer` must hold at
    /// least 11 complex elements.
    #[target_feature(enable = "avx", enable = "fma")]
    unsafe fn perform_fft_f32(&self, mut buffer: impl AvxArrayMut<f32>) {
        let input0 = _mm_castpd_ps(_mm_load1_pd(buffer.input_ptr() as *const f64)); // load the first element of the input, and duplicate it into both complex number slots of input0
        let input1234 = buffer.load_complex(1);
        let input56 = buffer.load_partial2_complex(5);
        let input78910 = buffer.load_complex(7);

        // reverse the order of input78910, and separate
        let [input55, input66] = AvxVector::unpack_complex([input56, input56]);
        let input10987 = AvxVector::reverse_complex_elements(input78910);

        // do some initial butterflies and rotations
        let [sum1234, diff10987] = AvxVector::column_butterfly2([input1234, input10987]);
        let [sum55, diff66] = AvxVector::column_butterfly2([input55, input66]);

        // The rotation here is always Inverse regardless of self.direction; the FFT
        // direction was already baked into self.twiddles by the constructor
        let rotation = AvxVector::make_rotation90(FftDirection::Inverse);
        let rotated10987 = AvxVector::rotate90(diff10987, rotation);
        let rotated66 = AvxVector::rotate90(diff66, rotation.lo());

        // arrange the data into the format to apply twiddles
        let [mid11038, mid2947] = AvxVector::unpack_complex([sum1234, rotated10987]);

        // duplicate each 128-bit mid pair into both halves of a 256-bit vector,
        // so one FMA can feed two different output pairs at once
        let mid110: __m256 = AvxVector256::merge(mid11038.lo(), mid11038.lo());
        let mid29: __m256 = AvxVector256::merge(mid2947.lo(), mid2947.lo());
        let mid38: __m256 = AvxVector256::merge(mid11038.hi(), mid11038.hi());
        let mid47: __m256 = AvxVector256::merge(mid2947.hi(), mid2947.hi());
        let mid56 = AvxVector::unpacklo_complex([sum55, rotated66]);
        let mid56: __m256 = AvxVector256::merge(mid56, mid56);

        // to compute the first output, compute the sum of all elements. Each mid value's
        // first slot already holds the sum of one input pair (1+10, 2+9, 3+8, 4+7, 5+6),
        // so adding them all plus input 0 gives the sum of all 11 inputs
        let mid12910 = AvxVector::add(mid110.lo(), mid29.lo());
        let mid3478 = AvxVector::add(mid38.lo(), mid47.lo());
        let output0_left = AvxVector::add(input0, mid56.lo());
        let output0_right = AvxVector::add(mid12910, mid3478);
        let output0 = AvxVector::add(output0_left, output0_right);
        buffer.store_partial1_complex(output0, 0);

        // we need to add the first input to each of our 5 twiddles values -- but right now, input0 is duplicated into both slots
        // but we only want to add it once, so zero the second element
        let zero = _mm_setzero_pd();
        let input0 = _mm_castpd_ps(_mm_move_sd(zero, _mm_castps_pd(input0)));
        let input0 = AvxVector256::merge(input0, input0);

        // apply twiddle factors, seeding each accumulator with input0 so the FMA
        // chains below also perform the final "+ input[0]" for free
        let twiddled11038 = AvxVector::fmadd(mid110, self.twiddles[0], input0);
        let twiddled2947 = AvxVector::fmadd(mid110, self.twiddles[1], input0);
        let twiddled56 = AvxVector::fmadd(mid110.lo(), self.twiddle_lo_4, input0.lo());

        let twiddled11038 = AvxVector::fmadd(mid29, self.twiddles[2], twiddled11038);
        let twiddled2947 = AvxVector::fmadd(mid29, self.twiddles[3], twiddled2947);
        let twiddled56 = AvxVector::fmadd(mid29.lo(), self.twiddle_lo_9, twiddled56);

        let twiddled11038 = AvxVector::fmadd(mid38, self.twiddles[4], twiddled11038);
        let twiddled2947 = AvxVector::fmadd(mid38, self.twiddles[5], twiddled2947);
        let twiddled56 = AvxVector::fmadd(mid38.lo(), self.twiddle_lo_3, twiddled56);

        let twiddled11038 = AvxVector::fmadd(mid47, self.twiddles[6], twiddled11038);
        let twiddled2947 = AvxVector::fmadd(mid47, self.twiddles[7], twiddled2947);
        let twiddled56 = AvxVector::fmadd(mid47.lo(), self.twiddle_lo_8, twiddled56);

        let twiddled11038 = AvxVector::fmadd(mid56, self.twiddles[8], twiddled11038);
        let twiddled2947 = AvxVector::fmadd(mid56, self.twiddles[9], twiddled2947);
        let twiddled56 = AvxVector::fmadd(mid56.lo(), self.twiddle_lo_2, twiddled56);

        // unpack the data for the last butterfly 2
        let [twiddled1234, twiddled10987] =
            AvxVector::unpack_complex([twiddled11038, twiddled2947]);
        let [twiddled55, twiddled66] = AvxVector::unpack_complex([twiddled56, twiddled56]);

        let [output1234, output10987] = AvxVector::column_butterfly2([twiddled1234, twiddled10987]);
        let [output55, output66] = AvxVector::column_butterfly2([twiddled55, twiddled66]);
        let output78910 = AvxVector::reverse_complex_elements(output10987);

        buffer.store_complex(output1234, 1);
        buffer.store_partial1_complex(output55, 5);
        buffer.store_partial1_complex(output66, 6);
        buffer.store_complex(output78910, 7);
    }
}
583
/// AVX-accelerated size-8 FFT butterfly for f32 data.
pub struct Butterfly8Avx<T> {
    // Packed size-8 mixed-radix twiddle chunk for row 1, columns 0..4
    twiddles: __m256,
    // Direction-dependent vector of signed zeros; the sign bits encode the ±i
    // rotation used inside the size-4 butterflies (see perform_fft_f32)
    twiddles_butterfly4: __m256,
    // FFT direction (forward/inverse) this instance was constructed for
    direction: FftDirection,
    // Generic over T only as part of the specialization workaround; no T value is stored
    _phantom_t: std::marker::PhantomData<T>,
}
boilerplate_fft_simd_butterfly!(Butterfly8Avx, 8);
impl Butterfly8Avx<f32> {
    /// Precomputes the twiddle factors for a size-8 FFT (treated as a 2x4 mixed-radix FFT).
    ///
    /// # Safety
    /// Requires the "avx" instruction set. Callers must verify CPU support before calling.
    #[target_feature(enable = "avx")]
    unsafe fn new_with_avx(direction: FftDirection) -> Self {
        Self {
            // Row-1 twiddles for the 2x4 mixed-radix decomposition of a size-8 FFT
            twiddles: AvxVector::make_mixedradix_twiddle_chunk(0, 1, 8, direction),
            // All magnitudes are zero -- only the SIGN BITS matter here. Forward puts
            // -0.0 in alternating imaginary slots, Inverse in alternating real slots;
            // presumably these sign masks are combined (e.g. via xor) with data in
            // perform_fft_f32 to implement multiply-by-(±i). Do not "simplify" the
            // -0.0 literals: the sign of zero is the entire payload.
            twiddles_butterfly4: match direction {
                FftDirection::Forward => [
                    Complex::new(0.0f32, 0.0),
                    Complex::new(0.0, -0.0),
                    Complex::new(0.0, 0.0),
                    Complex::new(0.0, -0.0),
                ]
                .as_slice()
                .load_complex(0),
                FftDirection::Inverse => [
                    Complex::new(0.0f32, 0.0),
                    Complex::new(-0.0, 0.0),
                    Complex::new(0.0, 0.0),
                    Complex::new(-0.0, 0.0),
                ]
                .as_slice()
                .load_complex(0),
            },
            direction,
            _phantom_t: PhantomData,
        }
    }
}
impl<T> Butterfly8Avx<T> {
    /// Performs a size-8 FFT in-place on `buffer`: butterfly 2's down the columns of a
    /// 2x4 view of the data, size-8 twiddles, then 4 parallel size-4 butterflies built
    /// from permutes, another butterfly-2 layer, and a sign-flip rotation.
    ///
    /// # Safety
    /// The caller must ensure the CPU supports the "avx" and "fma" instruction sets.
    #[target_feature(enable = "avx", enable = "fma")]
    unsafe fn perform_fft_f32(&self, mut buffer: impl AvxArrayMut<f32>) {
        let row0 = buffer.load_complex(0);
        let row1 = buffer.load_complex(4);

        // Do our butterfly 2's down the columns
        let [intermediate0, intermediate1_pretwiddle] = AvxVector::column_butterfly2([row0, row1]);

        // Apply the size-8 twiddle factors
        let intermediate1 = AvxVector::mul_complex(intermediate1_pretwiddle, self.twiddles);

        // Rearrange the data before we do our butterfly 4s. This swaps the last 2 elements of butterfly0 with the first two elements of butterfly1
        // The result is that we can then do a 4x butterfly 2, apply twiddles, use unpack instructions to transpose to the final output, then do another 4x butterfly 2
        let permuted0 = _mm256_permute2f128_ps(intermediate0, intermediate1, 0x20);
        let permuted1 = _mm256_permute2f128_ps(intermediate0, intermediate1, 0x31);

        // Do the first set of butterfly 2's
        let [postbutterfly0, postbutterfly1_pretwiddle] =
            AvxVector::column_butterfly2([permuted0, permuted1]);

        // Which negative we blend in depends on whether we're a forward or inverse FFT
        // Our goal is to swap the reals with the imaginaries, then negate either the reals or the imaginaries, based on the direction
        // but we can't use the AvxVector swap_complex_components function, because we only want to swap the odd reals with the odd imaginaries
        let elements_swapped = _mm256_permute_ps(postbutterfly1_pretwiddle, 0xB4);

        // We can negate the elements we want by xoring the row with a pre-set vector
        let postbutterfly1 = AvxVector::xor(elements_swapped, self.twiddles_butterfly4);

        // use unpack instructions to transpose, and to prepare for the final butterfly 2's
        let unpermuted0 = _mm256_permute2f128_ps(postbutterfly0, postbutterfly1, 0x20);
        let unpermuted1 = _mm256_permute2f128_ps(postbutterfly0, postbutterfly1, 0x31);
        let unpacked = AvxVector::unpack_complex([unpermuted0, unpermuted1]);

        let [output0, output1] = AvxVector::column_butterfly2(unpacked);

        buffer.store_complex(output0, 0);
        buffer.store_complex(output1, 4);
    }
}
659
/// AVX-accelerated size-9 FFT butterfly for f32 data (decomposed as a 3x3 mixed-radix FFT).
pub struct Butterfly9Avx<T> {
    /// Packed twiddles [1/9, 2/9, 2/9, 4/9] — ordered to match the merged-register
    /// multiply in `perform_fft_f32`
    twiddles: __m256,
    /// Broadcast twiddle used by the size-3 butterflies
    twiddles_butterfly3: __m256,
    direction: FftDirection,
    // Ties the generic `Fft<T>` impl (generated by the boilerplate macro) to this f32-only type
    _phantom_t: std::marker::PhantomData<T>,
}
boilerplate_fft_simd_butterfly!(Butterfly9Avx, 9);
impl Butterfly9Avx<f32> {
    /// Computes the constants for a size-9 FFT (decomposed as 3x3).
    ///
    /// # Safety
    /// The caller must ensure the CPU supports the "avx" instruction set.
    #[target_feature(enable = "avx")]
    unsafe fn new_with_avx(direction: FftDirection) -> Self {
        // Note the deliberate duplication of the 2/9 twiddle: the four values are laid out
        // to line up with the merged mid1/mid2 register multiplied in perform_fft_f32
        let twiddles: [Complex<f32>; 4] = [
            twiddles::compute_twiddle(1, 9, direction),
            twiddles::compute_twiddle(2, 9, direction),
            twiddles::compute_twiddle(2, 9, direction),
            twiddles::compute_twiddle(4, 9, direction),
        ];
        Self {
            twiddles: twiddles.as_slice().load_complex(0),
            twiddles_butterfly3: AvxVector::broadcast_twiddle(1, 3, direction),
            direction,
            _phantom_t: PhantomData,
        }
    }
}
impl<T> Butterfly9Avx<T> {
    /// Performs a size-9 FFT in-place on `buffer`, as a 3x3 mixed-radix FFT:
    /// size-3 butterflies down the columns, twiddles, transpose, size-3 butterflies again.
    ///
    /// # Safety
    /// The caller must ensure the CPU supports the "avx" and "fma" instruction sets.
    #[target_feature(enable = "avx", enable = "fma")]
    unsafe fn perform_fft_f32(&self, mut buffer: impl AvxArrayMut<f32>) {
        // we're going to load these elements in a peculiar way. instead of loading a row into the first 3 element of each register and leaving the last element empty
        // we're leaving the first element empty and putting the data in the last 3 elements. this will let us do 3 total complex multiplies instead of 4.

        // this duplicates element 0 across both complex slots of input0_lo, so the
        // otherwise-empty first lane just holds a harmless copy of element 0
        let input0_lo = _mm_castpd_ps(_mm_load1_pd(buffer.input_ptr() as *const f64));
        let input0_hi = buffer.load_partial2_complex(1);
        let input0 = AvxVector256::merge(input0_lo, input0_hi);
        let input1 = buffer.load_complex(2);
        let input2 = buffer.load_complex(5);

        // We're going to treat our input as a 3x3 2d array. First, do 3 butterfly 3's down the columns of that array.
        let [mid0, mid1, mid2] =
            AvxVector::column_butterfly3([input0, input1, input2], self.twiddles_butterfly3);

        // merge the twiddle-able data into a single avx vector
        let twiddle_data = _mm256_permute2f128_ps(mid1, mid2, 0x31);
        let twiddled = AvxVector::mul_complex(twiddle_data, self.twiddles);

        // Transpose our 3x3 array. We could use the 4x4 transpose with an empty bottom row, which would result in an empty last column
        // but it turns out that it'll make our packing process later simpler if we duplicate the second row into the last row
        // which will result in duplicating the second column into the last column after the transpose
        let permute0 = _mm256_permute2f128_ps(mid0, mid2, 0x20);
        let permute1 = _mm256_permute2f128_ps(mid1, mid1, 0x20);
        let permute2 = _mm256_permute2f128_ps(mid0, twiddled, 0x31);
        let permute3 = _mm256_permute2f128_ps(twiddled, twiddled, 0x20);

        let transposed0 = AvxVector::unpackhi_complex([permute0, permute1]);
        let [transposed1, transposed2] = AvxVector::unpack_complex([permute2, permute3]);

        // more size 3 butterflies
        let output_rows = AvxVector::column_butterfly3(
            [transposed0, transposed1, transposed2],
            self.twiddles_butterfly3,
        );

        // the elements of row 1 are in pretty much the worst possible order, thankfully we can fix that with just a couple instructions
        let swapped1 = _mm256_permute_ps(output_rows[1], 0x4E); // swap even and odd complex numbers
        let packed1 = _mm256_permute2f128_ps(swapped1, output_rows[2], 0x21);
        buffer.store_complex(packed1, 4);

        // merge just the high element of swapped_lo into the high element of row 0
        let zero_swapped1_lo = AvxVector256::merge(AvxVector::zero(), swapped1.lo());
        let packed0 = _mm256_blend_ps(output_rows[0], zero_swapped1_lo, 0xC0);
        buffer.store_complex(packed0, 0);

        // The last element can just be written on its own
        buffer.store_partial1_complex(output_rows[2].hi(), 8);
    }
}
735
/// AVX-accelerated size-12 FFT butterfly for f32 data (decomposed as a 3x4 mixed-radix FFT).
pub struct Butterfly12Avx<T> {
    /// Mixed-radix twiddle factors — deliberately ordered to match the merged-register
    /// multiply in `perform_fft_f32` (see `new_with_avx`)
    twiddles: [__m256; 2],
    /// Broadcast twiddle used by the size-3 butterflies
    twiddles_butterfly3: __m256,
    /// 90-degree rotation constant used by the size-4 butterflies
    twiddles_butterfly4: Rotation90<__m256>,
    direction: FftDirection,
    // Ties the generic `Fft<T>` impl (generated by the boilerplate macro) to this f32-only type
    _phantom_t: std::marker::PhantomData<T>,
}
boilerplate_fft_simd_butterfly!(Butterfly12Avx, 12);
impl Butterfly12Avx<f32> {
    /// Computes the constants for a size-12 FFT (decomposed as 3x4).
    ///
    /// # Safety
    /// The caller must ensure the CPU supports the "avx" instruction set.
    #[target_feature(enable = "avx")]
    unsafe fn new_with_avx(direction: FftDirection) -> Self {
        let twiddles = [
            // the first two entries are unity twiddles, kept so the slice loads below
            // produce full 4-element registers
            Complex {
                re: 1.0f32,
                im: 0.0,
            },
            Complex { re: 1.0, im: 0.0 },
            twiddles::compute_twiddle(2, 12, direction),
            twiddles::compute_twiddle(4, 12, direction),
            // note that these twiddles are deliberately in a weird order, see perform_fft_f32 for why
            twiddles::compute_twiddle(1, 12, direction),
            twiddles::compute_twiddle(2, 12, direction),
            twiddles::compute_twiddle(3, 12, direction),
            twiddles::compute_twiddle(6, 12, direction),
        ];
        Self {
            twiddles: [
                twiddles.as_slice().load_complex(0),
                twiddles.as_slice().load_complex(4),
            ],
            twiddles_butterfly3: AvxVector::broadcast_twiddle(1, 3, direction),
            twiddles_butterfly4: AvxVector::make_rotation90(direction),
            direction,
            _phantom_t: PhantomData,
        }
    }
}
impl<T> Butterfly12Avx<T> {
    /// Performs a size-12 FFT in-place on `buffer`, decomposed as 3x4:
    /// size-4 butterflies down the columns, twiddles, a custom transpose, then size-3 butterflies.
    ///
    /// # Safety
    /// The caller must ensure the CPU supports the "avx" and "fma" instruction sets.
    #[target_feature(enable = "avx", enable = "fma")]
    unsafe fn perform_fft_f32(&self, mut buffer: impl AvxArrayMut<f32>) {
        // we're going to load these elements in a peculiar way. instead of loading a row into the first 3 element of each register and leaving the last element empty
        // we're leaving the first element empty and putting the data in the last 3 elements. this will save us a complex multiply.

        // for everything but the first element, we can do overlapping reads. for the first element, an "overlapping read" would have us reading from index -1, so instead we have to shuffle some data around
        let input0_lo = _mm_castpd_ps(_mm_load1_pd(buffer.input_ptr() as *const f64));
        let input0_hi = buffer.load_partial2_complex(1);
        let input_rows = [
            AvxVector256::merge(input0_lo, input0_hi),
            buffer.load_complex(2),
            buffer.load_complex(5),
            buffer.load_complex(8),
        ];

        // 3 butterfly 4's down the columns
        let mut mid = AvxVector::column_butterfly4(input_rows, self.twiddles_butterfly4);

        // Multiply in our twiddle factors. mid2 will be normal, but for mid1 and mid3, we're going to merge the twiddle-able parts into a single vector,
        // and do a single complex multiply on it. this transformation saves a complex multiply and costs nothing,
        // because we need the second halves of mid1 and mid3 in a single vector for the transpose afterward anyways, so we would have done this permute2f128 operation either way
        mid[2] = AvxVector::mul_complex(mid[2], self.twiddles[0]);
        let merged_mid13 = _mm256_permute2f128_ps(mid[1], mid[3], 0x31);
        let twiddled13 = AvxVector::mul_complex(self.twiddles[1], merged_mid13);

        // Transpose our 3x4 array into a 4x3. we're doing a custom transpose here because we have to re-distribute the merged twiddled13 back out, and we can roll that into the transpose to make it free
        let transposed = {
            let permute0 = _mm256_permute2f128_ps(mid[0], mid[2], 0x20);
            let permute1 = _mm256_permute2f128_ps(mid[1], mid[3], 0x20);
            let permute2 = _mm256_permute2f128_ps(mid[0], mid[2], 0x31);
            let permute3 = twiddled13; // normally we'd need to do a permute here, but we can skip it because we already did it for twiddle factors

            let unpacked1 = AvxVector::unpackhi_complex([permute0, permute1]);
            let [unpacked2, unpacked3] = AvxVector::unpack_complex([permute2, permute3]);

            [unpacked1, unpacked2, unpacked3]
        };

        // Do 4 butterfly 3's down the columns of our transposed array
        let output_rows = AvxVector::column_butterfly3(transposed, self.twiddles_butterfly3);

        buffer.store_complex(output_rows[0], 0);
        buffer.store_complex(output_rows[1], 4);
        buffer.store_complex(output_rows[2], 8);
    }
}
820
/// AVX-accelerated size-16 FFT butterfly for f32 data (decomposed as a 4x4 mixed-radix FFT).
pub struct Butterfly16Avx<T> {
    /// Mixed-radix twiddles for rows 1..4 of the 4x4 decomposition
    twiddles: [__m256; 3],
    /// 90-degree rotation constant used by the size-4 butterflies
    twiddles_butterfly4: Rotation90<__m256>,
    direction: FftDirection,
    // Ties the generic `Fft<T>` impl (generated by the boilerplate macro) to this f32-only type
    _phantom_t: std::marker::PhantomData<T>,
}
boilerplate_fft_simd_butterfly!(Butterfly16Avx, 16);
828impl Butterfly16Avx<f32> {
829    #[target_feature(enable = "avx")]
830    unsafe fn new_with_avx(direction: FftDirection) -> Self {
831        Self {
832            twiddles: gen_butterfly_twiddles_interleaved_columns!(4, 4, 0, direction),
833            twiddles_butterfly4: AvxVector::make_rotation90(direction),
834            direction,
835            _phantom_t: PhantomData,
836        }
837    }
838}
839impl<T> Butterfly16Avx<T> {
840    #[target_feature(enable = "avx", enable = "fma")]
841    unsafe fn perform_fft_f32(&self, mut buffer: impl AvxArrayMut<f32>) {
842        // Manually unrolling this loop because writing a "for r in 0..4" loop results in slow codegen that makes the whole thing take 1.5x longer :(
843        let rows = [
844            buffer.load_complex(0),
845            buffer.load_complex(4),
846            buffer.load_complex(8),
847            buffer.load_complex(12),
848        ];
849
850        // We're going to treat our input as a 4x4 2d array. First, do 4 butterfly 4's down the columns of that array.
851        let mut mid = AvxVector::column_butterfly4(rows, self.twiddles_butterfly4);
852
853        // apply twiddle factors
854        for r in 1..4 {
855            mid[r] = AvxVector::mul_complex(mid[r], self.twiddles[r - 1]);
856        }
857
858        // Transpose our 4x4 array
859        let transposed = avx32_utils::transpose_4x4_f32(mid);
860
861        // Do 4 butterfly 4's down the columns of our transposed array
862        let output_rows = AvxVector::column_butterfly4(transposed, self.twiddles_butterfly4);
863
864        // Manually unrolling this loop because writing a "for r in 0..4" loop results in slow codegen that makes the whole thing take 1.5x longer :(
865        buffer.store_complex(output_rows[0], 0);
866        buffer.store_complex(output_rows[1], 4);
867        buffer.store_complex(output_rows[2], 8);
868        buffer.store_complex(output_rows[3], 12);
869    }
870}
871
/// AVX-accelerated size-24 FFT butterfly for f32 data (decomposed as a 6x4 mixed-radix FFT).
pub struct Butterfly24Avx<T> {
    /// Mixed-radix twiddles for rows 1..6 of the 6x4 decomposition
    twiddles: [__m256; 5],
    /// Broadcast twiddle used by the size-3 half of the size-6 butterflies
    twiddles_butterfly3: __m256,
    /// 90-degree rotation constant used by the size-4 butterflies
    twiddles_butterfly4: Rotation90<__m256>,
    direction: FftDirection,
    // Ties the generic `Fft<T>` impl (generated by the boilerplate macro) to this f32-only type
    _phantom_t: std::marker::PhantomData<T>,
}
boilerplate_fft_simd_butterfly!(Butterfly24Avx, 24);
880impl Butterfly24Avx<f32> {
881    #[target_feature(enable = "avx")]
882    unsafe fn new_with_avx(direction: FftDirection) -> Self {
883        Self {
884            twiddles: gen_butterfly_twiddles_interleaved_columns!(6, 4, 0, direction),
885            twiddles_butterfly3: AvxVector::broadcast_twiddle(1, 3, direction),
886            twiddles_butterfly4: AvxVector::make_rotation90(direction),
887            direction,
888            _phantom_t: PhantomData,
889        }
890    }
891}
impl<T> Butterfly24Avx<T> {
    /// Performs a size-24 FFT in-place on `buffer`, decomposed as 6x4:
    /// size-6 butterflies down the columns, twiddles, transpose, then size-4 butterflies.
    ///
    /// # Safety
    /// The caller must ensure the CPU supports the "avx" and "fma" instruction sets.
    #[target_feature(enable = "avx", enable = "fma")]
    unsafe fn perform_fft_f32(&self, mut buffer: impl AvxArrayMut<f32>) {
        // Manually unrolling this loop because writing a "for r in 0..6" loop results in slow codegen that makes the whole thing take 1.5x longer :(
        let rows = [
            buffer.load_complex(0),
            buffer.load_complex(4),
            buffer.load_complex(8),
            buffer.load_complex(12),
            buffer.load_complex(16),
            buffer.load_complex(20),
        ];

        // We're going to treat our input as a 4x6 2d array. First, do 4 butterfly 6's down the columns of that array.
        let mut mid = AvxVector256::column_butterfly6(rows, self.twiddles_butterfly3);

        // apply twiddle factors
        for r in 1..6 {
            mid[r] = AvxVector::mul_complex(mid[r], self.twiddles[r - 1]);
        }

        // Transpose our 6x4 array into a 4x6.
        let (transposed0, transposed1) = avx32_utils::transpose_4x6_to_6x4_f32(mid);

        // Do 6 butterfly 4's down the columns of our transposed array
        let output0 = AvxVector::column_butterfly4(transposed0, self.twiddles_butterfly4);
        let output1 = AvxVector::column_butterfly4(transposed1, self.twiddles_butterfly4);

        // the upper two elements of output1 are empty, so only store half the data for it
        for r in 0..4 {
            buffer.store_complex(output0[r], 6 * r);
            buffer.store_partial2_complex(output1[r].lo(), r * 6 + 4);
        }
    }
}
927
/// AVX-accelerated size-27 FFT butterfly for f32 data (decomposed as a 3x9 mixed-radix FFT).
pub struct Butterfly27Avx<T> {
    /// Mixed-radix twiddles for the 3x9 decomposition
    twiddles: [__m256; 4],
    /// Broadcast twiddles 1/9, 2/9 and 4/9, used by the size-9 butterflies
    twiddles_butterfly9: [__m256; 3],
    /// Broadcast twiddle used by the size-3 butterflies
    twiddles_butterfly3: __m256,
    direction: FftDirection,
    // Ties the generic `Fft<T>` impl (generated by the boilerplate macro) to this f32-only type
    _phantom_t: std::marker::PhantomData<T>,
}
boilerplate_fft_simd_butterfly!(Butterfly27Avx, 27);
impl Butterfly27Avx<f32> {
    /// Computes the constants for a size-27 FFT (decomposed as 3x9).
    ///
    /// # Safety
    /// The caller must ensure the CPU supports the "avx" instruction set.
    #[target_feature(enable = "avx")]
    unsafe fn new_with_avx(direction: FftDirection) -> Self {
        Self {
            twiddles: gen_butterfly_twiddles_interleaved_columns!(3, 9, 1, direction),
            // the size-9 butterflies need the 1/9, 2/9 and 4/9 twiddles, each broadcast
            // across a full register
            twiddles_butterfly9: [
                AvxVector::broadcast_twiddle(1, 9, direction),
                AvxVector::broadcast_twiddle(2, 9, direction),
                AvxVector::broadcast_twiddle(4, 9, direction),
            ],
            twiddles_butterfly3: AvxVector::broadcast_twiddle(1, 3, direction),
            direction,
            _phantom_t: PhantomData,
        }
    }
}
impl<T> Butterfly27Avx<T> {
    /// Performs a size-27 FFT in-place on `buffer`, decomposed as 3x9:
    /// size-3 butterflies down the columns, twiddles, transpose, then size-9 butterflies.
    ///
    /// # Safety
    /// The caller must ensure the CPU supports the "avx" and "fma" instruction sets.
    #[target_feature(enable = "avx", enable = "fma")]
    unsafe fn perform_fft_f32(&self, mut buffer: impl AvxArrayMut<f32>) {
        // we're going to load our data in a peculiar way. we're going to load the first column on its own as a column of __m128.
        // it's faster to just load the first 2 columns into these m128s than trying to worry about masks, etc, so the second column will piggyback along and we just won't use it
        let mut rows0 = [AvxVector::zero(); 3];
        let mut rows1 = [AvxVector::zero(); 3];
        let mut rows2 = [AvxVector::zero(); 3];
        for r in 0..3 {
            rows0[r] = buffer.load_partial2_complex(r * 9);
            rows1[r] = buffer.load_complex(r * 9 + 1);
            rows2[r] = buffer.load_complex(r * 9 + 5);
        }

        // butterfly 3s down the columns
        let mid0 = AvxVector::column_butterfly3(rows0, self.twiddles_butterfly3.lo());
        let mut mid1 = AvxVector::column_butterfly3(rows1, self.twiddles_butterfly3);
        let mut mid2 = AvxVector::column_butterfly3(rows2, self.twiddles_butterfly3);

        // apply twiddle factors
        mid1[1] = AvxVector::mul_complex(mid1[1], self.twiddles[0]);
        mid2[1] = AvxVector::mul_complex(mid2[1], self.twiddles[1]);
        mid1[2] = AvxVector::mul_complex(mid1[2], self.twiddles[2]);
        mid2[2] = AvxVector::mul_complex(mid2[2], self.twiddles[3]);

        // transpose 9x3 to 3x9. this will be a little awkward because of rows0 containing garbage data, so use a transpose function that knows to ignore it
        let transposed = avx32_utils::transpose_9x3_to_3x9_emptycolumn1_f32(mid0, mid1, mid2);

        // butterfly 9s down the rows
        let output_rows = AvxVector256::column_butterfly9(
            transposed,
            self.twiddles_butterfly9,
            self.twiddles_butterfly3,
        );

        // Our last column is empty, so it's a bit awkward to write out to memory. We could pack it in fewer vectors, but benchmarking shows it's simpler and just as fast to just brute-force it with partial writes
        buffer.store_partial3_complex(output_rows[0], 0);
        buffer.store_partial3_complex(output_rows[1], 3);
        buffer.store_partial3_complex(output_rows[2], 6);
        buffer.store_partial3_complex(output_rows[3], 9);
        buffer.store_partial3_complex(output_rows[4], 12);
        buffer.store_partial3_complex(output_rows[5], 15);
        buffer.store_partial3_complex(output_rows[6], 18);
        buffer.store_partial3_complex(output_rows[7], 21);
        buffer.store_partial3_complex(output_rows[8], 24);
    }
}
999
/// AVX-accelerated size-32 FFT butterfly for f32 data (decomposed as a 4x8 mixed-radix FFT).
pub struct Butterfly32Avx<T> {
    /// Mixed-radix twiddles for the 4x8 decomposition
    twiddles: [__m256; 6],
    /// 90-degree rotation constant used by the size-4 and size-8 butterflies
    twiddles_butterfly4: Rotation90<__m256>,
    direction: FftDirection,
    // Ties the generic `Fft<T>` impl (generated by the boilerplate macro) to this f32-only type
    _phantom_t: std::marker::PhantomData<T>,
}
boilerplate_fft_simd_butterfly!(Butterfly32Avx, 32);
1007impl Butterfly32Avx<f32> {
1008    #[target_feature(enable = "avx")]
1009    unsafe fn new_with_avx(direction: FftDirection) -> Self {
1010        Self {
1011            twiddles: gen_butterfly_twiddles_interleaved_columns!(4, 8, 0, direction),
1012            twiddles_butterfly4: AvxVector::make_rotation90(direction),
1013            direction,
1014            _phantom_t: PhantomData,
1015        }
1016    }
1017}
impl<T> Butterfly32Avx<T> {
    /// Performs a size-32 FFT in-place on `buffer`, decomposed as 8x4:
    /// size-4 butterflies down the columns, twiddles, transpose, then size-8 butterflies.
    ///
    /// # Safety
    /// The caller must ensure the CPU supports the "avx" and "fma" instruction sets.
    #[target_feature(enable = "avx", enable = "fma")]
    unsafe fn perform_fft_f32(&self, mut buffer: impl AvxArrayMut<f32>) {
        let mut rows0 = [AvxVector::zero(); 4];
        let mut rows1 = [AvxVector::zero(); 4];
        for r in 0..4 {
            rows0[r] = buffer.load_complex(8 * r);
            rows1[r] = buffer.load_complex(8 * r + 4);
        }

        // We're going to treat our input as a 8x4 2d array. First, do 8 butterfly 4's down the columns of that array.
        let mut mid0 = AvxVector::column_butterfly4(rows0, self.twiddles_butterfly4);
        let mut mid1 = AvxVector::column_butterfly4(rows1, self.twiddles_butterfly4);

        // apply twiddle factors
        for r in 1..4 {
            mid0[r] = AvxVector::mul_complex(mid0[r], self.twiddles[2 * r - 2]);
            mid1[r] = AvxVector::mul_complex(mid1[r], self.twiddles[2 * r - 1]);
        }

        // Transpose our 8x4 array to an 4x8 array
        let transposed = avx32_utils::transpose_8x4_to_4x8_f32(mid0, mid1);

        // Do 4 butterfly 8's down the columns of our transposed array
        let output_rows = AvxVector::column_butterfly8(transposed, self.twiddles_butterfly4);

        // Manually unrolling this loop because writing a "for r in 0..8" loop results in slow codegen that makes the whole thing take 1.5x longer :(
        buffer.store_complex(output_rows[0], 0);
        buffer.store_complex(output_rows[1], 1 * 4);
        buffer.store_complex(output_rows[2], 2 * 4);
        buffer.store_complex(output_rows[3], 3 * 4);
        buffer.store_complex(output_rows[4], 4 * 4);
        buffer.store_complex(output_rows[5], 5 * 4);
        buffer.store_complex(output_rows[6], 6 * 4);
        buffer.store_complex(output_rows[7], 7 * 4);
    }
}
1055
/// AVX-accelerated size-36 FFT butterfly for f32 data (decomposed as a 4x9 mixed-radix FFT).
pub struct Butterfly36Avx<T> {
    /// Mixed-radix twiddles for the 4x9 decomposition
    twiddles: [__m256; 6],
    /// Broadcast twiddles 1/9, 2/9 and 4/9, used by the size-9 butterflies
    twiddles_butterfly9: [__m256; 3],
    /// Broadcast twiddle used by the size-3 butterflies inside the size-9s
    twiddles_butterfly3: __m256,
    /// 90-degree rotation constant used by the size-4 butterflies
    twiddles_butterfly4: Rotation90<__m256>,
    direction: FftDirection,
    // Ties the generic `Fft<T>` impl (generated by the boilerplate macro) to this f32-only type
    _phantom_t: std::marker::PhantomData<T>,
}
boilerplate_fft_simd_butterfly!(Butterfly36Avx, 36);
impl Butterfly36Avx<f32> {
    /// Computes the constants for a size-36 FFT (decomposed as 4x9).
    ///
    /// # Safety
    /// The caller must ensure the CPU supports the "avx" instruction set.
    #[target_feature(enable = "avx")]
    unsafe fn new_with_avx(direction: FftDirection) -> Self {
        Self {
            twiddles: gen_butterfly_twiddles_interleaved_columns!(4, 9, 1, direction),
            // the size-9 butterflies need the 1/9, 2/9 and 4/9 twiddles, each broadcast
            // across a full register
            twiddles_butterfly9: [
                AvxVector::broadcast_twiddle(1, 9, direction),
                AvxVector::broadcast_twiddle(2, 9, direction),
                AvxVector::broadcast_twiddle(4, 9, direction),
            ],
            twiddles_butterfly3: AvxVector::broadcast_twiddle(1, 3, direction),
            twiddles_butterfly4: AvxVector::make_rotation90(direction),
            direction,
            _phantom_t: PhantomData,
        }
    }
}
impl<T> Butterfly36Avx<T> {
    /// Performs a size-36 FFT in-place on `buffer`, decomposed as 4x9:
    /// size-4 butterflies down the columns, twiddles, transpose, then size-9 butterflies.
    ///
    /// # Safety
    /// The caller must ensure the CPU supports the "avx" and "fma" instruction sets.
    #[target_feature(enable = "avx", enable = "fma")]
    unsafe fn perform_fft_f32(&self, mut buffer: impl AvxArrayMut<f32>) {
        // we're going to load our data in a peculiar way. we're going to load the first column on its own as a column of __m128.
        // it's faster to just load the first 2 columns into these m128s than trying to worry about masks, etc, so the second column will piggyback along and we just won't use it
        let mut rows0 = [AvxVector::zero(); 4];
        let mut rows1 = [AvxVector::zero(); 4];
        let mut rows2 = [AvxVector::zero(); 4];
        for r in 0..4 {
            rows0[r] = buffer.load_partial2_complex(r * 9);
            rows1[r] = buffer.load_complex(r * 9 + 1);
            rows2[r] = buffer.load_complex(r * 9 + 5);
        }

        // butterfly 4s down the columns
        let mid0 = AvxVector::column_butterfly4(rows0, self.twiddles_butterfly4.lo());
        let mut mid1 = AvxVector::column_butterfly4(rows1, self.twiddles_butterfly4);
        let mut mid2 = AvxVector::column_butterfly4(rows2, self.twiddles_butterfly4);

        // apply twiddle factors
        for r in 1..4 {
            mid1[r] = AvxVector::mul_complex(mid1[r], self.twiddles[2 * r - 2]);
            mid2[r] = AvxVector::mul_complex(mid2[r], self.twiddles[2 * r - 1]);
        }

        // transpose 9x4 to 4x9. this will be a little awkward because of rows0 containing garbage data, so use a transpose function that knows to ignore it
        let transposed = avx32_utils::transpose_9x4_to_4x9_emptycolumn1_f32(mid0, mid1, mid2);

        // butterfly 9s down the rows
        let output_rows = AvxVector256::column_butterfly9(
            transposed,
            self.twiddles_butterfly9,
            self.twiddles_butterfly3,
        );

        for r in 0..3 {
            buffer.store_complex(output_rows[r * 3], r * 12);
            buffer.store_complex(output_rows[r * 3 + 1], r * 12 + 4);
            buffer.store_complex(output_rows[r * 3 + 2], r * 12 + 8);
        }
    }
}
1124
/// AVX-accelerated size-48 FFT butterfly for f32 data (decomposed as a 4x12 mixed-radix FFT).
pub struct Butterfly48Avx<T> {
    /// Mixed-radix twiddles for the 4x12 decomposition
    twiddles: [__m256; 9],
    /// Broadcast twiddle used by the size-3 part of the size-12 butterflies
    twiddles_butterfly3: __m256,
    /// 90-degree rotation constant used by the size-4 butterflies
    twiddles_butterfly4: Rotation90<__m256>,
    direction: FftDirection,
    // Ties the generic `Fft<T>` impl (generated by the boilerplate macro) to this f32-only type
    _phantom_t: std::marker::PhantomData<T>,
}
boilerplate_fft_simd_butterfly!(Butterfly48Avx, 48);
1133impl Butterfly48Avx<f32> {
1134    #[target_feature(enable = "avx")]
1135    unsafe fn new_with_avx(direction: FftDirection) -> Self {
1136        Self {
1137            twiddles: gen_butterfly_twiddles_interleaved_columns!(4, 12, 0, direction),
1138            twiddles_butterfly3: AvxVector::broadcast_twiddle(1, 3, direction),
1139            twiddles_butterfly4: AvxVector::make_rotation90(direction),
1140            direction,
1141            _phantom_t: PhantomData,
1142        }
1143    }
1144}
impl<T> Butterfly48Avx<T> {
    /// Performs a size-48 FFT in-place on `buffer`, decomposed as 12x4:
    /// size-4 butterflies down the columns, twiddles, transpose, then size-12 butterflies.
    ///
    /// # Safety
    /// The caller must ensure the CPU supports the "avx" and "fma" instruction sets.
    #[target_feature(enable = "avx", enable = "fma")]
    unsafe fn perform_fft_f32(&self, mut buffer: impl AvxArrayMut<f32>) {
        let mut rows0 = [AvxVector::zero(); 4];
        let mut rows1 = [AvxVector::zero(); 4];
        let mut rows2 = [AvxVector::zero(); 4];
        for r in 0..4 {
            rows0[r] = buffer.load_complex(12 * r);
            rows1[r] = buffer.load_complex(12 * r + 4);
            rows2[r] = buffer.load_complex(12 * r + 8);
        }

        // We're going to treat our input as a 12x4 2d array. First, do 12 butterfly 4's down the columns of that array.
        let mut mid0 = AvxVector::column_butterfly4(rows0, self.twiddles_butterfly4);
        let mut mid1 = AvxVector::column_butterfly4(rows1, self.twiddles_butterfly4);
        let mut mid2 = AvxVector::column_butterfly4(rows2, self.twiddles_butterfly4);

        // apply twiddle factors
        for r in 1..4 {
            mid0[r] = AvxVector::mul_complex(mid0[r], self.twiddles[3 * r - 3]);
            mid1[r] = AvxVector::mul_complex(mid1[r], self.twiddles[3 * r - 2]);
            mid2[r] = AvxVector::mul_complex(mid2[r], self.twiddles[3 * r - 1]);
        }

        // Transpose our 12x4 array into a 4x12.
        let transposed = avx32_utils::transpose_12x4_to_4x12_f32(mid0, mid1, mid2);

        // Do 4 butterfly 12's down the columns of our transposed array
        let output_rows = AvxVector256::column_butterfly12(
            transposed,
            self.twiddles_butterfly3,
            self.twiddles_butterfly4,
        );

        // Manually unrolling this loop because writing a "for r in 0..12" loop results in slow codegen that makes the whole thing take 1.5x longer :(
        buffer.store_complex(output_rows[0], 0);
        buffer.store_complex(output_rows[1], 4);
        buffer.store_complex(output_rows[2], 8);
        buffer.store_complex(output_rows[3], 12);
        buffer.store_complex(output_rows[4], 16);
        buffer.store_complex(output_rows[5], 20);
        buffer.store_complex(output_rows[6], 24);
        buffer.store_complex(output_rows[7], 28);
        buffer.store_complex(output_rows[8], 32);
        buffer.store_complex(output_rows[9], 36);
        buffer.store_complex(output_rows[10], 40);
        buffer.store_complex(output_rows[11], 44);
    }
}
1194
/// AVX-accelerated size-54 FFT butterfly for f32 data (decomposed as a 6x9 mixed-radix FFT).
pub struct Butterfly54Avx<T> {
    /// Mixed-radix twiddles for the 6x9 decomposition
    twiddles: [__m256; 10],
    /// Broadcast twiddles 1/9, 2/9 and 4/9, used by the full-width size-9 butterflies
    twiddles_butterfly9: [__m256; 3],
    /// Pairs of 128-bit size-9 twiddles merged into 256-bit registers, used by the
    /// half-width (AvxVector128) size-9 butterflies
    twiddles_butterfly9_lo: [__m256; 2],
    /// Broadcast twiddle used by the size-3 butterflies
    twiddles_butterfly3: __m256,
    direction: FftDirection,
    // Ties the generic `Fft<T>` impl (generated by the boilerplate macro) to this f32-only type
    _phantom_t: std::marker::PhantomData<T>,
}
boilerplate_fft_simd_butterfly!(Butterfly54Avx, 54);
impl Butterfly54Avx<f32> {
    /// Computes the constants for a size-54 FFT (decomposed as 6x9).
    ///
    /// # Safety
    /// The caller must ensure the CPU supports the "avx" instruction set.
    #[target_feature(enable = "avx")]
    unsafe fn new_with_avx(direction: FftDirection) -> Self {
        // 128-bit copies of the size-9 twiddles; merged in pairs below so the half-width
        // butterfly 9s in perform_fft_f32 can consume them from 256-bit registers
        let twiddle1 = __m128::broadcast_twiddle(1, 9, direction);
        let twiddle2 = __m128::broadcast_twiddle(2, 9, direction);
        let twiddle4 = __m128::broadcast_twiddle(4, 9, direction);

        Self {
            twiddles: gen_butterfly_twiddles_interleaved_columns!(6, 9, 1, direction),
            twiddles_butterfly9: [
                AvxVector::broadcast_twiddle(1, 9, direction),
                AvxVector::broadcast_twiddle(2, 9, direction),
                AvxVector::broadcast_twiddle(4, 9, direction),
            ],
            twiddles_butterfly9_lo: [
                AvxVector256::merge(twiddle1, twiddle2),
                AvxVector256::merge(twiddle2, twiddle4),
            ],
            twiddles_butterfly3: AvxVector::broadcast_twiddle(1, 3, direction),
            direction,
            _phantom_t: PhantomData,
        }
    }
}
impl<T> Butterfly54Avx<T> {
    // Computes a size-54 FFT in-place, treating the buffer as a 9-wide x 6-tall 2D
    // array: butterfly 6s down the columns, twiddle factors, a 9x6 -> 6x9 transpose,
    // then butterfly 9s down the columns of the transposed data.
    // Safety: the caller must ensure the CPU supports AVX and FMA.
    #[target_feature(enable = "avx", enable = "fma")]
    unsafe fn perform_fft_f32(&self, mut buffer: impl AvxArrayMut<f32>) {
        // we're going to load our data in a peculiar way. we're going to load the first column on its own as a column of __m128.
        // it's faster to just load the first 2 columns into these m128s than trying to worry about masks, etc, so the second column will piggyback along and we just won't use it
        //
        // we have too much data to fit into registers all at once, so split up our data processing so that we entirely finish with one "rows_" array before moving to the next
        let mut rows0 = [AvxVector::zero(); 6];
        for r in 0..3 {
            // each partial2 load grabs columns 0-1 of rows 2r and 2r+1 (rows are 9 elements wide)
            rows0[r * 2] = buffer.load_partial2_complex(r * 18);
            rows0[r * 2 + 1] = buffer.load_partial2_complex(r * 18 + 9);
        }
        let mid0 = AvxVector128::column_butterfly6(rows0, self.twiddles_butterfly3);

        // next set of butterfly 6's: columns 1-4
        let mut rows1 = [AvxVector::zero(); 6];
        for r in 0..3 {
            rows1[r * 2] = buffer.load_complex(r * 18 + 1);
            rows1[r * 2 + 1] = buffer.load_complex(r * 18 + 10);
        }
        let mut mid1 = AvxVector256::column_butterfly6(rows1, self.twiddles_butterfly3);
        // apply twiddles (row 0 needs none); this column set's factors sit at the
        // even indices of the interleaved twiddle array
        for r in 1..6 {
            mid1[r] = AvxVector::mul_complex(mid1[r], self.twiddles[2 * r - 2]);
        }

        // final set of butterfly 6's: columns 5-8, twiddles at the odd indices
        let mut rows2 = [AvxVector::zero(); 6];
        for r in 0..3 {
            rows2[r * 2] = buffer.load_complex(r * 18 + 5);
            rows2[r * 2 + 1] = buffer.load_complex(r * 18 + 14);
        }
        let mut mid2 = AvxVector256::column_butterfly6(rows2, self.twiddles_butterfly3);
        for r in 1..6 {
            mid2[r] = AvxVector::mul_complex(mid2[r], self.twiddles[2 * r - 1]);
        }

        // transpose 9x6 to 6x9. this will be a little awkward because of rows0 containing garbage data, so use a transpose function that knows to ignore it
        let (transposed0, transposed1) =
            avx32_utils::transpose_9x6_to_6x9_emptycolumn1_f32(mid0, mid1, mid2);

        // butterfly 9s down the rows
        // process the other half
        let output_rows1 = AvxVector128::column_butterfly9(
            transposed1,
            self.twiddles_butterfly9_lo,
            self.twiddles_butterfly3,
        );
        for r in 0..9 {
            buffer.store_partial2_complex(output_rows1[r], r * 6 + 4);
        }

        // we have too much data to fit into registers all at once, do one set of butterfly 9's and output them before even starting on the others, to make it easier for the compiler to figure out what to spill
        let output_rows0 = AvxVector256::column_butterfly9(
            transposed0,
            self.twiddles_butterfly9,
            self.twiddles_butterfly3,
        );
        for r in 0..9 {
            buffer.store_complex(output_rows0[r], r * 6);
        }
    }
}
1290
// Size-64 FFT, processed as an 8x8 2D array: column butterfly 8s + twiddles,
// an 8x8 transpose, then butterfly 8s again.
pub struct Butterfly64Avx<T> {
    // 14 = 7 twiddles per column set x 2 column sets (see perform_fft_f32)
    twiddles: [__m256; 14],
    twiddles_butterfly4: Rotation90<__m256>,
    direction: FftDirection,
    _phantom_t: std::marker::PhantomData<T>,
}
// Generates `new()`, CPU-feature detection, and the `Fft` trait impl for this type
boilerplate_fft_simd_butterfly!(Butterfly64Avx, 64);
impl Butterfly64Avx<f32> {
    // Safety: the caller must ensure the CPU supports AVX before calling this.
    #[target_feature(enable = "avx")]
    unsafe fn new_with_avx(direction: FftDirection) -> Self {
        Self {
            // twiddles for the 8x8 layout used by perform_fft_f32
            twiddles: gen_butterfly_twiddles_separated_columns!(8, 8, 0, direction),
            twiddles_butterfly4: AvxVector::make_rotation90(direction),
            direction,
            _phantom_t: PhantomData,
        }
    }
}
impl<T> Butterfly64Avx<T> {
    // Computes a size-64 FFT in-place via an 8x8 decomposition.
    // Safety: the caller must ensure the CPU supports AVX and FMA.
    #[target_feature(enable = "avx", enable = "fma")]
    unsafe fn perform_fft_f32(&self, mut buffer: impl AvxArrayMut<f32>) {
        // We're going to treat our input as a 8x8 2d array. First, do 8 butterfly 8's down the columns of that array.
        // We can't fit the whole problem into AVX registers at once, so we'll have to spill some things.
        // By computing a sizeable chunk and not referencing any of it for a while, we're making it easy for the compiler to decide what to spill
        let mut rows0 = [AvxVector::zero(); 8];
        for r in 0..8 {
            rows0[r] = buffer.load_complex(8 * r);
        }
        let mut mid0 = AvxVector::column_butterfly8(rows0, self.twiddles_butterfly4);
        // apply twiddle factors (row 0 needs none, so the loop starts at 1)
        for r in 1..8 {
            mid0[r] = AvxVector::mul_complex(mid0[r], self.twiddles[r - 1]);
        }

        // One half is done, so the compiler can spill everything above this. Now do the other set of columns
        let mut rows1 = [AvxVector::zero(); 8];
        for r in 0..8 {
            rows1[r] = buffer.load_complex(8 * r + 4);
        }
        let mut mid1 = AvxVector::column_butterfly8(rows1, self.twiddles_butterfly4);
        for r in 1..8 {
            mid1[r] = AvxVector::mul_complex(mid1[r], self.twiddles[r - 1 + 7]);
        }

        // Transpose our 8x8 array
        let (transposed0, transposed1) = avx32_utils::transpose_8x8_f32(mid0, mid1);

        // Do 8 butterfly 8's down the columns of our transposed array, and store the results
        // Same thing as above - Do the half of the butterfly 8's separately to give the compiler a better hint about what to spill
        let output0 = AvxVector::column_butterfly8(transposed0, self.twiddles_butterfly4);
        for r in 0..8 {
            buffer.store_complex(output0[r], 8 * r);
        }

        let output1 = AvxVector::column_butterfly8(transposed1, self.twiddles_butterfly4);
        for r in 0..8 {
            buffer.store_complex(output1[r], 8 * r + 4);
        }
    }
}
1350
// Size-72 FFT, processed as a 12x6 2D array: butterfly 6s down the columns,
// a 12x6 -> 6x12 transpose, then butterfly 12s.
pub struct Butterfly72Avx<T> {
    // 15 = 5 twiddles per column set x 3 column sets (see perform_fft_f32)
    twiddles: [__m256; 15],
    twiddles_butterfly4: Rotation90<__m256>,
    twiddles_butterfly3: __m256,
    direction: FftDirection,
    _phantom_t: std::marker::PhantomData<T>,
}
// Generates `new()`, CPU-feature detection, and the `Fft` trait impl for this type
boilerplate_fft_simd_butterfly!(Butterfly72Avx, 72);
impl Butterfly72Avx<f32> {
    // Safety: the caller must ensure the CPU supports AVX before calling this.
    #[target_feature(enable = "avx")]
    unsafe fn new_with_avx(direction: FftDirection) -> Self {
        Self {
            // twiddles for the 6x12 layout used by perform_fft_f32
            twiddles: gen_butterfly_twiddles_separated_columns!(6, 12, 0, direction),
            twiddles_butterfly4: AvxVector::make_rotation90(direction),
            twiddles_butterfly3: AvxVector::broadcast_twiddle(1, 3, direction),
            direction,
            _phantom_t: PhantomData,
        }
    }
}
impl<T> Butterfly72Avx<T> {
    // Computes a size-72 FFT in-place via a 12x6 decomposition.
    // Safety: the caller must ensure the CPU supports AVX and FMA.
    #[target_feature(enable = "avx", enable = "fma")]
    unsafe fn perform_fft_f32(&self, mut buffer: impl AvxArrayMut<f32>) {
        // We're going to treat our input as a 12x6 2d array. First, do butterfly 6's down the columns of that array.
        // We can't fit the whole problem into AVX registers at once, so we'll have to spill some things.
        // By computing a sizeable chunk and not referencing any of it for a while, we're making it easy for the compiler to decide what to spill
        let mut rows0 = [AvxVector::zero(); 6];
        for r in 0..6 {
            rows0[r] = buffer.load_complex(12 * r);
        }
        let mut mid0 = AvxVector256::column_butterfly6(rows0, self.twiddles_butterfly3);
        // apply twiddle factors (row 0 needs none, so the loop starts at 1)
        for r in 1..6 {
            mid0[r] = AvxVector::mul_complex(mid0[r], self.twiddles[r - 1]);
        }

        // One third is done, so the compiler can spill everything above this. Now do the middle set of columns
        let mut rows1 = [AvxVector::zero(); 6];
        for r in 0..6 {
            rows1[r] = buffer.load_complex(12 * r + 4);
        }
        let mut mid1 = AvxVector256::column_butterfly6(rows1, self.twiddles_butterfly3);
        for r in 1..6 {
            mid1[r] = AvxVector::mul_complex(mid1[r], self.twiddles[r - 1 + 5]);
        }

        // two thirds are done, so the compiler can spill everything above this. Now do the final set of columns
        let mut rows2 = [AvxVector::zero(); 6];
        for r in 0..6 {
            rows2[r] = buffer.load_complex(12 * r + 8);
        }
        let mut mid2 = AvxVector256::column_butterfly6(rows2, self.twiddles_butterfly3);
        for r in 1..6 {
            mid2[r] = AvxVector::mul_complex(mid2[r], self.twiddles[r - 1 + 10]);
        }

        // Transpose our 12x6 array to 6x12 array
        let (transposed0, transposed1) = avx32_utils::transpose_12x6_to_6x12_f32(mid0, mid1, mid2);

        // Do butterfly 12's down the columns of our transposed array, and store the results
        // Same thing as above - Do the half of the butterfly 12's separately to give the compiler a better hint about what to spill
        // transposed0 holds the 2-wide leftover columns, so it goes through the 128-bit path
        let output0 = AvxVector128::column_butterfly12(
            transposed0,
            self.twiddles_butterfly3,
            self.twiddles_butterfly4,
        );
        for r in 0..12 {
            buffer.store_partial2_complex(output0[r], 6 * r);
        }

        let output1 = AvxVector256::column_butterfly12(
            transposed1,
            self.twiddles_butterfly3,
            self.twiddles_butterfly4,
        );
        for r in 0..12 {
            buffer.store_complex(output1[r], 6 * r + 2);
        }
    }
}
1430
// Size-128 FFT, computed in two phases with a scratch buffer in between:
// 16x8 column butterfly 8s + transpose into scratch, then size-16 row butterflies.
pub struct Butterfly128Avx<T> {
    // 28 = 7 twiddles per column set x 4 column sets (see column_butterflies_and_transpose)
    twiddles: [__m256; 28],
    // broadcast twiddles 1 and 3 of 16, consumed by column_butterfly16_loadfn!
    twiddles_butterfly16: [__m256; 2],
    twiddles_butterfly4: Rotation90<__m256>,
    direction: FftDirection,
    _phantom_t: std::marker::PhantomData<T>,
}
// Generates `new()`, CPU-feature detection, scratch handling, and the `Fft` trait impl
boilerplate_fft_simd_butterfly_with_scratch!(Butterfly128Avx, 128);
impl Butterfly128Avx<f32> {
    // Safety: the caller must ensure the CPU supports AVX before calling this.
    #[target_feature(enable = "avx")]
    unsafe fn new_with_avx(direction: FftDirection) -> Self {
        Self {
            twiddles: gen_butterfly_twiddles_separated_columns!(8, 16, 0, direction),
            // broadcast twiddles 1 and 3 of 16 for the size-16 row butterflies
            twiddles_butterfly16: [
                AvxVector::broadcast_twiddle(1, 16, direction),
                AvxVector::broadcast_twiddle(3, 16, direction),
            ],
            twiddles_butterfly4: AvxVector::make_rotation90(direction),
            direction,
            _phantom_t: PhantomData,
        }
    }
}
impl<T> Butterfly128Avx<T> {
    // First phase of the size-128 FFT: treat `input` as a 16x8 2D array, do
    // butterfly 8s down the columns, apply twiddles, and write a transposed copy
    // of the data into `output`.
    // Safety: the caller must ensure the CPU supports AVX and FMA.
    #[target_feature(enable = "avx", enable = "fma")]
    unsafe fn column_butterflies_and_transpose(
        &self,
        input: &[Complex<f32>],
        mut output: &mut [Complex<f32>],
    ) {
        // A size-128 FFT is way too big to fit in registers, so instead we're going to compute it in two phases, storing in scratch in between.

        // First phase is to treat this size-128 array like a 16x8 2D array, and do butterfly 8's down the columns
        // Then, apply twiddle factors, and finally transpose into the scratch space

        // But again, we don't have enough registers to load it all at once, so only load one column of AVX vectors at a time
        for columnset in 0..4 {
            let mut rows = [AvxVector::zero(); 8];
            for r in 0..8 {
                rows[r] = input.load_complex(columnset * 4 + 16 * r);
            }
            // apply butterfly 8
            let mut mid = AvxVector::column_butterfly8(rows, self.twiddles_butterfly4);

            // apply twiddle factors (row 0 needs none; 7 twiddles per column set)
            for r in 1..8 {
                mid[r] = AvxVector::mul_complex(mid[r], self.twiddles[r - 1 + 7 * columnset]);
            }

            // transpose
            let transposed = AvxVector::transpose8_packed(mid);

            // write out
            for i in 0..4 {
                output.store_complex(transposed[i * 2], columnset * 32 + i * 8);
                output.store_complex(transposed[i * 2 + 1], columnset * 32 + i * 8 + 4);
            }
        }
    }

    // Second phase: size-16 butterflies down the columns of the transposed data,
    // reading and writing in place.
    // Safety: the caller must ensure the CPU supports AVX and FMA.
    #[target_feature(enable = "avx", enable = "fma")]
    unsafe fn row_butterflies(&self, mut buffer: impl AvxArrayMut<f32>) {
        // Second phase: Butterfly 16's down the columns of our transposed array.
        // Thankfully, during the first phase, we set everything up so that all we have to do here is compute the size-16 FFT columns and write them back out where we got them
        // We're also using a customized butterfly16 function that is smarter about when it loads/stores data, to reduce register spilling
        for columnset in 0usize..2 {
            column_butterfly16_loadfn!(
                |index: usize| buffer.load_complex(columnset * 4 + index * 8),
                |data, index| buffer.store_complex(data, columnset * 4 + index * 8),
                self.twiddles_butterfly16,
                self.twiddles_butterfly4
            );
        }
    }
}
1506
1507#[allow(non_camel_case_types)]
1508pub struct Butterfly256Avx<T> {
1509    twiddles: [__m256; 56],
1510    twiddles_butterfly32: [__m256; 6],
1511    twiddles_butterfly4: Rotation90<__m256>,
1512    direction: FftDirection,
1513    _phantom_t: std::marker::PhantomData<T>,
1514}
1515boilerplate_fft_simd_butterfly_with_scratch!(Butterfly256Avx, 256);
impl Butterfly256Avx<f32> {
    // Safety: the caller must ensure the CPU supports AVX before calling this.
    #[target_feature(enable = "avx")]
    unsafe fn new_with_avx(direction: FftDirection) -> Self {
        Self {
            twiddles: gen_butterfly_twiddles_separated_columns!(8, 32, 0, direction),
            // broadcast twiddles 1,2,3,5,6,7 of 32 for the size-32 row butterflies
            // (4/32 is skipped -- presumably handled inside column_butterfly32_loadfn!;
            // confirm against the macro definition)
            twiddles_butterfly32: [
                AvxVector::broadcast_twiddle(1, 32, direction),
                AvxVector::broadcast_twiddle(2, 32, direction),
                AvxVector::broadcast_twiddle(3, 32, direction),
                AvxVector::broadcast_twiddle(5, 32, direction),
                AvxVector::broadcast_twiddle(6, 32, direction),
                AvxVector::broadcast_twiddle(7, 32, direction),
            ],
            twiddles_butterfly4: AvxVector::make_rotation90(direction),
            direction,
            _phantom_t: PhantomData,
        }
    }
}
impl<T> Butterfly256Avx<T> {
    // First phase of the size-256 FFT: treat `input` as a 32x8 2D array, do
    // butterfly 8s down the columns, apply twiddles, and write a transposed copy
    // of the data into `output`.
    // Safety: the caller must ensure the CPU supports AVX and FMA.
    #[target_feature(enable = "avx", enable = "fma")]
    unsafe fn column_butterflies_and_transpose(
        &self,
        input: &[Complex<f32>],
        mut output: &mut [Complex<f32>],
    ) {
        // A size-256 FFT is way too big to fit in registers, so instead we're going to compute it in two phases, storing in scratch in between.

        // First phase is to treat this size-256 array like a 32x8 2D array, and do butterfly 8's down the columns
        // Then, apply twiddle factors, and finally transpose into the scratch space

        // But again, we don't have enough registers to load it all at once, so only load one column of AVX vectors at a time
        for columnset in 0..8 {
            let mut rows = [AvxVector::zero(); 8];
            for r in 0..8 {
                rows[r] = input.load_complex(columnset * 4 + 32 * r);
            }
            let mut mid = AvxVector::column_butterfly8(rows, self.twiddles_butterfly4);
            // apply twiddle factors (row 0 needs none; 7 twiddles per column set)
            for r in 1..8 {
                mid[r] = AvxVector::mul_complex(mid[r], self.twiddles[r - 1 + 7 * columnset]);
            }

            // Before writing to the scratch, transpose this chunk of the array
            let transposed = AvxVector::transpose8_packed(mid);

            for i in 0..4 {
                output.store_complex(transposed[i * 2], columnset * 32 + i * 8);
                output.store_complex(transposed[i * 2 + 1], columnset * 32 + i * 8 + 4);
            }
        }
    }

    // Second phase: size-32 butterflies down the columns of the transposed data,
    // reading and writing in place.
    // Safety: the caller must ensure the CPU supports AVX and FMA.
    #[target_feature(enable = "avx", enable = "fma")]
    unsafe fn row_butterflies(&self, mut buffer: impl AvxArrayMut<f32>) {
        // Second phase: Butterfly 32's down the columns of our transposed array.
        // Thankfully, during the first phase, we set everything up so that all we have to do here is compute the size-32 FFT columns and write them back out where we got them
        // We're also using a customized butterfly32 function that is smarter about when it loads/stores data, to reduce register spilling
        for columnset in 0..2 {
            column_butterfly32_loadfn!(
                |index: usize| buffer.load_complex(columnset * 4 + index * 8),
                |data, index| buffer.store_complex(data, columnset * 4 + index * 8),
                self.twiddles_butterfly32,
                self.twiddles_butterfly4
            );
        }
    }
}
1583
// Size-512 FFT, computed in two phases with a scratch buffer in between:
// 32x16 column butterfly 16s + transpose into scratch, then size-32 row butterflies.
pub struct Butterfly512Avx<T> {
    // 120 = 15 twiddles per column set x 8 column sets (see TWIDDLES_PER_COLUMN below)
    twiddles: [__m256; 120],
    // broadcast twiddles 1,2,3,5,6,7 of 32, consumed by column_butterfly32_loadfn!
    twiddles_butterfly32: [__m256; 6],
    // broadcast twiddles 1 and 3 of 16, consumed by column_butterfly16_loadfn!
    twiddles_butterfly16: [__m256; 2],
    twiddles_butterfly4: Rotation90<__m256>,
    direction: FftDirection,
    _phantom_t: std::marker::PhantomData<T>,
}
// Generates `new()`, CPU-feature detection, scratch handling, and the `Fft` trait impl
boilerplate_fft_simd_butterfly_with_scratch!(Butterfly512Avx, 512);
impl Butterfly512Avx<f32> {
    // Safety: the caller must ensure the CPU supports AVX before calling this.
    #[target_feature(enable = "avx")]
    unsafe fn new_with_avx(direction: FftDirection) -> Self {
        Self {
            twiddles: gen_butterfly_twiddles_separated_columns!(16, 32, 0, direction),
            // broadcast twiddles 1,2,3,5,6,7 of 32 for the size-32 row butterflies
            // (4/32 is skipped -- presumably handled inside column_butterfly32_loadfn!;
            // confirm against the macro definition)
            twiddles_butterfly32: [
                AvxVector::broadcast_twiddle(1, 32, direction),
                AvxVector::broadcast_twiddle(2, 32, direction),
                AvxVector::broadcast_twiddle(3, 32, direction),
                AvxVector::broadcast_twiddle(5, 32, direction),
                AvxVector::broadcast_twiddle(6, 32, direction),
                AvxVector::broadcast_twiddle(7, 32, direction),
            ],
            // broadcast twiddles 1 and 3 of 16 for the column-phase size-16 butterflies
            twiddles_butterfly16: [
                AvxVector::broadcast_twiddle(1, 16, direction),
                AvxVector::broadcast_twiddle(3, 16, direction),
            ],
            twiddles_butterfly4: AvxVector::make_rotation90(direction),
            direction,
            _phantom_t: PhantomData,
        }
    }
}
impl<T> Butterfly512Avx<T> {
    // First phase of the size-512 FFT: treat `input` as a 32x16 2D array, do
    // butterfly 16s down the columns, apply twiddles, and write a transposed copy
    // of the data into `output`.
    // Safety: the caller must ensure the CPU supports AVX and FMA.
    #[target_feature(enable = "avx", enable = "fma")]
    unsafe fn column_butterflies_and_transpose(
        &self,
        input: &[Complex<f32>],
        mut output: &mut [Complex<f32>],
    ) {
        // A size-512 FFT is way too big to fit in registers, so instead we're going to compute it in two phases, storing in scratch in between.

        // First phase is to treat this size-512 array like a 32x16 2D array, and do butterfly 16's down the columns
        // Then, apply twiddle factors, and finally transpose into the scratch space

        // But again, we don't have enough registers to load it all at once, so only load one column of AVX vectors at a time
        // We're also using a customized butterfly16 function that is smarter about when it loads/stores data, to reduce register spilling
        const TWIDDLES_PER_COLUMN: usize = 15;
        for (columnset, twiddle_chunk) in
            self.twiddles.chunks_exact(TWIDDLES_PER_COLUMN).enumerate()
        {
            // Sadly we have to use MaybeUninit here. If we init an array like normal with AvxVector::Zero(), the compiler can't seem to figure out that it can
            // eliminate the dead stores of zeroes to the stack. By using uninit here, we avoid those unnecessary writes
            let mut mid_uninit: [MaybeUninit<__m256>; 16] = [MaybeUninit::<__m256>::uninit(); 16];

            // SAFETY(review): the assume_init() calls below rely on this macro's store
            // closure writing every index 0..16 exactly once -- confirm in the macro.
            column_butterfly16_loadfn!(
                |index: usize| input.load_complex(columnset * 4 + 32 * index),
                |data, index: usize| {
                    mid_uninit[index].as_mut_ptr().write(data);
                },
                self.twiddles_butterfly16,
                self.twiddles_butterfly4
            );

            // Apply twiddle factors, transpose, and store. Traditionally we apply all the twiddle factors at once and then do all the transposes at once,
            // But our data is pushing the limit of what we can store in registers, so the idea here is to get the data out the door with as few spills to the stack as possible
            for chunk in 0..4 {
                let twiddled = [
                    // the first row overall has no twiddle factor, so only chunk > 0 multiplies its first element
                    if chunk > 0 {
                        AvxVector::mul_complex(
                            mid_uninit[4 * chunk].assume_init(),
                            twiddle_chunk[4 * chunk - 1],
                        )
                    } else {
                        mid_uninit[4 * chunk].assume_init()
                    },
                    AvxVector::mul_complex(
                        mid_uninit[4 * chunk + 1].assume_init(),
                        twiddle_chunk[4 * chunk],
                    ),
                    AvxVector::mul_complex(
                        mid_uninit[4 * chunk + 2].assume_init(),
                        twiddle_chunk[4 * chunk + 1],
                    ),
                    AvxVector::mul_complex(
                        mid_uninit[4 * chunk + 3].assume_init(),
                        twiddle_chunk[4 * chunk + 2],
                    ),
                ];

                let transposed = AvxVector::transpose4_packed(twiddled);

                output.store_complex(transposed[0], columnset * 64 + 0 * 16 + 4 * chunk);
                output.store_complex(transposed[1], columnset * 64 + 1 * 16 + 4 * chunk);
                output.store_complex(transposed[2], columnset * 64 + 2 * 16 + 4 * chunk);
                output.store_complex(transposed[3], columnset * 64 + 3 * 16 + 4 * chunk);
            }
        }
    }

    // Second phase: size-32 butterflies down the columns of the transposed data,
    // reading and writing in place.
    // Safety: the caller must ensure the CPU supports AVX and FMA.
    #[target_feature(enable = "avx", enable = "fma")]
    unsafe fn row_butterflies(&self, mut buffer: impl AvxArrayMut<f32>) {
        // Second phase: Butterfly 32's down the columns of our transposed array.
        // Thankfully, during the first phase, we set everything up so that all we have to do here is compute the size-32 FFT columns and write them back out where we got them
        // We're also using a customized butterfly32 function that is smarter about when it loads/stores data, to reduce register spilling
        for columnset in 0..4 {
            column_butterfly32_loadfn!(
                |index: usize| buffer.load_complex(columnset * 4 + index * 16),
                |data, index| buffer.store_complex(data, columnset * 4 + index * 16),
                self.twiddles_butterfly32,
                self.twiddles_butterfly4
            );
        }
    }
}
1698
#[cfg(test)]
mod unit_tests {
    use super::*;
    use crate::test_utils::check_fft_algorithm;

    // Generates one #[test] fn per butterfly that runs `check_fft_algorithm` in both
    // the forward and inverse directions.
    // NOTE: these tests require a CPU with AVX and FMA -- on other machines `new()`
    // returns Err and the expect() below fails the test with an explanatory message.
    macro_rules! test_avx_butterfly {
        ($test_name:ident, $struct_name:ident, $size:expr) => (
            #[test]
            fn $test_name() {
                let butterfly = $struct_name::<f32>::new(FftDirection::Forward).expect("Can't run test because this machine doesn't have the required instruction sets");
                check_fft_algorithm(&butterfly as &dyn Fft<f32>, $size, FftDirection::Forward);

                let butterfly_inverse = $struct_name::<f32>::new(FftDirection::Inverse).expect("Can't run test because this machine doesn't have the required instruction sets");
                check_fft_algorithm(&butterfly_inverse as &dyn Fft<f32>, $size, FftDirection::Inverse);
            }
        )
    }
    test_avx_butterfly!(test_avx_butterfly5, Butterfly5Avx, 5);
    test_avx_butterfly!(test_avx_butterfly7, Butterfly7Avx, 7);
    test_avx_butterfly!(test_avx_butterfly8, Butterfly8Avx, 8);
    test_avx_butterfly!(test_avx_butterfly9, Butterfly9Avx, 9);
    test_avx_butterfly!(test_avx_butterfly11, Butterfly11Avx, 11);
    test_avx_butterfly!(test_avx_butterfly12, Butterfly12Avx, 12);
    test_avx_butterfly!(test_avx_butterfly16, Butterfly16Avx, 16);
    test_avx_butterfly!(test_avx_butterfly24, Butterfly24Avx, 24);
    test_avx_butterfly!(test_avx_butterfly27, Butterfly27Avx, 27);
    test_avx_butterfly!(test_avx_butterfly32, Butterfly32Avx, 32);
    test_avx_butterfly!(test_avx_butterfly36, Butterfly36Avx, 36);
    test_avx_butterfly!(test_avx_butterfly48, Butterfly48Avx, 48);
    test_avx_butterfly!(test_avx_butterfly54, Butterfly54Avx, 54);
    test_avx_butterfly!(test_avx_butterfly64, Butterfly64Avx, 64);
    test_avx_butterfly!(test_avx_butterfly72, Butterfly72Avx, 72);
    test_avx_butterfly!(test_avx_butterfly128, Butterfly128Avx, 128);
    test_avx_butterfly!(test_avx_butterfly256, Butterfly256Avx, 256);
    test_avx_butterfly!(test_avx_butterfly512, Butterfly512Avx, 512);
}