1use std::cmp::max;
2use std::sync::Arc;
3
4use num_complex::Complex;
5use num_integer::Integer;
6use strength_reduce::StrengthReducedUsize;
7use transpose;
8
9use crate::array_utils;
10use crate::common::{fft_error_inplace, fft_error_outofplace};
11use crate::{common::FftNum, FftDirection};
12use crate::{Direction, Fft, Length};
13
14pub struct GoodThomasAlgorithm<T> {
42 width: usize,
43 width_size_fft: Arc<dyn Fft<T>>,
44
45 height: usize,
46 height_size_fft: Arc<dyn Fft<T>>,
47
48 reduced_width: StrengthReducedUsize,
49 reduced_width_plus_one: StrengthReducedUsize,
50
51 inplace_scratch_len: usize,
52 outofplace_scratch_len: usize,
53
54 len: usize,
55 direction: FftDirection,
56}
57
58impl<T: FftNum> GoodThomasAlgorithm<T> {
59 pub fn new(mut width_fft: Arc<dyn Fft<T>>, mut height_fft: Arc<dyn Fft<T>>) -> Self {
63 assert_eq!(
64 width_fft.fft_direction(), height_fft.fft_direction(),
65 "width_fft and height_fft must have the same direction. got width direction={}, height direction={}",
66 width_fft.fft_direction(), height_fft.fft_direction());
67
68 let mut width = width_fft.len();
69 let mut height = height_fft.len();
70 let direction = width_fft.fft_direction();
71
72 let gcd = num_integer::gcd(width as i64, height as i64);
74 assert!(gcd == 1,
75 "Invalid width and height for Good-Thomas Algorithm (width={}, height={}): Inputs must be coprime",
76 width,
77 height);
78
79 if width > height {
81 std::mem::swap(&mut width, &mut height);
82 std::mem::swap(&mut width_fft, &mut height_fft);
83 }
84
85 let len = width * height;
86
87 let width_inplace_scratch = width_fft.get_inplace_scratch_len();
89 let height_inplace_scratch = height_fft.get_inplace_scratch_len();
90 let height_outofplace_scratch = height_fft.get_outofplace_scratch_len();
91
92 let max_inner_inplace_scratch = max(height_inplace_scratch, width_inplace_scratch);
100 let outofplace_scratch_len = if max_inner_inplace_scratch > len {
101 max_inner_inplace_scratch
102 } else {
103 0
104 };
105
106 let inplace_scratch_len = len
111 + max(
112 if width_inplace_scratch > len {
113 width_inplace_scratch
114 } else {
115 0
116 },
117 height_outofplace_scratch,
118 );
119
120 Self {
121 width,
122 width_size_fft: width_fft,
123
124 height,
125 height_size_fft: height_fft,
126
127 reduced_width: StrengthReducedUsize::new(width),
128 reduced_width_plus_one: StrengthReducedUsize::new(width + 1),
129
130 inplace_scratch_len,
131 outofplace_scratch_len,
132
133 len,
134 direction,
135 }
136 }
137
138 fn reindex_input(&self, source: &[Complex<T>], destination: &mut [Complex<T>]) {
139 let mut destination_index = 0;
154 for mut source_row in source.chunks_exact(self.width) {
155 let increments_until_cycle =
156 1 + (self.len() - destination_index) / self.reduced_width_plus_one;
157
158 if increments_until_cycle < self.width {
160 let (pre_cycle_row, post_cycle_row) = source_row.split_at(increments_until_cycle);
161
162 for input_element in pre_cycle_row {
163 destination[destination_index] = *input_element;
164 destination_index += self.reduced_width_plus_one.get();
165 }
166
167 source_row = post_cycle_row;
169 destination_index -= self.len();
170 }
171
172 for input_element in source_row {
174 destination[destination_index] = *input_element;
175 destination_index += self.reduced_width_plus_one.get();
176 }
177
178 destination_index -= self.width;
181 }
182 }
183
184 fn reindex_output(&self, source: &[Complex<T>], destination: &mut [Complex<T>]) {
185 for (y, source_chunk) in source.chunks_exact(self.height).enumerate() {
197 let (quotient, remainder) =
198 StrengthReducedUsize::div_rem(y * self.height, self.reduced_width);
199
200 let mut destination_index = remainder;
202 let start_x = self.height - quotient;
203
204 for x in start_x..self.height {
206 destination[destination_index] = source_chunk[x];
207 destination_index += self.width;
208 }
209
210 for x in 0..start_x {
212 destination[destination_index] = source_chunk[x];
213 destination_index += self.width;
214 }
215 }
216 }
217
218 fn perform_fft_inplace(&self, buffer: &mut [Complex<T>], scratch: &mut [Complex<T>]) {
219 let (scratch, inner_scratch) = scratch.split_at_mut(self.len());
220
221 self.reindex_input(buffer, scratch);
223
224 let width_scratch = if inner_scratch.len() > buffer.len() {
226 &mut inner_scratch[..]
227 } else {
228 &mut buffer[..]
229 };
230 self.width_size_fft
231 .process_with_scratch(scratch, width_scratch);
232
233 transpose::transpose(scratch, buffer, self.width, self.height);
235
236 self.height_size_fft
238 .process_outofplace_with_scratch(buffer, scratch, inner_scratch);
239
240 self.reindex_output(scratch, buffer);
242 }
243
244 fn perform_fft_out_of_place(
245 &self,
246 input: &mut [Complex<T>],
247 output: &mut [Complex<T>],
248 scratch: &mut [Complex<T>],
249 ) {
250 self.reindex_input(input, output);
252
253 let width_scratch = if scratch.len() > input.len() {
255 &mut scratch[..]
256 } else {
257 &mut input[..]
258 };
259 self.width_size_fft
260 .process_with_scratch(output, width_scratch);
261
262 transpose::transpose(output, input, self.width, self.height);
264
265 let height_scratch = if scratch.len() > output.len() {
267 &mut scratch[..]
268 } else {
269 &mut output[..]
270 };
271 self.height_size_fft
272 .process_with_scratch(input, height_scratch);
273
274 self.reindex_output(input, output);
276 }
277}
278boilerplate_fft!(
279 GoodThomasAlgorithm,
280 |this: &GoodThomasAlgorithm<_>| this.len,
281 |this: &GoodThomasAlgorithm<_>| this.inplace_scratch_len,
282 |this: &GoodThomasAlgorithm<_>| this.outofplace_scratch_len
283);
284
285pub struct GoodThomasAlgorithmSmall<T> {
313 width: usize,
314 width_size_fft: Arc<dyn Fft<T>>,
315
316 height: usize,
317 height_size_fft: Arc<dyn Fft<T>>,
318
319 input_output_map: Box<[usize]>,
320
321 direction: FftDirection,
322}
323
324impl<T: FftNum> GoodThomasAlgorithmSmall<T> {
325 pub fn new(width_fft: Arc<dyn Fft<T>>, height_fft: Arc<dyn Fft<T>>) -> Self {
329 assert_eq!(
330 width_fft.fft_direction(), height_fft.fft_direction(),
331 "n1_fft and height_fft must have the same direction. got width direction={}, height direction={}",
332 width_fft.fft_direction(), height_fft.fft_direction());
333
334 let width = width_fft.len();
335 let height = height_fft.len();
336 let len = width * height;
337
338 assert_eq!(width_fft.get_outofplace_scratch_len(), 0, "GoodThomasAlgorithmSmall should only be used with algorithms that require 0 out-of-place scratch. Width FFT (len={}) requires {}, should require 0", width, width_fft.get_outofplace_scratch_len());
339 assert_eq!(height_fft.get_outofplace_scratch_len(), 0, "GoodThomasAlgorithmSmall should only be used with algorithms that require 0 out-of-place scratch. Height FFT (len={}) requires {}, should require 0", height, height_fft.get_outofplace_scratch_len());
340
341 assert!(width_fft.get_inplace_scratch_len() <= width, "GoodThomasAlgorithmSmall should only be used with algorithms that require little inplace scratch. Width FFT (len={}) requires {}, should require {} or less", width, width_fft.get_inplace_scratch_len(), width);
342 assert!(height_fft.get_inplace_scratch_len() <= height, "GoodThomasAlgorithmSmall should only be used with algorithms that require little inplace scratch. Height FFT (len={}) requires {}, should require {} or less", height, height_fft.get_inplace_scratch_len(), height);
343
344 let gcd_data = i64::extended_gcd(&(width as i64), &(height as i64));
346 assert!(gcd_data.gcd == 1,
347 "Invalid input width and height to Good-Thomas Algorithm: ({},{}): Inputs must be coprime",
348 width,
349 height);
350
351 let width_inverse = if gcd_data.x >= 0 {
353 gcd_data.x
354 } else {
355 gcd_data.x + height as i64
356 } as usize;
357 let height_inverse = if gcd_data.y >= 0 {
358 gcd_data.y
359 } else {
360 gcd_data.y + width as i64
361 } as usize;
362
363 let input_iter = (0..len)
366 .map(|i| (i % width, i / width))
367 .map(|(x, y)| (x * height + y * width) % len);
368 let output_iter = (0..len).map(|i| (i % height, i / height)).map(|(y, x)| {
369 (x * height * height_inverse as usize + y * width * width_inverse as usize) % len
370 });
371
372 let input_output_map: Vec<usize> = input_iter.chain(output_iter).collect();
373
374 Self {
375 direction: width_fft.fft_direction(),
376
377 width,
378 width_size_fft: width_fft,
379
380 height,
381 height_size_fft: height_fft,
382
383 input_output_map: input_output_map.into_boxed_slice(),
384 }
385 }
386
387 fn perform_fft_out_of_place(
388 &self,
389 input: &mut [Complex<T>],
390 output: &mut [Complex<T>],
391 _scratch: &mut [Complex<T>],
392 ) {
393 assert_eq!(self.len(), input.len());
395 assert_eq!(self.len(), output.len());
396
397 let (input_map, output_map) = self.input_output_map.split_at(self.len());
398
399 for (output_element, &input_index) in output.iter_mut().zip(input_map.iter()) {
401 *output_element = input[input_index];
402 }
403
404 self.width_size_fft.process_with_scratch(output, input);
406
407 unsafe { array_utils::transpose_small(self.width, self.height, output, input) };
409
410 self.height_size_fft.process_with_scratch(input, output);
412
413 for (input_element, &output_index) in input.iter().zip(output_map.iter()) {
415 output[output_index] = *input_element;
416 }
417 }
418
419 fn perform_fft_inplace(&self, buffer: &mut [Complex<T>], scratch: &mut [Complex<T>]) {
420 assert_eq!(self.len(), buffer.len());
422 assert_eq!(self.len(), scratch.len());
423
424 let (input_map, output_map) = self.input_output_map.split_at(self.len());
425
426 for (output_element, &input_index) in scratch.iter_mut().zip(input_map.iter()) {
428 *output_element = buffer[input_index];
429 }
430
431 self.width_size_fft.process_with_scratch(scratch, buffer);
433
434 unsafe { array_utils::transpose_small(self.width, self.height, scratch, buffer) };
436
437 self.height_size_fft
439 .process_outofplace_with_scratch(buffer, scratch, &mut []);
440
441 for (input_element, &output_index) in scratch.iter().zip(output_map.iter()) {
443 buffer[output_index] = *input_element;
444 }
445 }
446}
447boilerplate_fft!(
448 GoodThomasAlgorithmSmall,
449 |this: &GoodThomasAlgorithmSmall<_>| this.width * this.height,
450 |this: &GoodThomasAlgorithmSmall<_>| this.len(),
451 |_| 0
452);
453
#[cfg(test)]
mod unit_tests {
    use super::*;
    use crate::test_utils::check_fft_algorithm;
    use crate::{algorithm::Dft, test_utils::BigScratchAlgorithm};
    use num_integer::gcd;
    use num_traits::Zero;
    use std::sync::Arc;

    /// Checks `GoodThomasAlgorithm` for every coprime pair of lengths in `1..12`, in both
    /// directions.
    #[test]
    fn test_good_thomas() {
        let coprime_pairs = (1..12usize)
            .flat_map(|width| (1..12usize).map(move |height| (width, height)))
            .filter(|&(width, height)| gcd(width, height) == 1);

        for (width, height) in coprime_pairs {
            test_good_thomas_with_lengths(width, height, FftDirection::Forward);
            test_good_thomas_with_lengths(width, height, FftDirection::Inverse);
        }
    }

    /// Checks `GoodThomasAlgorithmSmall` for every coprime pair drawn from a list of small
    /// sizes, in both directions.
    #[test]
    fn test_good_thomas_small() {
        let butterfly_sizes = [2, 3, 4, 5, 6, 7, 8, 16];
        for &width in butterfly_sizes.iter() {
            let coprime_heights = butterfly_sizes
                .iter()
                .filter(|&&height| gcd(width, height) == 1);
            for &height in coprime_heights {
                test_good_thomas_small_with_lengths(width, height, FftDirection::Forward);
                test_good_thomas_small_with_lengths(width, height, FftDirection::Inverse);
            }
        }
    }

    /// Builds a `GoodThomasAlgorithm` from two `Dft`s and runs the shared FFT checker on it.
    fn test_good_thomas_with_lengths(width: usize, height: usize, direction: FftDirection) {
        let inner_width = Arc::new(Dft::new(width, direction)) as Arc<dyn Fft<f32>>;
        let inner_height = Arc::new(Dft::new(height, direction)) as Arc<dyn Fft<f32>>;

        let algorithm = GoodThomasAlgorithm::new(inner_width, inner_height);
        check_fft_algorithm(&algorithm, width * height, direction);
    }

    /// Builds a `GoodThomasAlgorithmSmall` from two `Dft`s and runs the shared FFT checker on it.
    fn test_good_thomas_small_with_lengths(width: usize, height: usize, direction: FftDirection) {
        let inner_width = Arc::new(Dft::new(width, direction)) as Arc<dyn Fft<f32>>;
        let inner_height = Arc::new(Dft::new(height, direction)) as Arc<dyn Fft<f32>>;

        let algorithm = GoodThomasAlgorithmSmall::new(inner_width, inner_height);
        check_fft_algorithm(&algorithm, width * height, direction);
    }

    /// Smoke test: for width 15 and every coprime height in `3..15`, processing a zeroed
    /// buffer must not panic. No output values are asserted.
    #[test]
    fn test_output_mapping() {
        let width = 15;
        let coprime_heights = (3..width).filter(|&height| gcd(width, height) == 1);
        for height in coprime_heights {
            let inner_width =
                Arc::new(Dft::new(width, FftDirection::Forward)) as Arc<dyn Fft<f32>>;
            let inner_height =
                Arc::new(Dft::new(height, FftDirection::Forward)) as Arc<dyn Fft<f32>>;
            let algorithm = GoodThomasAlgorithm::new(inner_width, inner_height);

            let mut data = vec![Complex { re: 0.0, im: 0.0 }; algorithm.len()];
            algorithm.process(&mut data);
        }
    }

    /// Builds inner FFTs with every combination of length / in-place scratch / out-of-place
    /// scratch drawn from `[1, 5, 24]`. It pairs every two of them with different lengths and
    /// runs both the in-place and out-of-place paths with exactly the advertised scratch
    /// sizes.
    #[test]
    fn test_good_thomas_inner_scratch() {
        let scratch_lengths = [1, 5, 24];

        let mut inner_ffts = Vec::new();
        for &len in scratch_lengths.iter() {
            for &inplace_scratch in scratch_lengths.iter() {
                for &outofplace_scratch in scratch_lengths.iter() {
                    let inner = Arc::new(BigScratchAlgorithm {
                        len,
                        inplace_scratch,
                        outofplace_scratch,
                        direction: FftDirection::Forward,
                    }) as Arc<dyn Fft<f32>>;
                    inner_ffts.push(inner);
                }
            }
        }

        for width_fft in inner_ffts.iter() {
            let distinct_lengths = inner_ffts
                .iter()
                .filter(|candidate| candidate.len() != width_fft.len());
            for height_fft in distinct_lengths {
                let algorithm =
                    GoodThomasAlgorithm::new(Arc::clone(width_fft), Arc::clone(height_fft));

                let mut data = vec![Complex::zero(); algorithm.len()];
                let mut scratch = vec![Complex::zero(); algorithm.get_inplace_scratch_len()];
                algorithm.process_with_scratch(&mut data, &mut scratch);

                let mut source = vec![Complex::zero(); algorithm.len()];
                let mut sink = vec![Complex::zero(); algorithm.len()];
                let mut scratch = vec![Complex::zero(); algorithm.get_outofplace_scratch_len()];
                algorithm.process_outofplace_with_scratch(&mut source, &mut sink, &mut scratch);
            }
        }
    }
}