1use num_complex::Complex;
2
3use std::any::TypeId;
4use std::sync::Arc;
5
6use crate::array_utils::{self, bitreversed_transpose, workaround_transmute_mut};
7use crate::common::{fft_error_inplace, fft_error_outofplace};
8use crate::{common::FftNum, FftDirection};
9use crate::{Direction, Fft, Length};
10
11use super::SseNum;
12
13use super::sse_vector::{Rotation90, SseArray, SseArrayMut, SseVector};
14
15pub struct SseRadix4<S: SseNum, T> {
19 twiddles: Box<[S::VectorType]>,
20 rotation: Rotation90<S::VectorType>,
21
22 base_fft: Arc<dyn Fft<T>>,
23 base_len: usize,
24
25 len: usize,
26 direction: FftDirection,
27}
28
29impl<S: SseNum, T: FftNum> SseRadix4<S, T> {
30 #[inline]
32 pub fn new(k: u32, base_fft: Arc<dyn Fft<T>>) -> Result<Self, ()> {
33 let id_a = TypeId::of::<S>();
37 let id_t = TypeId::of::<T>();
38 assert_eq!(id_a, id_t);
39
40 let has_sse = is_x86_feature_detected!("sse4.1");
41 if has_sse {
42 Ok(unsafe { Self::new_with_sse(k, base_fft) })
44 } else {
45 Err(())
46 }
47 }
48
49 #[target_feature(enable = "sse4.1")]
50 unsafe fn new_with_sse(k: u32, base_fft: Arc<dyn Fft<T>>) -> Self {
51 let direction = base_fft.fft_direction();
52 let base_len = base_fft.len();
53
54 assert!(base_len % (2 * S::VectorType::COMPLEX_PER_VECTOR) == 0 && base_len > 0);
56
57 let len = base_len * (1 << (k * 2));
58
59 const ROW_COUNT: usize = 4;
64 let mut cross_fft_len = base_len * ROW_COUNT;
65 let mut twiddle_factors = Vec::with_capacity(len * 2);
66 while cross_fft_len <= len {
67 let num_scalar_columns = cross_fft_len / ROW_COUNT;
68 let num_vector_columns = num_scalar_columns / S::VectorType::COMPLEX_PER_VECTOR;
69
70 for i in 0..num_vector_columns {
71 for k in 1..ROW_COUNT {
72 twiddle_factors.push(SseVector::make_mixedradix_twiddle_chunk(
73 i * S::VectorType::COMPLEX_PER_VECTOR,
74 k,
75 cross_fft_len,
76 direction,
77 ));
78 }
79 }
80 cross_fft_len *= ROW_COUNT;
81 }
82
83 Self {
84 twiddles: twiddle_factors.into_boxed_slice(),
85 rotation: SseVector::make_rotate90(direction),
86
87 base_fft,
88 base_len,
89
90 len,
91 direction,
92 }
93 }
94
95 #[target_feature(enable = "sse4.1")]
96 unsafe fn perform_fft_out_of_place(
97 &self,
98 input: &[Complex<T>],
99 output: &mut [Complex<T>],
100 _scratch: &mut [Complex<T>],
101 ) {
102 if self.len() == self.base_len {
104 output.copy_from_slice(input);
105 } else {
106 bitreversed_transpose::<Complex<T>, 4>(self.base_len, input, output);
107 }
108
109 self.base_fft.process_with_scratch(output, &mut []);
111
112 const ROW_COUNT: usize = 4;
114 let mut cross_fft_len = self.base_len * ROW_COUNT;
115 let mut layer_twiddles: &[S::VectorType] = &self.twiddles;
116
117 while cross_fft_len <= input.len() {
118 let num_rows = input.len() / cross_fft_len;
119 let num_scalar_columns = cross_fft_len / ROW_COUNT;
120 let num_vector_columns = num_scalar_columns / S::VectorType::COMPLEX_PER_VECTOR;
121
122 for i in 0..num_rows {
123 butterfly_4::<S, T>(
124 &mut output[i * cross_fft_len..],
125 layer_twiddles,
126 num_scalar_columns,
127 &self.rotation,
128 )
129 }
130
131 let twiddle_offset = num_vector_columns * (ROW_COUNT - 1);
133 layer_twiddles = &layer_twiddles[twiddle_offset..];
134
135 cross_fft_len *= ROW_COUNT;
136 }
137 }
138}
139boilerplate_fft_sse_oop!(SseRadix4, |this: &SseRadix4<_, _>| this.len);
140
141#[target_feature(enable = "sse4.1")]
142unsafe fn butterfly_4<S: SseNum, T: FftNum>(
143 data: &mut [Complex<T>],
144 twiddles: &[S::VectorType],
145 num_ffts: usize,
146 rotation: &Rotation90<S::VectorType>,
147) {
148 let unroll_offset = S::VectorType::COMPLEX_PER_VECTOR;
149
150 let mut idx = 0usize;
151 let mut buffer: &mut [Complex<S>] = workaround_transmute_mut(data);
152 for tw in twiddles
153 .chunks_exact(6)
154 .take(num_ffts / (S::VectorType::COMPLEX_PER_VECTOR * 2))
155 {
156 let mut scratcha = [
157 buffer.load_complex(idx + 0 * num_ffts),
158 buffer.load_complex(idx + 1 * num_ffts),
159 buffer.load_complex(idx + 2 * num_ffts),
160 buffer.load_complex(idx + 3 * num_ffts),
161 ];
162 let mut scratchb = [
163 buffer.load_complex(idx + 0 * num_ffts + unroll_offset),
164 buffer.load_complex(idx + 1 * num_ffts + unroll_offset),
165 buffer.load_complex(idx + 2 * num_ffts + unroll_offset),
166 buffer.load_complex(idx + 3 * num_ffts + unroll_offset),
167 ];
168
169 scratcha[1] = SseVector::mul_complex(scratcha[1], tw[0]);
170 scratcha[2] = SseVector::mul_complex(scratcha[2], tw[1]);
171 scratcha[3] = SseVector::mul_complex(scratcha[3], tw[2]);
172 scratchb[1] = SseVector::mul_complex(scratchb[1], tw[3]);
173 scratchb[2] = SseVector::mul_complex(scratchb[2], tw[4]);
174 scratchb[3] = SseVector::mul_complex(scratchb[3], tw[5]);
175
176 let scratcha = SseVector::column_butterfly4(scratcha, *rotation);
177 let scratchb = SseVector::column_butterfly4(scratchb, *rotation);
178
179 buffer.store_complex(scratcha[0], idx + 0 * num_ffts);
180 buffer.store_complex(scratchb[0], idx + 0 * num_ffts + unroll_offset);
181 buffer.store_complex(scratcha[1], idx + 1 * num_ffts);
182 buffer.store_complex(scratchb[1], idx + 1 * num_ffts + unroll_offset);
183 buffer.store_complex(scratcha[2], idx + 2 * num_ffts);
184 buffer.store_complex(scratchb[2], idx + 2 * num_ffts + unroll_offset);
185 buffer.store_complex(scratcha[3], idx + 3 * num_ffts);
186 buffer.store_complex(scratchb[3], idx + 3 * num_ffts + unroll_offset);
187
188 idx += S::VectorType::COMPLEX_PER_VECTOR * 2;
189 }
190}
191
#[cfg(test)]
mod unit_tests {
    use super::*;
    use crate::test_utils::{check_fft_algorithm, construct_base};

    /// Exclusive upper bound on `k`: lengths from `base` up to `base * 4^3` are checked.
    const K_LIMIT: u32 = 4;

    /// Both transform directions, exercised for every base size.
    const DIRECTIONS: [FftDirection; 2] = [FftDirection::Forward, FftDirection::Inverse];

    #[test]
    fn test_sse_radix4_64() {
        for base in [2, 4, 6, 8, 12, 16] {
            for direction in DIRECTIONS {
                let base_fft = construct_base(base, direction);
                for k in 0..K_LIMIT {
                    test_sse_radix4_64_with_base(k, Arc::clone(&base_fft));
                }
            }
        }
    }

    /// Builds an f64 radix-4 FFT over `base_fft` and checks it against the reference.
    fn test_sse_radix4_64_with_base(k: u32, base_fft: Arc<dyn Fft<f64>>) {
        let direction = base_fft.fft_direction();
        let expected_len = base_fft.len() * 4usize.pow(k);
        let fft = SseRadix4::<f64, f64>::new(k, base_fft).unwrap();
        check_fft_algorithm::<f64>(&fft, expected_len, direction);
    }

    #[test]
    fn test_sse_radix4_32() {
        for base in [4, 8, 12, 16] {
            for direction in DIRECTIONS {
                let base_fft = construct_base(base, direction);
                for k in 0..K_LIMIT {
                    test_sse_radix4_32_with_base(k, Arc::clone(&base_fft));
                }
            }
        }
    }

    /// Builds an f32 radix-4 FFT over `base_fft` and checks it against the reference.
    fn test_sse_radix4_32_with_base(k: u32, base_fft: Arc<dyn Fft<f32>>) {
        let direction = base_fft.fft_direction();
        let expected_len = base_fft.len() * 4usize.pow(k);
        let fft = SseRadix4::<f32, f32>::new(k, base_fft).unwrap();
        check_fft_algorithm::<f32>(&fft, expected_len, direction);
    }
}