rustfft/algorithm/
mixed_radix.rs

1use std::cmp::max;
2use std::sync::Arc;
3
4use num_complex::Complex;
5use num_traits::Zero;
6use transpose;
7
8use crate::array_utils;
9use crate::common::{fft_error_inplace, fft_error_outofplace};
10use crate::{common::FftNum, twiddles, FftDirection};
11use crate::{Direction, Fft, Length};
12
13/// Implementation of the Mixed-Radix FFT algorithm
14///
15/// This algorithm factors a size n FFT into n1 * n2, computes several inner FFTs of size n1 and n2, then combines the
16/// results to get the final answer
17///
18/// ~~~
19/// // Computes a forward FFT of size 1200, using the Mixed-Radix Algorithm
20/// use rustfft::algorithm::MixedRadix;
21/// use rustfft::{Fft, FftPlanner};
22/// use rustfft::num_complex::Complex;
23///
24/// let mut buffer = vec![Complex{ re: 0.0f32, im: 0.0f32 }; 1200];
25///
26/// // we need to find an n1 and n2 such that n1 * n2 == 1200
27/// // n1 = 30 and n2 = 40 satisfies this
28/// let mut planner = FftPlanner::new();
29/// let inner_fft_n1 = planner.plan_fft_forward(30);
30/// let inner_fft_n2 = planner.plan_fft_forward(40);
31///
32/// // the mixed radix FFT length will be inner_fft_n1.len() * inner_fft_n2.len() = 1200
33/// let fft = MixedRadix::new(inner_fft_n1, inner_fft_n2);
34/// fft.process(&mut buffer);
35/// ~~~
36pub struct MixedRadix<T> {
37    twiddles: Box<[Complex<T>]>,
38
39    width_size_fft: Arc<dyn Fft<T>>,
40    width: usize,
41
42    height_size_fft: Arc<dyn Fft<T>>,
43    height: usize,
44
45    inplace_scratch_len: usize,
46    outofplace_scratch_len: usize,
47
48    direction: FftDirection,
49}
50
51impl<T: FftNum> MixedRadix<T> {
52    /// Creates a FFT instance which will process inputs/outputs of size `width_fft.len() * height_fft.len()`
53    pub fn new(width_fft: Arc<dyn Fft<T>>, height_fft: Arc<dyn Fft<T>>) -> Self {
54        assert_eq!(
55            width_fft.fft_direction(), height_fft.fft_direction(),
56            "width_fft and height_fft must have the same direction. got width direction={}, height direction={}",
57            width_fft.fft_direction(), height_fft.fft_direction());
58
59        let direction = width_fft.fft_direction();
60
61        let width = width_fft.len();
62        let height = height_fft.len();
63
64        let len = width * height;
65
66        let mut twiddles = vec![Complex::zero(); len];
67        for (x, twiddle_chunk) in twiddles.chunks_exact_mut(height).enumerate() {
68            for (y, twiddle_element) in twiddle_chunk.iter_mut().enumerate() {
69                *twiddle_element = twiddles::compute_twiddle(x * y, len, direction);
70            }
71        }
72
73        // Collect some data about what kind of scratch space our inner FFTs need
74        let height_inplace_scratch = height_fft.get_inplace_scratch_len();
75        let width_inplace_scratch = width_fft.get_inplace_scratch_len();
76        let width_outofplace_scratch = width_fft.get_outofplace_scratch_len();
77
78        // Computing the scratch we'll require is a somewhat confusing process.
79        // When we compute an out-of-place FFT, both of our inner FFTs are in-place
80        // When we compute an inplace FFT, our inner width FFT will be inplace, and our height FFT will be out-of-place
81        // For the out-of-place FFT, one of 2 things can happen regarding scratch:
82        //      - If the required scratch of both FFTs is <= self.len(), then we can use the input or output buffer as scratch, and so we need 0 extra scratch
83        //      - If either of the inner FFTs require more, then we'll have to request an entire scratch buffer for the inner FFTs,
84        //          whose size is the max of the two inner FFTs' required scratch
85        let max_inner_inplace_scratch = max(height_inplace_scratch, width_inplace_scratch);
86        let outofplace_scratch_len = if max_inner_inplace_scratch > len {
87            max_inner_inplace_scratch
88        } else {
89            0
90        };
91
92        // For the in-place FFT, again the best case is that we can just bounce data around between internal buffers, and the only inplace scratch we need is self.len()
93        // If our width fft's OOP FFT requires any scratch, then we can tack that on the end of our own scratch, and use split_at_mut to separate our own from our internal FFT's
94        // Likewise, if our height inplace FFT requires more inplace scracth than self.len(), we can tack that on to the end of our own inplace scratch.
95        // Thus, the total inplace scratch is our own length plus the max of what the two inner FFTs will need
96        let inplace_scratch_len = len
97            + max(
98                if height_inplace_scratch > len {
99                    height_inplace_scratch
100                } else {
101                    0
102                },
103                width_outofplace_scratch,
104            );
105
106        Self {
107            twiddles: twiddles.into_boxed_slice(),
108
109            width_size_fft: width_fft,
110            width: width,
111
112            height_size_fft: height_fft,
113            height: height,
114
115            inplace_scratch_len,
116            outofplace_scratch_len,
117
118            direction,
119        }
120    }
121
122    fn perform_fft_inplace(&self, buffer: &mut [Complex<T>], scratch: &mut [Complex<T>]) {
123        // SIX STEP FFT:
124        let (scratch, inner_scratch) = scratch.split_at_mut(self.len());
125
126        // STEP 1: transpose
127        transpose::transpose(buffer, scratch, self.width, self.height);
128
129        // STEP 2: perform FFTs of size `height`
130        let height_scratch = if inner_scratch.len() > buffer.len() {
131            &mut inner_scratch[..]
132        } else {
133            &mut buffer[..]
134        };
135        self.height_size_fft
136            .process_with_scratch(scratch, height_scratch);
137
138        // STEP 3: Apply twiddle factors
139        for (element, twiddle) in scratch.iter_mut().zip(self.twiddles.iter()) {
140            *element = *element * twiddle;
141        }
142
143        // STEP 4: transpose again
144        transpose::transpose(scratch, buffer, self.height, self.width);
145
146        // STEP 5: perform FFTs of size `width`
147        self.width_size_fft
148            .process_outofplace_with_scratch(buffer, scratch, inner_scratch);
149
150        // STEP 6: transpose again
151        transpose::transpose(scratch, buffer, self.width, self.height);
152    }
153
154    fn perform_fft_out_of_place(
155        &self,
156        input: &mut [Complex<T>],
157        output: &mut [Complex<T>],
158        scratch: &mut [Complex<T>],
159    ) {
160        // SIX STEP FFT:
161
162        // STEP 1: transpose
163        transpose::transpose(input, output, self.width, self.height);
164
165        // STEP 2: perform FFTs of size `height`
166        let height_scratch = if scratch.len() > input.len() {
167            &mut scratch[..]
168        } else {
169            &mut input[..]
170        };
171        self.height_size_fft
172            .process_with_scratch(output, height_scratch);
173
174        // STEP 3: Apply twiddle factors
175        for (element, twiddle) in output.iter_mut().zip(self.twiddles.iter()) {
176            *element = *element * twiddle;
177        }
178
179        // STEP 4: transpose again
180        transpose::transpose(output, input, self.height, self.width);
181
182        // STEP 5: perform FFTs of size `width`
183        let width_scratch = if scratch.len() > output.len() {
184            &mut scratch[..]
185        } else {
186            &mut output[..]
187        };
188        self.width_size_fft
189            .process_with_scratch(input, width_scratch);
190
191        // STEP 6: transpose again
192        transpose::transpose(input, output, self.width, self.height);
193    }
194}
195boilerplate_fft!(
196    MixedRadix,
197    |this: &MixedRadix<_>| this.twiddles.len(),
198    |this: &MixedRadix<_>| this.inplace_scratch_len,
199    |this: &MixedRadix<_>| this.outofplace_scratch_len
200);
201
202/// Implementation of the Mixed-Radix FFT algorithm, specialized for smaller input sizes
203///
204/// This algorithm factors a size n FFT into n1 * n2, computes several inner FFTs of size n1 and n2, then combines the
205/// results to get the final answer
206///
207/// ~~~
208/// // Computes a forward FFT of size 40 using MixedRadixSmall
209/// use std::sync::Arc;
210/// use rustfft::algorithm::MixedRadixSmall;
211/// use rustfft::algorithm::butterflies::{Butterfly5, Butterfly8};
212/// use rustfft::{Fft, FftDirection};
213/// use rustfft::num_complex::Complex;
214///
215/// let len = 40;
216///
217/// let mut buffer = vec![Complex{ re: 0.0f32, im: 0.0f32 }; len];
218///
219/// // we need to find an n1 and n2 such that n1 * n2 == 40
220/// // n1 = 5 and n2 = 8 satisfies this
221/// let inner_fft_n1 = Arc::new(Butterfly5::new(FftDirection::Forward));
222/// let inner_fft_n2 = Arc::new(Butterfly8::new(FftDirection::Forward));
223///
224/// // the mixed radix FFT length will be inner_fft_n1.len() * inner_fft_n2.len() = 40
225/// let fft = MixedRadixSmall::new(inner_fft_n1, inner_fft_n2);
226/// fft.process(&mut buffer);
227/// ~~~
228pub struct MixedRadixSmall<T> {
229    twiddles: Box<[Complex<T>]>,
230
231    width_size_fft: Arc<dyn Fft<T>>,
232    width: usize,
233
234    height_size_fft: Arc<dyn Fft<T>>,
235    height: usize,
236
237    direction: FftDirection,
238}
239
240impl<T: FftNum> MixedRadixSmall<T> {
241    /// Creates a FFT instance which will process inputs/outputs of size `width_fft.len() * height_fft.len()`
242    pub fn new(width_fft: Arc<dyn Fft<T>>, height_fft: Arc<dyn Fft<T>>) -> Self {
243        assert_eq!(
244            width_fft.fft_direction(), height_fft.fft_direction(),
245            "width_fft and height_fft must have the same direction. got width direction={}, height direction={}",
246            width_fft.fft_direction(), height_fft.fft_direction());
247
248        // Verify that the inner FFTs don't require out-of-place scratch, and only arequire a small amount of inplace scratch
249        let width = width_fft.len();
250        let height = height_fft.len();
251        let len = width * height;
252
253        assert_eq!(width_fft.get_outofplace_scratch_len(), 0, "MixedRadixSmall should only be used with algorithms that require 0 out-of-place scratch. Width FFT (len={}) requires {}, should require 0", width, width_fft.get_outofplace_scratch_len());
254        assert_eq!(height_fft.get_outofplace_scratch_len(), 0, "MixedRadixSmall should only be used with algorithms that require 0 out-of-place scratch. Height FFT (len={}) requires {}, should require 0", height, height_fft.get_outofplace_scratch_len());
255
256        assert!(width_fft.get_inplace_scratch_len() <= width, "MixedRadixSmall should only be used with algorithms that require little inplace scratch. Width FFT (len={}) requires {}, should require {} or less", width, width_fft.get_inplace_scratch_len(), width);
257        assert!(height_fft.get_inplace_scratch_len() <= height, "MixedRadixSmall should only be used with algorithms that require little inplace scratch. Height FFT (len={}) requires {}, should require {} or less", height, height_fft.get_inplace_scratch_len(), height);
258
259        let direction = width_fft.fft_direction();
260
261        let mut twiddles = vec![Complex::zero(); len];
262        for (x, twiddle_chunk) in twiddles.chunks_exact_mut(height).enumerate() {
263            for (y, twiddle_element) in twiddle_chunk.iter_mut().enumerate() {
264                *twiddle_element = twiddles::compute_twiddle(x * y, len, direction);
265            }
266        }
267
268        Self {
269            twiddles: twiddles.into_boxed_slice(),
270
271            width_size_fft: width_fft,
272            width: width,
273
274            height_size_fft: height_fft,
275            height: height,
276
277            direction,
278        }
279    }
280
281    fn perform_fft_inplace(&self, buffer: &mut [Complex<T>], scratch: &mut [Complex<T>]) {
282        // SIX STEP FFT:
283        // STEP 1: transpose
284        unsafe { array_utils::transpose_small(self.width, self.height, buffer, scratch) };
285
286        // STEP 2: perform FFTs of size `height`
287        self.height_size_fft.process_with_scratch(scratch, buffer);
288
289        // STEP 3: Apply twiddle factors
290        for (element, twiddle) in scratch.iter_mut().zip(self.twiddles.iter()) {
291            *element = *element * twiddle;
292        }
293
294        // STEP 4: transpose again
295        unsafe { array_utils::transpose_small(self.height, self.width, scratch, buffer) };
296
297        // STEP 5: perform FFTs of size `width`
298        self.width_size_fft
299            .process_outofplace_with_scratch(buffer, scratch, &mut []);
300
301        // STEP 6: transpose again
302        unsafe { array_utils::transpose_small(self.width, self.height, scratch, buffer) };
303    }
304
305    fn perform_fft_out_of_place(
306        &self,
307        input: &mut [Complex<T>],
308        output: &mut [Complex<T>],
309        _scratch: &mut [Complex<T>],
310    ) {
311        // SIX STEP FFT:
312        // STEP 1: transpose
313        unsafe { array_utils::transpose_small(self.width, self.height, input, output) };
314
315        // STEP 2: perform FFTs of size `height`
316        self.height_size_fft.process_with_scratch(output, input);
317
318        // STEP 3: Apply twiddle factors
319        for (element, twiddle) in output.iter_mut().zip(self.twiddles.iter()) {
320            *element = *element * twiddle;
321        }
322
323        // STEP 4: transpose again
324        unsafe { array_utils::transpose_small(self.height, self.width, output, input) };
325
326        // STEP 5: perform FFTs of size `width`
327        self.width_size_fft.process_with_scratch(input, output);
328
329        // STEP 6: transpose again
330        unsafe { array_utils::transpose_small(self.width, self.height, input, output) };
331    }
332}
333boilerplate_fft!(
334    MixedRadixSmall,
335    |this: &MixedRadixSmall<_>| this.twiddles.len(),
336    |this: &MixedRadixSmall<_>| this.len(),
337    |_| 0
338);
339
#[cfg(test)]
mod unit_tests {
    use super::*;
    use crate::test_utils::check_fft_algorithm;
    use crate::{algorithm::Dft, test_utils::BigScratchAlgorithm};
    use num_traits::Zero;
    use std::sync::Arc;

    /// Checks `MixedRadix` for every width/height pair in 1..7, in both directions.
    #[test]
    fn test_mixed_radix() {
        for width in 1..7 {
            for height in 1..7 {
                for &direction in &[FftDirection::Forward, FftDirection::Inverse] {
                    test_mixed_radix_with_lengths(width, height, direction);
                }
            }
        }
    }

    /// Checks `MixedRadixSmall` for every width/height pair in 2..7, in both directions.
    #[test]
    fn test_mixed_radix_small() {
        for width in 2..7 {
            for height in 2..7 {
                for &direction in &[FftDirection::Forward, FftDirection::Inverse] {
                    test_mixed_radix_small_with_lengths(width, height, direction);
                }
            }
        }
    }

    /// Builds a `MixedRadix` out of two `Dft` instances and runs it through `check_fft_algorithm`.
    fn test_mixed_radix_with_lengths(width: usize, height: usize, direction: FftDirection) {
        let width_fft: Arc<dyn Fft<f32>> = Arc::new(Dft::new(width, direction));
        let height_fft: Arc<dyn Fft<f32>> = Arc::new(Dft::new(height, direction));

        let fft = MixedRadix::new(width_fft, height_fft);

        check_fft_algorithm(&fft, width * height, direction);
    }

    /// Builds a `MixedRadixSmall` out of two `Dft` instances and runs it through `check_fft_algorithm`.
    fn test_mixed_radix_small_with_lengths(width: usize, height: usize, direction: FftDirection) {
        let width_fft: Arc<dyn Fft<f32>> = Arc::new(Dft::new(width, direction));
        let height_fft: Arc<dyn Fft<f32>> = Arc::new(Dft::new(height, direction));

        let fft = MixedRadixSmall::new(width_fft, height_fft);

        check_fft_algorithm(&fft, width * height, direction);
    }

    // Verify that the mixed radix algorithm correctly provides scratch space to inner FFTs
    #[test]
    fn test_mixed_radix_inner_scratch() {
        let scratch_lengths = [1, 5, 25];

        // One inner FFT for every combination of length, in-place scratch and out-of-place scratch
        let mut inner_ffts: Vec<Arc<dyn Fft<f32>>> = Vec::new();
        for &len in &scratch_lengths {
            for &inplace_scratch in &scratch_lengths {
                for &outofplace_scratch in &scratch_lengths {
                    let inner: Arc<dyn Fft<f32>> = Arc::new(BigScratchAlgorithm {
                        len,
                        inplace_scratch,
                        outofplace_scratch,
                        direction: FftDirection::Forward,
                    });
                    inner_ffts.push(inner);
                }
            }
        }

        // Pair every inner FFT with every other one, and run both processing modes with exactly the
        // amount of scratch the combined FFT asks for
        for width_fft in &inner_ffts {
            for height_fft in &inner_ffts {
                let fft = MixedRadix::new(Arc::clone(width_fft), Arc::clone(height_fft));
                let fft_len = fft.len();

                let mut inplace_buffer = vec![Complex::zero(); fft_len];
                let mut inplace_scratch = vec![Complex::zero(); fft.get_inplace_scratch_len()];

                fft.process_with_scratch(&mut inplace_buffer, &mut inplace_scratch);

                let mut outofplace_input = vec![Complex::zero(); fft_len];
                let mut outofplace_output = vec![Complex::zero(); fft_len];
                let mut outofplace_scratch =
                    vec![Complex::zero(); fft.get_outofplace_scratch_len()];
                fft.process_outofplace_with_scratch(
                    &mut outofplace_input,
                    &mut outofplace_output,
                    &mut outofplace_scratch,
                );
            }
        }
    }
}