// Source file: rustfft/avx/avx_mixed_radix.rs

1use std::any::TypeId;
2use std::sync::Arc;
3
4use num_complex::Complex;
5use num_integer::div_ceil;
6
7use crate::array_utils;
8use crate::common::{fft_error_inplace, fft_error_outofplace};
9use crate::{Direction, Fft, FftDirection, FftNum, Length};
10
11use super::{AvxNum, CommonSimdData};
12
13use super::avx_vector;
14use super::avx_vector::{AvxArray, AvxArrayMut, AvxVector, AvxVector128, AvxVector256, Rotation90};
15
// Generates the `new` constructor plus the in-place and out-of-place FFT drivers shared by
// every mixed-radix AVX struct in this file. The surrounding impl must provide
// `new_with_avx`, `perform_column_butterflies`, and `transpose`.
macro_rules! boilerplate_mixedradix {
    () => {
        /// Preallocates necessary arrays and precomputes necessary data to efficiently compute the FFT
        /// Returns Ok() if this machine has the required instruction sets, Err() if some instruction sets are missing
        #[inline]
        pub fn new(inner_fft: Arc<dyn Fft<T>>) -> Result<Self, ()> {
            // Internal sanity check: Make sure that A == T.
            // This struct has two generic parameters A and T, but they must always be the same, and are only kept separate to help work around the lack of specialization.
            // It would be cool if we could do this as a static_assert instead
            let id_a = TypeId::of::<A>();
            let id_t = TypeId::of::<T>();
            assert_eq!(id_a, id_t);

            let has_avx = is_x86_feature_detected!("avx");
            let has_fma = is_x86_feature_detected!("fma");
            if has_avx && has_fma {
                // Safety: new_with_avx requires the "avx" feature set. Since we know it's present, we're safe
                Ok(unsafe { Self::new_with_avx(inner_fft) })
            } else {
                Err(())
            }
        }

        // In-place driver: column butterflies in `buffer`, then the inner FFT writes the rows
        // out-of-place into `scratch`, then the transpose copies them back into `buffer`.
        #[inline]
        fn perform_fft_inplace(&self, buffer: &mut [Complex<T>], scratch: &mut [Complex<T>]) {
            // Perform the column FFTs
            // Safety: self.perform_column_butterflies() requires the "avx" and "fma" instruction sets, and we return Err() in our constructor if the instructions aren't available
            unsafe {
                // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary
                let transmuted_buffer: &mut [Complex<A>] =
                    array_utils::workaround_transmute_mut(buffer);

                self.perform_column_butterflies(transmuted_buffer)
            }

            // process the row FFTs
            // The first self.len() elements of scratch hold the row-FFT output; the rest is the inner FFT's own scratch
            let (scratch, inner_scratch) = scratch.split_at_mut(self.len());
            self.common_data.inner_fft.process_outofplace_with_scratch(
                buffer,
                scratch,
                inner_scratch,
            );

            // Transpose
            // Safety: self.transpose() requires the "avx" instruction set, and we return Err() in our constructor if the instructions aren't available
            unsafe {
                // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary
                let transmuted_scratch: &mut [Complex<A>] =
                    array_utils::workaround_transmute_mut(scratch);
                let transmuted_buffer: &mut [Complex<A>] =
                    array_utils::workaround_transmute_mut(buffer);

                self.transpose(transmuted_scratch, transmuted_buffer)
            }
        }

        // Out-of-place driver: column butterflies in `input`, inner FFT in place on `input`,
        // then transpose from `input` into `output`.
        #[inline]
        fn perform_fft_out_of_place(
            &self,
            input: &mut [Complex<T>],
            output: &mut [Complex<T>],
            scratch: &mut [Complex<T>],
        ) {
            // Perform the column FFTs
            // Safety: self.perform_column_butterflies() requires the "avx" and "fma" instruction sets, and we return Err() in our constructor if the instructions aren't available
            unsafe {
                // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary
                let transmuted_input: &mut [Complex<A>] =
                    array_utils::workaround_transmute_mut(input);

                self.perform_column_butterflies(transmuted_input);
            }

            // process the row FFTs. If extra scratch was provided, pass it in. Otherwise, use the output.
            let inner_scratch = if scratch.len() > 0 {
                scratch
            } else {
                &mut output[..]
            };
            self.common_data
                .inner_fft
                .process_with_scratch(input, inner_scratch);

            // Transpose
            // Safety: self.transpose() requires the "avx" instruction set, and we return Err() in our constructor if the instructions aren't available
            unsafe {
                // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary
                let transmuted_input: &mut [Complex<A>] =
                    array_utils::workaround_transmute_mut(input);
                let transmuted_output: &mut [Complex<A>] =
                    array_utils::workaround_transmute_mut(output);

                self.transpose(transmuted_input, transmuted_output)
            }
        }
    };
}
113
// Builds the `CommonSimdData` for a mixed-radix algorithm with $row_count rows stacked on
// top of an inner FFT: precomputes one set of twiddles per AVX-vector-wide column, and
// derives the scratch-space requirements from the inner FFT's own requirements.
macro_rules! mixedradix_gen_data {
    ($row_count: expr, $inner_fft:expr) => {{
        // Important constants
        const ROW_COUNT : usize = $row_count;
        const TWIDDLES_PER_COLUMN : usize = ROW_COUNT - 1;

        // derive some info from our inner FFT
        let direction = $inner_fft.fft_direction();
        let len_per_row = $inner_fft.len();
        let len = len_per_row * ROW_COUNT;

        // We're going to process each row of the FFT one AVX register at a time. We need to know how many AVX registers each row can fit,
        // and if the last register in each row going to have partial data (ie a remainder)
        let quotient = len_per_row / A::VectorType::COMPLEX_PER_VECTOR;
        let remainder = len_per_row % A::VectorType::COMPLEX_PER_VECTOR;

        // Compute our twiddle factors, and arrange them so that we can access them one column of AVX vectors at a time
        // div_ceil contributes exactly one extra twiddle column when there's a nonzero remainder
        let num_twiddle_columns = quotient + div_ceil(remainder, A::VectorType::COMPLEX_PER_VECTOR);
        let mut twiddles = Vec::with_capacity(num_twiddle_columns * TWIDDLES_PER_COLUMN);
        for x in 0..num_twiddle_columns {
            // Row 0 never needs twiddles, so only rows 1..ROW_COUNT are stored per column
            for y in 1..ROW_COUNT {
                twiddles.push(AvxVector::make_mixedradix_twiddle_chunk(x * A::VectorType::COMPLEX_PER_VECTOR, y, len, direction));
            }
        }

        let inner_outofplace_scratch = $inner_fft.get_outofplace_scratch_len();
        let inner_inplace_scratch = $inner_fft.get_inplace_scratch_len();

        CommonSimdData {
            twiddles: twiddles.into_boxed_slice(),
            // In-place processing splits its scratch into a len-sized row buffer plus the inner FFT's out-of-place scratch
            inplace_scratch_len: len + inner_outofplace_scratch,
            // Out-of-place processing reuses the output buffer as inner scratch unless the inner FFT needs more than len
            outofplace_scratch_len: if inner_inplace_scratch > len { inner_inplace_scratch } else { 0 },
            inner_fft: $inner_fft,
            len,
            direction,
        }
    }}
}
152
// Generates `perform_column_butterflies` for a mixed-radix algorithm with $row_count rows.
// $butterfly_fn operates on full AVX vectors; $butterfly_fn_lo operates on half vectors,
// used when the per-row remainder is small enough to fit a half vector.
macro_rules! mixedradix_column_butterflies {
    ($row_count: expr, $butterfly_fn: expr, $butterfly_fn_lo: expr) => {
        // Computes the size-ROW_COUNT FFT down each column of the buffer (viewed as a
        // ROW_COUNT x len_per_row array), applying twiddles to every row except row 0.
        #[target_feature(enable = "avx", enable = "fma")]
        unsafe fn perform_column_butterflies(&self, mut buffer: impl AvxArrayMut<A>) {
            // How many rows this FFT has, ie 2 for 2xn, 4 for 4xn, etc
            const ROW_COUNT: usize = $row_count;
            const TWIDDLES_PER_COLUMN: usize = ROW_COUNT - 1;

            let len_per_row = self.len() / ROW_COUNT;
            let chunk_count = len_per_row / A::VectorType::COMPLEX_PER_VECTOR;

            // process the column FFTs
            for (c, twiddle_chunk) in self
                .common_data
                .twiddles
                .chunks_exact(TWIDDLES_PER_COLUMN)
                .take(chunk_count)
                .enumerate()
            {
                let index_base = c * A::VectorType::COMPLEX_PER_VECTOR;

                // Load columns from the buffer into registers
                let mut columns = [AvxVector::zero(); ROW_COUNT];
                for i in 0..ROW_COUNT {
                    columns[i] = buffer.load_complex(index_base + len_per_row * i);
                }

                // apply our butterfly function down the columns
                let output = $butterfly_fn(columns, self);

                // always write the first row directly back without twiddles
                buffer.store_complex(output[0], index_base);

                // for every other row, apply twiddle factors and then write back to memory
                for i in 1..ROW_COUNT {
                    let twiddle = twiddle_chunk[i - 1];
                    let output = AvxVector::mul_complex(twiddle, output[i]);
                    buffer.store_complex(output, index_base + len_per_row * i);
                }
            }

            // finally, we might have a remainder chunk
            // Normally, we can fit COMPLEX_PER_VECTOR complex numbers into an AVX register, but we only have `partial_remainder` columns left, so we need special logic to handle these final columns
            let partial_remainder = len_per_row % A::VectorType::COMPLEX_PER_VECTOR;
            if partial_remainder > 0 {
                let partial_remainder_base = chunk_count * A::VectorType::COMPLEX_PER_VECTOR;
                // The last TWIDDLES_PER_COLUMN entries in the twiddle array belong to the remainder column
                let partial_remainder_twiddle_base =
                    self.common_data.twiddles.len() - TWIDDLES_PER_COLUMN;
                let final_twiddle_chunk =
                    &self.common_data.twiddles[partial_remainder_twiddle_base..];

                if partial_remainder > 2 {
                    // Load 3 columns into full AVX vectors to process our remainder
                    let mut columns = [AvxVector::zero(); ROW_COUNT];
                    for i in 0..ROW_COUNT {
                        columns[i] =
                            buffer.load_partial3_complex(partial_remainder_base + len_per_row * i);
                    }

                    // apply our butterfly function down the columns
                    let mid = $butterfly_fn(columns, self);

                    // always write the first row without twiddles
                    buffer.store_partial3_complex(mid[0], partial_remainder_base);

                    // for the remaining rows, apply twiddle factors and then write back to memory
                    for i in 1..ROW_COUNT {
                        let twiddle = final_twiddle_chunk[i - 1];
                        let output = AvxVector::mul_complex(twiddle, mid[i]);
                        buffer.store_partial3_complex(
                            output,
                            partial_remainder_base + len_per_row * i,
                        );
                    }
                } else {
                    // Load 1 or 2 columns into half vectors to process our remainder. Thankfully, the compiler is smart enough to eliminate this branch on f64, since the partial remainder can only possibly be 1
                    let mut columns = [AvxVector::zero(); ROW_COUNT];
                    if partial_remainder == 1 {
                        for i in 0..ROW_COUNT {
                            columns[i] = buffer
                                .load_partial1_complex(partial_remainder_base + len_per_row * i);
                        }
                    } else {
                        for i in 0..ROW_COUNT {
                            columns[i] = buffer
                                .load_partial2_complex(partial_remainder_base + len_per_row * i);
                        }
                    }

                    // apply our butterfly function down the columns
                    let mut mid = $butterfly_fn_lo(columns, self);

                    // apply twiddle factors
                    // these are half vectors, so use the low half of the stored full-vector twiddles
                    for i in 1..ROW_COUNT {
                        mid[i] = AvxVector::mul_complex(final_twiddle_chunk[i - 1].lo(), mid[i]);
                    }

                    // store output
                    if partial_remainder == 1 {
                        for i in 0..ROW_COUNT {
                            buffer.store_partial1_complex(
                                mid[i],
                                partial_remainder_base + len_per_row * i,
                            );
                        }
                    } else {
                        for i in 0..ROW_COUNT {
                            buffer.store_partial2_complex(
                                mid[i],
                                partial_remainder_base + len_per_row * i,
                            );
                        }
                    }
                }
            }
        }
    };
}
271
// Generates `transpose` for a mixed-radix algorithm with $row_count rows. The two
// semicolon-separated index lists are a manual loop-unroll workaround (see comments below):
// the first list covers the main loop's stores, the second covers the remainder==3 stores.
macro_rules! mixedradix_transpose{
    ($row_count: expr, $transpose_fn: path, $transpose_fn_lo: path, $($unroll_workaround_index:expr);*, $($remainder3_unroll_workaround_index:expr);*) => (

    // Transpose the input (treated as a nxc array) into the output (as a cxn array)
    #[target_feature(enable = "avx")]
    unsafe fn transpose(&self, input: &[Complex<A>], mut output: &mut [Complex<A>]) {
        const ROW_COUNT : usize = $row_count;

        let len_per_row = self.len() / ROW_COUNT;
        let chunk_count = len_per_row / A::VectorType::COMPLEX_PER_VECTOR;

        // transpose the scratch as a nx2 array into the buffer as an 2xn array
        for c in 0..chunk_count {
            let input_index_base = c*A::VectorType::COMPLEX_PER_VECTOR;
            let output_index_base = input_index_base * ROW_COUNT;

            // Load rows from the input into registers
            let mut rows : [A::VectorType; ROW_COUNT] = [AvxVector::zero(); ROW_COUNT];
            for i in 0..ROW_COUNT {
                rows[i] = input.load_complex(input_index_base + len_per_row*i);
            }

            // transpose the rows to the columns
            let transposed = $transpose_fn(rows);

            // store the transposed rows contiguously
            // IE, unlike the way we loaded the data, which was to load it strided across each of our rows
            // we will not output it strided, but instead writing it out as a contiguous block

            // we are using a macro hack to manually unroll the loop, to work around this rustc bug:
            // https://github.com/rust-lang/rust/issues/71025

            // if we don't manually unroll the loop, the compiler will insert unnecessary writes+reads to the stack which tank performance
            // once the compiler bug is fixed, this can be replaced by a "for i in 0..ROW_COUNT" loop
            $(
                output.store_complex(transposed[$unroll_workaround_index], output_index_base + A::VectorType::COMPLEX_PER_VECTOR * $unroll_workaround_index);
            )*
        }

        // transpose the remainder
        let input_index_base = chunk_count * A::VectorType::COMPLEX_PER_VECTOR;
        let output_index_base = input_index_base * ROW_COUNT;

        let partial_remainder = len_per_row % A::VectorType::COMPLEX_PER_VECTOR;
        if partial_remainder == 1 {
            // If the partial remainder is 1, there's no transposing to do - just gather from across the rows and store contiguously
            for i in 0..ROW_COUNT {
                // Safety: in-bounds because the buffer holds ROW_COUNT rows of len_per_row, and the remainder column is the last element of each row
                let input_cell = input.get_unchecked(input_index_base + len_per_row*i);
                let output_cell = output.get_unchecked_mut(output_index_base + i);
                *output_cell = *input_cell;
            }
        } else if partial_remainder == 2 {
            // If the partial remainder is 2, use the provided transpose_lo function to do a transpose on half-vectors
            let mut rows = [AvxVector::zero(); ROW_COUNT];
            for i in 0..ROW_COUNT {
                rows[i] = input.load_partial2_complex(input_index_base + len_per_row*i);
            }

            let transposed = $transpose_fn_lo(rows);

            // use the same macro hack as above to unroll the loop
            $(
                output.store_partial2_complex(transposed[$unroll_workaround_index], output_index_base + <A::VectorType as AvxVector256>::HalfVector::COMPLEX_PER_VECTOR * $unroll_workaround_index);
            )*
        }
        else if partial_remainder == 3 {
            // If the partial remainder is 3, we have to load full vectors, use the full transpose, and then write out a variable number of outputs
            let mut rows = [AvxVector::zero(); ROW_COUNT];
            for i in 0..ROW_COUNT {
                rows[i] = input.load_partial3_complex(input_index_base + len_per_row*i);
            }

            // transpose the rows to the columns
            let transposed = $transpose_fn(rows);

            // We're going to write constant number of full vectors, and then some constant-sized partial vector
            // Sadly, because of rust limitations, we can't make full_vector_count a const, so we have to cross our fingers that the compiler optimizes it to a constant
            let element_count = 3*ROW_COUNT;
            let full_vector_count = element_count / A::VectorType::COMPLEX_PER_VECTOR;
            let final_remainder_count = element_count % A::VectorType::COMPLEX_PER_VECTOR;

            // write out our full vectors
            // we are using a macro hack to manually unroll the loop, to work around this rustc bug:
            // https://github.com/rust-lang/rust/issues/71025

            // if we don't manually unroll the loop, the compiler will insert unnecessary writes+reads to the stack which tank performance
            // once the compiler bug is fixed, this can be replaced by a "for i in 0..full_vector_count" loop
            $(
                output.store_complex(transposed[$remainder3_unroll_workaround_index], output_index_base + A::VectorType::COMPLEX_PER_VECTOR * $remainder3_unroll_workaround_index);
            )*

            // write out our partial vector. again, this is a compile-time constant, even if we can't represent that within rust yet
            match final_remainder_count {
                0 => {},
                1 => output.store_partial1_complex(transposed[full_vector_count].lo(), output_index_base + full_vector_count * A::VectorType::COMPLEX_PER_VECTOR),
                2 => output.store_partial2_complex(transposed[full_vector_count].lo(), output_index_base + full_vector_count * A::VectorType::COMPLEX_PER_VECTOR),
                3 => output.store_partial3_complex(transposed[full_vector_count], output_index_base + full_vector_count * A::VectorType::COMPLEX_PER_VECTOR),
                _ => unreachable!(),
            }
        }
    }
)}
374
/// AVX-accelerated mixed-radix FFT whose total length is 2 times the inner FFT's length.
pub struct MixedRadix2xnAvx<A: AvxNum, T> {
    // Twiddles, scratch lengths, inner FFT, and direction shared by all mixed-radix AVX algorithms
    common_data: CommonSimdData<T, A::VectorType>,
    // Ties the otherwise-unused `T` parameter to the struct
    _phantom: std::marker::PhantomData<T>,
}
boilerplate_avx_fft_commondata!(MixedRadix2xnAvx);
380
impl<A: AvxNum, T: FftNum> MixedRadix2xnAvx<A, T> {
    /// Constructor body for `new`.
    ///
    /// # Safety
    /// Requires the "avx" instruction set; `new` verifies availability before calling this.
    #[target_feature(enable = "avx")]
    unsafe fn new_with_avx(inner_fft: Arc<dyn Fft<T>>) -> Self {
        Self {
            common_data: mixedradix_gen_data!(2, inner_fft),
            _phantom: std::marker::PhantomData,
        }
    }
    // Size-2 butterflies take no twiddle arguments, so the full-vector and half-vector
    // remainder paths use the same butterfly call
    mixedradix_column_butterflies!(
        2,
        |columns, _: _| AvxVector::column_butterfly2(columns),
        |columns, _: _| AvxVector::column_butterfly2(columns)
    );
    mixedradix_transpose!(2,
        AvxVector::transpose2_packed,
        AvxVector::transpose2_packed,
        0;1, 0
    );
    boilerplate_mixedradix!();
}
401
/// AVX-accelerated mixed-radix FFT whose total length is 3 times the inner FFT's length.
pub struct MixedRadix3xnAvx<A: AvxNum, T> {
    // Precomputed twiddle for the size-3 butterfly
    twiddles_butterfly3: A::VectorType,
    // Twiddles, scratch lengths, inner FFT, and direction shared by all mixed-radix AVX algorithms
    common_data: CommonSimdData<T, A::VectorType>,
    // Ties the otherwise-unused `T` parameter to the struct
    _phantom: std::marker::PhantomData<T>,
}
boilerplate_avx_fft_commondata!(MixedRadix3xnAvx);
408
impl<A: AvxNum, T: FftNum> MixedRadix3xnAvx<A, T> {
    /// Constructor body for `new`.
    ///
    /// # Safety
    /// Requires the "avx" instruction set; `new` verifies availability before calling this.
    #[target_feature(enable = "avx")]
    unsafe fn new_with_avx(inner_fft: Arc<dyn Fft<T>>) -> Self {
        Self {
            twiddles_butterfly3: AvxVector::broadcast_twiddle(1, 3, inner_fft.fft_direction()),
            common_data: mixedradix_gen_data!(3, inner_fft),
            _phantom: std::marker::PhantomData,
        }
    }
    // The half-vector remainder path passes the low half of the stored butterfly twiddle
    mixedradix_column_butterflies!(
        3,
        |columns, this: &Self| AvxVector::column_butterfly3(columns, this.twiddles_butterfly3),
        |columns, this: &Self| AvxVector::column_butterfly3(columns, this.twiddles_butterfly3.lo())
    );
    mixedradix_transpose!(3,
        AvxVector::transpose3_packed,
        AvxVector::transpose3_packed,
        0;1;2, 0;1
    );
    boilerplate_mixedradix!();
}
430
/// AVX-accelerated mixed-radix FFT whose total length is 4 times the inner FFT's length.
pub struct MixedRadix4xnAvx<A: AvxNum, T> {
    // Precomputed rotation used by the size-4 butterfly
    twiddles_butterfly4: Rotation90<A::VectorType>,
    // Twiddles, scratch lengths, inner FFT, and direction shared by all mixed-radix AVX algorithms
    common_data: CommonSimdData<T, A::VectorType>,
    // Ties the otherwise-unused `T` parameter to the struct
    _phantom: std::marker::PhantomData<T>,
}
boilerplate_avx_fft_commondata!(MixedRadix4xnAvx);
437
impl<A: AvxNum, T: FftNum> MixedRadix4xnAvx<A, T> {
    /// Constructor body for `new`.
    ///
    /// # Safety
    /// Requires the "avx" instruction set; `new` verifies availability before calling this.
    #[target_feature(enable = "avx")]
    unsafe fn new_with_avx(inner_fft: Arc<dyn Fft<T>>) -> Self {
        Self {
            twiddles_butterfly4: AvxVector::make_rotation90(inner_fft.fft_direction()),
            common_data: mixedradix_gen_data!(4, inner_fft),
            _phantom: std::marker::PhantomData,
        }
    }
    // The half-vector remainder path passes the low half of the stored rotation
    mixedradix_column_butterflies!(
        4,
        |columns, this: &Self| AvxVector::column_butterfly4(columns, this.twiddles_butterfly4),
        |columns, this: &Self| AvxVector::column_butterfly4(columns, this.twiddles_butterfly4.lo())
    );
    mixedradix_transpose!(4,
        AvxVector::transpose4_packed,
        AvxVector::transpose4_packed,
        0;1;2;3, 0;1;2
    );
    boilerplate_mixedradix!();
}
459
/// AVX-accelerated mixed-radix FFT whose total length is 5 times the inner FFT's length.
pub struct MixedRadix5xnAvx<A: AvxNum, T> {
    // Precomputed twiddles for the size-5 butterfly
    twiddles_butterfly5: [A::VectorType; 2],
    // Twiddles, scratch lengths, inner FFT, and direction shared by all mixed-radix AVX algorithms
    common_data: CommonSimdData<T, A::VectorType>,
    // Ties the otherwise-unused `T` parameter to the struct
    _phantom: std::marker::PhantomData<T>,
}
boilerplate_avx_fft_commondata!(MixedRadix5xnAvx);
466
impl<A: AvxNum, T: FftNum> MixedRadix5xnAvx<A, T> {
    /// Constructor body for `new`.
    ///
    /// # Safety
    /// Requires the "avx" instruction set; `new` verifies availability before calling this.
    #[target_feature(enable = "avx")]
    unsafe fn new_with_avx(inner_fft: Arc<dyn Fft<T>>) -> Self {
        Self {
            twiddles_butterfly5: [
                AvxVector::broadcast_twiddle(1, 5, inner_fft.fft_direction()),
                AvxVector::broadcast_twiddle(2, 5, inner_fft.fft_direction()),
            ],
            common_data: mixedradix_gen_data!(5, inner_fft),
            _phantom: std::marker::PhantomData,
        }
    }
    // The half-vector remainder path passes the low halves of the stored butterfly twiddles
    mixedradix_column_butterflies!(
        5,
        |columns, this: &Self| AvxVector::column_butterfly5(columns, this.twiddles_butterfly5),
        |columns, this: &Self| AvxVector::column_butterfly5(
            columns,
            [
                this.twiddles_butterfly5[0].lo(),
                this.twiddles_butterfly5[1].lo()
            ]
        )
    );
    mixedradix_transpose!(5,
        AvxVector::transpose5_packed,
        AvxVector::transpose5_packed,
        0;1;2;3;4, 0;1;2
    );
    boilerplate_mixedradix!();
}
497
/// AVX-accelerated mixed-radix FFT whose total length is 6 times the inner FFT's length.
pub struct MixedRadix6xnAvx<A: AvxNum, T> {
    // Precomputed twiddle for the size-3 butterfly used inside the size-6 butterfly
    twiddles_butterfly3: A::VectorType,
    // Twiddles, scratch lengths, inner FFT, and direction shared by all mixed-radix AVX algorithms
    common_data: CommonSimdData<T, A::VectorType>,
    // Ties the otherwise-unused `T` parameter to the struct
    _phantom: std::marker::PhantomData<T>,
}
boilerplate_avx_fft_commondata!(MixedRadix6xnAvx);
504
impl<A: AvxNum, T: FftNum> MixedRadix6xnAvx<A, T> {
    /// Constructor body for `new`.
    ///
    /// # Safety
    /// Requires the "avx" instruction set; `new` verifies availability before calling this.
    #[target_feature(enable = "avx")]
    unsafe fn new_with_avx(inner_fft: Arc<dyn Fft<T>>) -> Self {
        Self {
            twiddles_butterfly3: AvxVector::broadcast_twiddle(1, 3, inner_fft.fft_direction()),
            common_data: mixedradix_gen_data!(6, inner_fft),
            _phantom: std::marker::PhantomData,
        }
    }
    // The remainder path dispatches through the AvxVector128 trait's size-6 butterfly
    mixedradix_column_butterflies!(
        6,
        |columns, this: &Self| AvxVector256::column_butterfly6(columns, this.twiddles_butterfly3),
        |columns, this: &Self| AvxVector128::column_butterfly6(columns, this.twiddles_butterfly3)
    );
    mixedradix_transpose!(6,
        AvxVector::transpose6_packed,
        AvxVector::transpose6_packed,
        0;1;2;3;4;5, 0;1;2;3
    );
    boilerplate_mixedradix!();
}
526
/// AVX-accelerated mixed-radix FFT whose total length is 7 times the inner FFT's length.
pub struct MixedRadix7xnAvx<A: AvxNum, T> {
    // Precomputed twiddles for the size-7 butterfly
    twiddles_butterfly7: [A::VectorType; 3],
    // Twiddles, scratch lengths, inner FFT, and direction shared by all mixed-radix AVX algorithms
    common_data: CommonSimdData<T, A::VectorType>,
    // Ties the otherwise-unused `T` parameter to the struct
    _phantom: std::marker::PhantomData<T>,
}
boilerplate_avx_fft_commondata!(MixedRadix7xnAvx);
533
impl<A: AvxNum, T: FftNum> MixedRadix7xnAvx<A, T> {
    /// Constructor body for `new`.
    ///
    /// # Safety
    /// Requires the "avx" instruction set; `new` verifies availability before calling this.
    #[target_feature(enable = "avx")]
    unsafe fn new_with_avx(inner_fft: Arc<dyn Fft<T>>) -> Self {
        Self {
            twiddles_butterfly7: [
                AvxVector::broadcast_twiddle(1, 7, inner_fft.fft_direction()),
                AvxVector::broadcast_twiddle(2, 7, inner_fft.fft_direction()),
                AvxVector::broadcast_twiddle(3, 7, inner_fft.fft_direction()),
            ],
            common_data: mixedradix_gen_data!(7, inner_fft),
            _phantom: std::marker::PhantomData,
        }
    }
    // The half-vector remainder path passes the low halves of the stored butterfly twiddles
    mixedradix_column_butterflies!(
        7,
        |columns, this: &Self| AvxVector::column_butterfly7(columns, this.twiddles_butterfly7),
        |columns, this: &Self| AvxVector::column_butterfly7(
            columns,
            [
                this.twiddles_butterfly7[0].lo(),
                this.twiddles_butterfly7[1].lo(),
                this.twiddles_butterfly7[2].lo()
            ]
        )
    );
    mixedradix_transpose!(7,
        AvxVector::transpose7_packed,
        AvxVector::transpose7_packed,
        0;1;2;3;4;5;6, 0;1;2;3;4
    );
    boilerplate_mixedradix!();
}
566
/// AVX-accelerated mixed-radix FFT whose total length is 8 times the inner FFT's length.
pub struct MixedRadix8xnAvx<A: AvxNum, T> {
    // Precomputed rotation used by the size-8 butterfly
    twiddles_butterfly4: Rotation90<A::VectorType>,
    // Twiddles, scratch lengths, inner FFT, and direction shared by all mixed-radix AVX algorithms
    common_data: CommonSimdData<T, A::VectorType>,
    // Ties the otherwise-unused `T` parameter to the struct
    _phantom: std::marker::PhantomData<T>,
}
boilerplate_avx_fft_commondata!(MixedRadix8xnAvx);
573
impl<A: AvxNum, T: FftNum> MixedRadix8xnAvx<A, T> {
    /// Constructor body for `new`.
    ///
    /// # Safety
    /// Requires the "avx" instruction set; `new` verifies availability before calling this.
    #[target_feature(enable = "avx")]
    unsafe fn new_with_avx(inner_fft: Arc<dyn Fft<T>>) -> Self {
        Self {
            twiddles_butterfly4: AvxVector::make_rotation90(inner_fft.fft_direction()),
            common_data: mixedradix_gen_data!(8, inner_fft),
            _phantom: std::marker::PhantomData,
        }
    }

    // The half-vector remainder path passes the low half of the stored rotation
    mixedradix_column_butterflies!(
        8,
        |columns, this: &Self| AvxVector::column_butterfly8(columns, this.twiddles_butterfly4),
        |columns, this: &Self| AvxVector::column_butterfly8(columns, this.twiddles_butterfly4.lo())
    );
    mixedradix_transpose!(8,
        AvxVector::transpose8_packed,
        AvxVector::transpose8_packed,
        0;1;2;3;4;5;6;7, 0;1;2;3;4;5
    );
    boilerplate_mixedradix!();
}
596
/// AVX-accelerated mixed-radix FFT whose total length is 9 times the inner FFT's length.
pub struct MixedRadix9xnAvx<A: AvxNum, T> {
    // Precomputed twiddles for the full-vector size-9 butterfly
    twiddles_butterfly9: [A::VectorType; 3],
    // Merged twiddles used by the half-vector remainder path of the size-9 butterfly
    twiddles_butterfly9_lo: [A::VectorType; 2],
    // Precomputed twiddle for the size-3 butterfly used inside the size-9 butterfly
    twiddles_butterfly3: A::VectorType,
    // Twiddles, scratch lengths, inner FFT, and direction shared by all mixed-radix AVX algorithms
    common_data: CommonSimdData<T, A::VectorType>,
    // Ties the otherwise-unused `T` parameter to the struct
    _phantom: std::marker::PhantomData<T>,
}
boilerplate_avx_fft_commondata!(MixedRadix9xnAvx);
605
606impl<A: AvxNum, T: FftNum> MixedRadix9xnAvx<A, T> {
607    #[target_feature(enable = "avx")]
608    unsafe fn new_with_avx(inner_fft: Arc<dyn Fft<T>>) -> Self {
609        let inverse = inner_fft.fft_direction();
610
611        let twiddle1 = AvxVector::broadcast_twiddle(1, 9, inner_fft.fft_direction());
612        let twiddle2 = AvxVector::broadcast_twiddle(2, 9, inner_fft.fft_direction());
613        let twiddle4 = AvxVector::broadcast_twiddle(4, 9, inner_fft.fft_direction());
614
615        Self {
616            twiddles_butterfly9: [
617                AvxVector::broadcast_twiddle(1, 9, inverse),
618                AvxVector::broadcast_twiddle(2, 9, inverse),
619                AvxVector::broadcast_twiddle(4, 9, inverse),
620            ],
621            twiddles_butterfly9_lo: [
622                AvxVector256::merge(twiddle1, twiddle2),
623                AvxVector256::merge(twiddle2, twiddle4),
624            ],
625            twiddles_butterfly3: AvxVector::broadcast_twiddle(1, 3, inner_fft.fft_direction()),
626            common_data: mixedradix_gen_data!(9, inner_fft),
627            _phantom: std::marker::PhantomData,
628        }
629    }
630
631    mixedradix_column_butterflies!(
632        9,
633        |columns, this: &Self| AvxVector256::column_butterfly9(
634            columns,
635            this.twiddles_butterfly9,
636            this.twiddles_butterfly3
637        ),
638        |columns, this: &Self| AvxVector128::column_butterfly9(
639            columns,
640            this.twiddles_butterfly9_lo,
641            this.twiddles_butterfly3
642        )
643    );
644    mixedradix_transpose!(9,
645        AvxVector::transpose9_packed,
646        AvxVector::transpose9_packed,
647        0;1;2;3;4;5;6;7;8, 0;1;2;3;4;5
648    );
649    boilerplate_mixedradix!();
650}
651
/// AVX-accelerated mixed-radix FFT whose total length is 11 times the inner FFT's length.
pub struct MixedRadix11xnAvx<A: AvxNum, T> {
    // Precomputed twiddles for the size-11 butterfly
    twiddles_butterfly11: [A::VectorType; 5],
    // Twiddles, scratch lengths, inner FFT, and direction shared by all mixed-radix AVX algorithms
    common_data: CommonSimdData<T, A::VectorType>,
    // Ties the otherwise-unused `T` parameter to the struct
    _phantom: std::marker::PhantomData<T>,
}
boilerplate_avx_fft_commondata!(MixedRadix11xnAvx);
658
impl<A: AvxNum, T: FftNum> MixedRadix11xnAvx<A, T> {
    /// Constructor body for `new`.
    ///
    /// # Safety
    /// Requires the "avx" instruction set; `new` verifies availability before calling this.
    #[target_feature(enable = "avx")]
    unsafe fn new_with_avx(inner_fft: Arc<dyn Fft<T>>) -> Self {
        Self {
            twiddles_butterfly11: [
                AvxVector::broadcast_twiddle(1, 11, inner_fft.fft_direction()),
                AvxVector::broadcast_twiddle(2, 11, inner_fft.fft_direction()),
                AvxVector::broadcast_twiddle(3, 11, inner_fft.fft_direction()),
                AvxVector::broadcast_twiddle(4, 11, inner_fft.fft_direction()),
                AvxVector::broadcast_twiddle(5, 11, inner_fft.fft_direction()),
            ],
            common_data: mixedradix_gen_data!(11, inner_fft),
            _phantom: std::marker::PhantomData,
        }
    }
    // The half-vector remainder path passes the low halves of the stored butterfly twiddles
    mixedradix_column_butterflies!(
        11,
        |columns, this: &Self| AvxVector::column_butterfly11(columns, this.twiddles_butterfly11),
        |columns, this: &Self| AvxVector::column_butterfly11(
            columns,
            [
                this.twiddles_butterfly11[0].lo(),
                this.twiddles_butterfly11[1].lo(),
                this.twiddles_butterfly11[2].lo(),
                this.twiddles_butterfly11[3].lo(),
                this.twiddles_butterfly11[4].lo()
            ]
        )
    );
    mixedradix_transpose!(11,
        AvxVector::transpose11_packed,
        AvxVector::transpose11_packed,
        0;1;2;3;4;5;6;7;8;9;10, 0;1;2;3;4;5;6;7
    );
    boilerplate_mixedradix!();
}
695
/// Mixed-radix FFT whose first radix is 12: computes size-12 column FFTs with AVX,
/// then delegates the remaining factor of the length to an inner FFT.
pub struct MixedRadix12xnAvx<A: AvxNum, T> {
    // Rotate-by-90-degrees constants (see AvxVector::make_rotation90), passed into the
    // size-12 butterfly alongside the size-3 twiddle below
    twiddles_butterfly4: Rotation90<A::VectorType>,
    // Broadcast twiddle for k = 1 of 3, precomputed in new_with_avx
    twiddles_butterfly3: A::VectorType,
    // Inter-stage twiddles, inner FFT handle, and scratch metadata shared by all
    // mixed-radix AVX algorithms in this file
    common_data: CommonSimdData<T, A::VectorType>,
    // A and T are always the same type (asserted in the boilerplate `new()`); T is kept
    // as a separate parameter to work around the lack of specialization
    _phantom: std::marker::PhantomData<T>,
}
boilerplate_avx_fft_commondata!(MixedRadix12xnAvx);
703
impl<A: AvxNum, T: FftNum> MixedRadix12xnAvx<A, T> {
    /// Precomputes the size-12 butterfly constants and the shared mixed-radix data.
    ///
    /// # Safety
    /// Requires the "avx" instruction set. The safe `new()` generated by
    /// `boilerplate_mixedradix!()` verifies this before calling.
    #[target_feature(enable = "avx")]
    unsafe fn new_with_avx(inner_fft: Arc<dyn Fft<T>>) -> Self {
        // Query the direction once and reuse it for every direction-dependent constant
        let inverse = inner_fft.fft_direction();
        Self {
            twiddles_butterfly4: AvxVector::make_rotation90(inverse),
            twiddles_butterfly3: AvxVector::broadcast_twiddle(1, 3, inverse),
            common_data: mixedradix_gen_data!(12, inner_fft),
            _phantom: std::marker::PhantomData,
        }
    }

    // Generates perform_column_butterflies: the first closure is the full-width (256-bit)
    // vector path, the second is the 128-bit path for the leftover columns
    mixedradix_column_butterflies!(
        12,
        |columns, this: &Self| AvxVector256::column_butterfly12(
            columns,
            this.twiddles_butterfly3,
            this.twiddles_butterfly4
        ),
        |columns, this: &Self| AvxVector128::column_butterfly12(
            columns,
            this.twiddles_butterfly3,
            this.twiddles_butterfly4
        )
    );
    // Generates the row/column transpose; the two index lists are the packed-transpose
    // orderings for the f32 and f64 paths respectively
    mixedradix_transpose!(12,
        AvxVector::transpose12_packed,
        AvxVector::transpose12_packed,
        0;1;2;3;4;5;6;7;8;9;10;11, 0;1;2;3;4;5;6;7;8
    );
    // Generates the safe `new()` plus the in-place/out-of-place FFT entry points
    boilerplate_mixedradix!();
}
736
/// Mixed-radix FFT whose first radix is 16: computes size-16 column FFTs with AVX,
/// then delegates the remaining factor of the length to an inner FFT.
pub struct MixedRadix16xnAvx<A: AvxNum, T> {
    // Rotate-by-90-degrees constants (see AvxVector::make_rotation90), passed into the
    // size-16 butterfly alongside the twiddles below
    twiddles_butterfly4: Rotation90<A::VectorType>,
    // Broadcast twiddles for k = 1 and k = 3 of 16, precomputed in new_with_avx
    twiddles_butterfly16: [A::VectorType; 2],
    // Inter-stage twiddles, inner FFT handle, and scratch metadata shared by all
    // mixed-radix AVX algorithms in this file
    common_data: CommonSimdData<T, A::VectorType>,
    // A and T are always the same type (asserted in the boilerplate `new()`); T is kept
    // as a separate parameter to work around the lack of specialization
    _phantom: std::marker::PhantomData<T>,
}
boilerplate_avx_fft_commondata!(MixedRadix16xnAvx);
744
745impl<A: AvxNum, T: FftNum> MixedRadix16xnAvx<A, T> {
746    #[target_feature(enable = "avx")]
747    unsafe fn new_with_avx(inner_fft: Arc<dyn Fft<T>>) -> Self {
748        let inverse = inner_fft.fft_direction();
749        Self {
750            twiddles_butterfly4: AvxVector::make_rotation90(inner_fft.fft_direction()),
751            twiddles_butterfly16: [
752                AvxVector::broadcast_twiddle(1, 16, inverse),
753                AvxVector::broadcast_twiddle(3, 16, inverse),
754            ],
755            common_data: mixedradix_gen_data!(16, inner_fft),
756            _phantom: std::marker::PhantomData,
757        }
758    }
759
760    #[target_feature(enable = "avx", enable = "fma")]
761    unsafe fn perform_column_butterflies(&self, mut buffer: impl AvxArrayMut<A>) {
762        // How many rows this FFT has, ie 2 for 2xn, 4 for 4xn, etc
763        const ROW_COUNT: usize = 16;
764        const TWIDDLES_PER_COLUMN: usize = ROW_COUNT - 1;
765
766        let len_per_row = self.len() / ROW_COUNT;
767        let chunk_count = len_per_row / A::VectorType::COMPLEX_PER_VECTOR;
768
769        // process the column FFTs
770        for (c, twiddle_chunk) in self
771            .common_data
772            .twiddles
773            .chunks_exact(TWIDDLES_PER_COLUMN)
774            .take(chunk_count)
775            .enumerate()
776        {
777            let index_base = c * A::VectorType::COMPLEX_PER_VECTOR;
778
779            column_butterfly16_loadfn!(
780                |index| buffer.load_complex(index_base + len_per_row * index),
781                |mut data, index| {
782                    if index > 0 {
783                        data = AvxVector::mul_complex(data, twiddle_chunk[index - 1]);
784                    }
785                    buffer.store_complex(data, index_base + len_per_row * index)
786                },
787                self.twiddles_butterfly16,
788                self.twiddles_butterfly4
789            );
790        }
791
792        // finally, we might have a single partial chunk.
793        // Normally, we can fit 4 complex numbers into an AVX register, but we only have `partial_remainder` columns left, so we need special logic to handle these final columns
794        let partial_remainder = len_per_row % A::VectorType::COMPLEX_PER_VECTOR;
795        if partial_remainder > 0 {
796            let partial_remainder_base = chunk_count * A::VectorType::COMPLEX_PER_VECTOR;
797            let partial_remainder_twiddle_base =
798                self.common_data.twiddles.len() - TWIDDLES_PER_COLUMN;
799            let final_twiddle_chunk = &self.common_data.twiddles[partial_remainder_twiddle_base..];
800
801            match partial_remainder {
802                1 => {
803                    column_butterfly16_loadfn!(
804                        |index| buffer
805                            .load_partial1_complex(partial_remainder_base + len_per_row * index),
806                        |mut data, index| {
807                            if index > 0 {
808                                let twiddle: A::VectorType = final_twiddle_chunk[index - 1];
809                                data = AvxVector::mul_complex(data, twiddle.lo());
810                            }
811                            buffer.store_partial1_complex(
812                                data,
813                                partial_remainder_base + len_per_row * index,
814                            )
815                        },
816                        [
817                            self.twiddles_butterfly16[0].lo(),
818                            self.twiddles_butterfly16[1].lo()
819                        ],
820                        self.twiddles_butterfly4.lo()
821                    );
822                }
823                2 => {
824                    column_butterfly16_loadfn!(
825                        |index| buffer
826                            .load_partial2_complex(partial_remainder_base + len_per_row * index),
827                        |mut data, index| {
828                            if index > 0 {
829                                let twiddle: A::VectorType = final_twiddle_chunk[index - 1];
830                                data = AvxVector::mul_complex(data, twiddle.lo());
831                            }
832                            buffer.store_partial2_complex(
833                                data,
834                                partial_remainder_base + len_per_row * index,
835                            )
836                        },
837                        [
838                            self.twiddles_butterfly16[0].lo(),
839                            self.twiddles_butterfly16[1].lo()
840                        ],
841                        self.twiddles_butterfly4.lo()
842                    );
843                }
844                3 => {
845                    column_butterfly16_loadfn!(
846                        |index| buffer
847                            .load_partial3_complex(partial_remainder_base + len_per_row * index),
848                        |mut data, index| {
849                            if index > 0 {
850                                data = AvxVector::mul_complex(data, final_twiddle_chunk[index - 1]);
851                            }
852                            buffer.store_partial3_complex(
853                                data,
854                                partial_remainder_base + len_per_row * index,
855                            )
856                        },
857                        self.twiddles_butterfly16,
858                        self.twiddles_butterfly4
859                    );
860                }
861                _ => unreachable!(),
862            }
863        }
864    }
865    mixedradix_transpose!(16,
866        AvxVector::transpose16_packed,
867        AvxVector::transpose16_packed,
868        0;1;2;3;4;5;6;7;8;9;10;11;12;13;14;15, 0;1;2;3;4;5;6;7;8;9;10;11
869    );
870    boilerplate_mixedradix!();
871}
872
#[cfg(test)]
mod unit_tests {
    use super::*;
    use crate::algorithm::*;
    use crate::test_utils::check_fft_algorithm;
    use std::sync::Arc;

    // Generates an f32 test and an f64 test for one MixedRadix*xnAvx struct: for each
    // inner FFT length 1..32, builds the algorithm around a naive Dft inner FFT and
    // checks it in both directions. Each test bails with a message (rather than
    // silently passing) if this machine lacks AVX+FMA.
    macro_rules! test_avx_mixed_radix {
        ($f32_test_name:ident, $f64_test_name:ident, $struct_name:ident, $inner_count:expr) => (
            #[test]
            fn $f32_test_name() {
                for inner_fft_len in 1..32 {
                    let len = inner_fft_len * $inner_count;

                    let inner_fft_forward = Arc::new(Dft::new(inner_fft_len, FftDirection::Forward)) as Arc<dyn Fft<f32>>;
                    let fft_forward = $struct_name::<f32, f32>::new(inner_fft_forward).expect("Can't run test because this machine doesn't have the required instruction sets");
                    check_fft_algorithm(&fft_forward, len, FftDirection::Forward);

                    let inner_fft_inverse = Arc::new(Dft::new(inner_fft_len, FftDirection::Inverse)) as Arc<dyn Fft<f32>>;
                    let fft_inverse = $struct_name::<f32, f32>::new(inner_fft_inverse).expect("Can't run test because this machine doesn't have the required instruction sets");
                    check_fft_algorithm(&fft_inverse, len, FftDirection::Inverse);
                }
            }
            #[test]
            fn $f64_test_name() {
                for inner_fft_len in 1..32 {
                    let len = inner_fft_len * $inner_count;

                    let inner_fft_forward = Arc::new(Dft::new(inner_fft_len, FftDirection::Forward)) as Arc<dyn Fft<f64>>;
                    let fft_forward = $struct_name::<f64, f64>::new(inner_fft_forward).expect("Can't run test because this machine doesn't have the required instruction sets");
                    check_fft_algorithm(&fft_forward, len, FftDirection::Forward);

                    let inner_fft_inverse = Arc::new(Dft::new(inner_fft_len, FftDirection::Inverse)) as Arc<dyn Fft<f64>>;
                    let fft_inverse = $struct_name::<f64, f64>::new(inner_fft_inverse).expect("Can't run test because this machine doesn't have the required instruction sets");
                    check_fft_algorithm(&fft_inverse, len, FftDirection::Inverse);
                }
            }
        )
    }

    // One invocation per algorithm in this file; the final argument is the first radix
    test_avx_mixed_radix!(
        test_mixedradix_2xn_avx_f32,
        test_mixedradix_2xn_avx_f64,
        MixedRadix2xnAvx,
        2
    );
    test_avx_mixed_radix!(
        test_mixedradix_3xn_avx_f32,
        test_mixedradix_3xn_avx_f64,
        MixedRadix3xnAvx,
        3
    );
    test_avx_mixed_radix!(
        test_mixedradix_4xn_avx_f32,
        test_mixedradix_4xn_avx_f64,
        MixedRadix4xnAvx,
        4
    );
    test_avx_mixed_radix!(
        test_mixedradix_5xn_avx_f32,
        test_mixedradix_5xn_avx_f64,
        MixedRadix5xnAvx,
        5
    );
    test_avx_mixed_radix!(
        test_mixedradix_6xn_avx_f32,
        test_mixedradix_6xn_avx_f64,
        MixedRadix6xnAvx,
        6
    );
    test_avx_mixed_radix!(
        test_mixedradix_7xn_avx_f32,
        test_mixedradix_7xn_avx_f64,
        MixedRadix7xnAvx,
        7
    );
    test_avx_mixed_radix!(
        test_mixedradix_8xn_avx_f32,
        test_mixedradix_8xn_avx_f64,
        MixedRadix8xnAvx,
        8
    );
    test_avx_mixed_radix!(
        test_mixedradix_9xn_avx_f32,
        test_mixedradix_9xn_avx_f64,
        MixedRadix9xnAvx,
        9
    );
    test_avx_mixed_radix!(
        test_mixedradix_11xn_avx_f32,
        test_mixedradix_11xn_avx_f64,
        MixedRadix11xnAvx,
        11
    );
    test_avx_mixed_radix!(
        test_mixedradix_12xn_avx_f32,
        test_mixedradix_12xn_avx_f64,
        MixedRadix12xnAvx,
        12
    );
    test_avx_mixed_radix!(
        test_mixedradix_16xn_avx_f32,
        test_mixedradix_16xn_avx_f64,
        MixedRadix16xnAvx,
        16
    );
}