// rustfft/algorithm/radixn.rs

1use std::sync::Arc;
2
3use num_complex::Complex;
4
5use crate::array_utils::{self, factor_transpose, Load, LoadStore, TransposeFactor};
6use crate::common::{fft_error_inplace, fft_error_outofplace, RadixFactor};
7use crate::{common::FftNum, twiddles, FftDirection};
8use crate::{Direction, Fft, Length};
9
10use super::butterflies::{Butterfly2, Butterfly3, Butterfly4, Butterfly5, Butterfly6, Butterfly7};
11
12#[repr(u8)]
13enum InternalRadixFactor<T> {
14    Factor2(Butterfly2<T>),
15    Factor3(Butterfly3<T>),
16    Factor4(Butterfly4<T>),
17    Factor5(Butterfly5<T>),
18    Factor6(Butterfly6<T>),
19    Factor7(Butterfly7<T>),
20}
21impl<T> InternalRadixFactor<T> {
22    pub const fn radix(&self) -> usize {
23        // note: if we had rustc 1.66, we could just turn these values explicit discriminators on the enum
24        match self {
25            InternalRadixFactor::Factor2(_) => 2,
26            InternalRadixFactor::Factor3(_) => 3,
27            InternalRadixFactor::Factor4(_) => 4,
28            InternalRadixFactor::Factor5(_) => 5,
29            InternalRadixFactor::Factor6(_) => 6,
30            InternalRadixFactor::Factor7(_) => 7,
31        }
32    }
33}
34
35pub(crate) struct RadixN<T> {
36    twiddles: Box<[Complex<T>]>,
37
38    base_fft: Arc<dyn Fft<T>>,
39    base_len: usize,
40
41    factors: Box<[TransposeFactor]>,
42    butterflies: Box<[InternalRadixFactor<T>]>,
43
44    len: usize,
45    direction: FftDirection,
46    inplace_scratch_len: usize,
47    outofplace_scratch_len: usize,
48}
49
50impl<T: FftNum> RadixN<T> {
51    /// Constructs a RadixN instance which computes FFTs of length `factor_product * base_fft.len()`
52    pub fn new(factors: &[RadixFactor], base_fft: Arc<dyn Fft<T>>) -> Self {
53        let base_len = base_fft.len();
54        let direction = base_fft.fft_direction();
55
56        // set up our cross FFT butterfly instances. simultaneously, compute the number of twiddle factors
57        let mut butterflies = Vec::with_capacity(factors.len());
58        let mut cross_fft_len = base_len;
59        let mut twiddle_count = 0;
60
61        for factor in factors {
62            // compute how many twiddles this cross-FFT needs
63            let cross_fft_rows = factor.radix();
64            let cross_fft_columns = cross_fft_len;
65
66            twiddle_count += cross_fft_columns * (cross_fft_rows - 1);
67
68            // set up the butterfly for this cross-FFT
69            let butterfly = match factor {
70                RadixFactor::Factor2 => InternalRadixFactor::Factor2(Butterfly2::new(direction)),
71                RadixFactor::Factor3 => InternalRadixFactor::Factor3(Butterfly3::new(direction)),
72                RadixFactor::Factor4 => InternalRadixFactor::Factor4(Butterfly4::new(direction)),
73                RadixFactor::Factor5 => InternalRadixFactor::Factor5(Butterfly5::new(direction)),
74                RadixFactor::Factor6 => InternalRadixFactor::Factor6(Butterfly6::new(direction)),
75                RadixFactor::Factor7 => InternalRadixFactor::Factor7(Butterfly7::new(direction)),
76            };
77            butterflies.push(butterfly);
78
79            cross_fft_len *= cross_fft_rows;
80        }
81        let len = cross_fft_len;
82
83        // set up our list of transpose factors - it's the same list but reversed, and we want to collapse duplicates
84        // Note that we are only de-duplicating adjacent factors. If we're passed 7 * 2 * 7, we can't collapse the sevens
85        // because the exact order of factors is is important for the transpose
86        let mut transpose_factors: Vec<TransposeFactor> = Vec::with_capacity(factors.len());
87        for f in factors.iter().rev() {
88            // I really want let chains for this!
89            let mut push_new = true;
90            if let Some(last) = transpose_factors.last_mut() {
91                if last.factor == *f {
92                    last.count += 1;
93                    push_new = false;
94                }
95            }
96            if push_new {
97                transpose_factors.push(TransposeFactor {
98                    factor: *f,
99                    count: 1,
100                });
101            }
102        }
103
104        // precompute the twiddle factors this algorithm will use.
105        // we're doing the same precomputation of twiddle factors as the mixed radix algorithm where width=factor.radix() and height=len/factor.radix()
106        // but mixed radix only does one step and then calls itself recusrively, and this algorithm does every layer all the way down
107        // so we're going to pack all the "layers" of twiddle factors into a single array, starting with the bottom layer and going up
108        let mut cross_fft_len = base_len;
109        let mut twiddle_factors = Vec::with_capacity(twiddle_count);
110
111        for factor in factors {
112            // Compute the twiddle factors for the cross FFT
113            let cross_fft_columns = cross_fft_len;
114            cross_fft_len *= factor.radix();
115
116            for i in 0..cross_fft_columns {
117                for k in 1..factor.radix() {
118                    let twiddle = twiddles::compute_twiddle(i * k, cross_fft_len, direction);
119                    twiddle_factors.push(twiddle);
120                }
121            }
122        }
123
124        // figure out how much scratch space we need to request from callers
125        let base_inplace_scratch = base_fft.get_inplace_scratch_len();
126        let inplace_scratch_len = if base_inplace_scratch > len {
127            len + base_inplace_scratch
128        } else {
129            len
130        };
131        let outofplace_scratch_len = if base_inplace_scratch > len {
132            base_inplace_scratch
133        } else {
134            0
135        };
136
137        Self {
138            twiddles: twiddle_factors.into_boxed_slice(),
139
140            base_fft,
141            base_len,
142
143            factors: transpose_factors.into_boxed_slice(),
144            butterflies: butterflies.into_boxed_slice(),
145
146            len,
147            direction,
148
149            inplace_scratch_len,
150            outofplace_scratch_len,
151        }
152    }
153
154    fn inplace_scratch_len(&self) -> usize {
155        self.inplace_scratch_len
156    }
157    fn outofplace_scratch_len(&self) -> usize {
158        self.outofplace_scratch_len
159    }
160
161    fn perform_fft_out_of_place(
162        &self,
163        input: &mut [Complex<T>],
164        output: &mut [Complex<T>],
165        scratch: &mut [Complex<T>],
166    ) {
167        if let Some(unroll_factor) = self.factors.first() {
168            // for performance, we really, really want to unroll the transpose, but we need to make sure the output length is divisible by the unroll amount
169            // choosing the first factor seems to reliably perform well
170            match unroll_factor.factor {
171                RadixFactor::Factor2 => {
172                    factor_transpose::<Complex<T>, 2>(self.base_len, input, output, &self.factors)
173                }
174                RadixFactor::Factor3 => {
175                    factor_transpose::<Complex<T>, 3>(self.base_len, input, output, &self.factors)
176                }
177                RadixFactor::Factor4 => {
178                    factor_transpose::<Complex<T>, 4>(self.base_len, input, output, &self.factors)
179                }
180                RadixFactor::Factor5 => {
181                    factor_transpose::<Complex<T>, 5>(self.base_len, input, output, &self.factors)
182                }
183                RadixFactor::Factor6 => {
184                    factor_transpose::<Complex<T>, 6>(self.base_len, input, output, &self.factors)
185                }
186                RadixFactor::Factor7 => {
187                    factor_transpose::<Complex<T>, 7>(self.base_len, input, output, &self.factors)
188                }
189            }
190        } else {
191            // no factors, so just pass data straight to our base
192            output.copy_from_slice(input);
193        }
194
195        // Base-level FFTs
196        let base_scratch = if scratch.len() > 0 { scratch } else { input };
197        self.base_fft.process_with_scratch(output, base_scratch);
198
199        // cross-FFTs
200        let mut cross_fft_len = self.base_len;
201        let mut layer_twiddles: &[Complex<T>] = &self.twiddles;
202
203        for factor in self.butterflies.iter() {
204            let cross_fft_columns = cross_fft_len;
205            cross_fft_len *= factor.radix();
206
207            match factor {
208                InternalRadixFactor::Factor2(butterfly2) => {
209                    for data in output.chunks_exact_mut(cross_fft_len) {
210                        unsafe { butterfly_2(data, layer_twiddles, cross_fft_columns, butterfly2) }
211                    }
212                }
213                InternalRadixFactor::Factor3(butterfly3) => {
214                    for data in output.chunks_exact_mut(cross_fft_len) {
215                        unsafe { butterfly_3(data, layer_twiddles, cross_fft_columns, butterfly3) }
216                    }
217                }
218                InternalRadixFactor::Factor4(butterfly4) => {
219                    for data in output.chunks_exact_mut(cross_fft_len) {
220                        unsafe { butterfly_4(data, layer_twiddles, cross_fft_columns, butterfly4) }
221                    }
222                }
223                InternalRadixFactor::Factor5(butterfly5) => {
224                    for data in output.chunks_exact_mut(cross_fft_len) {
225                        unsafe { butterfly_5(data, layer_twiddles, cross_fft_columns, butterfly5) }
226                    }
227                }
228                InternalRadixFactor::Factor6(butterfly6) => {
229                    for data in output.chunks_exact_mut(cross_fft_len) {
230                        unsafe { butterfly_6(data, layer_twiddles, cross_fft_columns, butterfly6) }
231                    }
232                }
233                InternalRadixFactor::Factor7(butterfly7) => {
234                    for data in output.chunks_exact_mut(cross_fft_len) {
235                        unsafe { butterfly_7(data, layer_twiddles, cross_fft_columns, butterfly7) }
236                    }
237                }
238            }
239
240            // skip past all the twiddle factors used in this layer
241            let twiddle_offset = cross_fft_columns * (factor.radix() - 1);
242            layer_twiddles = &layer_twiddles[twiddle_offset..];
243        }
244    }
245}
246boilerplate_fft_oop!(RadixN, |this: &RadixN<_>| this.len);
247
248#[inline(never)]
249pub(crate) unsafe fn butterfly_2<T: FftNum>(
250    mut data: impl LoadStore<T>,
251    twiddles: impl Load<T>,
252    num_columns: usize,
253    butterfly2: &Butterfly2<T>,
254) {
255    for idx in 0..num_columns {
256        let mut scratch = [
257            data.load(idx + 0 * num_columns),
258            data.load(idx + 1 * num_columns) * twiddles.load(idx),
259        ];
260
261        butterfly2.perform_fft_butterfly(&mut scratch);
262
263        data.store(scratch[0], idx + num_columns * 0);
264        data.store(scratch[1], idx + num_columns * 1);
265    }
266}
267
268#[inline(never)]
269pub(crate) unsafe fn butterfly_3<T: FftNum>(
270    mut data: impl LoadStore<T>,
271    twiddles: impl Load<T>,
272    num_columns: usize,
273    butterfly3: &Butterfly3<T>,
274) {
275    for idx in 0..num_columns {
276        let tw_idx = idx * 2;
277        let mut scratch = [
278            data.load(idx + 0 * num_columns),
279            data.load(idx + 1 * num_columns) * twiddles.load(tw_idx + 0),
280            data.load(idx + 2 * num_columns) * twiddles.load(tw_idx + 1),
281        ];
282
283        butterfly3.perform_fft_butterfly(&mut scratch);
284
285        data.store(scratch[0], idx + 0 * num_columns);
286        data.store(scratch[1], idx + 1 * num_columns);
287        data.store(scratch[2], idx + 2 * num_columns);
288    }
289}
290
291#[inline(never)]
292pub(crate) unsafe fn butterfly_4<T: FftNum>(
293    mut data: impl LoadStore<T>,
294    twiddles: impl Load<T>,
295    num_columns: usize,
296    butterfly4: &Butterfly4<T>,
297) {
298    for idx in 0..num_columns {
299        let tw_idx = idx * 3;
300        let mut scratch = [
301            data.load(idx + 0 * num_columns),
302            data.load(idx + 1 * num_columns) * twiddles.load(tw_idx + 0),
303            data.load(idx + 2 * num_columns) * twiddles.load(tw_idx + 1),
304            data.load(idx + 3 * num_columns) * twiddles.load(tw_idx + 2),
305        ];
306
307        butterfly4.perform_fft_butterfly(&mut scratch);
308
309        data.store(scratch[0], idx + 0 * num_columns);
310        data.store(scratch[1], idx + 1 * num_columns);
311        data.store(scratch[2], idx + 2 * num_columns);
312        data.store(scratch[3], idx + 3 * num_columns);
313    }
314}
315
316#[inline(never)]
317pub(crate) unsafe fn butterfly_5<T: FftNum>(
318    mut data: impl LoadStore<T>,
319    twiddles: impl Load<T>,
320    num_columns: usize,
321    butterfly5: &Butterfly5<T>,
322) {
323    for idx in 0..num_columns {
324        let tw_idx = idx * 4;
325        let mut scratch = [
326            data.load(idx + 0 * num_columns),
327            data.load(idx + 1 * num_columns) * twiddles.load(tw_idx + 0),
328            data.load(idx + 2 * num_columns) * twiddles.load(tw_idx + 1),
329            data.load(idx + 3 * num_columns) * twiddles.load(tw_idx + 2),
330            data.load(idx + 4 * num_columns) * twiddles.load(tw_idx + 3),
331        ];
332
333        butterfly5.perform_fft_butterfly(&mut scratch);
334
335        data.store(scratch[0], idx + 0 * num_columns);
336        data.store(scratch[1], idx + 1 * num_columns);
337        data.store(scratch[2], idx + 2 * num_columns);
338        data.store(scratch[3], idx + 3 * num_columns);
339        data.store(scratch[4], idx + 4 * num_columns);
340    }
341}
342
343#[inline(never)]
344pub(crate) unsafe fn butterfly_6<T: FftNum>(
345    mut data: impl LoadStore<T>,
346    twiddles: impl Load<T>,
347    num_columns: usize,
348    butterfly6: &Butterfly6<T>,
349) {
350    for idx in 0..num_columns {
351        let tw_idx = idx * 5;
352        let mut scratch = [
353            data.load(idx + 0 * num_columns),
354            data.load(idx + 1 * num_columns) * twiddles.load(tw_idx + 0),
355            data.load(idx + 2 * num_columns) * twiddles.load(tw_idx + 1),
356            data.load(idx + 3 * num_columns) * twiddles.load(tw_idx + 2),
357            data.load(idx + 4 * num_columns) * twiddles.load(tw_idx + 3),
358            data.load(idx + 5 * num_columns) * twiddles.load(tw_idx + 4),
359        ];
360
361        butterfly6.perform_fft_butterfly(&mut scratch);
362
363        data.store(scratch[0], idx + 0 * num_columns);
364        data.store(scratch[1], idx + 1 * num_columns);
365        data.store(scratch[2], idx + 2 * num_columns);
366        data.store(scratch[3], idx + 3 * num_columns);
367        data.store(scratch[4], idx + 4 * num_columns);
368        data.store(scratch[5], idx + 5 * num_columns);
369    }
370}
371
372#[inline(never)]
373pub(crate) unsafe fn butterfly_7<T: FftNum>(
374    mut data: impl LoadStore<T>,
375    twiddles: impl Load<T>,
376    num_columns: usize,
377    butterfly7: &Butterfly7<T>,
378) {
379    for idx in 0..num_columns {
380        let tw_idx = idx * 6;
381        let mut scratch = [
382            data.load(idx + 0 * num_columns),
383            data.load(idx + 1 * num_columns) * twiddles.load(tw_idx + 0),
384            data.load(idx + 2 * num_columns) * twiddles.load(tw_idx + 1),
385            data.load(idx + 3 * num_columns) * twiddles.load(tw_idx + 2),
386            data.load(idx + 4 * num_columns) * twiddles.load(tw_idx + 3),
387            data.load(idx + 5 * num_columns) * twiddles.load(tw_idx + 4),
388            data.load(idx + 6 * num_columns) * twiddles.load(tw_idx + 5),
389        ];
390
391        butterfly7.perform_fft_butterfly(&mut scratch);
392
393        data.store(scratch[0], idx + 0 * num_columns);
394        data.store(scratch[1], idx + 1 * num_columns);
395        data.store(scratch[2], idx + 2 * num_columns);
396        data.store(scratch[3], idx + 3 * num_columns);
397        data.store(scratch[4], idx + 4 * num_columns);
398        data.store(scratch[5], idx + 5 * num_columns);
399        data.store(scratch[6], idx + 6 * num_columns);
400    }
401}
402
#[cfg(test)]
mod unit_tests {
    use super::*;
    use crate::test_utils::{check_fft_algorithm, construct_base};

    #[test]
    fn test_scalar_radixn() {
        let factor_list = &[
            RadixFactor::Factor2,
            RadixFactor::Factor3,
            RadixFactor::Factor4,
            RadixFactor::Factor5,
            RadixFactor::Factor6,
            RadixFactor::Factor7,
        ];

        // Three-factor lists that exercise the transpose-factor de-duplication in
        // `RadixN::new`. They cover adjacent duplicates, which get collapsed into a single
        // `TransposeFactor` with count > 1, and non-adjacent duplicates such as 7 * 2 * 7,
        // which must NOT be collapsed.
        let three_factor_lists = &[
            [RadixFactor::Factor7, RadixFactor::Factor2, RadixFactor::Factor7],
            [RadixFactor::Factor3, RadixFactor::Factor3, RadixFactor::Factor2],
            [RadixFactor::Factor2, RadixFactor::Factor4, RadixFactor::Factor4],
            [RadixFactor::Factor5, RadixFactor::Factor6, RadixFactor::Factor5],
        ];

        for base in 1..7 {
            let base_forward = construct_base(base, FftDirection::Forward);
            let base_inverse = construct_base(base, FftDirection::Inverse);

            // test just the base with no factors
            test_radixn(&[], Arc::clone(&base_forward));
            test_radixn(&[], Arc::clone(&base_inverse));

            // test one factor
            for factor_a in factor_list {
                let factors = &[*factor_a];
                test_radixn(factors, Arc::clone(&base_forward));
                test_radixn(factors, Arc::clone(&base_inverse));
            }

            // test two factors
            for factor_a in factor_list {
                for factor_b in factor_list {
                    let factors = &[*factor_a, *factor_b];
                    test_radixn(factors, Arc::clone(&base_forward));
                    test_radixn(factors, Arc::clone(&base_inverse));
                }
            }

            // test three factors, including repeated radixes
            for factors in three_factor_lists {
                test_radixn(factors, Arc::clone(&base_forward));
                test_radixn(factors, Arc::clone(&base_inverse));
            }
        }
    }

    /// Builds a RadixN from `factors` and `base_fft` and checks it against the reference FFT.
    fn test_radixn(factors: &[RadixFactor], base_fft: Arc<dyn Fft<f64>>) {
        let len = base_fft.len() * factors.iter().map(|f| f.radix()).product::<usize>();
        let direction = base_fft.fft_direction();
        let fft = RadixN::new(factors, base_fft);

        check_fft_algorithm::<f64>(&fft, len, direction);
    }
}