1use std::cmp::max;
2use std::sync::Arc;
3
4use num_complex::Complex;
5use num_traits::Zero;
6use transpose;
7
8use crate::array_utils;
9use crate::common::{fft_error_inplace, fft_error_outofplace};
10use crate::{common::FftNum, twiddles, FftDirection};
11use crate::{Direction, Fft, Length};
12
13pub struct MixedRadix<T> {
37 twiddles: Box<[Complex<T>]>,
38
39 width_size_fft: Arc<dyn Fft<T>>,
40 width: usize,
41
42 height_size_fft: Arc<dyn Fft<T>>,
43 height: usize,
44
45 inplace_scratch_len: usize,
46 outofplace_scratch_len: usize,
47
48 direction: FftDirection,
49}
50
51impl<T: FftNum> MixedRadix<T> {
52 pub fn new(width_fft: Arc<dyn Fft<T>>, height_fft: Arc<dyn Fft<T>>) -> Self {
54 assert_eq!(
55 width_fft.fft_direction(), height_fft.fft_direction(),
56 "width_fft and height_fft must have the same direction. got width direction={}, height direction={}",
57 width_fft.fft_direction(), height_fft.fft_direction());
58
59 let direction = width_fft.fft_direction();
60
61 let width = width_fft.len();
62 let height = height_fft.len();
63
64 let len = width * height;
65
66 let mut twiddles = vec![Complex::zero(); len];
67 for (x, twiddle_chunk) in twiddles.chunks_exact_mut(height).enumerate() {
68 for (y, twiddle_element) in twiddle_chunk.iter_mut().enumerate() {
69 *twiddle_element = twiddles::compute_twiddle(x * y, len, direction);
70 }
71 }
72
73 let height_inplace_scratch = height_fft.get_inplace_scratch_len();
75 let width_inplace_scratch = width_fft.get_inplace_scratch_len();
76 let width_outofplace_scratch = width_fft.get_outofplace_scratch_len();
77
78 let max_inner_inplace_scratch = max(height_inplace_scratch, width_inplace_scratch);
86 let outofplace_scratch_len = if max_inner_inplace_scratch > len {
87 max_inner_inplace_scratch
88 } else {
89 0
90 };
91
92 let inplace_scratch_len = len
97 + max(
98 if height_inplace_scratch > len {
99 height_inplace_scratch
100 } else {
101 0
102 },
103 width_outofplace_scratch,
104 );
105
106 Self {
107 twiddles: twiddles.into_boxed_slice(),
108
109 width_size_fft: width_fft,
110 width: width,
111
112 height_size_fft: height_fft,
113 height: height,
114
115 inplace_scratch_len,
116 outofplace_scratch_len,
117
118 direction,
119 }
120 }
121
122 fn perform_fft_inplace(&self, buffer: &mut [Complex<T>], scratch: &mut [Complex<T>]) {
123 let (scratch, inner_scratch) = scratch.split_at_mut(self.len());
125
126 transpose::transpose(buffer, scratch, self.width, self.height);
128
129 let height_scratch = if inner_scratch.len() > buffer.len() {
131 &mut inner_scratch[..]
132 } else {
133 &mut buffer[..]
134 };
135 self.height_size_fft
136 .process_with_scratch(scratch, height_scratch);
137
138 for (element, twiddle) in scratch.iter_mut().zip(self.twiddles.iter()) {
140 *element = *element * twiddle;
141 }
142
143 transpose::transpose(scratch, buffer, self.height, self.width);
145
146 self.width_size_fft
148 .process_outofplace_with_scratch(buffer, scratch, inner_scratch);
149
150 transpose::transpose(scratch, buffer, self.width, self.height);
152 }
153
154 fn perform_fft_out_of_place(
155 &self,
156 input: &mut [Complex<T>],
157 output: &mut [Complex<T>],
158 scratch: &mut [Complex<T>],
159 ) {
160 transpose::transpose(input, output, self.width, self.height);
164
165 let height_scratch = if scratch.len() > input.len() {
167 &mut scratch[..]
168 } else {
169 &mut input[..]
170 };
171 self.height_size_fft
172 .process_with_scratch(output, height_scratch);
173
174 for (element, twiddle) in output.iter_mut().zip(self.twiddles.iter()) {
176 *element = *element * twiddle;
177 }
178
179 transpose::transpose(output, input, self.height, self.width);
181
182 let width_scratch = if scratch.len() > output.len() {
184 &mut scratch[..]
185 } else {
186 &mut output[..]
187 };
188 self.width_size_fft
189 .process_with_scratch(input, width_scratch);
190
191 transpose::transpose(input, output, self.width, self.height);
193 }
194}
// Invokes the crate's `boilerplate_fft!` macro for `MixedRadix`.
// NOTE(review): the macro is defined outside this file. The three closures presumably supply,
// in order, the FFT length, the in-place scratch length and the out-of-place scratch length —
// confirm against the macro definition.
boilerplate_fft!(
    MixedRadix,
    |this: &MixedRadix<_>| this.twiddles.len(),
    |this: &MixedRadix<_>| this.inplace_scratch_len,
    |this: &MixedRadix<_>| this.outofplace_scratch_len
);
201
202pub struct MixedRadixSmall<T> {
229 twiddles: Box<[Complex<T>]>,
230
231 width_size_fft: Arc<dyn Fft<T>>,
232 width: usize,
233
234 height_size_fft: Arc<dyn Fft<T>>,
235 height: usize,
236
237 direction: FftDirection,
238}
239
240impl<T: FftNum> MixedRadixSmall<T> {
241 pub fn new(width_fft: Arc<dyn Fft<T>>, height_fft: Arc<dyn Fft<T>>) -> Self {
243 assert_eq!(
244 width_fft.fft_direction(), height_fft.fft_direction(),
245 "width_fft and height_fft must have the same direction. got width direction={}, height direction={}",
246 width_fft.fft_direction(), height_fft.fft_direction());
247
248 let width = width_fft.len();
250 let height = height_fft.len();
251 let len = width * height;
252
253 assert_eq!(width_fft.get_outofplace_scratch_len(), 0, "MixedRadixSmall should only be used with algorithms that require 0 out-of-place scratch. Width FFT (len={}) requires {}, should require 0", width, width_fft.get_outofplace_scratch_len());
254 assert_eq!(height_fft.get_outofplace_scratch_len(), 0, "MixedRadixSmall should only be used with algorithms that require 0 out-of-place scratch. Height FFT (len={}) requires {}, should require 0", height, height_fft.get_outofplace_scratch_len());
255
256 assert!(width_fft.get_inplace_scratch_len() <= width, "MixedRadixSmall should only be used with algorithms that require little inplace scratch. Width FFT (len={}) requires {}, should require {} or less", width, width_fft.get_inplace_scratch_len(), width);
257 assert!(height_fft.get_inplace_scratch_len() <= height, "MixedRadixSmall should only be used with algorithms that require little inplace scratch. Height FFT (len={}) requires {}, should require {} or less", height, height_fft.get_inplace_scratch_len(), height);
258
259 let direction = width_fft.fft_direction();
260
261 let mut twiddles = vec![Complex::zero(); len];
262 for (x, twiddle_chunk) in twiddles.chunks_exact_mut(height).enumerate() {
263 for (y, twiddle_element) in twiddle_chunk.iter_mut().enumerate() {
264 *twiddle_element = twiddles::compute_twiddle(x * y, len, direction);
265 }
266 }
267
268 Self {
269 twiddles: twiddles.into_boxed_slice(),
270
271 width_size_fft: width_fft,
272 width: width,
273
274 height_size_fft: height_fft,
275 height: height,
276
277 direction,
278 }
279 }
280
281 fn perform_fft_inplace(&self, buffer: &mut [Complex<T>], scratch: &mut [Complex<T>]) {
282 unsafe { array_utils::transpose_small(self.width, self.height, buffer, scratch) };
285
286 self.height_size_fft.process_with_scratch(scratch, buffer);
288
289 for (element, twiddle) in scratch.iter_mut().zip(self.twiddles.iter()) {
291 *element = *element * twiddle;
292 }
293
294 unsafe { array_utils::transpose_small(self.height, self.width, scratch, buffer) };
296
297 self.width_size_fft
299 .process_outofplace_with_scratch(buffer, scratch, &mut []);
300
301 unsafe { array_utils::transpose_small(self.width, self.height, scratch, buffer) };
303 }
304
305 fn perform_fft_out_of_place(
306 &self,
307 input: &mut [Complex<T>],
308 output: &mut [Complex<T>],
309 _scratch: &mut [Complex<T>],
310 ) {
311 unsafe { array_utils::transpose_small(self.width, self.height, input, output) };
314
315 self.height_size_fft.process_with_scratch(output, input);
317
318 for (element, twiddle) in output.iter_mut().zip(self.twiddles.iter()) {
320 *element = *element * twiddle;
321 }
322
323 unsafe { array_utils::transpose_small(self.height, self.width, output, input) };
325
326 self.width_size_fft.process_with_scratch(input, output);
328
329 unsafe { array_utils::transpose_small(self.width, self.height, input, output) };
331 }
332}
// Invokes the crate's `boilerplate_fft!` macro for `MixedRadixSmall`.
// NOTE(review): the macro is defined outside this file. The three closures presumably supply,
// in order, the FFT length, the in-place scratch length (here the FFT's own length) and the
// out-of-place scratch length (here always zero) — confirm against the macro definition.
boilerplate_fft!(
    MixedRadixSmall,
    |this: &MixedRadixSmall<_>| this.twiddles.len(),
    |this: &MixedRadixSmall<_>| this.len(),
    |_| 0
);
339
#[cfg(test)]
mod unit_tests {
    use super::*;
    use crate::test_utils::check_fft_algorithm;
    use crate::{algorithm::Dft, test_utils::BigScratchAlgorithm};
    use num_traits::Zero;
    use std::sync::Arc;

    /// Both transform directions, in the order the tests exercise them.
    const DIRECTIONS: [FftDirection; 2] = [FftDirection::Forward, FftDirection::Inverse];

    /// Wraps a `Dft` of the given length and direction as a shared `Fft<f32>` trait object.
    fn shared_dft(len: usize, direction: FftDirection) -> Arc<dyn Fft<f32>> {
        Arc::new(Dft::new(len, direction)) as Arc<dyn Fft<f32>>
    }

    /// Checks `MixedRadix` for every width/height pair in `1..7`, in both directions.
    #[test]
    fn test_mixed_radix() {
        for width in 1..7 {
            for height in 1..7 {
                for &direction in DIRECTIONS.iter() {
                    test_mixed_radix_with_lengths(width, height, direction);
                }
            }
        }
    }

    /// Checks `MixedRadixSmall` for every width/height pair in `2..7`, in both directions.
    #[test]
    fn test_mixed_radix_small() {
        for width in 2..7 {
            for height in 2..7 {
                for &direction in DIRECTIONS.iter() {
                    test_mixed_radix_small_with_lengths(width, height, direction);
                }
            }
        }
    }

    /// Builds a `MixedRadix` from two `Dft` instances and runs it through `check_fft_algorithm`.
    fn test_mixed_radix_with_lengths(width: usize, height: usize, direction: FftDirection) {
        let fft = MixedRadix::new(shared_dft(width, direction), shared_dft(height, direction));

        check_fft_algorithm(&fft, width * height, direction);
    }

    /// Builds a `MixedRadixSmall` from two `Dft` instances and runs it through
    /// `check_fft_algorithm`.
    fn test_mixed_radix_small_with_lengths(width: usize, height: usize, direction: FftDirection) {
        let fft = MixedRadixSmall::new(shared_dft(width, direction), shared_dft(height, direction));

        check_fft_algorithm(&fft, width * height, direction);
    }

    /// Runs `MixedRadix` over every pairing of inner FFTs with varied lengths and scratch
    /// requirements, giving it exactly the scratch it asks for.
    // NOTE(review): presumably `BigScratchAlgorithm` validates the scratch it receives, which
    // is what makes this test meaningful — confirm in `test_utils`.
    #[test]
    fn test_mixed_radix_inner_scratch() {
        let scratch_lengths = [1, 5, 25];

        // One inner FFT for every (len, in-place scratch, out-of-place scratch) combination.
        let mut inner_ffts = Vec::new();
        for len in scratch_lengths.iter().copied() {
            for inplace_scratch in scratch_lengths.iter().copied() {
                for outofplace_scratch in scratch_lengths.iter().copied() {
                    let inner = BigScratchAlgorithm {
                        len,
                        inplace_scratch,
                        outofplace_scratch,
                        direction: FftDirection::Forward,
                    };
                    inner_ffts.push(Arc::new(inner) as Arc<dyn Fft<f32>>);
                }
            }
        }

        for width_fft in &inner_ffts {
            for height_fft in &inner_ffts {
                let fft = MixedRadix::new(Arc::clone(width_fft), Arc::clone(height_fft));

                // In-place path with exactly the requested scratch.
                let mut buffer = vec![Complex::zero(); fft.len()];
                let mut scratch = vec![Complex::zero(); fft.get_inplace_scratch_len()];
                fft.process_with_scratch(&mut buffer, &mut scratch);

                // Out-of-place path with exactly the requested scratch.
                let mut input = vec![Complex::zero(); fft.len()];
                let mut output = vec![Complex::zero(); fft.len()];
                let mut oop_scratch = vec![Complex::zero(); fft.get_outofplace_scratch_len()];
                fft.process_outofplace_with_scratch(&mut input, &mut output, &mut oop_scratch);
            }
        }
    }
}