// rustfft/sse/sse_vector.rs

1use num_complex::Complex;
2use num_traits::Zero;
3use std::arch::x86_64::*;
4use std::fmt::Debug;
5use std::ops::{Deref, DerefMut};
6
7use crate::array_utils::DoubleBuf;
8use crate::{twiddles, FftDirection};
9
10use super::SseNum;
11
// Build an array of SIMD vectors by loading from each listed index of an SseArray.
// Arguments: the identifier of the array to read from, then a brace-enclosed list
// of integer-literal indexes.
// For example, this invocation:
// ```
// let values = read_complex_to_array!(input, {0, 1, 2, 3});
// ```
// expands to the same code as:
// ```
// let values = [
//     input.load_complex(0),
//     input.load_complex(1),
//     input.load_complex(2),
//     input.load_complex(3),
// ];
// ```
macro_rules! read_complex_to_array {
    ($input:ident, { $($idx:literal),* }) => {
        [ $( $input.load_complex($idx) ),* ]
    };
}
36
// Build an array of SIMD vectors by reading ONE complex number from each listed
// index of an SseArray and copying it into every complex slot of its vector
// (this is what `load1_complex` does), so each vector holds only one distinct value.
// Arguments: the identifier of the array to read from, then a brace-enclosed list
// of integer-literal indexes (a trailing comma is accepted).
// This statement:
// ```
// let values = read_partial1_complex_to_array!(input, {0, 1, 2, 3});
// ```
// is equivalent to:
// ```
// let values = [
//     input.load1_complex(0),
//     input.load1_complex(1),
//     input.load1_complex(2),
//     input.load1_complex(3),
// ];
// ```
macro_rules! read_partial1_complex_to_array {
    ($input:ident, { $($idx:literal),* $(,)? }) => {
        [
        $(
            $input.load1_complex($idx),
        )*
        ]
    }
}
61
// Store the listed elements of an array of SIMD vectors into the same indexes of a
// destination (anything with a `store_complex(vector, index)` method, e.g. an SseArrayMut).
// Arguments: the identifier of the source array, the identifier of the destination,
// then a brace-enclosed list of integer-literal indexes (a trailing comma is accepted).
// The macro expands to a sequence of store statements and produces no value, so it
// must be used in statement position. This statement:
// ```
// write_complex_to_array!(input, output, {0, 1, 2, 3});
// ```
// is equivalent to:
// ```
// output.store_complex(input[0], 0);
// output.store_complex(input[1], 1);
// output.store_complex(input[2], 2);
// output.store_complex(input[3], 3);
// ```
macro_rules! write_complex_to_array {
    ($input:ident, $output:ident, { $($idx:literal),* $(,)? }) => {
        $(
            $output.store_complex($input[$idx], $idx);
        )*
    }
}
84
// Store only the low complex number of each listed SIMD vector into the same index
// of a destination (anything with a `store_partial_lo_complex(vector, index)` method,
// e.g. an SseArrayMut).
// Arguments: the identifier of the source array, the identifier of the destination,
// then a brace-enclosed list of integer-literal indexes (a trailing comma is accepted).
// The macro expands to a sequence of store statements and produces no value, so it
// must be used in statement position. This statement:
// ```
// write_partial_lo_complex_to_array!(input, output, {0, 1, 2, 3});
// ```
// is equivalent to:
// ```
// output.store_partial_lo_complex(input[0], 0);
// output.store_partial_lo_complex(input[1], 1);
// output.store_partial_lo_complex(input[2], 2);
// output.store_partial_lo_complex(input[3], 3);
// ```
macro_rules! write_partial_lo_complex_to_array {
    ($input:ident, $output:ident, { $($idx:literal),* $(,)? }) => {
        $(
            $output.store_partial_lo_complex($input[$idx], $idx);
        )*
    }
}
107
// Store the listed elements of an array of SIMD vectors into a destination (anything
// with a `store_complex(vector, index)` method, e.g. an SseArrayMut), at each index
// multiplied by a stride.
// Arguments: the identifier of the source array, the identifier of the destination,
// an integer-literal stride, then a brace-enclosed list of integer-literal indexes
// (a trailing comma is accepted).
// The macro expands to a sequence of store statements and produces no value, so it
// must be used in statement position. This statement:
// ```
// write_complex_to_array_strided!(input, output, 2, {0, 1, 2, 3});
// ```
// is equivalent to:
// ```
// output.store_complex(input[0], 0);
// output.store_complex(input[1], 2);
// output.store_complex(input[2], 4);
// output.store_complex(input[3], 6);
// ```
macro_rules! write_complex_to_array_strided {
    ($input:ident, $output:ident, $stride:literal, { $($idx:literal),* $(,)? }) => {
        $(
            $output.store_complex($input[$idx], $idx*$stride);
        )*
    }
}
130
131#[derive(Copy, Clone)]
132pub struct Rotation90<V: SseVector>(V);
133
134// A trait to hold the BVectorType and COMPLEX_PER_VECTOR associated data
135pub trait SseVector: Copy + Debug + Send + Sync {
136    const SCALAR_PER_VECTOR: usize;
137    const COMPLEX_PER_VECTOR: usize;
138
139    type ScalarType: SseNum<VectorType = Self>;
140
141    // loads of complex numbers
142    unsafe fn load_complex(ptr: *const Complex<Self::ScalarType>) -> Self;
143    unsafe fn load_partial_lo_complex(ptr: *const Complex<Self::ScalarType>) -> Self;
144    unsafe fn load1_complex(ptr: *const Complex<Self::ScalarType>) -> Self;
145
146    // stores of complex numbers
147    unsafe fn store_complex(ptr: *mut Complex<Self::ScalarType>, data: Self);
148    unsafe fn store_partial_lo_complex(ptr: *mut Complex<Self::ScalarType>, data: Self);
149    unsafe fn store_partial_hi_complex(ptr: *mut Complex<Self::ScalarType>, data: Self);
150
151    // math ops
152    unsafe fn neg(a: Self) -> Self;
153    unsafe fn add(a: Self, b: Self) -> Self;
154    unsafe fn mul(a: Self, b: Self) -> Self;
155    unsafe fn fmadd(acc: Self, a: Self, b: Self) -> Self;
156    unsafe fn nmadd(acc: Self, a: Self, b: Self) -> Self;
157
158    unsafe fn broadcast_scalar(value: Self::ScalarType) -> Self;
159
160    /// Generates a chunk of twiddle factors starting at (X,Y) and incrementing X `COMPLEX_PER_VECTOR` times.
161    /// The result will be [twiddle(x*y, len), twiddle((x+1)*y, len), twiddle((x+2)*y, len), ...] for as many complex numbers fit in a vector
162    unsafe fn make_mixedradix_twiddle_chunk(
163        x: usize,
164        y: usize,
165        len: usize,
166        direction: FftDirection,
167    ) -> Self;
168
169    /// Pairwise multiply the complex numbers in `left` with the complex numbers in `right`.
170    unsafe fn mul_complex(left: Self, right: Self) -> Self;
171
172    /// Constructs a Rotate90 object that will apply eithr a 90 or 270 degree rotationto the complex elements
173    unsafe fn make_rotate90(direction: FftDirection) -> Rotation90<Self>;
174
175    /// Uses a pre-constructed rotate90 object to apply the given rotation
176    unsafe fn apply_rotate90(direction: Rotation90<Self>, values: Self) -> Self;
177
178    /// Each of these Interprets the input as rows of a Self::COMPLEX_PER_VECTOR-by-N 2D array, and computes parallel butterflies down the columns of the 2D array
179    unsafe fn column_butterfly2(rows: [Self; 2]) -> [Self; 2];
180    unsafe fn column_butterfly4(rows: [Self; 4], rotation: Rotation90<Self>) -> [Self; 4];
181}
182
183impl SseVector for __m128 {
184    const SCALAR_PER_VECTOR: usize = 4;
185    const COMPLEX_PER_VECTOR: usize = 2;
186
187    type ScalarType = f32;
188
189    #[inline(always)]
190    unsafe fn load_complex(ptr: *const Complex<Self::ScalarType>) -> Self {
191        _mm_loadu_ps(ptr as *const f32)
192    }
193
194    #[inline(always)]
195    unsafe fn load_partial_lo_complex(ptr: *const Complex<Self::ScalarType>) -> Self {
196        _mm_castpd_ps(_mm_load_sd(ptr as *const f64))
197    }
198
199    #[inline(always)]
200    unsafe fn load1_complex(ptr: *const Complex<Self::ScalarType>) -> Self {
201        _mm_castpd_ps(_mm_load1_pd(ptr as *const f64))
202    }
203
204    #[inline(always)]
205    unsafe fn store_complex(ptr: *mut Complex<Self::ScalarType>, data: Self) {
206        _mm_storeu_ps(ptr as *mut f32, data);
207    }
208
209    #[inline(always)]
210    unsafe fn store_partial_lo_complex(ptr: *mut Complex<Self::ScalarType>, data: Self) {
211        _mm_storel_pd(ptr as *mut f64, _mm_castps_pd(data));
212    }
213
214    #[inline(always)]
215    unsafe fn store_partial_hi_complex(ptr: *mut Complex<Self::ScalarType>, data: Self) {
216        _mm_storeh_pd(ptr as *mut f64, _mm_castps_pd(data));
217    }
218
219    #[inline(always)]
220    unsafe fn neg(a: Self) -> Self {
221        _mm_xor_ps(a, _mm_set1_ps(-0.0))
222    }
223    #[inline(always)]
224    unsafe fn add(a: Self, b: Self) -> Self {
225        _mm_add_ps(a, b)
226    }
227    #[inline(always)]
228    unsafe fn mul(a: Self, b: Self) -> Self {
229        _mm_mul_ps(a, b)
230    }
231    #[inline(always)]
232    unsafe fn fmadd(acc: Self, a: Self, b: Self) -> Self {
233        _mm_add_ps(acc, _mm_mul_ps(a, b))
234    }
235    #[inline(always)]
236    unsafe fn nmadd(acc: Self, a: Self, b: Self) -> Self {
237        _mm_sub_ps(acc, _mm_mul_ps(a, b))
238    }
239
240    #[inline(always)]
241    unsafe fn broadcast_scalar(value: Self::ScalarType) -> Self {
242        _mm_set1_ps(value)
243    }
244
245    #[inline(always)]
246    unsafe fn make_mixedradix_twiddle_chunk(
247        x: usize,
248        y: usize,
249        len: usize,
250        direction: FftDirection,
251    ) -> Self {
252        let mut twiddle_chunk = [Complex::<f32>::zero(); Self::COMPLEX_PER_VECTOR];
253        for i in 0..Self::COMPLEX_PER_VECTOR {
254            twiddle_chunk[i] = twiddles::compute_twiddle(y * (x + i), len, direction);
255        }
256
257        twiddle_chunk.as_slice().load_complex(0)
258    }
259
260    #[inline(always)]
261    unsafe fn mul_complex(left: Self, right: Self) -> Self {
262        //SSE3, taken from Intel performance manual
263        let mut temp1 = _mm_shuffle_ps(right, right, 0xA0);
264        let mut temp2 = _mm_shuffle_ps(right, right, 0xF5);
265        temp1 = _mm_mul_ps(temp1, left);
266        temp2 = _mm_mul_ps(temp2, left);
267        temp2 = _mm_shuffle_ps(temp2, temp2, 0xB1);
268        _mm_addsub_ps(temp1, temp2)
269    }
270
271    #[inline(always)]
272    unsafe fn make_rotate90(direction: FftDirection) -> Rotation90<Self> {
273        Rotation90(match direction {
274            FftDirection::Forward => _mm_set_ps(-0.0, 0.0, -0.0, 0.0),
275            FftDirection::Inverse => _mm_set_ps(0.0, -0.0, 0.0, -0.0),
276        })
277    }
278
279    #[inline(always)]
280    unsafe fn apply_rotate90(direction: Rotation90<Self>, values: Self) -> Self {
281        let temp = _mm_shuffle_ps(values, values, 0xB1);
282        _mm_xor_ps(temp, direction.0)
283    }
284
285    #[inline(always)]
286    unsafe fn column_butterfly2(rows: [Self; 2]) -> [Self; 2] {
287        [_mm_add_ps(rows[0], rows[1]), _mm_sub_ps(rows[0], rows[1])]
288    }
289
290    #[inline(always)]
291    unsafe fn column_butterfly4(rows: [Self; 4], rotation: Rotation90<Self>) -> [Self; 4] {
292        // Algorithm: 2x2 mixed radix
293
294        // Perform the first set of size-2 FFTs.
295        let [mid0, mid2] = Self::column_butterfly2([rows[0], rows[2]]);
296        let [mid1, mid3] = Self::column_butterfly2([rows[1], rows[3]]);
297
298        // Apply twiddle factors (in this case just a rotation)
299        let mid3_rotated = Self::apply_rotate90(rotation, mid3);
300
301        // Transpose the data and do size-2 FFTs down the columns
302        let [output0, output1] = Self::column_butterfly2([mid0, mid1]);
303        let [output2, output3] = Self::column_butterfly2([mid2, mid3_rotated]);
304
305        // Swap outputs 1 and 2 in the output to do a square transpose
306        [output0, output2, output1, output3]
307    }
308}
309
310impl SseVector for __m128d {
311    const SCALAR_PER_VECTOR: usize = 2;
312    const COMPLEX_PER_VECTOR: usize = 1;
313
314    type ScalarType = f64;
315
316    #[inline(always)]
317    unsafe fn load_complex(ptr: *const Complex<Self::ScalarType>) -> Self {
318        _mm_loadu_pd(ptr as *const f64)
319    }
320
321    #[inline(always)]
322    unsafe fn load_partial_lo_complex(_ptr: *const Complex<Self::ScalarType>) -> Self {
323        unimplemented!("Impossible to do a load store of complex f64's");
324    }
325
326    #[inline(always)]
327    unsafe fn load1_complex(_ptr: *const Complex<Self::ScalarType>) -> Self {
328        unimplemented!("Impossible to do a load store of complex f64's");
329    }
330
331    #[inline(always)]
332    unsafe fn store_complex(ptr: *mut Complex<Self::ScalarType>, data: Self) {
333        _mm_storeu_pd(ptr as *mut f64, data);
334    }
335
336    #[inline(always)]
337    unsafe fn store_partial_lo_complex(_ptr: *mut Complex<Self::ScalarType>, _data: Self) {
338        unimplemented!("Impossible to do a partial store of complex f64's");
339    }
340
341    #[inline(always)]
342    unsafe fn store_partial_hi_complex(_ptr: *mut Complex<Self::ScalarType>, _data: Self) {
343        unimplemented!("Impossible to do a partial store of complex f64's");
344    }
345
346    #[inline(always)]
347    unsafe fn neg(a: Self) -> Self {
348        _mm_xor_pd(a, _mm_set1_pd(-0.0))
349    }
350    #[inline(always)]
351    unsafe fn add(a: Self, b: Self) -> Self {
352        _mm_add_pd(a, b)
353    }
354    #[inline(always)]
355    unsafe fn mul(a: Self, b: Self) -> Self {
356        _mm_mul_pd(a, b)
357    }
358    #[inline(always)]
359    unsafe fn fmadd(acc: Self, a: Self, b: Self) -> Self {
360        _mm_add_pd(acc, _mm_mul_pd(a, b))
361    }
362    #[inline(always)]
363    unsafe fn nmadd(acc: Self, a: Self, b: Self) -> Self {
364        _mm_sub_pd(acc, _mm_mul_pd(a, b))
365    }
366
367    #[inline(always)]
368    unsafe fn broadcast_scalar(value: Self::ScalarType) -> Self {
369        _mm_set1_pd(value)
370    }
371
372    #[inline(always)]
373    unsafe fn make_mixedradix_twiddle_chunk(
374        x: usize,
375        y: usize,
376        len: usize,
377        direction: FftDirection,
378    ) -> Self {
379        let mut twiddle_chunk = [Complex::<f64>::zero(); Self::COMPLEX_PER_VECTOR];
380        for i in 0..Self::COMPLEX_PER_VECTOR {
381            twiddle_chunk[i] = twiddles::compute_twiddle(y * (x + i), len, direction);
382        }
383
384        twiddle_chunk.as_slice().load_complex(0)
385    }
386
387    #[inline(always)]
388    unsafe fn mul_complex(left: Self, right: Self) -> Self {
389        // SSE3, taken from Intel performance manual
390        let mut temp1 = _mm_unpacklo_pd(right, right);
391        let mut temp2 = _mm_unpackhi_pd(right, right);
392        temp1 = _mm_mul_pd(temp1, left);
393        temp2 = _mm_mul_pd(temp2, left);
394        temp2 = _mm_shuffle_pd(temp2, temp2, 0x01);
395        _mm_addsub_pd(temp1, temp2)
396    }
397
398    #[inline(always)]
399    unsafe fn make_rotate90(direction: FftDirection) -> Rotation90<Self> {
400        Rotation90(match direction {
401            FftDirection::Forward => _mm_set_pd(-0.0, 0.0),
402            FftDirection::Inverse => _mm_set_pd(0.0, -0.0),
403        })
404    }
405
406    #[inline(always)]
407    unsafe fn apply_rotate90(direction: Rotation90<Self>, values: Self) -> Self {
408        let temp = _mm_shuffle_pd(values, values, 0x01);
409        _mm_xor_pd(temp, direction.0)
410    }
411
412    #[inline(always)]
413    unsafe fn column_butterfly2(rows: [Self; 2]) -> [Self; 2] {
414        [_mm_add_pd(rows[0], rows[1]), _mm_sub_pd(rows[0], rows[1])]
415    }
416
417    #[inline(always)]
418    unsafe fn column_butterfly4(rows: [Self; 4], rotation: Rotation90<Self>) -> [Self; 4] {
419        // Algorithm: 2x2 mixed radix
420
421        // Perform the first set of size-2 FFTs.
422        let [mid0, mid2] = Self::column_butterfly2([rows[0], rows[2]]);
423        let [mid1, mid3] = Self::column_butterfly2([rows[1], rows[3]]);
424
425        // Apply twiddle factors (in this case just a rotation)
426        let mid3_rotated = Self::apply_rotate90(rotation, mid3);
427
428        // Transpose the data and do size-2 FFTs down the columns
429        let [output0, output1] = Self::column_butterfly2([mid0, mid1]);
430        let [output2, output3] = Self::column_butterfly2([mid2, mid3_rotated]);
431
432        // Swap outputs 1 and 2 in the output to do a square transpose
433        [output0, output2, output1, output3]
434    }
435}
436
437// A trait to handle reading from an array of complex floats into SSE vectors.
438// SSE works with 128-bit vectors, meaning a vector can hold two complex f32,
439// or a single complex f64.
440pub trait SseArray<S: SseNum>: Deref {
441    // Load complex numbers from the array to fill a SSE vector.
442    unsafe fn load_complex(&self, index: usize) -> S::VectorType;
443    // Load a single complex number from the array into a SSE vector, setting the unused elements to zero.
444    unsafe fn load_partial_lo_complex(&self, index: usize) -> S::VectorType;
445    // Load a single complex number from the array, and copy it to all elements of a SSE vector.
446    unsafe fn load1_complex(&self, index: usize) -> S::VectorType;
447}
448
449impl<S: SseNum> SseArray<S> for &[Complex<S>] {
450    #[inline(always)]
451    unsafe fn load_complex(&self, index: usize) -> S::VectorType {
452        debug_assert!(self.len() >= index + S::VectorType::COMPLEX_PER_VECTOR);
453        S::VectorType::load_complex(self.as_ptr().add(index))
454    }
455
456    #[inline(always)]
457    unsafe fn load_partial_lo_complex(&self, index: usize) -> S::VectorType {
458        debug_assert!(self.len() >= index + 1);
459        S::VectorType::load_partial_lo_complex(self.as_ptr().add(index))
460    }
461
462    #[inline(always)]
463    unsafe fn load1_complex(&self, index: usize) -> S::VectorType {
464        debug_assert!(self.len() >= index + 1);
465        S::VectorType::load1_complex(self.as_ptr().add(index))
466    }
467}
468impl<S: SseNum> SseArray<S> for &mut [Complex<S>] {
469    #[inline(always)]
470    unsafe fn load_complex(&self, index: usize) -> S::VectorType {
471        debug_assert!(self.len() >= index + S::VectorType::COMPLEX_PER_VECTOR);
472        S::VectorType::load_complex(self.as_ptr().add(index))
473    }
474
475    #[inline(always)]
476    unsafe fn load_partial_lo_complex(&self, index: usize) -> S::VectorType {
477        debug_assert!(self.len() >= index + 1);
478        S::VectorType::load_partial_lo_complex(self.as_ptr().add(index))
479    }
480
481    #[inline(always)]
482    unsafe fn load1_complex(&self, index: usize) -> S::VectorType {
483        debug_assert!(self.len() >= index + 1);
484        S::VectorType::load1_complex(self.as_ptr().add(index))
485    }
486}
487
488impl<'a, S: SseNum> SseArray<S> for DoubleBuf<'a, S>
489where
490    &'a [Complex<S>]: SseArray<S>,
491{
492    #[inline(always)]
493    unsafe fn load_complex(&self, index: usize) -> S::VectorType {
494        self.input.load_complex(index)
495    }
496    #[inline(always)]
497    unsafe fn load_partial_lo_complex(&self, index: usize) -> S::VectorType {
498        self.input.load_partial_lo_complex(index)
499    }
500    #[inline(always)]
501    unsafe fn load1_complex(&self, index: usize) -> S::VectorType {
502        self.input.load1_complex(index)
503    }
504}
505
506// A trait to handle writing to an array of complex floats from SSE vectors.
507// SSE works with 128-bit vectors, meaning a vector can hold two complex f32,
508// or a single complex f64.
509pub trait SseArrayMut<S: SseNum>: SseArray<S> + DerefMut {
510    // Store all complex numbers from a SSE vector to the array.
511    unsafe fn store_complex(&mut self, vector: S::VectorType, index: usize);
512    // Store the low complex number from a SSE vector to the array.
513    unsafe fn store_partial_lo_complex(&mut self, vector: S::VectorType, index: usize);
514    // Store the high complex number from a SSE vector to the array.
515    unsafe fn store_partial_hi_complex(&mut self, vector: S::VectorType, index: usize);
516}
517
518impl<S: SseNum> SseArrayMut<S> for &mut [Complex<S>] {
519    #[inline(always)]
520    unsafe fn store_complex(&mut self, vector: S::VectorType, index: usize) {
521        debug_assert!(self.len() >= index + S::VectorType::COMPLEX_PER_VECTOR);
522        S::VectorType::store_complex(self.as_mut_ptr().add(index), vector)
523    }
524
525    #[inline(always)]
526    unsafe fn store_partial_hi_complex(&mut self, vector: S::VectorType, index: usize) {
527        debug_assert!(self.len() >= index + 1);
528        S::VectorType::store_partial_hi_complex(self.as_mut_ptr().add(index), vector)
529    }
530    #[inline(always)]
531    unsafe fn store_partial_lo_complex(&mut self, vector: S::VectorType, index: usize) {
532        debug_assert!(self.len() >= index + 1);
533        S::VectorType::store_partial_lo_complex(self.as_mut_ptr().add(index), vector)
534    }
535}
536
537impl<'a, T: SseNum> SseArrayMut<T> for DoubleBuf<'a, T>
538where
539    Self: SseArray<T>,
540    &'a mut [Complex<T>]: SseArrayMut<T>,
541{
542    #[inline(always)]
543    unsafe fn store_complex(&mut self, vector: T::VectorType, index: usize) {
544        self.output.store_complex(vector, index);
545    }
546    #[inline(always)]
547    unsafe fn store_partial_lo_complex(&mut self, vector: T::VectorType, index: usize) {
548        self.output.store_partial_lo_complex(vector, index);
549    }
550    #[inline(always)]
551    unsafe fn store_partial_hi_complex(&mut self, vector: T::VectorType, index: usize) {
552        self.output.store_partial_hi_complex(vector, index);
553    }
554}