1use std::sync::Arc;
2
3use num_complex::Complex;
4
5use crate::array_utils::{self, factor_transpose, Load, LoadStore, TransposeFactor};
6use crate::common::{fft_error_inplace, fft_error_outofplace, RadixFactor};
7use crate::{common::FftNum, twiddles, FftDirection};
8use crate::{Direction, Fft, Length};
9
10use super::butterflies::{Butterfly2, Butterfly3, Butterfly4, Butterfly5, Butterfly6, Butterfly7};
11
12#[repr(u8)]
13enum InternalRadixFactor<T> {
14 Factor2(Butterfly2<T>),
15 Factor3(Butterfly3<T>),
16 Factor4(Butterfly4<T>),
17 Factor5(Butterfly5<T>),
18 Factor6(Butterfly6<T>),
19 Factor7(Butterfly7<T>),
20}
21impl<T> InternalRadixFactor<T> {
22 pub const fn radix(&self) -> usize {
23 match self {
25 InternalRadixFactor::Factor2(_) => 2,
26 InternalRadixFactor::Factor3(_) => 3,
27 InternalRadixFactor::Factor4(_) => 4,
28 InternalRadixFactor::Factor5(_) => 5,
29 InternalRadixFactor::Factor6(_) => 6,
30 InternalRadixFactor::Factor7(_) => 7,
31 }
32 }
33}
34
35pub(crate) struct RadixN<T> {
36 twiddles: Box<[Complex<T>]>,
37
38 base_fft: Arc<dyn Fft<T>>,
39 base_len: usize,
40
41 factors: Box<[TransposeFactor]>,
42 butterflies: Box<[InternalRadixFactor<T>]>,
43
44 len: usize,
45 direction: FftDirection,
46 inplace_scratch_len: usize,
47 outofplace_scratch_len: usize,
48}
49
50impl<T: FftNum> RadixN<T> {
51 pub fn new(factors: &[RadixFactor], base_fft: Arc<dyn Fft<T>>) -> Self {
53 let base_len = base_fft.len();
54 let direction = base_fft.fft_direction();
55
56 let mut butterflies = Vec::with_capacity(factors.len());
58 let mut cross_fft_len = base_len;
59 let mut twiddle_count = 0;
60
61 for factor in factors {
62 let cross_fft_rows = factor.radix();
64 let cross_fft_columns = cross_fft_len;
65
66 twiddle_count += cross_fft_columns * (cross_fft_rows - 1);
67
68 let butterfly = match factor {
70 RadixFactor::Factor2 => InternalRadixFactor::Factor2(Butterfly2::new(direction)),
71 RadixFactor::Factor3 => InternalRadixFactor::Factor3(Butterfly3::new(direction)),
72 RadixFactor::Factor4 => InternalRadixFactor::Factor4(Butterfly4::new(direction)),
73 RadixFactor::Factor5 => InternalRadixFactor::Factor5(Butterfly5::new(direction)),
74 RadixFactor::Factor6 => InternalRadixFactor::Factor6(Butterfly6::new(direction)),
75 RadixFactor::Factor7 => InternalRadixFactor::Factor7(Butterfly7::new(direction)),
76 };
77 butterflies.push(butterfly);
78
79 cross_fft_len *= cross_fft_rows;
80 }
81 let len = cross_fft_len;
82
83 let mut transpose_factors: Vec<TransposeFactor> = Vec::with_capacity(factors.len());
87 for f in factors.iter().rev() {
88 let mut push_new = true;
90 if let Some(last) = transpose_factors.last_mut() {
91 if last.factor == *f {
92 last.count += 1;
93 push_new = false;
94 }
95 }
96 if push_new {
97 transpose_factors.push(TransposeFactor {
98 factor: *f,
99 count: 1,
100 });
101 }
102 }
103
104 let mut cross_fft_len = base_len;
109 let mut twiddle_factors = Vec::with_capacity(twiddle_count);
110
111 for factor in factors {
112 let cross_fft_columns = cross_fft_len;
114 cross_fft_len *= factor.radix();
115
116 for i in 0..cross_fft_columns {
117 for k in 1..factor.radix() {
118 let twiddle = twiddles::compute_twiddle(i * k, cross_fft_len, direction);
119 twiddle_factors.push(twiddle);
120 }
121 }
122 }
123
124 let base_inplace_scratch = base_fft.get_inplace_scratch_len();
126 let inplace_scratch_len = if base_inplace_scratch > len {
127 len + base_inplace_scratch
128 } else {
129 len
130 };
131 let outofplace_scratch_len = if base_inplace_scratch > len {
132 base_inplace_scratch
133 } else {
134 0
135 };
136
137 Self {
138 twiddles: twiddle_factors.into_boxed_slice(),
139
140 base_fft,
141 base_len,
142
143 factors: transpose_factors.into_boxed_slice(),
144 butterflies: butterflies.into_boxed_slice(),
145
146 len,
147 direction,
148
149 inplace_scratch_len,
150 outofplace_scratch_len,
151 }
152 }
153
154 fn inplace_scratch_len(&self) -> usize {
155 self.inplace_scratch_len
156 }
157 fn outofplace_scratch_len(&self) -> usize {
158 self.outofplace_scratch_len
159 }
160
161 fn perform_fft_out_of_place(
162 &self,
163 input: &mut [Complex<T>],
164 output: &mut [Complex<T>],
165 scratch: &mut [Complex<T>],
166 ) {
167 if let Some(unroll_factor) = self.factors.first() {
168 match unroll_factor.factor {
171 RadixFactor::Factor2 => {
172 factor_transpose::<Complex<T>, 2>(self.base_len, input, output, &self.factors)
173 }
174 RadixFactor::Factor3 => {
175 factor_transpose::<Complex<T>, 3>(self.base_len, input, output, &self.factors)
176 }
177 RadixFactor::Factor4 => {
178 factor_transpose::<Complex<T>, 4>(self.base_len, input, output, &self.factors)
179 }
180 RadixFactor::Factor5 => {
181 factor_transpose::<Complex<T>, 5>(self.base_len, input, output, &self.factors)
182 }
183 RadixFactor::Factor6 => {
184 factor_transpose::<Complex<T>, 6>(self.base_len, input, output, &self.factors)
185 }
186 RadixFactor::Factor7 => {
187 factor_transpose::<Complex<T>, 7>(self.base_len, input, output, &self.factors)
188 }
189 }
190 } else {
191 output.copy_from_slice(input);
193 }
194
195 let base_scratch = if scratch.len() > 0 { scratch } else { input };
197 self.base_fft.process_with_scratch(output, base_scratch);
198
199 let mut cross_fft_len = self.base_len;
201 let mut layer_twiddles: &[Complex<T>] = &self.twiddles;
202
203 for factor in self.butterflies.iter() {
204 let cross_fft_columns = cross_fft_len;
205 cross_fft_len *= factor.radix();
206
207 match factor {
208 InternalRadixFactor::Factor2(butterfly2) => {
209 for data in output.chunks_exact_mut(cross_fft_len) {
210 unsafe { butterfly_2(data, layer_twiddles, cross_fft_columns, butterfly2) }
211 }
212 }
213 InternalRadixFactor::Factor3(butterfly3) => {
214 for data in output.chunks_exact_mut(cross_fft_len) {
215 unsafe { butterfly_3(data, layer_twiddles, cross_fft_columns, butterfly3) }
216 }
217 }
218 InternalRadixFactor::Factor4(butterfly4) => {
219 for data in output.chunks_exact_mut(cross_fft_len) {
220 unsafe { butterfly_4(data, layer_twiddles, cross_fft_columns, butterfly4) }
221 }
222 }
223 InternalRadixFactor::Factor5(butterfly5) => {
224 for data in output.chunks_exact_mut(cross_fft_len) {
225 unsafe { butterfly_5(data, layer_twiddles, cross_fft_columns, butterfly5) }
226 }
227 }
228 InternalRadixFactor::Factor6(butterfly6) => {
229 for data in output.chunks_exact_mut(cross_fft_len) {
230 unsafe { butterfly_6(data, layer_twiddles, cross_fft_columns, butterfly6) }
231 }
232 }
233 InternalRadixFactor::Factor7(butterfly7) => {
234 for data in output.chunks_exact_mut(cross_fft_len) {
235 unsafe { butterfly_7(data, layer_twiddles, cross_fft_columns, butterfly7) }
236 }
237 }
238 }
239
240 let twiddle_offset = cross_fft_columns * (factor.radix() - 1);
242 layer_twiddles = &layer_twiddles[twiddle_offset..];
243 }
244 }
245}
// Generates the public FFT plumbing for `RadixN` from `perform_fft_out_of_place` and the
// scratch-length methods. The closure supplies the FFT length.
// NOTE(review): presumably this expands to the `Fft`, `Length` and `Direction` impls (which
// would explain the otherwise-unused `fft_error_*`, `Fft`, `Length`, `Direction` imports) —
// confirm against the macro definition.
boilerplate_fft_oop!(RadixN, |this: &RadixN<_>| this.len);
247
248#[inline(never)]
249pub(crate) unsafe fn butterfly_2<T: FftNum>(
250 mut data: impl LoadStore<T>,
251 twiddles: impl Load<T>,
252 num_columns: usize,
253 butterfly2: &Butterfly2<T>,
254) {
255 for idx in 0..num_columns {
256 let mut scratch = [
257 data.load(idx + 0 * num_columns),
258 data.load(idx + 1 * num_columns) * twiddles.load(idx),
259 ];
260
261 butterfly2.perform_fft_butterfly(&mut scratch);
262
263 data.store(scratch[0], idx + num_columns * 0);
264 data.store(scratch[1], idx + num_columns * 1);
265 }
266}
267
268#[inline(never)]
269pub(crate) unsafe fn butterfly_3<T: FftNum>(
270 mut data: impl LoadStore<T>,
271 twiddles: impl Load<T>,
272 num_columns: usize,
273 butterfly3: &Butterfly3<T>,
274) {
275 for idx in 0..num_columns {
276 let tw_idx = idx * 2;
277 let mut scratch = [
278 data.load(idx + 0 * num_columns),
279 data.load(idx + 1 * num_columns) * twiddles.load(tw_idx + 0),
280 data.load(idx + 2 * num_columns) * twiddles.load(tw_idx + 1),
281 ];
282
283 butterfly3.perform_fft_butterfly(&mut scratch);
284
285 data.store(scratch[0], idx + 0 * num_columns);
286 data.store(scratch[1], idx + 1 * num_columns);
287 data.store(scratch[2], idx + 2 * num_columns);
288 }
289}
290
291#[inline(never)]
292pub(crate) unsafe fn butterfly_4<T: FftNum>(
293 mut data: impl LoadStore<T>,
294 twiddles: impl Load<T>,
295 num_columns: usize,
296 butterfly4: &Butterfly4<T>,
297) {
298 for idx in 0..num_columns {
299 let tw_idx = idx * 3;
300 let mut scratch = [
301 data.load(idx + 0 * num_columns),
302 data.load(idx + 1 * num_columns) * twiddles.load(tw_idx + 0),
303 data.load(idx + 2 * num_columns) * twiddles.load(tw_idx + 1),
304 data.load(idx + 3 * num_columns) * twiddles.load(tw_idx + 2),
305 ];
306
307 butterfly4.perform_fft_butterfly(&mut scratch);
308
309 data.store(scratch[0], idx + 0 * num_columns);
310 data.store(scratch[1], idx + 1 * num_columns);
311 data.store(scratch[2], idx + 2 * num_columns);
312 data.store(scratch[3], idx + 3 * num_columns);
313 }
314}
315
316#[inline(never)]
317pub(crate) unsafe fn butterfly_5<T: FftNum>(
318 mut data: impl LoadStore<T>,
319 twiddles: impl Load<T>,
320 num_columns: usize,
321 butterfly5: &Butterfly5<T>,
322) {
323 for idx in 0..num_columns {
324 let tw_idx = idx * 4;
325 let mut scratch = [
326 data.load(idx + 0 * num_columns),
327 data.load(idx + 1 * num_columns) * twiddles.load(tw_idx + 0),
328 data.load(idx + 2 * num_columns) * twiddles.load(tw_idx + 1),
329 data.load(idx + 3 * num_columns) * twiddles.load(tw_idx + 2),
330 data.load(idx + 4 * num_columns) * twiddles.load(tw_idx + 3),
331 ];
332
333 butterfly5.perform_fft_butterfly(&mut scratch);
334
335 data.store(scratch[0], idx + 0 * num_columns);
336 data.store(scratch[1], idx + 1 * num_columns);
337 data.store(scratch[2], idx + 2 * num_columns);
338 data.store(scratch[3], idx + 3 * num_columns);
339 data.store(scratch[4], idx + 4 * num_columns);
340 }
341}
342
343#[inline(never)]
344pub(crate) unsafe fn butterfly_6<T: FftNum>(
345 mut data: impl LoadStore<T>,
346 twiddles: impl Load<T>,
347 num_columns: usize,
348 butterfly6: &Butterfly6<T>,
349) {
350 for idx in 0..num_columns {
351 let tw_idx = idx * 5;
352 let mut scratch = [
353 data.load(idx + 0 * num_columns),
354 data.load(idx + 1 * num_columns) * twiddles.load(tw_idx + 0),
355 data.load(idx + 2 * num_columns) * twiddles.load(tw_idx + 1),
356 data.load(idx + 3 * num_columns) * twiddles.load(tw_idx + 2),
357 data.load(idx + 4 * num_columns) * twiddles.load(tw_idx + 3),
358 data.load(idx + 5 * num_columns) * twiddles.load(tw_idx + 4),
359 ];
360
361 butterfly6.perform_fft_butterfly(&mut scratch);
362
363 data.store(scratch[0], idx + 0 * num_columns);
364 data.store(scratch[1], idx + 1 * num_columns);
365 data.store(scratch[2], idx + 2 * num_columns);
366 data.store(scratch[3], idx + 3 * num_columns);
367 data.store(scratch[4], idx + 4 * num_columns);
368 data.store(scratch[5], idx + 5 * num_columns);
369 }
370}
371
372#[inline(never)]
373pub(crate) unsafe fn butterfly_7<T: FftNum>(
374 mut data: impl LoadStore<T>,
375 twiddles: impl Load<T>,
376 num_columns: usize,
377 butterfly7: &Butterfly7<T>,
378) {
379 for idx in 0..num_columns {
380 let tw_idx = idx * 6;
381 let mut scratch = [
382 data.load(idx + 0 * num_columns),
383 data.load(idx + 1 * num_columns) * twiddles.load(tw_idx + 0),
384 data.load(idx + 2 * num_columns) * twiddles.load(tw_idx + 1),
385 data.load(idx + 3 * num_columns) * twiddles.load(tw_idx + 2),
386 data.load(idx + 4 * num_columns) * twiddles.load(tw_idx + 3),
387 data.load(idx + 5 * num_columns) * twiddles.load(tw_idx + 4),
388 data.load(idx + 6 * num_columns) * twiddles.load(tw_idx + 5),
389 ];
390
391 butterfly7.perform_fft_butterfly(&mut scratch);
392
393 data.store(scratch[0], idx + 0 * num_columns);
394 data.store(scratch[1], idx + 1 * num_columns);
395 data.store(scratch[2], idx + 2 * num_columns);
396 data.store(scratch[3], idx + 3 * num_columns);
397 data.store(scratch[4], idx + 4 * num_columns);
398 data.store(scratch[5], idx + 5 * num_columns);
399 data.store(scratch[6], idx + 6 * num_columns);
400 }
401}
402
#[cfg(test)]
mod unit_tests {
    use super::*;
    use crate::test_utils::{check_fft_algorithm, construct_base};

    /// Checks `RadixN` against the reference implementation, in both directions and for base
    /// lengths 1 through 6, with these factor plans:
    /// - no factors;
    /// - every single factor;
    /// - every ordered pair of factors;
    /// - a few three-layer plans.
    #[test]
    fn test_scalar_radixn() {
        let factor_list = &[
            RadixFactor::Factor2,
            RadixFactor::Factor3,
            RadixFactor::Factor4,
            RadixFactor::Factor5,
            RadixFactor::Factor6,
            RadixFactor::Factor7,
        ];

        // Three-layer plans exercise twiddle-table slicing across more than two layers, and
        // transposes mixing merged and unmerged factor runs (including a run of 3). They are
        // kept small so the test stays fast.
        let triple_factor_list: &[[RadixFactor; 3]] = &[
            [RadixFactor::Factor2, RadixFactor::Factor3, RadixFactor::Factor2],
            [RadixFactor::Factor4, RadixFactor::Factor4, RadixFactor::Factor5],
            [RadixFactor::Factor7, RadixFactor::Factor3, RadixFactor::Factor3],
            [RadixFactor::Factor2, RadixFactor::Factor2, RadixFactor::Factor2],
        ];

        for base in 1..7 {
            let base_forward = construct_base(base, FftDirection::Forward);
            let base_inverse = construct_base(base, FftDirection::Inverse);

            // Test just the base FFT with no cross-FFTs.
            test_radixn(&[], Arc::clone(&base_forward));
            test_radixn(&[], Arc::clone(&base_inverse));

            // Test every single factor.
            for factor_a in factor_list {
                let factors = &[*factor_a];
                test_radixn(factors, Arc::clone(&base_forward));
                test_radixn(factors, Arc::clone(&base_inverse));
            }

            // Test every ordered pair of factors.
            for factor_a in factor_list {
                for factor_b in factor_list {
                    let factors = &[*factor_a, *factor_b];
                    test_radixn(factors, Arc::clone(&base_forward));
                    test_radixn(factors, Arc::clone(&base_inverse));
                }
            }

            // Test a few three-layer plans.
            for factors in triple_factor_list {
                test_radixn(factors, Arc::clone(&base_forward));
                test_radixn(factors, Arc::clone(&base_inverse));
            }
        }
    }

    /// Builds a `RadixN` from `factors` and `base_fft` and checks it with `check_fft_algorithm`.
    fn test_radixn(factors: &[RadixFactor], base_fft: Arc<dyn Fft<f64>>) {
        let len = base_fft.len() * factors.iter().map(|f| f.radix()).product::<usize>();
        let direction = base_fft.fft_direction();
        let fft = RadixN::new(factors, base_fft);

        check_fft_algorithm::<f64>(&fft, len, direction);
    }
}