rustfft/algorithm/
bluesteins_algorithm.rs1use std::sync::Arc;
2
3use num_complex::Complex;
4use num_traits::Zero;
5
6use crate::array_utils;
7use crate::common::{fft_error_inplace, fft_error_outofplace};
8use crate::{common::FftNum, twiddles, FftDirection};
9use crate::{Direction, Fft, Length};
10
11pub struct BluesteinsAlgorithm<T> {
42 inner_fft: Arc<dyn Fft<T>>,
43
44 inner_fft_multiplier: Box<[Complex<T>]>,
45 twiddles: Box<[Complex<T>]>,
46
47 len: usize,
48 direction: FftDirection,
49}
50
51impl<T: FftNum> BluesteinsAlgorithm<T> {
52 pub fn new(len: usize, inner_fft: Arc<dyn Fft<T>>) -> Self {
61 let inner_fft_len = inner_fft.len();
62 assert!(len * 2 - 1 <= inner_fft_len, "Bluestein's algorithm requires inner_fft.len() >= self.len() * 2 - 1. Expected >= {}, got {}", len * 2 - 1, inner_fft_len);
63
64 let inner_fft_scale = T::one() / T::from_usize(inner_fft_len).unwrap();
66 let direction = inner_fft.fft_direction();
67
68 let mut inner_fft_input = vec![Complex::zero(); inner_fft_len];
70 twiddles::fill_bluesteins_twiddles(
71 &mut inner_fft_input[..len],
72 direction.opposite_direction(),
73 );
74
75 inner_fft_input[0] = inner_fft_input[0] * inner_fft_scale;
77 for i in 1..len {
78 let twiddle = inner_fft_input[i] * inner_fft_scale;
79 inner_fft_input[i] = twiddle;
80 inner_fft_input[inner_fft_len - i] = twiddle;
81 }
82
83 let mut inner_fft_scratch = vec![Complex::zero(); inner_fft.get_inplace_scratch_len()];
85 inner_fft.process_with_scratch(&mut inner_fft_input, &mut inner_fft_scratch);
86
87 let mut twiddles = vec![Complex::zero(); len];
89 twiddles::fill_bluesteins_twiddles(&mut twiddles, direction);
90
91 Self {
92 inner_fft: inner_fft,
93
94 inner_fft_multiplier: inner_fft_input.into_boxed_slice(),
95 twiddles: twiddles.into_boxed_slice(),
96
97 len,
98 direction,
99 }
100 }
101
102 fn perform_fft_inplace(&self, input: &mut [Complex<T>], scratch: &mut [Complex<T>]) {
103 let (inner_input, inner_scratch) = scratch.split_at_mut(self.inner_fft_multiplier.len());
104
105 for ((buffer_entry, inner_entry), twiddle) in input
107 .iter()
108 .zip(inner_input.iter_mut())
109 .zip(self.twiddles.iter())
110 {
111 *inner_entry = *buffer_entry * *twiddle;
112 }
113 for inner in (&mut inner_input[input.len()..]).iter_mut() {
114 *inner = Complex::zero();
115 }
116
117 self.inner_fft
119 .process_with_scratch(inner_input, inner_scratch);
120
121 for (inner, multiplier) in inner_input.iter_mut().zip(self.inner_fft_multiplier.iter()) {
123 *inner = (*inner * *multiplier).conj();
124 }
125
126 self.inner_fft
128 .process_with_scratch(inner_input, inner_scratch);
129
130 for ((buffer_entry, inner_entry), twiddle) in input
132 .iter_mut()
133 .zip(inner_input.iter())
134 .zip(self.twiddles.iter())
135 {
136 *buffer_entry = inner_entry.conj() * twiddle;
137 }
138 }
139
140 fn perform_fft_out_of_place(
141 &self,
142 input: &mut [Complex<T>],
143 output: &mut [Complex<T>],
144 scratch: &mut [Complex<T>],
145 ) {
146 let (inner_input, inner_scratch) = scratch.split_at_mut(self.inner_fft_multiplier.len());
147
148 for ((buffer_entry, inner_entry), twiddle) in input
150 .iter()
151 .zip(inner_input.iter_mut())
152 .zip(self.twiddles.iter())
153 {
154 *inner_entry = *buffer_entry * *twiddle;
155 }
156 for inner in inner_input.iter_mut().skip(input.len()) {
157 *inner = Complex::zero();
158 }
159
160 self.inner_fft
162 .process_with_scratch(inner_input, inner_scratch);
163
164 for (inner, multiplier) in inner_input.iter_mut().zip(self.inner_fft_multiplier.iter()) {
166 *inner = (*inner * *multiplier).conj();
167 }
168
169 self.inner_fft
171 .process_with_scratch(inner_input, inner_scratch);
172
173 for ((buffer_entry, inner_entry), twiddle) in output
175 .iter_mut()
176 .zip(inner_input.iter())
177 .zip(self.twiddles.iter())
178 {
179 *buffer_entry = inner_entry.conj() * twiddle;
180 }
181 }
182}
183boilerplate_fft!(
184 BluesteinsAlgorithm,
185 |this: &BluesteinsAlgorithm<_>| this.len, |this: &BluesteinsAlgorithm<_>| this.inner_fft_multiplier.len()
187 + this.inner_fft.get_inplace_scratch_len(), |this: &BluesteinsAlgorithm<_>| this.inner_fft_multiplier.len()
189 + this.inner_fft.get_inplace_scratch_len() );
191
#[cfg(test)]
mod unit_tests {
    use super::*;
    use crate::algorithm::Dft;
    use crate::test_utils::check_fft_algorithm;
    use std::sync::Arc;

    /// Checks Bluestein's algorithm for several prime lengths, in both directions,
    /// using an inner FFT padded up to the next power of two.
    #[test]
    fn test_bluesteins_scalar() {
        for &len in &[3, 5, 7, 11, 13] {
            test_bluesteins_with_length(len, FftDirection::Forward);
            test_bluesteins_with_length(len, FftDirection::Inverse);
        }
    }

    /// Uses the smallest inner FFT the constructor accepts (`2 * len - 1`, usually not a
    /// power of two). The mirrored half of the kernel then starts immediately after its
    /// head, which the padded test never exercises. Tiny lengths are included as well.
    #[test]
    fn test_bluesteins_minimal_inner_len() {
        for &len in &[1, 2, 3, 4, 5, 7, 11, 13] {
            for &direction in &[FftDirection::Forward, FftDirection::Inverse] {
                let inner_fft = Arc::new(Dft::new(len * 2 - 1, direction));
                let fft = BluesteinsAlgorithm::new(len, inner_fft);

                check_fft_algorithm::<f32>(&fft, len, direction);
            }
        }
    }

    /// Builds a Bluestein FFT of `len` over a power-of-two `Dft` and checks it against
    /// the reference implementation.
    fn test_bluesteins_with_length(len: usize, direction: FftDirection) {
        let inner_fft = Arc::new(Dft::new(
            (len * 2 - 1).checked_next_power_of_two().unwrap(),
            direction,
        ));
        let fft = BluesteinsAlgorithm::new(len, inner_fft);

        check_fft_algorithm::<f32>(&fft, len, direction);
    }
}