rustfft/algorithm/
radix4.rs1use std::sync::Arc;
2
3use num_complex::Complex;
4
5use crate::algorithm::butterflies::{
6 Butterfly1, Butterfly16, Butterfly2, Butterfly32, Butterfly4, Butterfly8,
7};
8use crate::algorithm::radixn::butterfly_4;
9use crate::array_utils::{self, bitreversed_transpose};
10use crate::common::{fft_error_inplace, fft_error_outofplace};
11use crate::{common::FftNum, twiddles, FftDirection};
12use crate::{Direction, Fft, Length};
13
14pub struct Radix4<T> {
29 twiddles: Box<[Complex<T>]>,
30
31 base_fft: Arc<dyn Fft<T>>,
32 base_len: usize,
33
34 len: usize,
35 direction: FftDirection,
36 inplace_scratch_len: usize,
37 outofplace_scratch_len: usize,
38}
39
40impl<T: FftNum> Radix4<T> {
41 pub fn new(len: usize, direction: FftDirection) -> Self {
43 assert!(
44 len.is_power_of_two(),
45 "Radix4 algorithm requires a power-of-two input size. Got {}",
46 len
47 );
48
49 let exponent = len.trailing_zeros();
51 let (base_exponent, base_fft) = match exponent {
52 0 => (0, Arc::new(Butterfly1::new(direction)) as Arc<dyn Fft<T>>),
53 1 => (1, Arc::new(Butterfly2::new(direction)) as Arc<dyn Fft<T>>),
54 2 => (2, Arc::new(Butterfly4::new(direction)) as Arc<dyn Fft<T>>),
55 3 => (3, Arc::new(Butterfly8::new(direction)) as Arc<dyn Fft<T>>),
56 _ => {
57 if exponent % 2 == 1 {
58 (5, Arc::new(Butterfly32::new(direction)) as Arc<dyn Fft<T>>)
59 } else {
60 (4, Arc::new(Butterfly16::new(direction)) as Arc<dyn Fft<T>>)
61 }
62 }
63 };
64
65 Self::new_with_base((exponent - base_exponent) / 2, base_fft)
66 }
67
68 pub fn new_with_base(k: u32, base_fft: Arc<dyn Fft<T>>) -> Self {
70 let base_len = base_fft.len();
71 let len = base_len * (1 << (k * 2));
72
73 let direction = base_fft.fft_direction();
74
75 const ROW_COUNT: usize = 4;
80 let mut cross_fft_len = base_len;
81 let mut twiddle_factors = Vec::with_capacity(len * 2);
82 while cross_fft_len < len {
83 let num_columns = cross_fft_len;
84 cross_fft_len *= ROW_COUNT;
85
86 for i in 0..num_columns {
87 for k in 1..ROW_COUNT {
88 let twiddle = twiddles::compute_twiddle(i * k, cross_fft_len, direction);
89 twiddle_factors.push(twiddle);
90 }
91 }
92 }
93
94 let base_inplace_scratch = base_fft.get_inplace_scratch_len();
95 let inplace_scratch_len = if base_inplace_scratch > cross_fft_len {
96 cross_fft_len + base_inplace_scratch
97 } else {
98 cross_fft_len
99 };
100 let outofplace_scratch_len = if base_inplace_scratch > len {
101 base_inplace_scratch
102 } else {
103 0
104 };
105
106 Self {
107 twiddles: twiddle_factors.into_boxed_slice(),
108
109 base_fft,
110 base_len,
111
112 len,
113 direction,
114
115 inplace_scratch_len,
116 outofplace_scratch_len,
117 }
118 }
119
120 fn inplace_scratch_len(&self) -> usize {
121 self.inplace_scratch_len
122 }
123 fn outofplace_scratch_len(&self) -> usize {
124 self.outofplace_scratch_len
125 }
126
127 fn perform_fft_out_of_place(
128 &self,
129 input: &mut [Complex<T>],
130 output: &mut [Complex<T>],
131 scratch: &mut [Complex<T>],
132 ) {
133 if self.len() == self.base_len {
135 output.copy_from_slice(input);
136 } else {
137 bitreversed_transpose::<Complex<T>, 4>(self.base_len, input, output);
138 }
139
140 let base_scratch = if scratch.len() > 0 { scratch } else { input };
142 self.base_fft.process_with_scratch(output, base_scratch);
143
144 const ROW_COUNT: usize = 4;
146 let mut cross_fft_len = self.base_len;
147 let mut layer_twiddles: &[Complex<T>] = &self.twiddles;
148
149 let butterfly4 = Butterfly4::new(self.direction);
150
151 while cross_fft_len < output.len() {
152 let num_columns = cross_fft_len;
153 cross_fft_len *= ROW_COUNT;
154
155 for data in output.chunks_exact_mut(cross_fft_len) {
156 unsafe { butterfly_4(data, layer_twiddles, num_columns, &butterfly4) }
157 }
158
159 let twiddle_offset = num_columns * (ROW_COUNT - 1);
161 layer_twiddles = &layer_twiddles[twiddle_offset..];
162 }
163 }
164}
165boilerplate_fft_oop!(Radix4, |this: &Radix4<_>| this.len);
166
#[cfg(test)]
mod unit_tests {
    use super::*;
    use crate::test_utils::{check_fft_algorithm, construct_base};

    /// Checks `Radix4::new` for every power of two from 1 to 512.
    ///
    /// This covers every base butterfly `new` can select, and both exponent parities with up to
    /// two radix-4 layers.
    #[test]
    fn test_radix4_with_length() {
        for pow in 0..10 {
            let len = 1 << pow;

            let forward_fft = Radix4::new(len, FftDirection::Forward);
            check_fft_algorithm::<f32>(&forward_fft, len, FftDirection::Forward);

            let inverse_fft = Radix4::new(len, FftDirection::Inverse);
            check_fft_algorithm::<f32>(&inverse_fft, len, FftDirection::Inverse);
        }
    }

    /// Checks `Radix4::new_with_base` for base sizes 1..=9 combined with 0..=3 radix-4 layers.
    #[test]
    fn test_radix4_with_base() {
        for base in 1..=9 {
            let base_forward = construct_base(base, FftDirection::Forward);
            let base_inverse = construct_base(base, FftDirection::Inverse);

            for k in 0..4 {
                test_radix4(k, Arc::clone(&base_forward));
                test_radix4(k, Arc::clone(&base_inverse));
            }
        }
    }

    /// Builds `Radix4::new_with_base(k, base_fft)` and validates it with `check_fft_algorithm`.
    fn test_radix4(k: u32, base_fft: Arc<dyn Fft<f64>>) {
        let len = base_fft.len() * 4usize.pow(k);
        let direction = base_fft.fft_direction();
        let fft = Radix4::new_with_base(k, base_fft);

        check_fft_algorithm::<f64>(&fft, len, direction);
    }
}