1use crate::{Fft, FftDirection, FftNum};
2use std::arch::x86_64::{__m256, __m256d};
3use std::sync::Arc;
4
5pub trait AvxNum: FftNum {
6 type VectorType: AvxVector256<ScalarType = Self>;
7}
8
// Single-precision scalars are paired with the `__m256` vector type
// from `std::arch::x86_64`.
impl AvxNum for f32 {
    type VectorType = __m256;
}
// Double-precision scalars are paired with the `__m256d` vector type
// from `std::arch::x86_64`.
impl AvxNum for f64 {
    type VectorType = __m256d;
}
15
/// Bundle of per-instance data for a SIMD FFT algorithm.
///
/// The `boilerplate_avx_fft_commondata!` macro in this file expects the struct it
/// is applied to to hold one of these in a field named `common_data`, and reads
/// `len`, `inplace_scratch_len`, `outofplace_scratch_len` and `direction` from it.
struct CommonSimdData<T, V> {
    /// Another FFT instance owned (shared) by this algorithm.
    /// NOTE(review): how it is used is decided by each algorithm and is not
    /// visible in this file.
    inner_fft: Arc<dyn Fft<T>>,
    /// Boxed slice of values of the vector type `V`.
    /// NOTE(review): presumably precomputed twiddle factors, going by the field
    /// name; confirm against the algorithms that fill it in.
    twiddles: Box<[V]>,

    /// FFT length, reported through `Length::len` by the boilerplate macro.
    len: usize,

    /// Minimum scratch length accepted by `process_with_scratch`.
    inplace_scratch_len: usize,
    /// Minimum scratch length accepted by `process_outofplace_with_scratch`.
    outofplace_scratch_len: usize,

    /// Transform direction, reported through `Direction::fft_direction`.
    direction: FftDirection,
}
29
// Generates the `Fft`, `Length` and `Direction` impls for an AVX FFT struct that is
// generic over `<A, T>` and does NOT keep its data in a `common_data` field.
//
// Arguments:
// - `$struct_name`: the struct to implement the traits for.
// - `$len_fn`: expression callable with `&Self`, returning the FFT length.
// - `$inplace_scratch_len_fn`: expression callable with `&Self`, returning the scratch
//   length required by `process_with_scratch`.
// - `$out_of_place_scratch_len_fn`: expression callable with `&Self`, returning the
//   scratch length required by `process_outofplace_with_scratch`.
//
// Requirements on `$struct_name`:
// - a `direction: FftDirection` field;
// - `perform_fft_inplace(&self, chunk, scratch)`;
// - `perform_fft_out_of_place(&self, in_chunk, out_chunk, scratch)`.
//
// Requirements on the invocation site: `Complex`, `Length`, `Direction`, `array_utils`,
// `fft_error_inplace` and `fft_error_outofplace` must be in scope, because the expansion
// refers to them by unqualified name.
//
// Behavior of the generated `process_*` methods:
// - a zero-length FFT returns immediately without touching any buffer (same as
//   `boilerplate_avx_fft_commondata!`);
// - buffers that are too short, mismatched, or not a whole number of FFT-sized chunks are
//   reported through `fft_error_inplace` / `fft_error_outofplace`;
// - otherwise the buffers are processed one `len()`-sized chunk at a time.
macro_rules! boilerplate_avx_fft {
    ($struct_name:ident, $len_fn:expr, $inplace_scratch_len_fn:expr, $out_of_place_scratch_len_fn:expr) => {
        impl<A: AvxNum, T: FftNum> Fft<T> for $struct_name<A, T> {
            fn process_outofplace_with_scratch(
                &self,
                input: &mut [Complex<T>],
                output: &mut [Complex<T>],
                scratch: &mut [Complex<T>],
            ) {
                // A zero-length FFT has nothing to compute, and splitting buffers into
                // chunks of size 0 is not meaningful, so bail out before any validation
                // or chunking happens.
                if self.len() == 0 {
                    return;
                }

                let required_scratch = self.get_outofplace_scratch_len();
                if scratch.len() < required_scratch
                    || input.len() < self.len()
                    || output.len() != input.len()
                {
                    // Error reporting is delegated to a separate function.
                    fft_error_outofplace(
                        self.len(),
                        input.len(),
                        output.len(),
                        self.get_outofplace_scratch_len(),
                        scratch.len(),
                    );
                    return;
                }

                // Hand the algorithm exactly the amount of scratch it asked for.
                let scratch = &mut scratch[..required_scratch];
                let result = array_utils::iter_chunks_zipped(
                    input,
                    output,
                    self.len(),
                    |in_chunk, out_chunk| {
                        self.perform_fft_out_of_place(in_chunk, out_chunk, scratch)
                    },
                );

                // The chunk iterator reports an error when the buffers could not be
                // consumed as whole chunks.
                if result.is_err() {
                    fft_error_outofplace(
                        self.len(),
                        input.len(),
                        output.len(),
                        self.get_outofplace_scratch_len(),
                        scratch.len(),
                    );
                }
            }
            fn process_with_scratch(&self, buffer: &mut [Complex<T>], scratch: &mut [Complex<T>]) {
                // See `process_outofplace_with_scratch`: nothing to do for length 0.
                if self.len() == 0 {
                    return;
                }

                let required_scratch = self.get_inplace_scratch_len();
                if scratch.len() < required_scratch || buffer.len() < self.len() {
                    fft_error_inplace(
                        self.len(),
                        buffer.len(),
                        self.get_inplace_scratch_len(),
                        scratch.len(),
                    );
                    return;
                }

                // Hand the algorithm exactly the amount of scratch it asked for.
                let scratch = &mut scratch[..required_scratch];
                let result = array_utils::iter_chunks(buffer, self.len(), |chunk| {
                    self.perform_fft_inplace(chunk, scratch)
                });

                // The chunk iterator reports an error when the buffer could not be
                // consumed as whole chunks.
                if result.is_err() {
                    fft_error_inplace(
                        self.len(),
                        buffer.len(),
                        self.get_inplace_scratch_len(),
                        scratch.len(),
                    );
                }
            }
            #[inline(always)]
            fn get_inplace_scratch_len(&self) -> usize {
                $inplace_scratch_len_fn(self)
            }
            #[inline(always)]
            fn get_outofplace_scratch_len(&self) -> usize {
                $out_of_place_scratch_len_fn(self)
            }
        }
        impl<A: AvxNum, T> Length for $struct_name<A, T> {
            #[inline(always)]
            fn len(&self) -> usize {
                $len_fn(self)
            }
        }
        impl<A: AvxNum, T> Direction for $struct_name<A, T> {
            #[inline(always)]
            fn fft_direction(&self) -> FftDirection {
                self.direction
            }
        }
    };
}
129
// Generates the `Fft`, `Length` and `Direction` impls for an AVX FFT struct that is
// generic over `<A, T>` and keeps its length, scratch sizes and direction in a field
// named `common_data`.
//
// The struct must also provide `perform_fft_inplace(&self, chunk, scratch)` and
// `perform_fft_out_of_place(&self, in_chunk, out_chunk, scratch)`. The invocation site
// must have `Complex`, `Length`, `Direction`, `array_utils`, `fft_error_inplace` and
// `fft_error_outofplace` in scope, because the expansion names them unqualified.
macro_rules! boilerplate_avx_fft_commondata {
    ($struct_name:ident) => {
        impl<A: AvxNum, T: FftNum> Fft<T> for $struct_name<A, T> {
            fn process_outofplace_with_scratch(
                &self,
                input: &mut [Complex<T>],
                output: &mut [Complex<T>],
                scratch: &mut [Complex<T>],
            ) {
                // Nothing to do for a zero-length FFT.
                let fft_len = self.len();
                if fft_len == 0 {
                    return;
                }

                // Validate all three buffers up front; any failure is reported through
                // the shared error function.
                let scratch_needed = self.get_outofplace_scratch_len();
                let buffers_ok = scratch.len() >= scratch_needed
                    && input.len() >= fft_len
                    && output.len() == input.len();
                if !buffers_ok {
                    fft_error_outofplace(
                        fft_len,
                        input.len(),
                        output.len(),
                        scratch_needed,
                        scratch.len(),
                    );
                    return;
                }

                // Trim the scratch to exactly the requested size, then run the FFT on
                // each pair of `fft_len`-sized chunks.
                let scratch = &mut scratch[..scratch_needed];
                let chunked =
                    array_utils::iter_chunks_zipped(input, output, fft_len, |src, dst| {
                        self.perform_fft_out_of_place(src, dst, scratch)
                    });

                // An `Err` from the chunk iterator is reported the same way as a failed
                // up-front validation.
                if chunked.is_err() {
                    fft_error_outofplace(
                        fft_len,
                        input.len(),
                        output.len(),
                        scratch_needed,
                        scratch.len(),
                    );
                }
            }
            fn process_with_scratch(&self, buffer: &mut [Complex<T>], scratch: &mut [Complex<T>]) {
                // Nothing to do for a zero-length FFT.
                let fft_len = self.len();
                if fft_len == 0 {
                    return;
                }

                let scratch_needed = self.get_inplace_scratch_len();
                let buffers_ok = scratch.len() >= scratch_needed && buffer.len() >= fft_len;
                if !buffers_ok {
                    fft_error_inplace(fft_len, buffer.len(), scratch_needed, scratch.len());
                    return;
                }

                // Trim the scratch to exactly the requested size, then run the FFT on
                // each `fft_len`-sized chunk of the buffer.
                let scratch = &mut scratch[..scratch_needed];
                let chunked = array_utils::iter_chunks(buffer, fft_len, |chunk| {
                    self.perform_fft_inplace(chunk, scratch)
                });

                if chunked.is_err() {
                    fft_error_inplace(fft_len, buffer.len(), scratch_needed, scratch.len());
                }
            }
            #[inline(always)]
            fn get_inplace_scratch_len(&self) -> usize {
                let shared = &self.common_data;
                shared.inplace_scratch_len
            }
            #[inline(always)]
            fn get_outofplace_scratch_len(&self) -> usize {
                let shared = &self.common_data;
                shared.outofplace_scratch_len
            }
        }
        impl<A: AvxNum, T> Length for $struct_name<A, T> {
            #[inline(always)]
            fn len(&self) -> usize {
                let shared = &self.common_data;
                shared.len
            }
        }
        impl<A: AvxNum, T> Direction for $struct_name<A, T> {
            #[inline(always)]
            fn fft_direction(&self) -> FftDirection {
                let shared = &self.common_data;
                shared.direction
            }
        }
    };
}
237
238#[macro_use]
239mod avx_vector;
240
241mod avx32_butterflies;
242mod avx32_utils;
243
244mod avx64_butterflies;
245mod avx64_utils;
246
247mod avx_bluesteins;
248mod avx_mixed_radix;
249mod avx_raders;
250
251pub mod avx_planner;
252
253pub use self::avx32_butterflies::{
254 Butterfly11Avx, Butterfly128Avx, Butterfly12Avx, Butterfly16Avx, Butterfly24Avx,
255 Butterfly256Avx, Butterfly27Avx, Butterfly32Avx, Butterfly36Avx, Butterfly48Avx,
256 Butterfly512Avx, Butterfly54Avx, Butterfly5Avx, Butterfly64Avx, Butterfly72Avx, Butterfly7Avx,
257 Butterfly8Avx, Butterfly9Avx,
258};
259
260pub use self::avx64_butterflies::{
261 Butterfly11Avx64, Butterfly128Avx64, Butterfly12Avx64, Butterfly16Avx64, Butterfly18Avx64,
262 Butterfly24Avx64, Butterfly256Avx64, Butterfly27Avx64, Butterfly32Avx64, Butterfly36Avx64,
263 Butterfly512Avx64, Butterfly5Avx64, Butterfly64Avx64, Butterfly7Avx64, Butterfly8Avx64,
264 Butterfly9Avx64,
265};
266
267pub use self::avx_bluesteins::BluesteinsAvx;
268pub use self::avx_mixed_radix::{
269 MixedRadix11xnAvx, MixedRadix12xnAvx, MixedRadix16xnAvx, MixedRadix2xnAvx, MixedRadix3xnAvx,
270 MixedRadix4xnAvx, MixedRadix5xnAvx, MixedRadix6xnAvx, MixedRadix7xnAvx, MixedRadix8xnAvx,
271 MixedRadix9xnAvx,
272};
273pub use self::avx_raders::RadersAvx2;
274use self::avx_vector::AvxVector256;