use num_complex::Complex;
use num_traits::Zero;
use std::arch::x86_64::*;
use std::fmt::Debug;
use std::ops::{Deref, DerefMut};

use crate::array_utils::DoubleBuf;
use crate::{twiddles, FftDirection};

use super::SseNum;

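// Loads a full vector of complex values from `$input` at each listed index,
// returning the results as an array.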
macro_rules! read_complex_to_array {
    ($input:ident, { $($idx:literal),* }) => {
        [
        $(
            $input.load_complex($idx),
        )*
        ]
    }
}

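// Loads a single complex value from `$input` at each listed index, broadcast
// across the whole vector.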
macro_rules! read_partial1_complex_to_array {
    ($input:ident, { $($idx:literal),* }) => {
        [
        $(
            $input.load1_complex($idx),
        )*
        ]
    }
}

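// Stores a full vector of complex values to `$output` at each listed index.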
macro_rules! write_complex_to_array {
    ($input:ident, $output:ident, { $($idx:literal),* }) => {
        $(
            $output.store_complex($input[$idx], $idx);
        )*
    }
}

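// Stores only the low (first) complex value of each vector to `$output` at
// each listed index.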
macro_rules! write_partial_lo_complex_to_array {
    ($input:ident, $output:ident, { $($idx:literal),* }) => {
        $(
            $output.store_partial_lo_complex($input[$idx], $idx);
        )*
    }
}

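// Stores a full vector of complex values to `$output` at each listed index,
// scaled by a constant stride.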
macro_rules! write_complex_to_array_strided {
    ($input:ident, $output:ident, $stride:literal, { $($idx:literal),* }) => {
        $(
            $output.store_complex($input[$idx], $idx*$stride);
        )*
    }
}

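/// A sign mask consumed by `apply_rotate90` to rotate complex values by 90
/// degrees (multiply by +/- i), with the sign chosen by the FFT direction.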
#[derive(Copy, Clone)]
pub struct Rotation90<V: SseVector>(V);

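/// SSE primitives shared by the f32 (`__m128`) and f64 (`__m128d`) FFT
/// implementations: packed complex loads/stores, complex arithmetic,
/// twiddle-factor generation, and small column butterflies.
///
/// All methods are `unsafe`: they are thin wrappers around SSE intrinsics and
/// raw-pointer loads/stores.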
pub trait SseVector: Copy + Debug + Send + Sync {
    const SCALAR_PER_VECTOR: usize;
    const COMPLEX_PER_VECTOR: usize;

    type ScalarType: SseNum<VectorType = Self>;

    unsafe fn load_complex(ptr: *const Complex<Self::ScalarType>) -> Self;
    unsafe fn load_partial_lo_complex(ptr: *const Complex<Self::ScalarType>) -> Self;
    unsafe fn load1_complex(ptr: *const Complex<Self::ScalarType>) -> Self;

    unsafe fn store_complex(ptr: *mut Complex<Self::ScalarType>, data: Self);
    unsafe fn store_partial_lo_complex(ptr: *mut Complex<Self::ScalarType>, data: Self);
    unsafe fn store_partial_hi_complex(ptr: *mut Complex<Self::ScalarType>, data: Self);

    unsafe fn neg(a: Self) -> Self;
    unsafe fn add(a: Self, b: Self) -> Self;
    unsafe fn mul(a: Self, b: Self) -> Self;
    unsafe fn fmadd(acc: Self, a: Self, b: Self) -> Self;
    unsafe fn nmadd(acc: Self, a: Self, b: Self) -> Self;

    unsafe fn broadcast_scalar(value: Self::ScalarType) -> Self;

    unsafe fn make_mixedradix_twiddle_chunk(
        x: usize,
        y: usize,
        len: usize,
        direction: FftDirection,
    ) -> Self;

    unsafe fn mul_complex(left: Self, right: Self) -> Self;

    unsafe fn make_rotate90(direction: FftDirection) -> Rotation90<Self>;

    unsafe fn apply_rotate90(direction: Rotation90<Self>, values: Self) -> Self;

    unsafe fn column_butterfly2(rows: [Self; 2]) -> [Self; 2];
    unsafe fn column_butterfly4(rows: [Self; 4], rotation: Rotation90<Self>) -> [Self; 4];
}

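// f32 implementation: one `__m128` holds two interleaved complex numbers,
// laid out as [re0, im0, re1, im1].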
impl SseVector for __m128 {
    const SCALAR_PER_VECTOR: usize = 4;
    const COMPLEX_PER_VECTOR: usize = 2;

    type ScalarType = f32;

    #[inline(always)]
    unsafe fn load_complex(ptr: *const Complex<Self::ScalarType>) -> Self {
        _mm_loadu_ps(ptr as *const f32)
    }

    #[inline(always)]
    unsafe fn load_partial_lo_complex(ptr: *const Complex<Self::ScalarType>) -> Self {
        _mm_castpd_ps(_mm_load_sd(ptr as *const f64))
    }

    #[inline(always)]
    unsafe fn load1_complex(ptr: *const Complex<Self::ScalarType>) -> Self {
        _mm_castpd_ps(_mm_load1_pd(ptr as *const f64))
    }

    #[inline(always)]
    unsafe fn store_complex(ptr: *mut Complex<Self::ScalarType>, data: Self) {
        _mm_storeu_ps(ptr as *mut f32, data);
    }

    #[inline(always)]
    unsafe fn store_partial_lo_complex(ptr: *mut Complex<Self::ScalarType>, data: Self) {
        _mm_storel_pd(ptr as *mut f64, _mm_castps_pd(data));
    }

    #[inline(always)]
    unsafe fn store_partial_hi_complex(ptr: *mut Complex<Self::ScalarType>, data: Self) {
        _mm_storeh_pd(ptr as *mut f64, _mm_castps_pd(data));
    }

    #[inline(always)]
    unsafe fn neg(a: Self) -> Self {
        _mm_xor_ps(a, _mm_set1_ps(-0.0))
    }
    #[inline(always)]
    unsafe fn add(a: Self, b: Self) -> Self {
        _mm_add_ps(a, b)
    }
    #[inline(always)]
    unsafe fn mul(a: Self, b: Self) -> Self {
        _mm_mul_ps(a, b)
    }
    #[inline(always)]
    unsafe fn fmadd(acc: Self, a: Self, b: Self) -> Self {
        _mm_add_ps(acc, _mm_mul_ps(a, b))
    }
    #[inline(always)]
    unsafe fn nmadd(acc: Self, a: Self, b: Self) -> Self {
        _mm_sub_ps(acc, _mm_mul_ps(a, b))
    }

    #[inline(always)]
    unsafe fn broadcast_scalar(value: Self::ScalarType) -> Self {
        _mm_set1_ps(value)
    }

    #[inline(always)]
    unsafe fn make_mixedradix_twiddle_chunk(
        x: usize,
        y: usize,
        len: usize,
        direction: FftDirection,
    ) -> Self {
        let mut twiddle_chunk = [Complex::<f32>::zero(); Self::COMPLEX_PER_VECTOR];
        for i in 0..Self::COMPLEX_PER_VECTOR {
            twiddle_chunk[i] = twiddles::compute_twiddle(y * (x + i), len, direction);
        }

        twiddle_chunk.as_slice().load_complex(0)
    }

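    // Complex multiply: duplicate the real and imaginary parts of `right`,
    // multiply each copy by `left`, swap re/im in the imaginary product, then
    // combine with addsub to produce (ac - bd) + (ad + bc)i in each lane.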
    #[inline(always)]
    unsafe fn mul_complex(left: Self, right: Self) -> Self {
        let mut temp1 = _mm_shuffle_ps(right, right, 0xA0);
        let mut temp2 = _mm_shuffle_ps(right, right, 0xF5);
        temp1 = _mm_mul_ps(temp1, left);
        temp2 = _mm_mul_ps(temp2, left);
        temp2 = _mm_shuffle_ps(temp2, temp2, 0xB1);
        _mm_addsub_ps(temp1, temp2)
    }

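    // 90-degree rotation: swap re/im in each complex value, then flip the sign
    // of one half of each pair with the precomputed mask (multiply by -i for
    // Forward, by +i for Inverse).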
    #[inline(always)]
    unsafe fn make_rotate90(direction: FftDirection) -> Rotation90<Self> {
        Rotation90(match direction {
            FftDirection::Forward => _mm_set_ps(-0.0, 0.0, -0.0, 0.0),
            FftDirection::Inverse => _mm_set_ps(0.0, -0.0, 0.0, -0.0),
        })
    }

    #[inline(always)]
    unsafe fn apply_rotate90(direction: Rotation90<Self>, values: Self) -> Self {
        let temp = _mm_shuffle_ps(values, values, 0xB1);
        _mm_xor_ps(temp, direction.0)
    }

    #[inline(always)]
    unsafe fn column_butterfly2(rows: [Self; 2]) -> [Self; 2] {
        [_mm_add_ps(rows[0], rows[1]), _mm_sub_ps(rows[0], rows[1])]
    }

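    // Size-4 FFT butterfly built from two layers of size-2 butterflies plus a
    // 90-degree rotation; outputs 1 and 2 are swapped on return, performing
    // the transpose step of the mixed-radix algorithm.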
    #[inline(always)]
    unsafe fn column_butterfly4(rows: [Self; 4], rotation: Rotation90<Self>) -> [Self; 4] {
        let [mid0, mid2] = Self::column_butterfly2([rows[0], rows[2]]);
        let [mid1, mid3] = Self::column_butterfly2([rows[1], rows[3]]);

        let mid3_rotated = Self::apply_rotate90(rotation, mid3);

        let [output0, output1] = Self::column_butterfly2([mid0, mid1]);
        let [output2, output3] = Self::column_butterfly2([mid2, mid3_rotated]);

        [output0, output2, output1, output3]
    }
}

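// f64 implementation: one `__m128d` holds a single complex number ([re, im]),
// so the partial load/store methods cannot be implemented meaningfully.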
impl SseVector for __m128d {
    const SCALAR_PER_VECTOR: usize = 2;
    const COMPLEX_PER_VECTOR: usize = 1;

    type ScalarType = f64;

    #[inline(always)]
    unsafe fn load_complex(ptr: *const Complex<Self::ScalarType>) -> Self {
        _mm_loadu_pd(ptr as *const f64)
    }

    #[inline(always)]
    unsafe fn load_partial_lo_complex(_ptr: *const Complex<Self::ScalarType>) -> Self {
        unimplemented!("Impossible to do a partial load of complex f64's");
    }

    #[inline(always)]
    unsafe fn load1_complex(_ptr: *const Complex<Self::ScalarType>) -> Self {
        unimplemented!("Impossible to do a partial load of complex f64's");
    }

    #[inline(always)]
    unsafe fn store_complex(ptr: *mut Complex<Self::ScalarType>, data: Self) {
        _mm_storeu_pd(ptr as *mut f64, data);
    }

    #[inline(always)]
    unsafe fn store_partial_lo_complex(_ptr: *mut Complex<Self::ScalarType>, _data: Self) {
        unimplemented!("Impossible to do a partial store of complex f64's");
    }

    #[inline(always)]
    unsafe fn store_partial_hi_complex(_ptr: *mut Complex<Self::ScalarType>, _data: Self) {
        unimplemented!("Impossible to do a partial store of complex f64's");
    }

    #[inline(always)]
    unsafe fn neg(a: Self) -> Self {
        _mm_xor_pd(a, _mm_set1_pd(-0.0))
    }
    #[inline(always)]
    unsafe fn add(a: Self, b: Self) -> Self {
        _mm_add_pd(a, b)
    }
    #[inline(always)]
    unsafe fn mul(a: Self, b: Self) -> Self {
        _mm_mul_pd(a, b)
    }
    #[inline(always)]
    unsafe fn fmadd(acc: Self, a: Self, b: Self) -> Self {
        _mm_add_pd(acc, _mm_mul_pd(a, b))
    }
    #[inline(always)]
    unsafe fn nmadd(acc: Self, a: Self, b: Self) -> Self {
        _mm_sub_pd(acc, _mm_mul_pd(a, b))
    }

    #[inline(always)]
    unsafe fn broadcast_scalar(value: Self::ScalarType) -> Self {
        _mm_set1_pd(value)
    }

    #[inline(always)]
    unsafe fn make_mixedradix_twiddle_chunk(
        x: usize,
        y: usize,
        len: usize,
        direction: FftDirection,
    ) -> Self {
        let mut twiddle_chunk = [Complex::<f64>::zero(); Self::COMPLEX_PER_VECTOR];
        for i in 0..Self::COMPLEX_PER_VECTOR {
            twiddle_chunk[i] = twiddles::compute_twiddle(y * (x + i), len, direction);
        }

        twiddle_chunk.as_slice().load_complex(0)
    }

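    // Complex multiply for a single f64 complex value: same duplicate /
    // multiply / swap / addsub sequence as the f32 version, using unpack
    // instructions to duplicate the real and imaginary parts of `right`.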
    #[inline(always)]
    unsafe fn mul_complex(left: Self, right: Self) -> Self {
        let mut temp1 = _mm_unpacklo_pd(right, right);
        let mut temp2 = _mm_unpackhi_pd(right, right);
        temp1 = _mm_mul_pd(temp1, left);
        temp2 = _mm_mul_pd(temp2, left);
        temp2 = _mm_shuffle_pd(temp2, temp2, 0x01);
        _mm_addsub_pd(temp1, temp2)
    }

    #[inline(always)]
    unsafe fn make_rotate90(direction: FftDirection) -> Rotation90<Self> {
        Rotation90(match direction {
            FftDirection::Forward => _mm_set_pd(-0.0, 0.0),
            FftDirection::Inverse => _mm_set_pd(0.0, -0.0),
        })
    }

    #[inline(always)]
    unsafe fn apply_rotate90(direction: Rotation90<Self>, values: Self) -> Self {
        let temp = _mm_shuffle_pd(values, values, 0x01);
        _mm_xor_pd(temp, direction.0)
    }

    #[inline(always)]
    unsafe fn column_butterfly2(rows: [Self; 2]) -> [Self; 2] {
        [_mm_add_pd(rows[0], rows[1]), _mm_sub_pd(rows[0], rows[1])]
    }

    #[inline(always)]
    unsafe fn column_butterfly4(rows: [Self; 4], rotation: Rotation90<Self>) -> [Self; 4] {
        let [mid0, mid2] = Self::column_butterfly2([rows[0], rows[2]]);
        let [mid1, mid3] = Self::column_butterfly2([rows[1], rows[3]]);

        let mid3_rotated = Self::apply_rotate90(rotation, mid3);

        let [output0, output1] = Self::column_butterfly2([mid0, mid1]);
        let [output2, output3] = Self::column_butterfly2([mid2, mid3_rotated]);

        [output0, output2, output1, output3]
    }
}

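/// Read-side helpers for slices of complex numbers, dispatching to the
/// `SseVector` load methods. The methods are `unsafe` and only bounds-check
/// the index via `debug_assert!`.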
pub trait SseArray<S: SseNum>: Deref {
    unsafe fn load_complex(&self, index: usize) -> S::VectorType;
    unsafe fn load_partial_lo_complex(&self, index: usize) -> S::VectorType;
    unsafe fn load1_complex(&self, index: usize) -> S::VectorType;
}

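// Loads for plain slices: compute the element pointer, then defer to the
// vector type.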
impl<S: SseNum> SseArray<S> for &[Complex<S>] {
    #[inline(always)]
    unsafe fn load_complex(&self, index: usize) -> S::VectorType {
        debug_assert!(self.len() >= index + S::VectorType::COMPLEX_PER_VECTOR);
        S::VectorType::load_complex(self.as_ptr().add(index))
    }

    #[inline(always)]
    unsafe fn load_partial_lo_complex(&self, index: usize) -> S::VectorType {
        debug_assert!(self.len() >= index + 1);
        S::VectorType::load_partial_lo_complex(self.as_ptr().add(index))
    }

    #[inline(always)]
    unsafe fn load1_complex(&self, index: usize) -> S::VectorType {
        debug_assert!(self.len() >= index + 1);
        S::VectorType::load1_complex(self.as_ptr().add(index))
    }
}
impl<S: SseNum> SseArray<S> for &mut [Complex<S>] {
    #[inline(always)]
    unsafe fn load_complex(&self, index: usize) -> S::VectorType {
        debug_assert!(self.len() >= index + S::VectorType::COMPLEX_PER_VECTOR);
        S::VectorType::load_complex(self.as_ptr().add(index))
    }

    #[inline(always)]
    unsafe fn load_partial_lo_complex(&self, index: usize) -> S::VectorType {
        debug_assert!(self.len() >= index + 1);
        S::VectorType::load_partial_lo_complex(self.as_ptr().add(index))
    }

    #[inline(always)]
    unsafe fn load1_complex(&self, index: usize) -> S::VectorType {
        debug_assert!(self.len() >= index + 1);
        S::VectorType::load1_complex(self.as_ptr().add(index))
    }
}

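// For an out-of-place `DoubleBuf`, reads always come from the input buffer.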
impl<'a, S: SseNum> SseArray<S> for DoubleBuf<'a, S>
where
    &'a [Complex<S>]: SseArray<S>,
{
    #[inline(always)]
    unsafe fn load_complex(&self, index: usize) -> S::VectorType {
        self.input.load_complex(index)
    }
    #[inline(always)]
    unsafe fn load_partial_lo_complex(&self, index: usize) -> S::VectorType {
        self.input.load_partial_lo_complex(index)
    }
    #[inline(always)]
    unsafe fn load1_complex(&self, index: usize) -> S::VectorType {
        self.input.load1_complex(index)
    }
}

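/// Write-side counterpart of `SseArray`: stores of full vectors and of the
/// low/high halves of a vector, again bounds-checked only via `debug_assert!`.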
pub trait SseArrayMut<S: SseNum>: SseArray<S> + DerefMut {
    unsafe fn store_complex(&mut self, vector: S::VectorType, index: usize);
    unsafe fn store_partial_lo_complex(&mut self, vector: S::VectorType, index: usize);
    unsafe fn store_partial_hi_complex(&mut self, vector: S::VectorType, index: usize);
}

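// Stores for mutable slices: compute the element pointer, then defer to the
// vector type.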
impl<S: SseNum> SseArrayMut<S> for &mut [Complex<S>] {
    #[inline(always)]
    unsafe fn store_complex(&mut self, vector: S::VectorType, index: usize) {
        debug_assert!(self.len() >= index + S::VectorType::COMPLEX_PER_VECTOR);
        S::VectorType::store_complex(self.as_mut_ptr().add(index), vector)
    }

    #[inline(always)]
    unsafe fn store_partial_hi_complex(&mut self, vector: S::VectorType, index: usize) {
        debug_assert!(self.len() >= index + 1);
        S::VectorType::store_partial_hi_complex(self.as_mut_ptr().add(index), vector)
    }
    #[inline(always)]
    unsafe fn store_partial_lo_complex(&mut self, vector: S::VectorType, index: usize) {
        debug_assert!(self.len() >= index + 1);
        S::VectorType::store_partial_lo_complex(self.as_mut_ptr().add(index), vector)
    }
}

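// For an out-of-place `DoubleBuf`, writes always go to the output buffer.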
impl<'a, T: SseNum> SseArrayMut<T> for DoubleBuf<'a, T>
where
    Self: SseArray<T>,
    &'a mut [Complex<T>]: SseArrayMut<T>,
{
    #[inline(always)]
    unsafe fn store_complex(&mut self, vector: T::VectorType, index: usize) {
        self.output.store_complex(vector, index);
    }
    #[inline(always)]
    unsafe fn store_partial_lo_complex(&mut self, vector: T::VectorType, index: usize) {
        self.output.store_partial_lo_complex(vector, index);
    }
    #[inline(always)]
    unsafe fn store_partial_hi_complex(&mut self, vector: T::VectorType, index: usize) {
        self.output.store_partial_hi_complex(vector, index);
    }
}