rustfft/algorithm/
radix3.rs1use std::sync::Arc;
2
3use num_complex::Complex;
4
5use crate::algorithm::butterflies::{Butterfly1, Butterfly27, Butterfly3, Butterfly9};
6use crate::algorithm::radixn::butterfly_3;
7use crate::array_utils::{self, bitreversed_transpose, compute_logarithm};
8use crate::common::{fft_error_inplace, fft_error_outofplace};
9use crate::{common::FftNum, twiddles, FftDirection};
10use crate::{Direction, Fft, Length};
11
12pub struct Radix3<T> {
27 twiddles: Box<[Complex<T>]>,
28 butterfly3: Butterfly3<T>,
29
30 base_fft: Arc<dyn Fft<T>>,
31 base_len: usize,
32
33 len: usize,
34 direction: FftDirection,
35 inplace_scratch_len: usize,
36 outofplace_scratch_len: usize,
37}
38
39impl<T: FftNum> Radix3<T> {
40 pub fn new(len: usize, direction: FftDirection) -> Self {
42 let exponent = compute_logarithm::<3>(len).unwrap_or_else(|| {
44 panic!(
45 "Radix3 algorithm requires a power-of-three input size. Got {}",
46 len
47 )
48 });
49
50 let (base_exponent, base_fft) = match exponent {
52 0 => (0, Arc::new(Butterfly1::new(direction)) as Arc<dyn Fft<T>>),
53 1 => (1, Arc::new(Butterfly3::new(direction)) as Arc<dyn Fft<T>>),
54 2 => (2, Arc::new(Butterfly9::new(direction)) as Arc<dyn Fft<T>>),
55 _ => (3, Arc::new(Butterfly27::new(direction)) as Arc<dyn Fft<T>>),
56 };
57
58 Self::new_with_base(exponent - base_exponent, base_fft)
59 }
60
61 pub fn new_with_base(k: u32, base_fft: Arc<dyn Fft<T>>) -> Self {
63 let base_len = base_fft.len();
64 let len = base_len * 3usize.pow(k);
65
66 let direction = base_fft.fft_direction();
67
68 const ROW_COUNT: usize = 3;
73 let mut cross_fft_len = base_len;
74 let mut twiddle_factors = Vec::with_capacity(len * 2);
75 while cross_fft_len < len {
76 let num_columns = cross_fft_len;
77 cross_fft_len *= ROW_COUNT;
78
79 for i in 0..num_columns {
80 for k in 1..ROW_COUNT {
81 let twiddle = twiddles::compute_twiddle(i * k, cross_fft_len, direction);
82 twiddle_factors.push(twiddle);
83 }
84 }
85 }
86
87 let base_inplace_scratch = base_fft.get_inplace_scratch_len();
88 let inplace_scratch_len = if base_inplace_scratch > cross_fft_len {
89 cross_fft_len + base_inplace_scratch
90 } else {
91 cross_fft_len
92 };
93 let outofplace_scratch_len = if base_inplace_scratch > len {
94 base_inplace_scratch
95 } else {
96 0
97 };
98
99 Self {
100 twiddles: twiddle_factors.into_boxed_slice(),
101 butterfly3: Butterfly3::new(direction),
102
103 base_fft,
104 base_len,
105
106 len,
107 direction,
108
109 inplace_scratch_len,
110 outofplace_scratch_len,
111 }
112 }
113
114 fn inplace_scratch_len(&self) -> usize {
115 self.inplace_scratch_len
116 }
117 fn outofplace_scratch_len(&self) -> usize {
118 self.outofplace_scratch_len
119 }
120
121 fn perform_fft_out_of_place(
122 &self,
123 input: &mut [Complex<T>],
124 output: &mut [Complex<T>],
125 scratch: &mut [Complex<T>],
126 ) {
127 if self.len() == self.base_len {
129 output.copy_from_slice(input);
130 } else {
131 bitreversed_transpose::<Complex<T>, 3>(self.base_len, input, output);
132 }
133
134 let base_scratch = if scratch.len() > 0 { scratch } else { input };
136 self.base_fft.process_with_scratch(output, base_scratch);
137
138 const ROW_COUNT: usize = 3;
140 let mut cross_fft_len = self.base_len;
141 let mut layer_twiddles: &[Complex<T>] = &self.twiddles;
142
143 while cross_fft_len < output.len() {
144 let num_columns = cross_fft_len;
145 cross_fft_len *= ROW_COUNT;
146
147 for data in output.chunks_exact_mut(cross_fft_len) {
148 unsafe { butterfly_3(data, layer_twiddles, num_columns, &self.butterfly3) }
149 }
150
151 let twiddle_offset = num_columns * (ROW_COUNT - 1);
153 layer_twiddles = &layer_twiddles[twiddle_offset..];
154 }
155 }
156}
// The closure tells the macro how to read the FFT length from a `Radix3` (its `len` field).
// NOTE(review): presumably the macro implements the `Fft`/`Length`/`Direction` traits for
// `Radix3` on top of `perform_fft_out_of_place`, `inplace_scratch_len` and
// `outofplace_scratch_len` — confirm against the `boilerplate_fft_oop!` definition.
boilerplate_fft_oop!(Radix3, |this: &Radix3<_>| this.len);
158
#[cfg(test)]
mod unit_tests {
    use super::*;
    use crate::test_utils::{check_fft_algorithm, construct_base};

    /// Validates `Radix3::new` with `check_fft_algorithm` for lengths 3^0 through 3^7,
    /// in both directions.
    #[test]
    fn test_radix3_with_length() {
        for pow in 0..8 {
            let len = 3usize.pow(pow);

            let forward_fft = Radix3::new(len, FftDirection::Forward);
            check_fft_algorithm::<f32>(&forward_fft, len, FftDirection::Forward);

            let inverse_fft = Radix3::new(len, FftDirection::Inverse);
            check_fft_algorithm::<f32>(&inverse_fft, len, FftDirection::Inverse);
        }
    }

    /// Validates `Radix3::new_with_base` for bases built by `construct_base(1..=9, _)`,
    /// each with 0 to 4 radix-3 layers, in both directions.
    #[test]
    fn test_radix3_with_base() {
        for base in 1..=9 {
            let base_forward = construct_base(base, FftDirection::Forward);
            let base_inverse = construct_base(base, FftDirection::Inverse);

            for k in 0..5 {
                test_radix3(k, Arc::clone(&base_forward));
                test_radix3(k, Arc::clone(&base_inverse));
            }
        }
    }

    /// Builds a `Radix3` with `k` radix-3 layers on top of `base_fft` and validates it.
    fn test_radix3(k: u32, base_fft: Arc<dyn Fft<f32>>) {
        let len = base_fft.len() * 3usize.pow(k);
        let direction = base_fft.fft_direction();
        let fft = Radix3::new_with_base(k, base_fft);

        check_fft_algorithm::<f32>(&fft, len, direction);
    }
}