// rustfft/algorithm/raders_algorithm.rs
use std::sync::Arc;
2
3use num_complex::Complex;
4use num_integer::Integer;
5use num_traits::Zero;
6use primal_check::miller_rabin;
7use strength_reduce::StrengthReducedUsize;
8
9use crate::array_utils;
10use crate::common::{fft_error_inplace, fft_error_outofplace};
11use crate::math_utils;
12use crate::{common::FftNum, twiddles, FftDirection};
13use crate::{Direction, Fft, Length};
14
/// FFT implementation based on Rader's algorithm.
///
/// An instance processes buffers of length `inner_fft.len() + 1`; the constructor requires that
/// length to be prime. Each transform is carried out with two passes of the inner FFT (whose
/// length is one less than this FFT's length) plus a precomputed table of multipliers.
pub struct RadersAlgorithm<T> {
    /// The inner FFT of length `len - 1`, used for both passes of every transform.
    inner_fft: Arc<dyn Fft<T>>,
    /// Output of running `inner_fft` once, at construction time, over the scaled and reordered
    /// twiddle factors. The result of the first inner pass is multiplied element-wise by it.
    inner_fft_data: Box<[Complex<T>]>,

    /// Value returned by `math_utils::primitive_root` for this length; its successive powers
    /// (mod `len`) give the order in which input elements are gathered.
    primitive_root: usize,
    /// Multiplicative inverse of `primitive_root` modulo `len`, normalized to be non-negative;
    /// its successive powers (mod `len`) give the order in which results are scattered.
    primitive_root_inverse: usize,

    /// Length of this FFT (`inner_fft.len() + 1`), stored in strength-reduced form because it is
    /// used as the modulus in the index-permutation loops.
    len: StrengthReducedUsize,
    /// Scratch length reported for in-place processing: `len - 1` elements for the reordered
    /// data plus any extra scratch the inner FFT needs.
    inplace_scratch_len: usize,
    /// Scratch length reported for out-of-place processing: only the extra scratch the inner FFT
    /// needs (zero when an idle `len - 1` buffer is large enough to be lent to it).
    outofplace_scratch_len: usize,
    /// Direction of this FFT, copied from `inner_fft.fft_direction()`.
    direction: FftDirection,
}
55
56impl<T: FftNum> RadersAlgorithm<T> {
57 pub fn new(inner_fft: Arc<dyn Fft<T>>) -> Self {
66 let inner_fft_len = inner_fft.len();
67 let len = inner_fft_len + 1;
68 assert!(miller_rabin(len as u64), "For raders algorithm, inner_fft.len() + 1 must be prime. Expected prime number, got {} + 1 = {}", inner_fft_len, len);
69
70 let direction = inner_fft.fft_direction();
71 let reduced_len = StrengthReducedUsize::new(len);
72
73 let primitive_root = math_utils::primitive_root(len as u64).unwrap() as usize;
75
76 let gcd_data = i64::extended_gcd(&(primitive_root as i64), &(len as i64));
80 let primitive_root_inverse = if gcd_data.x >= 0 {
81 gcd_data.x
82 } else {
83 gcd_data.x + len as i64
84 } as usize;
85
86 let inner_fft_scale = T::one() / T::from_usize(inner_fft_len).unwrap();
88 let mut inner_fft_input = vec![Complex::zero(); inner_fft_len];
89 let mut twiddle_input = 1;
90 for input_cell in &mut inner_fft_input {
91 let twiddle = twiddles::compute_twiddle(twiddle_input, len, direction);
92 *input_cell = twiddle * inner_fft_scale;
93
94 twiddle_input = (twiddle_input * primitive_root_inverse) % reduced_len;
95 }
96
97 let required_inner_scratch = inner_fft.get_inplace_scratch_len();
98 let extra_inner_scratch = if required_inner_scratch <= inner_fft_len {
99 0
100 } else {
101 required_inner_scratch
102 };
103
104 let mut inner_fft_scratch = vec![Zero::zero(); required_inner_scratch];
106 inner_fft.process_with_scratch(&mut inner_fft_input, &mut inner_fft_scratch);
107
108 Self {
109 inner_fft,
110 inner_fft_data: inner_fft_input.into_boxed_slice(),
111
112 primitive_root,
113 primitive_root_inverse,
114
115 len: reduced_len,
116 inplace_scratch_len: inner_fft_len + extra_inner_scratch,
117 outofplace_scratch_len: extra_inner_scratch,
118 direction,
119 }
120 }
121
122 fn perform_fft_out_of_place(
123 &self,
124 input: &mut [Complex<T>],
125 output: &mut [Complex<T>],
126 scratch: &mut [Complex<T>],
127 ) {
128 let (output_first, output) = output.split_first_mut().unwrap();
130 let (input_first, input) = input.split_first_mut().unwrap();
131
132 let mut input_index = 1;
134 for output_element in output.iter_mut() {
135 input_index = (input_index * self.primitive_root) % self.len;
136
137 let input_element = input[input_index - 1];
138 *output_element = input_element;
139 }
140
141 let inner_scratch = if scratch.len() > 0 {
143 &mut scratch[..]
144 } else {
145 &mut input[..]
146 };
147 self.inner_fft.process_with_scratch(output, inner_scratch);
148
149 *output_first = *input_first + output[0];
151
152 for ((output_cell, input_cell), &multiple) in output
156 .iter()
157 .zip(input.iter_mut())
158 .zip(self.inner_fft_data.iter())
159 {
160 *input_cell = (*output_cell * multiple).conj();
161 }
162
163 input[0] = input[0] + input_first.conj();
166
167 let inner_scratch = if scratch.len() > 0 {
169 scratch
170 } else {
171 &mut output[..]
172 };
173 self.inner_fft.process_with_scratch(input, inner_scratch);
174
175 let mut output_index = 1;
177 for input_element in input {
178 output_index = (output_index * self.primitive_root_inverse) % self.len;
179 output[output_index - 1] = input_element.conj();
180 }
181 }
182 fn perform_fft_inplace(&self, buffer: &mut [Complex<T>], scratch: &mut [Complex<T>]) {
183 let (buffer_first, buffer) = buffer.split_first_mut().unwrap();
185 let buffer_first_val = *buffer_first;
186
187 let (scratch, extra_scratch) = scratch.split_at_mut(self.len() - 1);
188
189 let mut input_index = 1;
191 for scratch_element in scratch.iter_mut() {
192 input_index = (input_index * self.primitive_root) % self.len;
193
194 let buffer_element = buffer[input_index - 1];
195 *scratch_element = buffer_element;
196 }
197
198 let inner_scratch = if extra_scratch.len() > 0 {
200 extra_scratch
201 } else {
202 &mut buffer[..]
203 };
204 self.inner_fft.process_with_scratch(scratch, inner_scratch);
205
206 *buffer_first = *buffer_first + scratch[0];
208
209 for (scratch_cell, &twiddle) in scratch.iter_mut().zip(self.inner_fft_data.iter()) {
213 *scratch_cell = (*scratch_cell * twiddle).conj();
214 }
215
216 scratch[0] = scratch[0] + buffer_first_val.conj();
219
220 self.inner_fft.process_with_scratch(scratch, inner_scratch);
222
223 let mut output_index = 1;
225 for scratch_element in scratch {
226 output_index = (output_index * self.primitive_root_inverse) % self.len;
227 buffer[output_index - 1] = scratch_element.conj();
228 }
229 }
230}
// Invoke the crate's `boilerplate_fft!` macro (defined elsewhere) for `RadersAlgorithm`, passing
// three accessors: the FFT length, the in-place scratch length and the out-of-place scratch length.
// NOTE(review): presumably this generates the `Fft`/`Length`/`Direction` trait impls that forward
// to `perform_fft_inplace` and `perform_fft_out_of_place` — confirm against the macro definition.
boilerplate_fft!(
    RadersAlgorithm,
    |this: &RadersAlgorithm<_>| this.len.get(),
    |this: &RadersAlgorithm<_>| this.inplace_scratch_len,
    |this: &RadersAlgorithm<_>| this.outofplace_scratch_len
);
237
#[cfg(test)]
mod unit_tests {
    use super::*;
    use crate::algorithm::Dft;
    use crate::test_utils::check_fft_algorithm;
    use std::sync::Arc;

    /// Checks `RadersAlgorithm` for every prime length in `3..100`, in both directions.
    #[test]
    fn test_raders() {
        let prime_lengths = (3usize..100).filter(|&candidate| miller_rabin(candidate as u64));

        for prime_len in prime_lengths {
            // Forward first, then inverse, for each length.
            for &direction in &[FftDirection::Forward, FftDirection::Inverse] {
                test_raders_with_length(prime_len, direction);
            }
        }
    }

    /// Builds a `RadersAlgorithm` of size `len` on top of a `Dft` of size `len - 1` and runs
    /// the shared `check_fft_algorithm` test helper on it with `f32` samples.
    fn test_raders_with_length(len: usize, direction: FftDirection) {
        let inner_len = len - 1;
        let reference_inner = Arc::new(Dft::new(inner_len, direction));
        let raders = RadersAlgorithm::new(reference_inner);

        check_fft_algorithm::<f32>(&raders, len, direction);
    }
}