1use core::arch::x86_64::*;
2
/// Rotates packed single-precision complex numbers by 90 degrees, and by the
/// 45/135/225 degree angles derived from that rotation.
///
/// A `__m128` is treated as two interleaved complex values laid out as
/// `[lo.re, lo.im, hi.re, hi.im]`. The rotation direction is chosen once in
/// [`Rotate90F32::new`] and stored as sign-bit masks that are applied with XOR.
pub struct Rotate90F32 {
    /// Sign-bit mask applied by `rotate_hi`; only lanes of the upper complex value are set.
    sign_hi: __m128,
    /// Sign-bit mask applied by `rotate_both`; covers both complex values.
    sign_both: __m128,
}

impl Rotate90F32 {
    /// Builds the sign masks for one rotation direction.
    ///
    /// With `positive == true` a 90 degree rotation maps `(re, im)` to `(-im, re)`;
    /// with `positive == false` it maps `(re, im)` to `(im, -re)`.
    pub fn new(positive: bool) -> Self {
        // `_mm_set_ps` lists lanes from highest to lowest, so the last argument is lane 0.
        // The masks are applied after re/im have been swapped: negating the even lanes
        // yields the positive direction, negating the odd lanes the negative one.
        //
        // SAFETY: `_mm_set_ps` only needs SSE, which is part of the x86_64 baseline, and it
        // merely builds a vector from constants without touching memory.
        let (sign_hi, sign_both) = unsafe {
            if positive {
                (
                    _mm_set_ps(0.0, -0.0, 0.0, 0.0),
                    _mm_set_ps(0.0, -0.0, 0.0, -0.0),
                )
            } else {
                (
                    _mm_set_ps(-0.0, 0.0, 0.0, 0.0),
                    _mm_set_ps(-0.0, 0.0, -0.0, 0.0),
                )
            }
        };
        Self { sign_hi, sign_both }
    }

    /// Rotates only the upper complex value by 90 degrees; the lower one passes through
    /// unchanged.
    ///
    /// # Safety
    ///
    /// Calls SSE intrinsics; the caller must ensure they are available on the executing CPU.
    #[inline(always)]
    pub unsafe fn rotate_hi(&self, values: __m128) -> __m128 {
        // Keep lanes 0 and 1, exchange lanes 2 and 3 (re <-> im of the upper value).
        let swapped = _mm_shuffle_ps(values, values, 0b10_11_01_00);
        // XOR with the mask flips the sign of the lane selected by the direction.
        _mm_xor_ps(swapped, self.sign_hi)
    }

    /// Rotates both complex values by 90 degrees.
    ///
    /// # Safety
    ///
    /// Calls SSE intrinsics; the caller must ensure they are available on the executing CPU.
    #[inline(always)]
    pub unsafe fn rotate_both(&self, values: __m128) -> __m128 {
        // Exchange re and im inside each complex value: lanes (0, 1) and lanes (2, 3).
        let swapped = _mm_shuffle_ps(values, values, 0b10_11_00_01);
        _mm_xor_ps(swapped, self.sign_both)
    }

    /// Rotates both complex values by 45 degrees in the configured direction.
    ///
    /// Computed as `(rotate_both(v) + v) * sqrt(0.5)`.
    ///
    /// # Safety
    ///
    /// Calls SSE intrinsics; the caller must ensure they are available on the executing CPU.
    #[inline(always)]
    pub unsafe fn rotate_both_45(&self, values: __m128) -> __m128 {
        let quarter_turn = self.rotate_both(values);
        let scale = _mm_set1_ps(0.5f32.sqrt());
        _mm_mul_ps(_mm_add_ps(quarter_turn, values), scale)
    }

    /// Rotates both complex values by 135 degrees in the configured direction.
    ///
    /// Computed as `(rotate_both(v) - v) * sqrt(0.5)`.
    ///
    /// # Safety
    ///
    /// Calls SSE intrinsics; the caller must ensure they are available on the executing CPU.
    #[inline(always)]
    pub unsafe fn rotate_both_135(&self, values: __m128) -> __m128 {
        let quarter_turn = self.rotate_both(values);
        let scale = _mm_set1_ps(0.5f32.sqrt());
        _mm_mul_ps(_mm_sub_ps(quarter_turn, values), scale)
    }

    /// Rotates both complex values by 225 degrees in the configured direction.
    ///
    /// Computed as `(rotate_both(v) + v) * -sqrt(0.5)`, i.e. the 45 degree result negated.
    ///
    /// # Safety
    ///
    /// Calls SSE intrinsics; the caller must ensure they are available on the executing CPU.
    #[inline(always)]
    pub unsafe fn rotate_both_225(&self, values: __m128) -> __m128 {
        let quarter_turn = self.rotate_both(values);
        let negated_scale = _mm_set1_ps(-(0.5f32.sqrt()));
        _mm_mul_ps(_mm_add_ps(quarter_turn, values), negated_scale)
    }
}
88
/// Packs the lower complex value of each input into one vector.
///
/// Given `left = [l0.re, l0.im, l1.re, l1.im]` and `right = [r0.re, r0.im, r1.re, r1.im]`,
/// returns `[l0.re, l0.im, r0.re, r0.im]`.
///
/// # Safety
///
/// Calls SSE2 intrinsics; the caller must ensure they are available on the executing CPU.
#[inline(always)]
pub unsafe fn extract_lo_lo_f32(left: __m128, right: __m128) -> __m128 {
    // Viewing each (re, im) pair of f32 as a single f64 lane lets one unpack move a whole
    // complex value at a time; the casts only reinterpret the bits.
    let left_pairs = _mm_castps_pd(left);
    let right_pairs = _mm_castps_pd(right);
    let packed = _mm_unpacklo_pd(left_pairs, right_pairs);
    _mm_castpd_ps(packed)
}
98
/// Packs the upper complex value of each input into one vector.
///
/// Given `left = [l0.re, l0.im, l1.re, l1.im]` and `right = [r0.re, r0.im, r1.re, r1.im]`,
/// returns `[l1.re, l1.im, r1.re, r1.im]`.
///
/// # Safety
///
/// Calls SSE2 intrinsics; the caller must ensure they are available on the executing CPU.
#[inline(always)]
pub unsafe fn extract_hi_hi_f32(left: __m128, right: __m128) -> __m128 {
    // Each (re, im) pair of f32 is handled as one f64 lane so a single unpack moves a whole
    // complex value; the casts only reinterpret the bits.
    let left_pairs = _mm_castps_pd(left);
    let right_pairs = _mm_castps_pd(right);
    let packed = _mm_unpackhi_pd(left_pairs, right_pairs);
    _mm_castpd_ps(packed)
}
107
/// Combines the lower complex value of `left` with the upper complex value of `right`.
///
/// Given `left = [l0.re, l0.im, l1.re, l1.im]` and `right = [r0.re, r0.im, r1.re, r1.im]`,
/// returns `[l0.re, l0.im, r1.re, r1.im]`.
///
/// # Safety
///
/// Calls the SSE4.1 intrinsic `_mm_blend_ps`; the caller must ensure SSE4.1 is available on
/// the executing CPU.
#[inline(always)]
pub unsafe fn extract_lo_hi_f32(left: __m128, right: __m128) -> __m128 {
    // A set bit selects that lane from `right`: bits 2 and 3 pick its upper complex value.
    let blended = _mm_blend_ps(left, right, 0b1100);
    blended
}
116
/// Combines the upper complex value of `left` with the lower complex value of `right`.
///
/// Given `left = [l0.re, l0.im, l1.re, l1.im]` and `right = [r0.re, r0.im, r1.re, r1.im]`,
/// returns `[l1.re, l1.im, r0.re, r0.im]`.
///
/// # Safety
///
/// Calls SSE intrinsics; the caller must ensure they are available on the executing CPU.
#[inline(always)]
pub unsafe fn extract_hi_lo_f32(left: __m128, right: __m128) -> __m128 {
    // The two low selectors (2, 3) index into `left`, the two high selectors (0, 1) into
    // `right`.
    let combined = _mm_shuffle_ps(left, right, 0b01_00_11_10);
    combined
}
125
/// Swaps the two complex values held in a vector.
///
/// `[c0.re, c0.im, c1.re, c1.im]` becomes `[c1.re, c1.im, c0.re, c0.im]`.
///
/// # Safety
///
/// Calls SSE intrinsics; the caller must ensure they are available on the executing CPU.
#[inline(always)]
pub unsafe fn reverse_complex_elements_f32(values: __m128) -> __m128 {
    // Output lanes 0..=3 read input lanes 2, 3, 0, 1 respectively.
    let reversed = _mm_shuffle_ps(values, values, 0b01_00_11_10);
    reversed
}
133
/// Negates the upper complex value and leaves the lower one untouched.
///
/// `[c0.re, c0.im, c1.re, c1.im]` becomes `[c0.re, c0.im, -c1.re, -c1.im]`.
///
/// # Safety
///
/// Calls SSE intrinsics; the caller must ensure they are available on the executing CPU.
#[inline(always)]
pub unsafe fn negate_hi_f32(values: __m128) -> __m128 {
    // `_mm_set_ps` takes lanes from highest to lowest, so only lanes 2 and 3 carry a sign bit.
    let upper_sign_bits = _mm_set_ps(-0.0, -0.0, 0.0, 0.0);
    // XOR with a sign bit flips the sign without changing the magnitude.
    _mm_xor_ps(values, upper_sign_bits)
}
141
/// Broadcasts the lower complex value into both halves of the vector.
///
/// `[c0.re, c0.im, c1.re, c1.im]` becomes `[c0.re, c0.im, c0.re, c0.im]`.
///
/// # Safety
///
/// Calls SSE intrinsics; the caller must ensure they are available on the executing CPU.
#[inline(always)]
pub unsafe fn duplicate_lo_f32(values: __m128) -> __m128 {
    // Output lanes 0..=3 read input lanes 0, 1, 0, 1 respectively.
    let doubled = _mm_shuffle_ps(values, values, 0b01_00_01_00);
    doubled
}
149
/// Broadcasts the upper complex value into both halves of the vector.
///
/// `[c0.re, c0.im, c1.re, c1.im]` becomes `[c1.re, c1.im, c1.re, c1.im]`.
///
/// # Safety
///
/// Calls SSE intrinsics; the caller must ensure they are available on the executing CPU.
#[inline(always)]
pub unsafe fn duplicate_hi_f32(values: __m128) -> __m128 {
    // Output lanes 0..=3 read input lanes 2, 3, 2, 3 respectively.
    let doubled = _mm_shuffle_ps(values, values, 0b11_10_11_10);
    doubled
}
157
158#[inline(always)]
161pub unsafe fn transpose_complex_2x2_f32(left: __m128, right: __m128) -> [__m128; 2] {
162 let temp02 = extract_lo_lo_f32(left, right);
163 let temp13 = extract_hi_hi_f32(left, right);
164 [temp02, temp13]
165}
166
/// Rotates a double-precision complex number by 90 degrees, and by the 45/135/225 degree
/// angles derived from that rotation.
///
/// A `__m128d` is treated as one complex value laid out as `[re, im]`. The rotation
/// direction is chosen once in [`Rotate90F64::new`] and stored as a sign-bit mask.
pub(crate) struct Rotate90F64 {
    /// Sign-bit mask XORed onto the swapped lanes in `rotate`.
    sign: __m128d,
}

impl Rotate90F64 {
    /// Builds the sign mask for one rotation direction.
    ///
    /// With `positive == true` a 90 degree rotation maps `(re, im)` to `(-im, re)`;
    /// with `positive == false` it maps `(re, im)` to `(im, -re)`.
    pub fn new(positive: bool) -> Self {
        // `_mm_set_pd` takes (high lane, low lane). The mask is applied after re/im have been
        // swapped, so the set sign bit decides which of the swapped lanes gets negated.
        //
        // SAFETY: `_mm_set_pd` only needs SSE2, which is part of the x86_64 baseline, and it
        // merely builds a vector from constants without touching memory.
        let sign = if positive {
            unsafe { _mm_set_pd(0.0, -0.0) }
        } else {
            unsafe { _mm_set_pd(-0.0, 0.0) }
        };
        Self { sign }
    }

    /// Rotates the complex value by 90 degrees in the configured direction.
    ///
    /// # Safety
    ///
    /// Calls SSE2 intrinsics; the caller must ensure they are available on the executing CPU.
    #[inline(always)]
    pub unsafe fn rotate(&self, values: __m128d) -> __m128d {
        // Exchange the two lanes: the result is [im, re].
        let swapped = _mm_shuffle_pd(values, values, 0b01);
        // XOR with the mask flips the sign of the lane selected by the direction.
        _mm_xor_pd(swapped, self.sign)
    }

    /// Rotates the complex value by 45 degrees in the configured direction.
    ///
    /// Computed as `(rotate(v) + v) * sqrt(0.5)`.
    ///
    /// # Safety
    ///
    /// Calls SSE2 intrinsics; the caller must ensure they are available on the executing CPU.
    #[inline(always)]
    pub unsafe fn rotate_45(&self, values: __m128d) -> __m128d {
        let quarter_turn = self.rotate(values);
        let scale = _mm_set1_pd(0.5f64.sqrt());
        _mm_mul_pd(_mm_add_pd(quarter_turn, values), scale)
    }

    /// Rotates the complex value by 135 degrees in the configured direction.
    ///
    /// Computed as `(rotate(v) - v) * sqrt(0.5)`.
    ///
    /// # Safety
    ///
    /// Calls SSE2 intrinsics; the caller must ensure they are available on the executing CPU.
    #[inline(always)]
    pub unsafe fn rotate_135(&self, values: __m128d) -> __m128d {
        let quarter_turn = self.rotate(values);
        let scale = _mm_set1_pd(0.5f64.sqrt());
        _mm_mul_pd(_mm_sub_pd(quarter_turn, values), scale)
    }

    /// Rotates the complex value by 225 degrees in the configured direction.
    ///
    /// Computed as `(rotate(v) + v) * -sqrt(0.5)`, i.e. the 45 degree result negated.
    ///
    /// # Safety
    ///
    /// Calls SSE2 intrinsics; the caller must ensure they are available on the executing CPU.
    #[inline(always)]
    pub unsafe fn rotate_225(&self, values: __m128d) -> __m128d {
        let quarter_turn = self.rotate(values);
        let negated_scale = _mm_set1_pd(-(0.5f64.sqrt()));
        _mm_mul_pd(_mm_add_pd(quarter_turn, values), negated_scale)
    }
}
217
#[cfg(test)]
mod unit_tests {
    use super::*;
    use crate::sse::sse_vector::SseVector;
    use num_complex::Complex;

    /// Multiplies one pair of f64 complex numbers and compares against the product
    /// expanded by hand.
    #[test]
    fn test_mul_complex_f64() {
        unsafe {
            // `_mm_set_pd` takes (high, low) = (im, re): left is 7 + 5i, right is 2 + 1i.
            let left = _mm_set_pd(5.0, 7.0);
            let right = _mm_set_pd(1.0, 2.0);
            let product = SseVector::mul_complex(left, right);
            let expected = _mm_set_pd(2.0 * 5.0 + 1.0 * 7.0, 2.0 * 7.0 - 1.0 * 5.0);
            assert_eq!(
                std::mem::transmute::<__m128d, Complex<f64>>(product),
                std::mem::transmute::<__m128d, Complex<f64>>(expected)
            );
        }
    }

    /// Multiplies two pairs of f32 complex numbers lane-wise and compares against the
    /// scalar `Complex` products.
    #[test]
    fn test_mul_complex_f32() {
        let lhs = [Complex::<f32>::new(1.0, 2.5), Complex::<f32>::new(3.2, 4.2)];
        let rhs = [Complex::<f32>::new(5.6, 6.2), Complex::<f32>::new(7.4, 8.3)];
        unsafe {
            // `_mm_set_ps` lists lanes from highest to lowest, hence the reversed order.
            let packed_rhs = _mm_set_ps(rhs[1].im, rhs[1].re, rhs[0].im, rhs[0].re);
            let packed_lhs = _mm_set_ps(lhs[1].im, lhs[1].re, lhs[0].im, lhs[0].re);
            let product = SseVector::mul_complex(packed_lhs, packed_rhs);
            let product = std::mem::transmute::<__m128, [Complex<f32>; 2]>(product);
            let expected = [lhs[0] * rhs[0], lhs[1] * rhs[1]];
            assert_eq!(product, expected);
        }
    }

    /// Checks that the lo/lo and hi/hi packers pick the expected complex values.
    #[test]
    fn test_pack() {
        unsafe {
            let upper_input = _mm_set_ps(8.0, 7.0, 6.0, 5.0);
            let lower_input = _mm_set_ps(4.0, 3.0, 2.0, 1.0);

            let lows = extract_lo_lo_f32(lower_input, upper_input);
            let lows = std::mem::transmute::<__m128, [Complex<f32>; 2]>(lows);
            let highs = extract_hi_hi_f32(lower_input, upper_input);
            let highs = std::mem::transmute::<__m128, [Complex<f32>; 2]>(highs);

            assert_eq!(lows, [Complex::new(1.0, 2.0), Complex::new(5.0, 6.0)]);
            assert_eq!(highs, [Complex::new(3.0, 4.0), Complex::new(7.0, 8.0)]);
        }
    }
}