rustfft/sse/
sse_utils.rs

1use core::arch::x86_64::*;
2
//  __  __       _   _               _________  _     _ _
// |  \/  | __ _| |_| |__           |___ /___ \| |__ (_) |_
// | |\/| |/ _` | __| '_ \   _____    |_ \ __) | '_ \| | __|
// | |  | | (_| | |_| | | | |_____|  ___) / __/| |_) | | |_
// |_|  |_|\__,_|\__|_| |_|         |____/_____|_.__/|_|\__|
//
9
/// Rotates packed `f32` complex numbers by 90 degrees (multiplication by `i` or `-i`).
///
/// A `__m128` is treated as two complex numbers laid out as `[lo.re, lo.im, hi.re, hi.im]`.
/// A rotation swaps `re` and `im` with a shuffle and then flips one sign bit with an XOR;
/// the XOR masks are built once in [`Rotate90F32::new`] and reused on every call.
pub struct Rotate90F32 {
    //sign_lo: __m128,
    /// Sign mask used by `rotate_hi`: only lanes of the high complex carry a sign bit.
    sign_hi: __m128,
    /// Sign mask used by `rotate_both`: one lane of each complex carries a sign bit.
    sign_both: __m128,
}

impl Rotate90F32 {
    /// Creates a rotator.
    ///
    /// With `positive == true` the rotate methods map `(re, im)` to `(-im, re)`
    /// (multiplication by `i`); with `positive == false` they map it to `(im, -re)`
    /// (multiplication by `-i`).
    pub fn new(positive: bool) -> Self {
        // `_mm_set_ps` lists lanes from highest to lowest (e3, e2, e1, e0), so the masks
        // below read right-to-left compared to the `[lo.re, lo.im, hi.re, hi.im]` layout.

        // There doesn't seem to be any need for rotating just the first element, but let's keep the code just in case
        //let sign_lo = unsafe {
        //    if positive {
        //        _mm_set_ps(0.0, 0.0, 0.0, -0.0)
        //    }
        //    else {
        //        _mm_set_ps(0.0, 0.0, -0.0, 0.0)
        //    }
        //};

        // SAFETY: `_mm_set_ps` only builds a register value from four floats and needs
        // nothing beyond SSE, which the usual x86_64 targets enable by default.
        let sign_hi = unsafe {
            if positive {
                _mm_set_ps(0.0, -0.0, 0.0, 0.0)
            } else {
                _mm_set_ps(-0.0, 0.0, 0.0, 0.0)
            }
        };
        // SAFETY: same as above.
        let sign_both = unsafe {
            if positive {
                _mm_set_ps(0.0, -0.0, 0.0, -0.0)
            } else {
                _mm_set_ps(-0.0, 0.0, -0.0, 0.0)
            }
        };
        Self {
            //sign_lo,
            sign_hi,
            sign_both,
        }
    }

    /// Rotates only the high (2nd) complex by 90 degrees; the low complex is passed through.
    ///
    /// # Safety
    /// Executes SSE intrinsics; the caller must ensure SSE is available on the running CPU.
    #[inline(always)]
    pub unsafe fn rotate_hi(&self, values: __m128) -> __m128 {
        // 0xB4 keeps lanes 0 and 1 and swaps lanes 2 and 3 (re <-> im of the high complex).
        let temp = _mm_shuffle_ps(values, values, 0xB4);
        _mm_xor_ps(temp, self.sign_hi)
    }

    // There doesn't seem to be any need for rotating just the first element, but let's keep the code just in case
    //#[inline(always)]
    //pub unsafe fn rotate_lo(&self, values: __m128) -> __m128 {
    //    let temp = _mm_shuffle_ps(values, values, 0xE1);
    //    _mm_xor_ps(temp, self.sign_lo)
    //}

    /// Rotates both complex numbers by 90 degrees.
    ///
    /// # Safety
    /// Executes SSE intrinsics; the caller must ensure SSE is available on the running CPU.
    #[inline(always)]
    pub unsafe fn rotate_both(&self, values: __m128) -> __m128 {
        // 0xB1 swaps lanes 0 <-> 1 and 2 <-> 3 (re <-> im of each complex).
        let temp = _mm_shuffle_ps(values, values, 0xB1);
        _mm_xor_ps(temp, self.sign_both)
    }

    /// Rotates both complex numbers by 45 degrees in the direction chosen in `new`.
    ///
    /// Computed as `(rotate_both(v) + v) * (1 / sqrt(2))`.
    ///
    /// # Safety
    /// Executes SSE intrinsics; the caller must ensure SSE is available on the running CPU.
    #[inline(always)]
    pub unsafe fn rotate_both_45(&self, values: __m128) -> __m128 {
        let rotated = self.rotate_both(values);
        let sum = _mm_add_ps(rotated, values);
        // FRAC_1_SQRT_2 is the correctly rounded 1/sqrt(2), i.e. the same bits as `0.5f32.sqrt()`.
        _mm_mul_ps(sum, _mm_set1_ps(core::f32::consts::FRAC_1_SQRT_2))
    }

    /// Rotates both complex numbers by 135 degrees in the direction chosen in `new`.
    ///
    /// Computed as `(rotate_both(v) - v) * (1 / sqrt(2))`.
    ///
    /// # Safety
    /// Executes SSE intrinsics; the caller must ensure SSE is available on the running CPU.
    #[inline(always)]
    pub unsafe fn rotate_both_135(&self, values: __m128) -> __m128 {
        let rotated = self.rotate_both(values);
        let diff = _mm_sub_ps(rotated, values);
        _mm_mul_ps(diff, _mm_set1_ps(core::f32::consts::FRAC_1_SQRT_2))
    }

    /// Rotates both complex numbers by 225 degrees in the direction chosen in `new`.
    ///
    /// Computed as `(rotate_both(v) + v) * (-1 / sqrt(2))`, i.e. the negated 45 degree result.
    ///
    /// # Safety
    /// Executes SSE intrinsics; the caller must ensure SSE is available on the running CPU.
    #[inline(always)]
    pub unsafe fn rotate_both_225(&self, values: __m128) -> __m128 {
        let rotated = self.rotate_both(values);
        let sum = _mm_add_ps(rotated, values);
        _mm_mul_ps(sum, _mm_set1_ps(-core::f32::consts::FRAC_1_SQRT_2))
    }
}
88
/// Packs the low (1st) complex of each input into one vector.
///
/// ```text
/// left:   l1.re, l1.im, l2.re, l2.im
/// right:  r1.re, r1.im, r2.re, r2.im
/// result: l1.re, l1.im, r1.re, r1.im
/// ```
///
/// # Safety
/// Executes SSE2 intrinsics; the caller must ensure SSE2 is available on the running CPU.
#[inline(always)]
pub unsafe fn extract_lo_lo_f32(left: __m128, right: __m128) -> __m128 {
    // One f32 complex is 64 bits wide, so a 64-bit `pd` unpack on the reinterpreted
    // registers moves whole complex numbers. Alternative with the same result:
    //_mm_shuffle_ps(left, right, 0x44)
    _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(left), _mm_castps_pd(right)))
}
98
/// Packs the high (2nd) complex of each input into one vector.
///
/// ```text
/// left:   l1.re, l1.im, l2.re, l2.im
/// right:  r1.re, r1.im, r2.re, r2.im
/// result: l2.re, l2.im, r2.re, r2.im
/// ```
///
/// # Safety
/// Executes SSE2 intrinsics; the caller must ensure SSE2 is available on the running CPU.
#[inline(always)]
pub unsafe fn extract_hi_hi_f32(left: __m128, right: __m128) -> __m128 {
    // One f32 complex is 64 bits wide, so a 64-bit `pd` unpack on the reinterpreted
    // registers moves whole complex numbers.
    _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(left), _mm_castps_pd(right)))
}
107
/// Packs the low (1st) complex of `left` and the high (2nd) complex of `right`.
///
/// ```text
/// left:   l1.re, l1.im, l2.re, l2.im
/// right:  r1.re, r1.im, r2.re, r2.im
/// result: l1.re, l1.im, r2.re, r2.im
/// ```
///
/// # Safety
/// Executes `_mm_blend_ps`, an SSE4.1 intrinsic; the caller must ensure SSE4.1 is
/// available on the running CPU.
#[inline(always)]
pub unsafe fn extract_lo_hi_f32(left: __m128, right: __m128) -> __m128 {
    // Mask 0x0C = 0b1100: lanes 2 and 3 come from `right`, lanes 0 and 1 from `left`.
    _mm_blend_ps(left, right, 0x0C)
}
116
/// Packs the high (2nd) complex of `left` and the low (1st) complex of `right`.
///
/// ```text
/// left:   l1.re, l1.im, l2.re, l2.im
/// right:  r1.re, r1.im, r2.re, r2.im
/// result: l2.re, l2.im, r1.re, r1.im
/// ```
///
/// # Safety
/// Executes SSE intrinsics; the caller must ensure SSE is available on the running CPU.
#[inline(always)]
pub unsafe fn extract_hi_lo_f32(left: __m128, right: __m128) -> __m128 {
    // 0x4E selects lanes 2, 3 of `left` for the low half and lanes 0, 1 of `right` for the high half.
    _mm_shuffle_ps(left, right, 0x4E)
}
125
/// Swaps the two complex numbers held in `values`.
///
/// ```text
/// values: a.re, a.im, b.re, b.im
/// result: b.re, b.im, a.re, a.im
/// ```
///
/// # Safety
/// Executes SSE intrinsics; the caller must ensure SSE is available on the running CPU.
#[inline(always)]
pub unsafe fn reverse_complex_elements_f32(values: __m128) -> __m128 {
    // 0x4E picks lanes 2, 3, 0, 1 in that order.
    _mm_shuffle_ps(values, values, 0x4E)
}
133
/// Inverts the sign of the high (2nd) complex and leaves the low complex unchanged.
///
/// ```text
/// values: a.re, a.im,  b.re,  b.im
/// result: a.re, a.im, -b.re, -b.im
/// ```
///
/// # Safety
/// Executes SSE intrinsics; the caller must ensure SSE is available on the running CPU.
#[inline(always)]
pub unsafe fn negate_hi_f32(values: __m128) -> __m128 {
    // `-0.0` has only the sign bit set, so the XOR flips the sign of lanes 2 and 3
    // (`_mm_set_ps` lists lanes from highest to lowest).
    _mm_xor_ps(values, _mm_set_ps(-0.0, -0.0, 0.0, 0.0))
}
141
/// Broadcasts the low (1st) complex into both halves of the vector.
///
/// ```text
/// values: a.re, a.im, b.re, b.im
/// result: a.re, a.im, a.re, a.im
/// ```
///
/// # Safety
/// Executes SSE intrinsics; the caller must ensure SSE is available on the running CPU.
#[inline(always)]
pub unsafe fn duplicate_lo_f32(values: __m128) -> __m128 {
    // 0x44 picks lanes 0, 1, 0, 1 in that order.
    _mm_shuffle_ps(values, values, 0x44)
}
149
/// Broadcasts the high (2nd) complex into both halves of the vector.
///
/// ```text
/// values: a.re, a.im, b.re, b.im
/// result: b.re, b.im, b.re, b.im
/// ```
///
/// # Safety
/// Executes SSE intrinsics; the caller must ensure SSE is available on the running CPU.
#[inline(always)]
pub unsafe fn duplicate_hi_f32(values: __m128) -> __m128 {
    // 0xEE picks lanes 2, 3, 2, 3 in that order.
    _mm_shuffle_ps(values, values, 0xEE)
}
157
158// transpose a 2x2 complex matrix given as [x0, x1], [x2, x3]
159// result is [x0, x2], [x1, x3]
160#[inline(always)]
161pub unsafe fn transpose_complex_2x2_f32(left: __m128, right: __m128) -> [__m128; 2] {
162    let temp02 = extract_lo_lo_f32(left, right);
163    let temp13 = extract_hi_hi_f32(left, right);
164    [temp02, temp13]
165}
166
//  __  __       _   _                __   _  _   _     _ _
// |  \/  | __ _| |_| |__            / /_ | || | | |__ (_) |_
// | |\/| |/ _` | __| '_ \   _____  | '_ \| || |_| '_ \| | __|
// | |  | | (_| | |_| | | | |_____| | (_) |__   _| |_) | | |_
// |_|  |_|\__,_|\__|_| |_|          \___/   |_| |_.__/|_|\__|
//
173
/// Rotates one `f64` complex number by 90 degrees (multiplication by `i` or `-i`).
///
/// A `__m128d` is treated as a single complex number laid out as `[re, im]`.
/// A rotation swaps the two lanes with a shuffle and then flips one sign bit with an XOR;
/// the XOR mask is built once in [`Rotate90F64::new`].
pub(crate) struct Rotate90F64 {
    /// Sign mask applied after the lane swap.
    sign: __m128d,
}

impl Rotate90F64 {
    /// Creates a rotator.
    ///
    /// With `positive == true` the rotate methods map `(re, im)` to `(-im, re)`
    /// (multiplication by `i`); with `positive == false` they map it to `(im, -re)`
    /// (multiplication by `-i`).
    pub fn new(positive: bool) -> Self {
        // `_mm_set_pd` lists lanes as (high, low), so the first argument is the mask for
        // the second lane of the swapped value.
        // SAFETY: `_mm_set_pd` only builds a register value from two floats and needs
        // nothing beyond SSE2, which the usual x86_64 targets enable by default.
        let sign = unsafe {
            if positive {
                _mm_set_pd(0.0, -0.0)
            } else {
                _mm_set_pd(-0.0, 0.0)
            }
        };
        Self { sign }
    }

    /// Rotates the complex number by 90 degrees.
    ///
    /// # Safety
    /// Executes SSE2 intrinsics; the caller must ensure SSE2 is available on the running CPU.
    #[inline(always)]
    pub unsafe fn rotate(&self, values: __m128d) -> __m128d {
        // 0x01 swaps the two lanes (re <-> im).
        let temp = _mm_shuffle_pd(values, values, 0x01);
        _mm_xor_pd(temp, self.sign)
    }

    /// Rotates the complex number by 45 degrees in the direction chosen in `new`.
    ///
    /// Computed as `(rotate(v) + v) * (1 / sqrt(2))`.
    ///
    /// # Safety
    /// Executes SSE2 intrinsics; the caller must ensure SSE2 is available on the running CPU.
    #[inline(always)]
    pub unsafe fn rotate_45(&self, values: __m128d) -> __m128d {
        let rotated = self.rotate(values);
        let sum = _mm_add_pd(rotated, values);
        // FRAC_1_SQRT_2 is the correctly rounded 1/sqrt(2), i.e. the same bits as `0.5f64.sqrt()`.
        _mm_mul_pd(sum, _mm_set1_pd(core::f64::consts::FRAC_1_SQRT_2))
    }

    /// Rotates the complex number by 135 degrees in the direction chosen in `new`.
    ///
    /// Computed as `(rotate(v) - v) * (1 / sqrt(2))`.
    ///
    /// # Safety
    /// Executes SSE2 intrinsics; the caller must ensure SSE2 is available on the running CPU.
    #[inline(always)]
    pub unsafe fn rotate_135(&self, values: __m128d) -> __m128d {
        let rotated = self.rotate(values);
        let diff = _mm_sub_pd(rotated, values);
        _mm_mul_pd(diff, _mm_set1_pd(core::f64::consts::FRAC_1_SQRT_2))
    }

    /// Rotates the complex number by 225 degrees in the direction chosen in `new`.
    ///
    /// Computed as `(rotate(v) + v) * (-1 / sqrt(2))`, i.e. the negated 45 degree result.
    ///
    /// # Safety
    /// Executes SSE2 intrinsics; the caller must ensure SSE2 is available on the running CPU.
    #[inline(always)]
    pub unsafe fn rotate_225(&self, values: __m128d) -> __m128d {
        let rotated = self.rotate(values);
        let sum = _mm_add_pd(rotated, values);
        _mm_mul_pd(sum, _mm_set1_pd(-core::f64::consts::FRAC_1_SQRT_2))
    }
}
217
#[cfg(test)]
mod unit_tests {
    use crate::sse::sse_vector::SseVector;

    use super::*;
    use num_complex::Complex;

    /// Checks `SseVector::mul_complex` on one `f64` pair against hand-computed values.
    #[test]
    fn test_mul_complex_f64() {
        unsafe {
            // `_mm_set_pd` takes (high, low) = (im, re): right = 2 + 1i, left = 7 + 5i.
            let right = _mm_set_pd(1.0, 2.0);
            let left = _mm_set_pd(5.0, 7.0);
            let res = SseVector::mul_complex(left, right);
            // (7 + 5i) * (2 + 1i) = 9 + 17i, again written as (im, re).
            let expected = _mm_set_pd(2.0 * 5.0 + 1.0 * 7.0, 2.0 * 7.0 - 1.0 * 5.0);
            assert_eq!(
                std::mem::transmute::<__m128d, Complex<f64>>(res),
                std::mem::transmute::<__m128d, Complex<f64>>(expected)
            );
        }
    }

    /// Checks `SseVector::mul_complex` on two packed `f32` pairs against `Complex` multiplication.
    #[test]
    fn test_mul_complex_f32() {
        unsafe {
            let val1 = Complex::<f32>::new(1.0, 2.5);
            let val2 = Complex::<f32>::new(3.2, 4.2);
            let val3 = Complex::<f32>::new(5.6, 6.2);
            let val4 = Complex::<f32>::new(7.4, 8.3);

            // `_mm_set_ps` lists lanes from highest to lowest, so each vector holds
            // [first complex, second complex] in memory order.
            let nbr2 = _mm_set_ps(val4.im, val4.re, val3.im, val3.re);
            let nbr1 = _mm_set_ps(val2.im, val2.re, val1.im, val1.re);
            let res = SseVector::mul_complex(nbr1, nbr2);
            let res = std::mem::transmute::<__m128, [Complex<f32>; 2]>(res);
            let expected = [val1 * val3, val2 * val4];
            assert_eq!(res, expected);
        }
    }

    /// Checks that the lo/lo and hi/hi extractors pick whole complex numbers from each input.
    #[test]
    fn test_pack() {
        unsafe {
            // nbr1 = [1 + 2i, 3 + 4i], nbr2 = [5 + 6i, 7 + 8i]
            let nbr2 = _mm_set_ps(8.0, 7.0, 6.0, 5.0);
            let nbr1 = _mm_set_ps(4.0, 3.0, 2.0, 1.0);
            let first = extract_lo_lo_f32(nbr1, nbr2);
            let second = extract_hi_hi_f32(nbr1, nbr2);
            let first = std::mem::transmute::<__m128, [Complex<f32>; 2]>(first);
            let second = std::mem::transmute::<__m128, [Complex<f32>; 2]>(second);
            let first_expected = [Complex::new(1.0, 2.0), Complex::new(5.0, 6.0)];
            let second_expected = [Complex::new(3.0, 4.0), Complex::new(7.0, 8.0)];
            assert_eq!(first, first_expected);
            assert_eq!(second, second_expected);
        }
    }
}
271}