rustfft/algorithm/dft.rs

1use num_complex::Complex;
2use num_traits::Zero;
3
4use crate::array_utils;
5use crate::common::{fft_error_inplace, fft_error_outofplace};
6use crate::{twiddles, FftDirection};
7use crate::{Direction, Fft, FftNum, Length};
8
9/// Naive O(n^2 ) Discrete Fourier Transform implementation
10///
11/// This implementation is primarily used to test other FFT algorithms.
12///
13/// ~~~
14/// // Computes a naive DFT of size 123
15/// use rustfft::algorithm::Dft;
16/// use rustfft::{Fft, FftDirection};
17/// use rustfft::num_complex::Complex;
18///
19/// let mut buffer = vec![Complex{ re: 0.0f32, im: 0.0f32 }; 123];
20///
21/// let dft = Dft::new(123, FftDirection::Forward);
22/// dft.process(&mut buffer);
23/// ~~~
24pub struct Dft<T> {
25    twiddles: Vec<Complex<T>>,
26    direction: FftDirection,
27}
28
29impl<T: FftNum> Dft<T> {
30    /// Preallocates necessary arrays and precomputes necessary data to efficiently compute Dft
31    pub fn new(len: usize, direction: FftDirection) -> Self {
32        let twiddles = (0..len)
33            .map(|i| twiddles::compute_twiddle(i, len, direction))
34            .collect();
35        Self {
36            twiddles,
37            direction,
38        }
39    }
40
41    fn inplace_scratch_len(&self) -> usize {
42        self.len()
43    }
44    fn outofplace_scratch_len(&self) -> usize {
45        0
46    }
47
48    fn perform_fft_out_of_place(
49        &self,
50        signal: &[Complex<T>],
51        spectrum: &mut [Complex<T>],
52        _scratch: &mut [Complex<T>],
53    ) {
54        for k in 0..spectrum.len() {
55            let output_cell = spectrum.get_mut(k).unwrap();
56
57            *output_cell = Zero::zero();
58            let mut twiddle_index = 0;
59
60            for input_cell in signal {
61                let twiddle = self.twiddles[twiddle_index];
62                *output_cell = *output_cell + twiddle * input_cell;
63
64                twiddle_index += k;
65                if twiddle_index >= self.twiddles.len() {
66                    twiddle_index -= self.twiddles.len();
67                }
68            }
69        }
70    }
71}
72boilerplate_fft_oop!(Dft, |this: &Dft<_>| this.twiddles.len());
73
#[cfg(test)]
mod unit_tests {
    use super::*;
    use crate::test_utils::{compare_vectors, random_signal};
    use num_complex::Complex;
    use num_traits::Zero;
    use std::f32::consts::PI;

    /// Reference DFT evaluated straight from the definition
    /// `X[k] = sum_i x[i] * exp(s * 2*pi*j * i*k / N)`, with `s = -1` for a forward
    /// transform and `s = +1` for an inverse transform (no normalization in either case).
    ///
    /// Writes one bin per element of `spectrum`; `signal` and `spectrum` are expected to
    /// have the same length.
    fn reference_dft(
        signal: &[Complex<f32>],
        spectrum: &mut [Complex<f32>],
        direction: FftDirection,
    ) {
        let sign = match direction {
            FftDirection::Forward => -1f32,
            FftDirection::Inverse => 1f32,
        };
        let len = signal.len();

        for (k, spec_bin) in spectrum.iter_mut().enumerate() {
            let mut sum = Zero::zero();
            for (i, &x) in signal.iter().enumerate() {
                // Reduce i*k modulo len before converting to f32 so the angle stays small
                // and precise; the twiddle is periodic in len, so the result is unchanged.
                let angle = sign * ((i * k) % len) as f32 * 2f32 * PI / len as f32;
                let twiddle = Complex::from_polar(1f32, angle);

                sum = sum + twiddle * x;
            }
            *spec_bin = sum;
        }
    }

    /// Runs every processing entry point of `dft_instance` over `input` (which may hold
    /// several back-to-back FFTs) and asserts that each one reproduces `expected_output`.
    /// `len` is only used in failure messages.
    fn check_entry_points(
        dft_instance: &Dft<f32>,
        input: &[Complex<f32>],
        expected_output: &[Complex<f32>],
        len: usize,
    ) {
        // test process()
        {
            let mut inplace_buffer = input.to_vec();

            dft_instance.process(&mut inplace_buffer);

            assert!(
                compare_vectors(expected_output, &inplace_buffer),
                "process() failed, length = {}",
                len
            );
        }

        // test process_with_scratch()
        {
            let mut inplace_with_scratch_buffer = input.to_vec();
            let mut inplace_scratch = vec![Zero::zero(); dft_instance.get_inplace_scratch_len()];

            dft_instance
                .process_with_scratch(&mut inplace_with_scratch_buffer, &mut inplace_scratch);

            assert!(
                compare_vectors(expected_output, &inplace_with_scratch_buffer),
                "process_with_scratch() failed, length = {}",
                len
            );

            // one more thing: make sure that the Dft algorithm even works with dirty scratch space
            for item in inplace_scratch.iter_mut() {
                *item = Complex::new(100.0, 100.0);
            }
            inplace_with_scratch_buffer.copy_from_slice(input);

            dft_instance
                .process_with_scratch(&mut inplace_with_scratch_buffer, &mut inplace_scratch);

            assert!(
                compare_vectors(expected_output, &inplace_with_scratch_buffer),
                "process_with_scratch() failed the 'dirty scratch' test for len = {}",
                len
            );
        }

        // test process_outofplace_with_scratch()
        {
            let mut outofplace_input = input.to_vec();
            // Pre-fill the output with garbage rather than the expected answer; otherwise an
            // implementation that never wrote to its output buffer would still pass.
            let mut outofplace_output = vec![Complex::new(100.0, 100.0); input.len()];
            let mut outofplace_scratch =
                vec![Zero::zero(); dft_instance.get_outofplace_scratch_len()];

            dft_instance.process_outofplace_with_scratch(
                &mut outofplace_input,
                &mut outofplace_output,
                &mut outofplace_scratch,
            );

            assert!(
                compare_vectors(expected_output, &outofplace_output),
                "process_outofplace_with_scratch() failed, length = {}",
                len
            );
        }
    }

    #[test]
    fn test_matches_dft() {
        // Number of back-to-back FFTs packed into each test buffer.
        let n = 4;

        for &direction in &[FftDirection::Forward, FftDirection::Inverse] {
            for len in 1..20 {
                let dft_instance = Dft::new(len, direction);
                assert_eq!(
                    dft_instance.len(),
                    len,
                    "Dft instance reported incorrect length"
                );

                let input = random_signal(len * n);
                let mut expected_output = input.clone();

                // Compute the control data using our simplified Dft definition
                for (input_chunk, output_chunk) in
                    input.chunks(len).zip(expected_output.chunks_mut(len))
                {
                    reference_dft(input_chunk, output_chunk, direction);
                }

                check_entry_points(&dft_instance, &input, &expected_output, len);
            }
        }

        //verify that it doesn't crash or infinite loop if we have a length of 0
        let zero_dft = Dft::new(0, FftDirection::Forward);
        let mut zero_input: Vec<Complex<f32>> = Vec::new();
        let mut zero_output: Vec<Complex<f32>> = Vec::new();
        let mut zero_scratch: Vec<Complex<f32>> = Vec::new();

        zero_dft.process(&mut zero_input);
        zero_dft.process_with_scratch(&mut zero_input, &mut zero_scratch);
        zero_dft.process_outofplace_with_scratch(
            &mut zero_input,
            &mut zero_output,
            &mut zero_scratch,
        );
    }

    /// Asserts that the reference DFT maps `input` to `expected_output`, and that every
    /// entry point of a forward `Dft` instance does the same.
    fn test_dft_correct(input: &[Complex<f32>], expected_output: &[Complex<f32>]) {
        assert_eq!(input.len(), expected_output.len());
        let len = input.len();

        let mut reference_output = vec![Zero::zero(); len];
        reference_dft(input, &mut reference_output, FftDirection::Forward);
        assert!(
            compare_vectors(expected_output, &reference_output),
            "Reference implementation failed for len={}",
            len
        );

        let dft_instance = Dft::new(len, FftDirection::Forward);
        check_entry_points(&dft_instance, input, expected_output, len);
    }

    #[test]
    fn test_dft_known_len_2() {
        let signal = [Complex::new(1f32, 0f32), Complex::new(-1f32, 0f32)];
        let spectrum = [Complex::new(0f32, 0f32), Complex::new(2f32, 0f32)];
        test_dft_correct(&signal[..], &spectrum[..]);
    }

    #[test]
    fn test_dft_known_len_3() {
        let signal = [
            Complex::new(1f32, 1f32),
            Complex::new(2f32, -3f32),
            Complex::new(-1f32, 4f32),
        ];
        let spectrum = [
            Complex::new(2f32, 2f32),
            Complex::new(-5.562177f32, -2.098076f32),
            Complex::new(6.562178f32, 3.09807f32),
        ];
        test_dft_correct(&signal[..], &spectrum[..]);
    }

    #[test]
    fn test_dft_known_len_4() {
        let signal = [
            Complex::new(0f32, 1f32),
            Complex::new(2.5f32, -3f32),
            Complex::new(-1f32, -1f32),
            Complex::new(4f32, 0f32),
        ];
        let spectrum = [
            Complex::new(5.5f32, -3f32),
            Complex::new(-2f32, 3.5f32),
            Complex::new(-7.5f32, 3f32),
            Complex::new(4f32, 0.5f32),
        ];
        test_dft_correct(&signal[..], &spectrum[..]);
    }

    #[test]
    fn test_dft_known_len_6() {
        let signal = [
            Complex::new(1f32, 1f32),
            Complex::new(2f32, 2f32),
            Complex::new(3f32, 3f32),
            Complex::new(4f32, 4f32),
            Complex::new(5f32, 5f32),
            Complex::new(6f32, 6f32),
        ];
        let spectrum = [
            Complex::new(21f32, 21f32),
            Complex::new(-8.16f32, 2.16f32),
            Complex::new(-4.76f32, -1.24f32),
            Complex::new(-3f32, -3f32),
            Complex::new(-1.24f32, -4.76f32),
            Complex::new(2.16f32, -8.16f32),
        ];
        test_dft_correct(&signal[..], &spectrum[..]);
    }
}