1use std::any::TypeId;
2use std::sync::Arc;
3
4use num_complex::Complex;
5use num_integer::div_ceil;
6
7use crate::array_utils;
8use crate::common::{fft_error_inplace, fft_error_outofplace};
9use crate::{Direction, Fft, FftDirection, FftNum, Length};
10
11use super::{AvxNum, CommonSimdData};
12
13use super::avx_vector;
14use super::avx_vector::{AvxArray, AvxArrayMut, AvxVector, AvxVector128, AvxVector256, Rotation90};
15
/// Generates the public constructor and the two FFT drivers shared by the `MixedRadix*xnAvx`
/// types in this file.
///
/// The expansion must sit inside an `impl<A: AvxNum, T: FftNum>` block that also provides
/// `new_with_avx`, `perform_column_butterflies` and `transpose`.
macro_rules! boilerplate_mixedradix {
    () => {
        /// Wraps `inner_fft` in this mixed radix step.
        ///
        /// Returns `Err(())` unless the running CPU reports both the "avx" and "fma" features.
        ///
        /// # Panics
        ///
        /// Panics if `A` and `T` are not the same type.
        #[inline]
        pub fn new(inner_fft: Arc<dyn Fft<T>>) -> Result<Self, ()> {
            // The drivers below reinterpret `Complex<T>` slices as `Complex<A>` slices, so refuse
            // to construct anything unless the two parameters name the same type.
            assert_eq!(TypeId::of::<A>(), TypeId::of::<T>());

            let avx_detected = is_x86_feature_detected!("avx");
            let fma_detected = is_x86_feature_detected!("fma");
            match (avx_detected, fma_detected) {
                // SAFETY: `new_with_avx` is compiled with the "avx" target feature, which was
                // detected just above.
                (true, true) => Ok(unsafe { Self::new_with_avx(inner_fft) }),
                _ => Err(()),
            }
        }

        /// Computes the FFT of `buffer` in place.
        ///
        /// Runs the column butterflies on `buffer`, runs the inner FFT out-of-place from `buffer`
        /// into the first `self.len()` elements of `scratch` (the rest of `scratch` is handed to
        /// the inner FFT), then transposes that result back into `buffer`.
        ///
        /// # Panics
        ///
        /// Panics if `scratch` holds fewer than `self.len()` elements.
        #[inline]
        fn perform_fft_inplace(&self, buffer: &mut [Complex<T>], scratch: &mut [Complex<T>]) {
            // SAFETY: `perform_column_butterflies` needs "avx" and "fma"; `new` only builds an
            // instance after detecting both.
            unsafe {
                let simd_buffer: &mut [Complex<A>] = array_utils::workaround_transmute_mut(buffer);

                self.perform_column_butterflies(simd_buffer)
            }

            let (row_output, inner_scratch) = scratch.split_at_mut(self.len());
            self.common_data
                .inner_fft
                .process_outofplace_with_scratch(buffer, row_output, inner_scratch);

            // SAFETY: `transpose` needs "avx"; `new` only builds an instance after detecting it.
            unsafe {
                let simd_rows: &mut [Complex<A>] =
                    array_utils::workaround_transmute_mut(row_output);
                let simd_buffer: &mut [Complex<A>] = array_utils::workaround_transmute_mut(buffer);

                self.transpose(simd_rows, simd_buffer)
            }
        }

        /// Computes the FFT of `input` into `output`, overwriting `input` along the way.
        ///
        /// Runs the column butterflies on `input`, runs the inner FFT in place on `input`, then
        /// transposes `input` into `output`. The inner FFT is given `scratch` as its scratch space
        /// when `scratch` is non-empty, and `output` otherwise.
        #[inline]
        fn perform_fft_out_of_place(
            &self,
            input: &mut [Complex<T>],
            output: &mut [Complex<T>],
            scratch: &mut [Complex<T>],
        ) {
            // SAFETY: `perform_column_butterflies` needs "avx" and "fma"; `new` only builds an
            // instance after detecting both.
            unsafe {
                let simd_input: &mut [Complex<A>] = array_utils::workaround_transmute_mut(input);

                self.perform_column_butterflies(simd_input);
            }

            // `output` is not written until the transpose, so it can double as scratch space.
            let inner_scratch = if scratch.is_empty() {
                &mut output[..]
            } else {
                scratch
            };
            self.common_data
                .inner_fft
                .process_with_scratch(input, inner_scratch);

            // SAFETY: `transpose` needs "avx"; `new` only builds an instance after detecting it.
            unsafe {
                let simd_input: &mut [Complex<A>] = array_utils::workaround_transmute_mut(input);
                let simd_output: &mut [Complex<A>] = array_utils::workaround_transmute_mut(output);

                self.transpose(simd_input, simd_output)
            }
        }
    };
}
113
/// Builds the `CommonSimdData` for a `$row_count` x n mixed radix FFT that wraps `$inner_fft`.
///
/// The resulting length is `$inner_fft.len() * $row_count`. `$inner_fft` is evaluated several
/// times and finally moved into the result, so it should be a plain binding. The expansion refers
/// to the generic parameter `A` of the surrounding `impl`.
macro_rules! mixedradix_gen_data {
    ($row_count: expr, $inner_fft:expr) => {{
        const ROW_COUNT: usize = $row_count;
        // Row 0 is never multiplied by a twiddle factor, so each column needs one fewer.
        const TWIDDLES_PER_COLUMN: usize = ROW_COUNT - 1;

        let complex_per_vector = A::VectorType::COMPLEX_PER_VECTOR;

        let direction = $inner_fft.fft_direction();
        let len_per_row = $inner_fft.len();
        let len = len_per_row * ROW_COUNT;

        // Each row is covered one vector at a time; a trailing partially-filled vector still
        // needs its own column of twiddles.
        let full_columns = len_per_row / complex_per_vector;
        let leftover = len_per_row % complex_per_vector;
        let num_twiddle_columns = full_columns + div_ceil(leftover, complex_per_vector);

        // Layout: for each column, the twiddles for rows 1..ROW_COUNT, stored back to back.
        let mut twiddles = Vec::with_capacity(num_twiddle_columns * TWIDDLES_PER_COLUMN);
        for column in 0..num_twiddle_columns {
            let first_index = column * complex_per_vector;
            for row in 1..ROW_COUNT {
                twiddles.push(AvxVector::make_mixedradix_twiddle_chunk(
                    first_index,
                    row,
                    len,
                    direction,
                ));
            }
        }

        // In-place processing needs `len` elements for the inner FFT's output plus whatever the
        // inner FFT wants for its own out-of-place scratch.
        let inplace_scratch_len = len + $inner_fft.get_outofplace_scratch_len();

        // Out-of-place processing only asks for scratch when the inner FFT's in-place scratch
        // requirement exceeds `len`.
        let inner_inplace_scratch = $inner_fft.get_inplace_scratch_len();
        let outofplace_scratch_len = if inner_inplace_scratch > len {
            inner_inplace_scratch
        } else {
            0
        };

        CommonSimdData {
            twiddles: twiddles.into_boxed_slice(),
            inplace_scratch_len,
            outofplace_scratch_len,
            inner_fft: $inner_fft,
            len,
            direction,
        }
    }};
}
152
/// Generates `perform_column_butterflies` for a `$row_count` x n mixed radix FFT.
///
/// `$butterfly_fn` is called as `$butterfly_fn(columns, self)` on arrays filled by `load_complex`
/// or `load_partial3_complex`. `$butterfly_fn_lo` is called the same way on arrays filled by
/// `load_partial1_complex` or `load_partial2_complex`.
macro_rules! mixedradix_column_butterflies {
    ($row_count: expr, $butterfly_fn: expr, $butterfly_fn_lo: expr) => {
        /// Applies the size-`ROW_COUNT` butterfly down every column of `buffer` (viewed as
        /// `ROW_COUNT` rows of `self.len() / ROW_COUNT` elements) and multiplies every row except
        /// the first by the precomputed twiddle factors.
        ///
        /// # Safety
        ///
        /// The CPU must support the "avx" and "fma" features.
        /// NOTE(review): the load/store helpers are presumably unchecked, so `buffer` should hold
        /// at least `self.len()` elements - confirm against `AvxArrayMut`.
        #[target_feature(enable = "avx", enable = "fma")]
        unsafe fn perform_column_butterflies(&self, mut buffer: impl AvxArrayMut<A>) {
            const ROW_COUNT: usize = $row_count;
            const TWIDDLES_PER_COLUMN: usize = ROW_COUNT - 1;

            let row_len = self.len() / ROW_COUNT;
            let full_chunks = row_len / A::VectorType::COMPLEX_PER_VECTOR;

            // Columns that fill whole vectors.
            let full_chunk_twiddles = self
                .common_data
                .twiddles
                .chunks_exact(TWIDDLES_PER_COLUMN)
                .take(full_chunks);
            for (chunk, chunk_twiddles) in full_chunk_twiddles.enumerate() {
                let base = chunk * A::VectorType::COMPLEX_PER_VECTOR;

                let mut columns = [AvxVector::zero(); ROW_COUNT];
                for row in 0..ROW_COUNT {
                    columns[row] = buffer.load_complex(base + row_len * row);
                }

                let butterflied = $butterfly_fn(columns, self);

                // Row 0 has no twiddle factor.
                buffer.store_complex(butterflied[0], base);

                for row in 1..ROW_COUNT {
                    let twiddled =
                        AvxVector::mul_complex(chunk_twiddles[row - 1], butterflied[row]);
                    buffer.store_complex(twiddled, base + row_len * row);
                }
            }

            // Whatever is left over does not fill a vector and needs the partial load/stores.
            let leftover = row_len % A::VectorType::COMPLEX_PER_VECTOR;
            if leftover == 0 {
                return;
            }

            let tail_base = full_chunks * A::VectorType::COMPLEX_PER_VECTOR;
            // The last `TWIDDLES_PER_COLUMN` twiddles belong to the trailing partial column.
            let all_twiddles = &self.common_data.twiddles;
            let tail_twiddles = &all_twiddles[all_twiddles.len() - TWIDDLES_PER_COLUMN..];

            if leftover <= 2 {
                // One or two leftover columns: go through the `_lo` butterfly and the `lo()` part
                // of each twiddle vector.
                let mut columns = [AvxVector::zero(); ROW_COUNT];
                if leftover == 1 {
                    for row in 0..ROW_COUNT {
                        columns[row] = buffer.load_partial1_complex(tail_base + row_len * row);
                    }
                } else {
                    for row in 0..ROW_COUNT {
                        columns[row] = buffer.load_partial2_complex(tail_base + row_len * row);
                    }
                }

                let mut butterflied = $butterfly_fn_lo(columns, self);

                for row in 1..ROW_COUNT {
                    butterflied[row] =
                        AvxVector::mul_complex(tail_twiddles[row - 1].lo(), butterflied[row]);
                }

                if leftover == 1 {
                    for row in 0..ROW_COUNT {
                        buffer.store_partial1_complex(butterflied[row], tail_base + row_len * row);
                    }
                } else {
                    for row in 0..ROW_COUNT {
                        buffer.store_partial2_complex(butterflied[row], tail_base + row_len * row);
                    }
                }
            } else {
                // More than two leftover columns: use the regular butterfly with the partial3
                // load/stores.
                let mut columns = [AvxVector::zero(); ROW_COUNT];
                for row in 0..ROW_COUNT {
                    columns[row] = buffer.load_partial3_complex(tail_base + row_len * row);
                }

                let butterflied = $butterfly_fn(columns, self);

                // Row 0 has no twiddle factor.
                buffer.store_partial3_complex(butterflied[0], tail_base);

                for row in 1..ROW_COUNT {
                    let twiddled = AvxVector::mul_complex(tail_twiddles[row - 1], butterflied[row]);
                    buffer.store_partial3_complex(twiddled, tail_base + row_len * row);
                }
            }
        }
    };
}
271
/// Generates `transpose` for a `$row_count` x n mixed radix FFT.
///
/// `$transpose_fn` is applied to arrays filled by `load_complex` or `load_partial3_complex`, and
/// `$transpose_fn_lo` to arrays filled by `load_partial2_complex`. The first `;`-separated index
/// list enumerates the vectors stored per full (and per 2-element) chunk; the second enumerates
/// the full vectors stored when three columns are left over.
macro_rules! mixedradix_transpose{
    ($row_count: expr, $transpose_fn: path, $transpose_fn_lo: path, $($unroll_workaround_index:expr);*, $($remainder3_unroll_workaround_index:expr);*) => (

    /// Reads `input` as `ROW_COUNT` rows of `self.len() / ROW_COUNT` elements and writes its
    /// transpose to `output`.
    ///
    /// # Safety
    ///
    /// The CPU must support the "avx" feature. Elements are accessed without bounds checks in the
    /// single-leftover-column path, so `input` and `output` must each hold at least `self.len()`
    /// elements.
    #[target_feature(enable = "avx")]
    unsafe fn transpose(&self, input: &[Complex<A>], mut output: &mut [Complex<A>]) {
        const ROW_COUNT : usize = $row_count;

        let row_len = self.len() / ROW_COUNT;
        let full_chunks = row_len / A::VectorType::COMPLEX_PER_VECTOR;

        for chunk in 0..full_chunks {
            let src_base = chunk * A::VectorType::COMPLEX_PER_VECTOR;
            let dst_base = src_base * ROW_COUNT;

            // Gather one vector from each row (strided reads).
            let mut rows : [A::VectorType; ROW_COUNT] = [AvxVector::zero(); ROW_COUNT];
            for row in 0..ROW_COUNT {
                rows[row] = input.load_complex(src_base + row_len * row);
            }

            let transposed = $transpose_fn(rows);

            // Write the transposed vectors out as one contiguous run. The stores are spelled out
            // through macro repetition rather than a loop; keep them unrolled.
            // NOTE(review): presumably a codegen workaround, judging by the metavariable name.
            $(
                output.store_complex(transposed[$unroll_workaround_index], dst_base + A::VectorType::COMPLEX_PER_VECTOR * $unroll_workaround_index);
            )*
        }

        // Columns that did not fill a whole vector.
        let src_base = full_chunks * A::VectorType::COMPLEX_PER_VECTOR;
        let dst_base = src_base * ROW_COUNT;

        match row_len % A::VectorType::COMPLEX_PER_VECTOR {
            1 => {
                // A single column: just copy one element per row into consecutive slots.
                for row in 0..ROW_COUNT {
                    let src = input.get_unchecked(src_base + row_len * row);
                    let dst = output.get_unchecked_mut(dst_base + row);
                    *dst = *src;
                }
            }
            2 => {
                // Two columns: transpose via the `_lo` function and store 2-element pieces.
                let mut rows = [AvxVector::zero(); ROW_COUNT];
                for row in 0..ROW_COUNT {
                    rows[row] = input.load_partial2_complex(src_base + row_len * row);
                }

                let transposed = $transpose_fn_lo(rows);

                $(
                    output.store_partial2_complex(transposed[$unroll_workaround_index], dst_base + <A::VectorType as AvxVector256>::HalfVector::COMPLEX_PER_VECTOR * $unroll_workaround_index);
                )*
            }
            3 => {
                // Three columns: load with the partial3 helper, run the full transpose, then
                // emit `3 * ROW_COUNT` elements as full vectors plus one partial store.
                let mut rows = [AvxVector::zero(); ROW_COUNT];
                for row in 0..ROW_COUNT {
                    rows[row] = input.load_partial3_complex(src_base + row_len * row);
                }

                let transposed = $transpose_fn(rows);

                let element_count = 3 * ROW_COUNT;
                let full_vector_count = element_count / A::VectorType::COMPLEX_PER_VECTOR;
                let tail_count = element_count % A::VectorType::COMPLEX_PER_VECTOR;

                // The second index list is expected to enumerate 0..full_vector_count.
                $(
                    output.store_complex(transposed[$remainder3_unroll_workaround_index], dst_base + A::VectorType::COMPLEX_PER_VECTOR * $remainder3_unroll_workaround_index);
                )*

                let tail_base = dst_base + full_vector_count * A::VectorType::COMPLEX_PER_VECTOR;
                match tail_count {
                    0 => {},
                    1 => output.store_partial1_complex(transposed[full_vector_count].lo(), tail_base),
                    2 => output.store_partial2_complex(transposed[full_vector_count].lo(), tail_base),
                    3 => output.store_partial3_complex(transposed[full_vector_count], tail_base),
                    _ => unreachable!(),
                }
            }
            _ => {}
        }
    }
)}
374
375pub struct MixedRadix2xnAvx<A: AvxNum, T> {
376 common_data: CommonSimdData<T, A::VectorType>,
377 _phantom: std::marker::PhantomData<T>,
378}
379boilerplate_avx_fft_commondata!(MixedRadix2xnAvx);
380
381impl<A: AvxNum, T: FftNum> MixedRadix2xnAvx<A, T> {
382 #[target_feature(enable = "avx")]
383 unsafe fn new_with_avx(inner_fft: Arc<dyn Fft<T>>) -> Self {
384 Self {
385 common_data: mixedradix_gen_data!(2, inner_fft),
386 _phantom: std::marker::PhantomData,
387 }
388 }
389 mixedradix_column_butterflies!(
390 2,
391 |columns, _: _| AvxVector::column_butterfly2(columns),
392 |columns, _: _| AvxVector::column_butterfly2(columns)
393 );
394 mixedradix_transpose!(2,
395 AvxVector::transpose2_packed,
396 AvxVector::transpose2_packed,
397 0;1, 0
398 );
399 boilerplate_mixedradix!();
400}
401
402pub struct MixedRadix3xnAvx<A: AvxNum, T> {
403 twiddles_butterfly3: A::VectorType,
404 common_data: CommonSimdData<T, A::VectorType>,
405 _phantom: std::marker::PhantomData<T>,
406}
407boilerplate_avx_fft_commondata!(MixedRadix3xnAvx);
408
409impl<A: AvxNum, T: FftNum> MixedRadix3xnAvx<A, T> {
410 #[target_feature(enable = "avx")]
411 unsafe fn new_with_avx(inner_fft: Arc<dyn Fft<T>>) -> Self {
412 Self {
413 twiddles_butterfly3: AvxVector::broadcast_twiddle(1, 3, inner_fft.fft_direction()),
414 common_data: mixedradix_gen_data!(3, inner_fft),
415 _phantom: std::marker::PhantomData,
416 }
417 }
418 mixedradix_column_butterflies!(
419 3,
420 |columns, this: &Self| AvxVector::column_butterfly3(columns, this.twiddles_butterfly3),
421 |columns, this: &Self| AvxVector::column_butterfly3(columns, this.twiddles_butterfly3.lo())
422 );
423 mixedradix_transpose!(3,
424 AvxVector::transpose3_packed,
425 AvxVector::transpose3_packed,
426 0;1;2, 0;1
427 );
428 boilerplate_mixedradix!();
429}
430
431pub struct MixedRadix4xnAvx<A: AvxNum, T> {
432 twiddles_butterfly4: Rotation90<A::VectorType>,
433 common_data: CommonSimdData<T, A::VectorType>,
434 _phantom: std::marker::PhantomData<T>,
435}
436boilerplate_avx_fft_commondata!(MixedRadix4xnAvx);
437
438impl<A: AvxNum, T: FftNum> MixedRadix4xnAvx<A, T> {
439 #[target_feature(enable = "avx")]
440 unsafe fn new_with_avx(inner_fft: Arc<dyn Fft<T>>) -> Self {
441 Self {
442 twiddles_butterfly4: AvxVector::make_rotation90(inner_fft.fft_direction()),
443 common_data: mixedradix_gen_data!(4, inner_fft),
444 _phantom: std::marker::PhantomData,
445 }
446 }
447 mixedradix_column_butterflies!(
448 4,
449 |columns, this: &Self| AvxVector::column_butterfly4(columns, this.twiddles_butterfly4),
450 |columns, this: &Self| AvxVector::column_butterfly4(columns, this.twiddles_butterfly4.lo())
451 );
452 mixedradix_transpose!(4,
453 AvxVector::transpose4_packed,
454 AvxVector::transpose4_packed,
455 0;1;2;3, 0;1;2
456 );
457 boilerplate_mixedradix!();
458}
459
460pub struct MixedRadix5xnAvx<A: AvxNum, T> {
461 twiddles_butterfly5: [A::VectorType; 2],
462 common_data: CommonSimdData<T, A::VectorType>,
463 _phantom: std::marker::PhantomData<T>,
464}
465boilerplate_avx_fft_commondata!(MixedRadix5xnAvx);
466
467impl<A: AvxNum, T: FftNum> MixedRadix5xnAvx<A, T> {
468 #[target_feature(enable = "avx")]
469 unsafe fn new_with_avx(inner_fft: Arc<dyn Fft<T>>) -> Self {
470 Self {
471 twiddles_butterfly5: [
472 AvxVector::broadcast_twiddle(1, 5, inner_fft.fft_direction()),
473 AvxVector::broadcast_twiddle(2, 5, inner_fft.fft_direction()),
474 ],
475 common_data: mixedradix_gen_data!(5, inner_fft),
476 _phantom: std::marker::PhantomData,
477 }
478 }
479 mixedradix_column_butterflies!(
480 5,
481 |columns, this: &Self| AvxVector::column_butterfly5(columns, this.twiddles_butterfly5),
482 |columns, this: &Self| AvxVector::column_butterfly5(
483 columns,
484 [
485 this.twiddles_butterfly5[0].lo(),
486 this.twiddles_butterfly5[1].lo()
487 ]
488 )
489 );
490 mixedradix_transpose!(5,
491 AvxVector::transpose5_packed,
492 AvxVector::transpose5_packed,
493 0;1;2;3;4, 0;1;2
494 );
495 boilerplate_mixedradix!();
496}
497
498pub struct MixedRadix6xnAvx<A: AvxNum, T> {
499 twiddles_butterfly3: A::VectorType,
500 common_data: CommonSimdData<T, A::VectorType>,
501 _phantom: std::marker::PhantomData<T>,
502}
503boilerplate_avx_fft_commondata!(MixedRadix6xnAvx);
504
505impl<A: AvxNum, T: FftNum> MixedRadix6xnAvx<A, T> {
506 #[target_feature(enable = "avx")]
507 unsafe fn new_with_avx(inner_fft: Arc<dyn Fft<T>>) -> Self {
508 Self {
509 twiddles_butterfly3: AvxVector::broadcast_twiddle(1, 3, inner_fft.fft_direction()),
510 common_data: mixedradix_gen_data!(6, inner_fft),
511 _phantom: std::marker::PhantomData,
512 }
513 }
514 mixedradix_column_butterflies!(
515 6,
516 |columns, this: &Self| AvxVector256::column_butterfly6(columns, this.twiddles_butterfly3),
517 |columns, this: &Self| AvxVector128::column_butterfly6(columns, this.twiddles_butterfly3)
518 );
519 mixedradix_transpose!(6,
520 AvxVector::transpose6_packed,
521 AvxVector::transpose6_packed,
522 0;1;2;3;4;5, 0;1;2;3
523 );
524 boilerplate_mixedradix!();
525}
526
527pub struct MixedRadix7xnAvx<A: AvxNum, T> {
528 twiddles_butterfly7: [A::VectorType; 3],
529 common_data: CommonSimdData<T, A::VectorType>,
530 _phantom: std::marker::PhantomData<T>,
531}
532boilerplate_avx_fft_commondata!(MixedRadix7xnAvx);
533
534impl<A: AvxNum, T: FftNum> MixedRadix7xnAvx<A, T> {
535 #[target_feature(enable = "avx")]
536 unsafe fn new_with_avx(inner_fft: Arc<dyn Fft<T>>) -> Self {
537 Self {
538 twiddles_butterfly7: [
539 AvxVector::broadcast_twiddle(1, 7, inner_fft.fft_direction()),
540 AvxVector::broadcast_twiddle(2, 7, inner_fft.fft_direction()),
541 AvxVector::broadcast_twiddle(3, 7, inner_fft.fft_direction()),
542 ],
543 common_data: mixedradix_gen_data!(7, inner_fft),
544 _phantom: std::marker::PhantomData,
545 }
546 }
547 mixedradix_column_butterflies!(
548 7,
549 |columns, this: &Self| AvxVector::column_butterfly7(columns, this.twiddles_butterfly7),
550 |columns, this: &Self| AvxVector::column_butterfly7(
551 columns,
552 [
553 this.twiddles_butterfly7[0].lo(),
554 this.twiddles_butterfly7[1].lo(),
555 this.twiddles_butterfly7[2].lo()
556 ]
557 )
558 );
559 mixedradix_transpose!(7,
560 AvxVector::transpose7_packed,
561 AvxVector::transpose7_packed,
562 0;1;2;3;4;5;6, 0;1;2;3;4
563 );
564 boilerplate_mixedradix!();
565}
566
567pub struct MixedRadix8xnAvx<A: AvxNum, T> {
568 twiddles_butterfly4: Rotation90<A::VectorType>,
569 common_data: CommonSimdData<T, A::VectorType>,
570 _phantom: std::marker::PhantomData<T>,
571}
572boilerplate_avx_fft_commondata!(MixedRadix8xnAvx);
573
574impl<A: AvxNum, T: FftNum> MixedRadix8xnAvx<A, T> {
575 #[target_feature(enable = "avx")]
576 unsafe fn new_with_avx(inner_fft: Arc<dyn Fft<T>>) -> Self {
577 Self {
578 twiddles_butterfly4: AvxVector::make_rotation90(inner_fft.fft_direction()),
579 common_data: mixedradix_gen_data!(8, inner_fft),
580 _phantom: std::marker::PhantomData,
581 }
582 }
583
584 mixedradix_column_butterflies!(
585 8,
586 |columns, this: &Self| AvxVector::column_butterfly8(columns, this.twiddles_butterfly4),
587 |columns, this: &Self| AvxVector::column_butterfly8(columns, this.twiddles_butterfly4.lo())
588 );
589 mixedradix_transpose!(8,
590 AvxVector::transpose8_packed,
591 AvxVector::transpose8_packed,
592 0;1;2;3;4;5;6;7, 0;1;2;3;4;5
593 );
594 boilerplate_mixedradix!();
595}
596
597pub struct MixedRadix9xnAvx<A: AvxNum, T> {
598 twiddles_butterfly9: [A::VectorType; 3],
599 twiddles_butterfly9_lo: [A::VectorType; 2],
600 twiddles_butterfly3: A::VectorType,
601 common_data: CommonSimdData<T, A::VectorType>,
602 _phantom: std::marker::PhantomData<T>,
603}
604boilerplate_avx_fft_commondata!(MixedRadix9xnAvx);
605
606impl<A: AvxNum, T: FftNum> MixedRadix9xnAvx<A, T> {
607 #[target_feature(enable = "avx")]
608 unsafe fn new_with_avx(inner_fft: Arc<dyn Fft<T>>) -> Self {
609 let inverse = inner_fft.fft_direction();
610
611 let twiddle1 = AvxVector::broadcast_twiddle(1, 9, inner_fft.fft_direction());
612 let twiddle2 = AvxVector::broadcast_twiddle(2, 9, inner_fft.fft_direction());
613 let twiddle4 = AvxVector::broadcast_twiddle(4, 9, inner_fft.fft_direction());
614
615 Self {
616 twiddles_butterfly9: [
617 AvxVector::broadcast_twiddle(1, 9, inverse),
618 AvxVector::broadcast_twiddle(2, 9, inverse),
619 AvxVector::broadcast_twiddle(4, 9, inverse),
620 ],
621 twiddles_butterfly9_lo: [
622 AvxVector256::merge(twiddle1, twiddle2),
623 AvxVector256::merge(twiddle2, twiddle4),
624 ],
625 twiddles_butterfly3: AvxVector::broadcast_twiddle(1, 3, inner_fft.fft_direction()),
626 common_data: mixedradix_gen_data!(9, inner_fft),
627 _phantom: std::marker::PhantomData,
628 }
629 }
630
631 mixedradix_column_butterflies!(
632 9,
633 |columns, this: &Self| AvxVector256::column_butterfly9(
634 columns,
635 this.twiddles_butterfly9,
636 this.twiddles_butterfly3
637 ),
638 |columns, this: &Self| AvxVector128::column_butterfly9(
639 columns,
640 this.twiddles_butterfly9_lo,
641 this.twiddles_butterfly3
642 )
643 );
644 mixedradix_transpose!(9,
645 AvxVector::transpose9_packed,
646 AvxVector::transpose9_packed,
647 0;1;2;3;4;5;6;7;8, 0;1;2;3;4;5
648 );
649 boilerplate_mixedradix!();
650}
651
652pub struct MixedRadix11xnAvx<A: AvxNum, T> {
653 twiddles_butterfly11: [A::VectorType; 5],
654 common_data: CommonSimdData<T, A::VectorType>,
655 _phantom: std::marker::PhantomData<T>,
656}
657boilerplate_avx_fft_commondata!(MixedRadix11xnAvx);
658
659impl<A: AvxNum, T: FftNum> MixedRadix11xnAvx<A, T> {
660 #[target_feature(enable = "avx")]
661 unsafe fn new_with_avx(inner_fft: Arc<dyn Fft<T>>) -> Self {
662 Self {
663 twiddles_butterfly11: [
664 AvxVector::broadcast_twiddle(1, 11, inner_fft.fft_direction()),
665 AvxVector::broadcast_twiddle(2, 11, inner_fft.fft_direction()),
666 AvxVector::broadcast_twiddle(3, 11, inner_fft.fft_direction()),
667 AvxVector::broadcast_twiddle(4, 11, inner_fft.fft_direction()),
668 AvxVector::broadcast_twiddle(5, 11, inner_fft.fft_direction()),
669 ],
670 common_data: mixedradix_gen_data!(11, inner_fft),
671 _phantom: std::marker::PhantomData,
672 }
673 }
674 mixedradix_column_butterflies!(
675 11,
676 |columns, this: &Self| AvxVector::column_butterfly11(columns, this.twiddles_butterfly11),
677 |columns, this: &Self| AvxVector::column_butterfly11(
678 columns,
679 [
680 this.twiddles_butterfly11[0].lo(),
681 this.twiddles_butterfly11[1].lo(),
682 this.twiddles_butterfly11[2].lo(),
683 this.twiddles_butterfly11[3].lo(),
684 this.twiddles_butterfly11[4].lo()
685 ]
686 )
687 );
688 mixedradix_transpose!(11,
689 AvxVector::transpose11_packed,
690 AvxVector::transpose11_packed,
691 0;1;2;3;4;5;6;7;8;9;10, 0;1;2;3;4;5;6;7
692 );
693 boilerplate_mixedradix!();
694}
695
696pub struct MixedRadix12xnAvx<A: AvxNum, T> {
697 twiddles_butterfly4: Rotation90<A::VectorType>,
698 twiddles_butterfly3: A::VectorType,
699 common_data: CommonSimdData<T, A::VectorType>,
700 _phantom: std::marker::PhantomData<T>,
701}
702boilerplate_avx_fft_commondata!(MixedRadix12xnAvx);
703
704impl<A: AvxNum, T: FftNum> MixedRadix12xnAvx<A, T> {
705 #[target_feature(enable = "avx")]
706 unsafe fn new_with_avx(inner_fft: Arc<dyn Fft<T>>) -> Self {
707 let inverse = inner_fft.fft_direction();
708 Self {
709 twiddles_butterfly4: AvxVector::make_rotation90(inverse),
710 twiddles_butterfly3: AvxVector::broadcast_twiddle(1, 3, inverse),
711 common_data: mixedradix_gen_data!(12, inner_fft),
712 _phantom: std::marker::PhantomData,
713 }
714 }
715
716 mixedradix_column_butterflies!(
717 12,
718 |columns, this: &Self| AvxVector256::column_butterfly12(
719 columns,
720 this.twiddles_butterfly3,
721 this.twiddles_butterfly4
722 ),
723 |columns, this: &Self| AvxVector128::column_butterfly12(
724 columns,
725 this.twiddles_butterfly3,
726 this.twiddles_butterfly4
727 )
728 );
729 mixedradix_transpose!(12,
730 AvxVector::transpose12_packed,
731 AvxVector::transpose12_packed,
732 0;1;2;3;4;5;6;7;8;9;10;11, 0;1;2;3;4;5;6;7;8
733 );
734 boilerplate_mixedradix!();
735}
736
737pub struct MixedRadix16xnAvx<A: AvxNum, T> {
738 twiddles_butterfly4: Rotation90<A::VectorType>,
739 twiddles_butterfly16: [A::VectorType; 2],
740 common_data: CommonSimdData<T, A::VectorType>,
741 _phantom: std::marker::PhantomData<T>,
742}
743boilerplate_avx_fft_commondata!(MixedRadix16xnAvx);
744
745impl<A: AvxNum, T: FftNum> MixedRadix16xnAvx<A, T> {
746 #[target_feature(enable = "avx")]
747 unsafe fn new_with_avx(inner_fft: Arc<dyn Fft<T>>) -> Self {
748 let inverse = inner_fft.fft_direction();
749 Self {
750 twiddles_butterfly4: AvxVector::make_rotation90(inner_fft.fft_direction()),
751 twiddles_butterfly16: [
752 AvxVector::broadcast_twiddle(1, 16, inverse),
753 AvxVector::broadcast_twiddle(3, 16, inverse),
754 ],
755 common_data: mixedradix_gen_data!(16, inner_fft),
756 _phantom: std::marker::PhantomData,
757 }
758 }
759
760 #[target_feature(enable = "avx", enable = "fma")]
761 unsafe fn perform_column_butterflies(&self, mut buffer: impl AvxArrayMut<A>) {
762 const ROW_COUNT: usize = 16;
764 const TWIDDLES_PER_COLUMN: usize = ROW_COUNT - 1;
765
766 let len_per_row = self.len() / ROW_COUNT;
767 let chunk_count = len_per_row / A::VectorType::COMPLEX_PER_VECTOR;
768
769 for (c, twiddle_chunk) in self
771 .common_data
772 .twiddles
773 .chunks_exact(TWIDDLES_PER_COLUMN)
774 .take(chunk_count)
775 .enumerate()
776 {
777 let index_base = c * A::VectorType::COMPLEX_PER_VECTOR;
778
779 column_butterfly16_loadfn!(
780 |index| buffer.load_complex(index_base + len_per_row * index),
781 |mut data, index| {
782 if index > 0 {
783 data = AvxVector::mul_complex(data, twiddle_chunk[index - 1]);
784 }
785 buffer.store_complex(data, index_base + len_per_row * index)
786 },
787 self.twiddles_butterfly16,
788 self.twiddles_butterfly4
789 );
790 }
791
792 let partial_remainder = len_per_row % A::VectorType::COMPLEX_PER_VECTOR;
795 if partial_remainder > 0 {
796 let partial_remainder_base = chunk_count * A::VectorType::COMPLEX_PER_VECTOR;
797 let partial_remainder_twiddle_base =
798 self.common_data.twiddles.len() - TWIDDLES_PER_COLUMN;
799 let final_twiddle_chunk = &self.common_data.twiddles[partial_remainder_twiddle_base..];
800
801 match partial_remainder {
802 1 => {
803 column_butterfly16_loadfn!(
804 |index| buffer
805 .load_partial1_complex(partial_remainder_base + len_per_row * index),
806 |mut data, index| {
807 if index > 0 {
808 let twiddle: A::VectorType = final_twiddle_chunk[index - 1];
809 data = AvxVector::mul_complex(data, twiddle.lo());
810 }
811 buffer.store_partial1_complex(
812 data,
813 partial_remainder_base + len_per_row * index,
814 )
815 },
816 [
817 self.twiddles_butterfly16[0].lo(),
818 self.twiddles_butterfly16[1].lo()
819 ],
820 self.twiddles_butterfly4.lo()
821 );
822 }
823 2 => {
824 column_butterfly16_loadfn!(
825 |index| buffer
826 .load_partial2_complex(partial_remainder_base + len_per_row * index),
827 |mut data, index| {
828 if index > 0 {
829 let twiddle: A::VectorType = final_twiddle_chunk[index - 1];
830 data = AvxVector::mul_complex(data, twiddle.lo());
831 }
832 buffer.store_partial2_complex(
833 data,
834 partial_remainder_base + len_per_row * index,
835 )
836 },
837 [
838 self.twiddles_butterfly16[0].lo(),
839 self.twiddles_butterfly16[1].lo()
840 ],
841 self.twiddles_butterfly4.lo()
842 );
843 }
844 3 => {
845 column_butterfly16_loadfn!(
846 |index| buffer
847 .load_partial3_complex(partial_remainder_base + len_per_row * index),
848 |mut data, index| {
849 if index > 0 {
850 data = AvxVector::mul_complex(data, final_twiddle_chunk[index - 1]);
851 }
852 buffer.store_partial3_complex(
853 data,
854 partial_remainder_base + len_per_row * index,
855 )
856 },
857 self.twiddles_butterfly16,
858 self.twiddles_butterfly4
859 );
860 }
861 _ => unreachable!(),
862 }
863 }
864 }
865 mixedradix_transpose!(16,
866 AvxVector::transpose16_packed,
867 AvxVector::transpose16_packed,
868 0;1;2;3;4;5;6;7;8;9;10;11;12;13;14;15, 0;1;2;3;4;5;6;7;8;9;10;11
869 );
870 boilerplate_mixedradix!();
871}
872
#[cfg(test)]
mod unit_tests {
    use super::*;
    use crate::algorithm::*;
    use crate::test_utils::check_fft_algorithm;
    use std::sync::Arc;

    // Generates one f32 and one f64 test for `$struct_name`. Each test wraps `Dft` instances of
    // lengths 1 through 31 and checks the forward and then the inverse transform of length
    // `inner_fft_len * $inner_count`.
    macro_rules! test_avx_mixed_radix {
        ($f32_test_name:ident, $f64_test_name:ident, $struct_name:ident, $inner_count:expr) => (
            #[test]
            fn $f32_test_name() {
                for inner_fft_len in 1..32 {
                    let len = inner_fft_len * $inner_count;

                    for &direction in &[FftDirection::Forward, FftDirection::Inverse] {
                        let inner_fft = Arc::new(Dft::new(inner_fft_len, direction)) as Arc<dyn Fft<f32>>;
                        let fft = $struct_name::<f32, f32>::new(inner_fft).expect("Can't run test because this machine doesn't have the required instruction sets");
                        check_fft_algorithm(&fft, len, direction);
                    }
                }
            }
            #[test]
            fn $f64_test_name() {
                for inner_fft_len in 1..32 {
                    let len = inner_fft_len * $inner_count;

                    for &direction in &[FftDirection::Forward, FftDirection::Inverse] {
                        let inner_fft = Arc::new(Dft::new(inner_fft_len, direction)) as Arc<dyn Fft<f64>>;
                        let fft = $struct_name::<f64, f64>::new(inner_fft).expect("Can't run test because this machine doesn't have the required instruction sets");
                        check_fft_algorithm(&fft, len, direction);
                    }
                }
            }
        )
    }

    test_avx_mixed_radix!(
        test_mixedradix_2xn_avx_f32,
        test_mixedradix_2xn_avx_f64,
        MixedRadix2xnAvx,
        2
    );
    test_avx_mixed_radix!(
        test_mixedradix_3xn_avx_f32,
        test_mixedradix_3xn_avx_f64,
        MixedRadix3xnAvx,
        3
    );
    test_avx_mixed_radix!(
        test_mixedradix_4xn_avx_f32,
        test_mixedradix_4xn_avx_f64,
        MixedRadix4xnAvx,
        4
    );
    test_avx_mixed_radix!(
        test_mixedradix_5xn_avx_f32,
        test_mixedradix_5xn_avx_f64,
        MixedRadix5xnAvx,
        5
    );
    test_avx_mixed_radix!(
        test_mixedradix_6xn_avx_f32,
        test_mixedradix_6xn_avx_f64,
        MixedRadix6xnAvx,
        6
    );
    test_avx_mixed_radix!(
        test_mixedradix_7xn_avx_f32,
        test_mixedradix_7xn_avx_f64,
        MixedRadix7xnAvx,
        7
    );
    test_avx_mixed_radix!(
        test_mixedradix_8xn_avx_f32,
        test_mixedradix_8xn_avx_f64,
        MixedRadix8xnAvx,
        8
    );
    test_avx_mixed_radix!(
        test_mixedradix_9xn_avx_f32,
        test_mixedradix_9xn_avx_f64,
        MixedRadix9xnAvx,
        9
    );
    test_avx_mixed_radix!(
        test_mixedradix_11xn_avx_f32,
        test_mixedradix_11xn_avx_f64,
        MixedRadix11xnAvx,
        11
    );
    test_avx_mixed_radix!(
        test_mixedradix_12xn_avx_f32,
        test_mixedradix_12xn_avx_f64,
        MixedRadix12xnAvx,
        12
    );
    test_avx_mixed_radix!(
        test_mixedradix_16xn_avx_f32,
        test_mixedradix_16xn_avx_f64,
        MixedRadix16xnAvx,
        16
    );
}