// rustfft/sse/sse_common.rs
use std::any::TypeId;
2
/// Asserts that the generic parameter `T` is exactly `f32`.
///
/// The check compares the `TypeId` of `T` with the `TypeId` of `f32`, so it
/// happens at run time.
///
/// # Panics
///
/// Panics with the message `"Wrong float type, must be f32"` when `T` is any
/// type other than `f32`.
pub fn assert_f32<T: 'static>() {
    let id_f32 = TypeId::of::<f32>();
    let id_t = TypeId::of::<T>();
    assert!(id_t == id_f32, "Wrong float type, must be f32");
}
9
/// Asserts that the generic parameter `T` is exactly `f64`.
///
/// The check compares the `TypeId` of `T` with the `TypeId` of `f64`, so it
/// happens at run time.
///
/// # Panics
///
/// Panics with the message `"Wrong float type, must be f64"` when `T` is any
/// type other than `f64`.
pub fn assert_f64<T: 'static>() {
    let id_f64 = TypeId::of::<f64>();
    let id_t = TypeId::of::<T>();
    assert!(id_t == id_f64, "Wrong float type, must be f64");
}
16
// Shuffle elements to interleave two contiguous sets of f32, from an array of simd vectors to a new array of simd vectors.
//
// For every listed index `idx`, the element at `idx` is paired with the element at `idx + offset`,
// and the pair contributes two consecutive output entries: first the `extract_lo_lo_f32` result,
// then the `extract_hi_hi_f32` result.
//
// This statement:
// ```
// let values = interleave_complex_f32!(input, 2, {0, 1});
// ```
// is equivalent to:
// ```
// let values = [
//     extract_lo_lo_f32(input[0], input[0 + 2]),
//     extract_hi_hi_f32(input[0], input[0 + 2]),
//     extract_lo_lo_f32(input[1], input[1 + 2]),
//     extract_hi_hi_f32(input[1], input[1 + 2]),
// ];
// ```
//
// `extract_lo_lo_f32` and `extract_hi_hi_f32` are resolved at the call site; they must be in scope
// wherever this macro is invoked.
macro_rules! interleave_complex_f32 {
    ($input:ident, $offset:literal, { $($idx:literal),* }) => {
        [
            $(
                extract_lo_lo_f32($input[$idx], $input[$idx+$offset]),
                extract_hi_hi_f32($input[$idx], $input[$idx+$offset]),
            )*
        ]
    }
}
28
// Shuffle elements to separate interleaved f32 data, from an array of simd vectors to a new array of simd vectors.
//
// For every listed index `idx`, the element at `idx` is paired with its neighbour at `idx + 1`.
// The output holds all `extract_lo_lo_f32` results first (in index order), followed by all
// `extract_hi_hi_f32` results (in the same order).
//
// This statement:
// ```
// let values = separate_interleaved_complex_f32!(input, {0, 2, 4});
// ```
// is equivalent to:
// ```
// let values = [
//     extract_lo_lo_f32(input[0], input[1]),
//     extract_lo_lo_f32(input[2], input[3]),
//     extract_lo_lo_f32(input[4], input[5]),
//     extract_hi_hi_f32(input[0], input[1]),
//     extract_hi_hi_f32(input[2], input[3]),
//     extract_hi_hi_f32(input[4], input[5]),
// ];
// ```
//
// `extract_lo_lo_f32` and `extract_hi_hi_f32` are resolved at the call site; they must be in scope
// wherever this macro is invoked.
macro_rules! separate_interleaved_complex_f32 {
    ($input:ident, { $($idx:literal),* }) => {
        [
            $(
                extract_lo_lo_f32($input[$idx], $input[$idx+1]),
            )*
            $(
                extract_hi_hi_f32($input[$idx], $input[$idx+1]),
            )*
        ]
    }
}
56
// Generates the `Fft<T>`, `Length` and `Direction` impls for a struct `$struct_name<S, T>` that
// only provides an out-of-place kernel.
//
// Arguments:
// - `$struct_name`: name of the struct to implement the traits for. It is used as
//   `$struct_name<S, T>`, and the generated code calls `self.perform_fft_out_of_place(..)` and
//   reads `self.direction`, so both must exist on the struct.
// - `$len_fn`: an expression that is called as `$len_fn(self)` and yields the FFT length as `usize`.
//
// Generated behavior:
// - Out-of-place processing ignores the caller's scratch and reports a required scratch length of 0.
// - In-place processing is emulated: each chunk is transformed into the scratch buffer and copied
//   back, so the required in-place scratch length equals the FFT length.
// - Invalid buffer sizes are reported through `fft_error_outofplace` / `fft_error_inplace`.
//
// NOTE(review): the `unsafe` blocks are presumably needed because `perform_fft_out_of_place` is an
// unsafe (target-feature gated) function on the SSE structs — confirm against the struct
// definitions that invoke this macro.
macro_rules! boilerplate_fft_sse_oop {
    ($struct_name:ident, $len_fn:expr) => {
        impl<S: SseNum, T: FftNum> Fft<T> for $struct_name<S, T> {
            fn process_outofplace_with_scratch(
                &self,
                input: &mut [Complex<T>],
                output: &mut [Complex<T>],
                _scratch: &mut [Complex<T>],
            ) {
                // A zero-length FFT has nothing to compute.
                if self.len() == 0 {
                    return;
                }

                if input.len() < self.len() || output.len() != input.len() {
                    // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us
                    fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0);
                    return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here
                }

                // Transform the buffers one FFT-sized chunk pair at a time. The kernel is handed
                // an empty scratch slice because this implementation needs no out-of-place scratch.
                let result = unsafe {
                    array_utils::iter_chunks_zipped(
                        input,
                        output,
                        self.len(),
                        |in_chunk, out_chunk| {
                            self.perform_fft_out_of_place(in_chunk, out_chunk, &mut [])
                        },
                    )
                };

                if result.is_err() {
                    // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size,
                    // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us
                    fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0);
                }
            }
            fn process_with_scratch(&self, buffer: &mut [Complex<T>], scratch: &mut [Complex<T>]) {
                // A zero-length FFT has nothing to compute.
                if self.len() == 0 {
                    return;
                }

                let required_scratch = self.get_inplace_scratch_len();
                if scratch.len() < required_scratch || buffer.len() < self.len() {
                    // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us
                    fft_error_inplace(
                        self.len(),
                        buffer.len(),
                        self.get_inplace_scratch_len(),
                        scratch.len(),
                    );
                    return; // Unreachable, because fft_error_inplace asserts, but it helps codegen to put it here
                }

                // Trim the scratch to exactly one FFT length so it can serve as the out-of-place
                // destination and be copied back with `copy_from_slice`, which requires equal lengths.
                let scratch = &mut scratch[..required_scratch];
                let result = unsafe {
                    array_utils::iter_chunks(buffer, self.len(), |chunk| {
                        // Emulate in-place processing: transform into scratch, then copy back.
                        self.perform_fft_out_of_place(chunk, scratch, &mut []);
                        chunk.copy_from_slice(scratch);
                    })
                };
                if result.is_err() {
                    // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size,
                    // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us
                    fft_error_inplace(
                        self.len(),
                        buffer.len(),
                        self.get_inplace_scratch_len(),
                        scratch.len(),
                    );
                }
            }
            // One full FFT length of scratch is needed to hold the out-of-place result of a chunk.
            #[inline(always)]
            fn get_inplace_scratch_len(&self) -> usize {
                self.len()
            }
            // Out-of-place processing never touches the caller's scratch buffer.
            #[inline(always)]
            fn get_outofplace_scratch_len(&self) -> usize {
                0
            }
        }
        impl<S: SseNum, T> Length for $struct_name<S, T> {
            #[inline(always)]
            fn len(&self) -> usize {
                $len_fn(self)
            }
        }
        impl<S: SseNum, T> Direction for $struct_name<S, T> {
            #[inline(always)]
            fn fft_direction(&self) -> FftDirection {
                self.direction
            }
        }
    };
}
151
152/* Not used now, but maybe later for the mixed radixes etc
153macro_rules! boilerplate_sse_fft {
154 ($struct_name:ident, $len_fn:expr, $inplace_scratch_len_fn:expr, $out_of_place_scratch_len_fn:expr) => {
155 impl<T: FftNum> Fft<T> for $struct_name<T> {
156 fn process_outofplace_with_scratch(
157 &self,
158 input: &mut [Complex<T>],
159 output: &mut [Complex<T>],
160 scratch: &mut [Complex<T>],
161 ) {
162 if self.len() == 0 {
163 return;
164 }
165
166 let required_scratch = self.get_outofplace_scratch_len();
167 if scratch.len() < required_scratch
168 || input.len() < self.len()
169 || output.len() != input.len()
170 {
171 // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us
172 fft_error_outofplace(
173 self.len(),
174 input.len(),
175 output.len(),
176 self.get_outofplace_scratch_len(),
177 scratch.len(),
178 );
179 return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here
180 }
181
182 let scratch = &mut scratch[..required_scratch];
183 let result = array_utils::iter_chunks_zipped(
184 input,
185 output,
186 self.len(),
187 |in_chunk, out_chunk| {
188 self.perform_fft_out_of_place(in_chunk, out_chunk, scratch)
189 },
190 );
191
192 if result.is_err() {
193 // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size,
194 // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us
195 fft_error_outofplace(
196 self.len(),
197 input.len(),
198 output.len(),
199 self.get_outofplace_scratch_len(),
200 scratch.len(),
201 );
202 }
203 }
204 fn process_with_scratch(&self, buffer: &mut [Complex<T>], scratch: &mut [Complex<T>]) {
205 if self.len() == 0 {
206 return;
207 }
208
209 let required_scratch = self.get_inplace_scratch_len();
210 if scratch.len() < required_scratch || buffer.len() < self.len() {
211 // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us
212 fft_error_inplace(
213 self.len(),
214 buffer.len(),
215 self.get_inplace_scratch_len(),
216 scratch.len(),
217 );
218 return; // Unreachable, because fft_error_inplace asserts, but it helps codegen to put it here
219 }
220
221 let scratch = &mut scratch[..required_scratch];
222 let result = array_utils::iter_chunks(buffer, self.len(), |chunk| {
223 self.perform_fft_inplace(chunk, scratch)
224 });
225
226 if result.is_err() {
227 // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size,
228 // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us
229 fft_error_inplace(
230 self.len(),
231 buffer.len(),
232 self.get_inplace_scratch_len(),
233 scratch.len(),
234 );
235 }
236 }
237 #[inline(always)]
238 fn get_inplace_scratch_len(&self) -> usize {
239 $inplace_scratch_len_fn(self)
240 }
241 #[inline(always)]
242 fn get_outofplace_scratch_len(&self) -> usize {
243 $out_of_place_scratch_len_fn(self)
244 }
245 }
246 impl<T: FftNum> Length for $struct_name<T> {
247 #[inline(always)]
248 fn len(&self) -> usize {
249 $len_fn(self)
250 }
251 }
252 impl<T: FftNum> Direction for $struct_name<T> {
253 #[inline(always)]
254 fn fft_direction(&self) -> FftDirection {
255 self.direction
256 }
257 }
258 };
259}
260*/