// rustfft/algorithm/radix3.rs

1use std::sync::Arc;
2
3use num_complex::Complex;
4
5use crate::algorithm::butterflies::{Butterfly1, Butterfly27, Butterfly3, Butterfly9};
6use crate::algorithm::radixn::butterfly_3;
7use crate::array_utils::{self, bitreversed_transpose, compute_logarithm};
8use crate::common::{fft_error_inplace, fft_error_outofplace};
9use crate::{common::FftNum, twiddles, FftDirection};
10use crate::{Direction, Fft, Length};
11
12/// FFT algorithm optimized for power-of-three sizes
13///
14/// ~~~
15/// // Computes a forward FFT of size 2187
16/// use rustfft::algorithm::Radix3;
17/// use rustfft::{Fft, FftDirection};
18/// use rustfft::num_complex::Complex;
19///
20/// let mut buffer = vec![Complex{ re: 0.0f32, im: 0.0f32 }; 2187];
21///
22/// let fft = Radix3::new(2187, FftDirection::Forward);
23/// fft.process(&mut buffer);
24/// ~~~
25
26pub struct Radix3<T> {
27    twiddles: Box<[Complex<T>]>,
28    butterfly3: Butterfly3<T>,
29
30    base_fft: Arc<dyn Fft<T>>,
31    base_len: usize,
32
33    len: usize,
34    direction: FftDirection,
35    inplace_scratch_len: usize,
36    outofplace_scratch_len: usize,
37}
38
39impl<T: FftNum> Radix3<T> {
40    /// Preallocates necessary arrays and precomputes necessary data to efficiently compute the power-of-three FFT
41    pub fn new(len: usize, direction: FftDirection) -> Self {
42        // Compute the total power of 3 for this length. IE, len = 3^exponent
43        let exponent = compute_logarithm::<3>(len).unwrap_or_else(|| {
44            panic!(
45                "Radix3 algorithm requires a power-of-three input size. Got {}",
46                len
47            )
48        });
49
50        // figure out which base length we're going to use
51        let (base_exponent, base_fft) = match exponent {
52            0 => (0, Arc::new(Butterfly1::new(direction)) as Arc<dyn Fft<T>>),
53            1 => (1, Arc::new(Butterfly3::new(direction)) as Arc<dyn Fft<T>>),
54            2 => (2, Arc::new(Butterfly9::new(direction)) as Arc<dyn Fft<T>>),
55            _ => (3, Arc::new(Butterfly27::new(direction)) as Arc<dyn Fft<T>>),
56        };
57
58        Self::new_with_base(exponent - base_exponent, base_fft)
59    }
60
61    /// Constructs a Radix3 instance which computes FFTs of length `3^k * base_fft.len()`
62    pub fn new_with_base(k: u32, base_fft: Arc<dyn Fft<T>>) -> Self {
63        let base_len = base_fft.len();
64        let len = base_len * 3usize.pow(k);
65
66        let direction = base_fft.fft_direction();
67
68        // precompute the twiddle factors this algorithm will use.
69        // we're doing the same precomputation of twiddle factors as the mixed radix algorithm where width=3 and height=len/3
70        // but mixed radix only does one step and then calls itself recusrively, and this algorithm does every layer all the way down
71        // so we're going to pack all the "layers" of twiddle factors into a single array, starting with the bottom layer and going up
72        const ROW_COUNT: usize = 3;
73        let mut cross_fft_len = base_len;
74        let mut twiddle_factors = Vec::with_capacity(len * 2);
75        while cross_fft_len < len {
76            let num_columns = cross_fft_len;
77            cross_fft_len *= ROW_COUNT;
78
79            for i in 0..num_columns {
80                for k in 1..ROW_COUNT {
81                    let twiddle = twiddles::compute_twiddle(i * k, cross_fft_len, direction);
82                    twiddle_factors.push(twiddle);
83                }
84            }
85        }
86
87        let base_inplace_scratch = base_fft.get_inplace_scratch_len();
88        let inplace_scratch_len = if base_inplace_scratch > cross_fft_len {
89            cross_fft_len + base_inplace_scratch
90        } else {
91            cross_fft_len
92        };
93        let outofplace_scratch_len = if base_inplace_scratch > len {
94            base_inplace_scratch
95        } else {
96            0
97        };
98
99        Self {
100            twiddles: twiddle_factors.into_boxed_slice(),
101            butterfly3: Butterfly3::new(direction),
102
103            base_fft,
104            base_len,
105
106            len,
107            direction,
108
109            inplace_scratch_len,
110            outofplace_scratch_len,
111        }
112    }
113
114    fn inplace_scratch_len(&self) -> usize {
115        self.inplace_scratch_len
116    }
117    fn outofplace_scratch_len(&self) -> usize {
118        self.outofplace_scratch_len
119    }
120
121    fn perform_fft_out_of_place(
122        &self,
123        input: &mut [Complex<T>],
124        output: &mut [Complex<T>],
125        scratch: &mut [Complex<T>],
126    ) {
127        // copy the data into the output vector
128        if self.len() == self.base_len {
129            output.copy_from_slice(input);
130        } else {
131            bitreversed_transpose::<Complex<T>, 3>(self.base_len, input, output);
132        }
133
134        // Base-level FFTs
135        let base_scratch = if scratch.len() > 0 { scratch } else { input };
136        self.base_fft.process_with_scratch(output, base_scratch);
137
138        // cross-FFTs
139        const ROW_COUNT: usize = 3;
140        let mut cross_fft_len = self.base_len;
141        let mut layer_twiddles: &[Complex<T>] = &self.twiddles;
142
143        while cross_fft_len < output.len() {
144            let num_columns = cross_fft_len;
145            cross_fft_len *= ROW_COUNT;
146
147            for data in output.chunks_exact_mut(cross_fft_len) {
148                unsafe { butterfly_3(data, layer_twiddles, num_columns, &self.butterfly3) }
149            }
150
151            // skip past all the twiddle factors used in this layer
152            let twiddle_offset = num_columns * (ROW_COUNT - 1);
153            layer_twiddles = &layer_twiddles[twiddle_offset..];
154        }
155    }
156}
157boilerplate_fft_oop!(Radix3, |this: &Radix3<_>| this.len);
158
#[cfg(test)]
mod unit_tests {
    use super::*;
    use crate::test_utils::{check_fft_algorithm, construct_base};

    /// Checks `Radix3::new` for every power of three from 3^0 through 3^7,
    /// in both directions.
    #[test]
    fn test_radix3_with_length() {
        for pow in 0..8 {
            let len = 3usize.pow(pow);

            for &direction in &[FftDirection::Forward, FftDirection::Inverse] {
                let fft = Radix3::new(len, direction);
                check_fft_algorithm::<f32>(&fft, len, direction);
            }
        }
    }

    /// Checks `Radix3::new_with_base` with base FFTs of length 1 through 9 and
    /// between 0 and 4 radix-3 layers on top, in both directions.
    #[test]
    fn test_radix3_with_base() {
        for base in 1..=9 {
            let bases = [
                construct_base(base, FftDirection::Forward),
                construct_base(base, FftDirection::Inverse),
            ];

            for k in 0..5 {
                for base_fft in &bases {
                    test_radix3(k, Arc::clone(base_fft));
                }
            }
        }
    }

    /// Builds a `Radix3` with `k` radix-3 layers over `base_fft` and verifies it
    /// at the expected total length, using the base FFT's direction.
    fn test_radix3(k: u32, base_fft: Arc<dyn Fft<f32>>) {
        let direction = base_fft.fft_direction();
        let len = base_fft.len() * 3usize.pow(k);
        let fft = Radix3::new_with_base(k, base_fft);

        check_fft_algorithm::<f32>(&fft, len, direction);
    }
}