1use num_complex::Complex;
2use num_traits::Zero;
3
4use crate::array_utils;
5use crate::common::{fft_error_inplace, fft_error_outofplace};
6use crate::{twiddles, FftDirection};
7use crate::{Direction, Fft, FftNum, Length};
8
/// Naive O(n^2) Discrete Fourier Transform.
///
/// Every output bin is computed as a direct sum over all input samples, using a table of
/// twiddle factors precomputed at construction time. Because the cost grows quadratically
/// with the length, this is mainly useful for small sizes and as a reference for checking
/// faster algorithms.
pub struct Dft<T> {
    /// Precomputed twiddle factors, one per index `0..len`; the table length is the FFT length.
    twiddles: Vec<Complex<T>>,
    /// Direction (forward/inverse) the twiddle table was computed for.
    direction: FftDirection,
}
28
29impl<T: FftNum> Dft<T> {
30 pub fn new(len: usize, direction: FftDirection) -> Self {
32 let twiddles = (0..len)
33 .map(|i| twiddles::compute_twiddle(i, len, direction))
34 .collect();
35 Self {
36 twiddles,
37 direction,
38 }
39 }
40
41 fn inplace_scratch_len(&self) -> usize {
42 self.len()
43 }
44 fn outofplace_scratch_len(&self) -> usize {
45 0
46 }
47
48 fn perform_fft_out_of_place(
49 &self,
50 signal: &[Complex<T>],
51 spectrum: &mut [Complex<T>],
52 _scratch: &mut [Complex<T>],
53 ) {
54 for k in 0..spectrum.len() {
55 let output_cell = spectrum.get_mut(k).unwrap();
56
57 *output_cell = Zero::zero();
58 let mut twiddle_index = 0;
59
60 for input_cell in signal {
61 let twiddle = self.twiddles[twiddle_index];
62 *output_cell = *output_cell + twiddle * input_cell;
63
64 twiddle_index += k;
65 if twiddle_index >= self.twiddles.len() {
66 twiddle_index -= self.twiddles.len();
67 }
68 }
69 }
70 }
71}
// Presumably expands to the public trait impls for `Dft`. The unit tests call `len()`,
// `process()`, `process_with_scratch()` and `process_outofplace_with_scratch()` on it, none of
// which are defined in the inherent impl. NOTE(review): confirm against the macro definition.
// The closure reports the transform length as the size of the precomputed twiddle table.
boilerplate_fft_oop!(Dft, |this: &Dft<_>| this.twiddles.len());
73
#[cfg(test)]
mod unit_tests {
    use super::*;
    use crate::test_utils::{compare_vectors, random_signal};
    use num_complex::Complex;
    use num_traits::Zero;
    use std::f32;

    /// Reference DFT computed straight from the definition, building a fresh `from_polar`
    /// twiddle for every term. `spectrum` must be the same length as `signal`.
    fn dft(signal: &[Complex<f32>], spectrum: &mut [Complex<f32>]) {
        let len = signal.len();
        for (k, spec_bin) in spectrum.iter_mut().enumerate() {
            let mut sum = Zero::zero();
            for (i, &x) in signal.iter().enumerate() {
                // e^{-2*pi*i*m/len} has period `len` in m, so reduce i * k first. This keeps
                // the f32 angle small and accurate. The modulo is only reached when `signal`
                // is non-empty, so `len` is never zero here.
                let angle =
                    -1f32 * ((i * k) % len) as f32 * 2f32 * f32::consts::PI / len as f32;
                let twiddle = Complex::from_polar(1f32, angle);

                sum = sum + twiddle * x;
            }
            *spec_bin = sum;
        }
    }

    /// Runs `dft_instance` through `process()`, `process_with_scratch()` (with both zeroed and
    /// dirty scratch) and `process_outofplace_with_scratch()`. Asserts that every result
    /// matches `expected_output`.
    ///
    /// `input.len()` must be a multiple of `dft_instance.len()`.
    fn check_all_process_variants(
        dft_instance: &Dft<f32>,
        input: &[Complex<f32>],
        expected_output: &[Complex<f32>],
    ) {
        let len = dft_instance.len();
        // Sentinel placed in buffers the algorithm must overwrite (or must not depend on).
        let garbage = Complex::new(100.0f32, 100.0f32);

        // process()
        {
            let mut inplace_buffer = input.to_vec();

            dft_instance.process(&mut inplace_buffer);

            assert!(
                compare_vectors(expected_output, &inplace_buffer),
                "process() failed, length = {}",
                len
            );
        }

        // process_with_scratch(): first with zeroed scratch, then with dirty scratch.
        {
            let mut inplace_with_scratch_buffer = input.to_vec();
            let mut inplace_scratch = vec![Zero::zero(); dft_instance.get_inplace_scratch_len()];

            dft_instance
                .process_with_scratch(&mut inplace_with_scratch_buffer, &mut inplace_scratch);

            assert!(
                compare_vectors(expected_output, &inplace_with_scratch_buffer),
                "process_with_scratch() failed, length = {}",
                len
            );

            // The result must not depend on whatever the scratch buffer happens to contain.
            for item in inplace_scratch.iter_mut() {
                *item = garbage;
            }
            inplace_with_scratch_buffer.copy_from_slice(input);

            dft_instance
                .process_with_scratch(&mut inplace_with_scratch_buffer, &mut inplace_scratch);

            assert!(
                compare_vectors(expected_output, &inplace_with_scratch_buffer),
                "process_with_scratch() failed the 'dirty scratch' test for len = {}",
                len
            );
        }

        // process_outofplace_with_scratch()
        {
            let mut outofplace_input = input.to_vec();
            // Start from garbage, not from a copy of the expected result. Otherwise an
            // implementation that never writes its output buffer would still pass.
            let mut outofplace_output = vec![garbage; input.len()];

            dft_instance.process_outofplace_with_scratch(
                &mut outofplace_input,
                &mut outofplace_output,
                &mut [],
            );

            assert!(
                compare_vectors(expected_output, &outofplace_output),
                "process_outofplace_with_scratch() failed, length = {}",
                len
            );
        }
    }

    #[test]
    fn test_matches_dft() {
        // Number of back-to-back transforms packed into each test buffer.
        let n = 4;

        for len in 1..20 {
            let dft_instance = Dft::new(len, FftDirection::Forward);
            assert_eq!(
                dft_instance.len(),
                len,
                "Dft instance reported incorrect length"
            );

            let input = random_signal(len * n);
            let mut expected_output = input.clone();

            for (input_chunk, output_chunk) in
                input.chunks(len).zip(expected_output.chunks_mut(len))
            {
                dft(input_chunk, output_chunk);
            }

            check_all_process_variants(&dft_instance, &input, &expected_output);
        }

        // A zero-length transform must accept empty buffers without panicking.
        let zero_dft = Dft::new(0, FftDirection::Forward);
        let mut zero_input: Vec<Complex<f32>> = Vec::new();
        let mut zero_output: Vec<Complex<f32>> = Vec::new();
        let mut zero_scratch: Vec<Complex<f32>> = Vec::new();

        zero_dft.process(&mut zero_input);
        zero_dft.process_with_scratch(&mut zero_input, &mut zero_scratch);
        zero_dft.process_outofplace_with_scratch(
            &mut zero_input,
            &mut zero_output,
            &mut zero_scratch,
        );
    }

    /// Checks that both the reference `dft` and every `Dft` processing path reproduce the
    /// known answer `expected_output` for `input`.
    fn test_dft_correct(input: &[Complex<f32>], expected_output: &[Complex<f32>]) {
        assert_eq!(input.len(), expected_output.len());
        let len = input.len();

        let mut reference_output = vec![Zero::zero(); len];
        dft(input, &mut reference_output);
        assert!(
            compare_vectors(expected_output, &reference_output),
            "Reference implementation failed for len={}",
            len
        );

        let dft_instance = Dft::new(len, FftDirection::Forward);
        check_all_process_variants(&dft_instance, input, expected_output);
    }

    #[test]
    fn test_dft_known_len_2() {
        let signal = [
            Complex { re: 1f32, im: 0f32 },
            Complex {
                re: -1f32,
                im: 0f32,
            },
        ];
        let spectrum = [
            Complex { re: 0f32, im: 0f32 },
            Complex { re: 2f32, im: 0f32 },
        ];
        test_dft_correct(&signal[..], &spectrum[..]);
    }

    #[test]
    fn test_dft_known_len_3() {
        let signal = [
            Complex { re: 1f32, im: 1f32 },
            Complex {
                re: 2f32,
                im: -3f32,
            },
            Complex {
                re: -1f32,
                im: 4f32,
            },
        ];
        let spectrum = [
            Complex { re: 2f32, im: 2f32 },
            Complex {
                re: -5.562177f32,
                im: -2.098076f32,
            },
            Complex {
                re: 6.562178f32,
                im: 3.09807f32,
            },
        ];
        test_dft_correct(&signal[..], &spectrum[..]);
    }

    #[test]
    fn test_dft_known_len_4() {
        let signal = [
            Complex { re: 0f32, im: 1f32 },
            Complex {
                re: 2.5f32,
                im: -3f32,
            },
            Complex {
                re: -1f32,
                im: -1f32,
            },
            Complex { re: 4f32, im: 0f32 },
        ];
        let spectrum = [
            Complex {
                re: 5.5f32,
                im: -3f32,
            },
            Complex {
                re: -2f32,
                im: 3.5f32,
            },
            Complex {
                re: -7.5f32,
                im: 3f32,
            },
            Complex {
                re: 4f32,
                im: 0.5f32,
            },
        ];
        test_dft_correct(&signal[..], &spectrum[..]);
    }

    #[test]
    fn test_dft_known_len_6() {
        let signal = [
            Complex { re: 1f32, im: 1f32 },
            Complex { re: 2f32, im: 2f32 },
            Complex { re: 3f32, im: 3f32 },
            Complex { re: 4f32, im: 4f32 },
            Complex { re: 5f32, im: 5f32 },
            Complex { re: 6f32, im: 6f32 },
        ];
        // Exact values are (1 + i) * DFT([1, 2, 3, 4, 5, 6]). That DFT is 21 for bin 0 and
        // 6 / (w^k - 1) for k != 0, with w = e^{-i*pi/3}. This gives, for example, bin 1 =
        // (-3 - 3*sqrt(3)) + (3*sqrt(3) - 3)i.
        let spectrum = [
            Complex {
                re: 21f32,
                im: 21f32,
            },
            Complex {
                re: -8.196152f32,
                im: 2.196152f32,
            },
            Complex {
                re: -4.732051f32,
                im: -1.267949f32,
            },
            Complex {
                re: -3f32,
                im: -3f32,
            },
            Complex {
                re: -1.267949f32,
                im: -4.732051f32,
            },
            Complex {
                re: 2.196152f32,
                im: -8.196152f32,
            },
        ];
        test_dft_correct(&signal[..], &spectrum[..]);
    }
}