rustfft/avx/
avx32_utils.rs

1use std::arch::x86_64::*;
2
3use super::avx_vector::{AvxVector, AvxVector256};
4
5// Treat the input like the rows of a 4x4 array, and transpose said rows to the columns
6#[inline(always)]
7pub unsafe fn transpose_4x4_f32(rows: [__m256; 4]) -> [__m256; 4] {
8    let permute0 = _mm256_permute2f128_ps(rows[0], rows[2], 0x20);
9    let permute1 = _mm256_permute2f128_ps(rows[1], rows[3], 0x20);
10    let permute2 = _mm256_permute2f128_ps(rows[0], rows[2], 0x31);
11    let permute3 = _mm256_permute2f128_ps(rows[1], rows[3], 0x31);
12
13    let [unpacked0, unpacked1] = AvxVector::unpack_complex([permute0, permute1]);
14    let [unpacked2, unpacked3] = AvxVector::unpack_complex([permute2, permute3]);
15
16    [unpacked0, unpacked1, unpacked2, unpacked3]
17}
18
19// Treat the input like the rows of a 4x8 array, and transpose it to a 8x4 array, where each array of 4 is one set of 4 columns
20// The assumption here is that it's very likely that the caller wants to do some more AVX operations on the columns of the transposed array, so the output is arranged to make that more convenient
21// The second array only has two columns of valid data. TODO: make them __m128 instead
22#[inline(always)]
23pub unsafe fn transpose_4x6_to_6x4_f32(rows: [__m256; 6]) -> ([__m256; 4], [__m256; 4]) {
24    let chunk0 = [rows[0], rows[1], rows[2], rows[3]];
25    let chunk1 = [rows[4], rows[5], _mm256_setzero_ps(), _mm256_setzero_ps()];
26
27    let output0 = transpose_4x4_f32(chunk0);
28    let output1 = transpose_4x4_f32(chunk1);
29
30    (output0, output1)
31}
32
33// Treat the input like the rows of a 8x4 array, and transpose it to a 4x8 array
34#[inline(always)]
35pub unsafe fn transpose_8x4_to_4x8_f32(rows0: [__m256; 4], rows1: [__m256; 4]) -> [__m256; 8] {
36    let transposed0 = transpose_4x4_f32(rows0);
37    let transposed1 = transpose_4x4_f32(rows1);
38
39    [
40        transposed0[0],
41        transposed0[1],
42        transposed0[2],
43        transposed0[3],
44        transposed1[0],
45        transposed1[1],
46        transposed1[2],
47        transposed1[3],
48    ]
49}
50
51// Treat the input like the rows of a 9x3 array, and transpose it to a 3x9 array.
52// our parameters are technically 10 columns, not 9 -- we're going to discard the second element of row0
53#[inline(always)]
54pub unsafe fn transpose_9x3_to_3x9_emptycolumn1_f32(
55    rows0: [__m128; 3],
56    rows1: [__m256; 3],
57    rows2: [__m256; 3],
58) -> [__m256; 9] {
59    // the first row of the output will be the first column of the input
60    let unpacked0 = AvxVector::unpacklo_complex([rows0[0], rows0[1]]);
61    let unpacked1 = AvxVector::unpacklo_complex([rows0[2], _mm_setzero_ps()]);
62    let output0 = AvxVector256::merge(unpacked0, unpacked1);
63
64    let transposed0 = transpose_4x4_f32([rows1[0], rows1[1], rows1[2], _mm256_setzero_ps()]);
65    let transposed1 = transpose_4x4_f32([rows2[0], rows2[1], rows2[2], _mm256_setzero_ps()]);
66
67    [
68        output0,
69        transposed0[0],
70        transposed0[1],
71        transposed0[2],
72        transposed0[3],
73        transposed1[0],
74        transposed1[1],
75        transposed1[2],
76        transposed1[3],
77    ]
78}
79
80// Treat the input like the rows of a 9x4 array, and transpose it to a 4x9 array.
81// our parameters are technically 10 columns, not 9 -- we're going to discard the second element of row0
82#[inline(always)]
83pub unsafe fn transpose_9x4_to_4x9_emptycolumn1_f32(
84    rows0: [__m128; 4],
85    rows1: [__m256; 4],
86    rows2: [__m256; 4],
87) -> [__m256; 9] {
88    // the first row of the output will be the first column of the input
89    let unpacked0 = AvxVector::unpacklo_complex([rows0[0], rows0[1]]);
90    let unpacked1 = AvxVector::unpacklo_complex([rows0[2], rows0[3]]);
91    let output0 = AvxVector256::merge(unpacked0, unpacked1);
92
93    let transposed0 = transpose_4x4_f32([rows1[0], rows1[1], rows1[2], rows1[3]]);
94    let transposed1 = transpose_4x4_f32([rows2[0], rows2[1], rows2[2], rows2[3]]);
95
96    [
97        output0,
98        transposed0[0],
99        transposed0[1],
100        transposed0[2],
101        transposed0[3],
102        transposed1[0],
103        transposed1[1],
104        transposed1[2],
105        transposed1[3],
106    ]
107}
108
109// Treat the input like the rows of a 9x4 array, and transpose it to a 4x9 array.
110// our parameters are technically 10 columns, not 9 -- we're going to discard the second element of row0
111#[inline(always)]
112pub unsafe fn transpose_9x6_to_6x9_emptycolumn1_f32(
113    rows0: [__m128; 6],
114    rows1: [__m256; 6],
115    rows2: [__m256; 6],
116) -> ([__m256; 9], [__m128; 9]) {
117    // the first row of the output will be the first column of the input
118    let unpacked0 = AvxVector::unpacklo_complex([rows0[0], rows0[1]]);
119    let unpacked1 = AvxVector::unpacklo_complex([rows0[2], rows0[3]]);
120    let unpacked2 = AvxVector::unpacklo_complex([rows0[4], rows0[5]]);
121    let output0 = AvxVector256::merge(unpacked0, unpacked1);
122
123    let transposed_hi0 = transpose_4x4_f32([rows1[0], rows1[1], rows1[2], rows1[3]]);
124    let transposed_hi1 = transpose_4x4_f32([rows2[0], rows2[1], rows2[2], rows2[3]]);
125
126    let [unpacked_bottom0, unpacked_bottom1] = AvxVector::unpack_complex([rows1[4], rows1[5]]);
127    let [unpacked_bottom2, unpacked_bottom3] = AvxVector::unpack_complex([rows2[4], rows2[5]]);
128
129    let transposed_lo = [
130        unpacked2,
131        unpacked_bottom0.lo(),
132        unpacked_bottom1.lo(),
133        unpacked_bottom0.hi(),
134        unpacked_bottom1.hi(),
135        unpacked_bottom2.lo(),
136        unpacked_bottom3.lo(),
137        unpacked_bottom2.hi(),
138        unpacked_bottom3.hi(),
139    ];
140
141    (
142        [
143            output0,
144            transposed_hi0[0],
145            transposed_hi0[1],
146            transposed_hi0[2],
147            transposed_hi0[3],
148            transposed_hi1[0],
149            transposed_hi1[1],
150            transposed_hi1[2],
151            transposed_hi1[3],
152        ],
153        transposed_lo,
154    )
155}
156
157// Treat the input like the rows of a 12x4 array, and transpose it to a 4x12 array
158// The assumption here is that the caller wants to do some more AVX operations on the columns of the transposed array, so the output is arranged to make that more convenient
159#[inline(always)]
160pub unsafe fn transpose_12x4_to_4x12_f32(
161    rows0: [__m256; 4],
162    rows1: [__m256; 4],
163    rows2: [__m256; 4],
164) -> [__m256; 12] {
165    let transposed0 = transpose_4x4_f32(rows0);
166    let transposed1 = transpose_4x4_f32(rows1);
167    let transposed2 = transpose_4x4_f32(rows2);
168
169    [
170        transposed0[0],
171        transposed0[1],
172        transposed0[2],
173        transposed0[3],
174        transposed1[0],
175        transposed1[1],
176        transposed1[2],
177        transposed1[3],
178        transposed2[0],
179        transposed2[1],
180        transposed2[2],
181        transposed2[3],
182    ]
183}
184
185// Treat the input like the rows of a 12x6 array, and transpose it to a 6x12 array
186// The assumption here is that the caller wants to do some more AVX operations on the columns of the transposed array, so the output is arranged to make that more convenient
187#[inline(always)]
188pub unsafe fn transpose_12x6_to_6x12_f32(
189    rows0: [__m256; 6],
190    rows1: [__m256; 6],
191    rows2: [__m256; 6],
192) -> ([__m128; 12], [__m256; 12]) {
193    let [unpacked0, unpacked1] = AvxVector::unpack_complex([rows0[0], rows0[1]]);
194    let [unpacked2, unpacked3] = AvxVector::unpack_complex([rows1[0], rows1[1]]);
195    let [unpacked4, unpacked5] = AvxVector::unpack_complex([rows2[0], rows2[1]]);
196
197    let output0 = [
198        unpacked0.lo(),
199        unpacked1.lo(),
200        unpacked0.hi(),
201        unpacked1.hi(),
202        unpacked2.lo(),
203        unpacked3.lo(),
204        unpacked2.hi(),
205        unpacked3.hi(),
206        unpacked4.lo(),
207        unpacked5.lo(),
208        unpacked4.hi(),
209        unpacked5.hi(),
210    ];
211    let transposed0 = transpose_4x4_f32([rows0[2], rows0[3], rows0[4], rows0[5]]);
212    let transposed1 = transpose_4x4_f32([rows1[2], rows1[3], rows1[4], rows1[5]]);
213    let transposed2 = transpose_4x4_f32([rows2[2], rows2[3], rows2[4], rows2[5]]);
214
215    let output1 = [
216        transposed0[0],
217        transposed0[1],
218        transposed0[2],
219        transposed0[3],
220        transposed1[0],
221        transposed1[1],
222        transposed1[2],
223        transposed1[3],
224        transposed2[0],
225        transposed2[1],
226        transposed2[2],
227        transposed2[3],
228    ];
229
230    (output0, output1)
231}
232
233// Treat the input like the rows of a 8x8 array, and transpose said rows to the columns
234// The assumption here is that the caller wants to do some more AVX operations on the columns of the transposed array, so the output is arranged to make that more convenient
235#[inline(always)]
236pub unsafe fn transpose_8x8_f32(
237    rows0: [__m256; 8],
238    rows1: [__m256; 8],
239) -> ([__m256; 8], [__m256; 8]) {
240    let chunk00 = [rows0[0], rows0[1], rows0[2], rows0[3]];
241    let chunk01 = [rows0[4], rows0[5], rows0[6], rows0[7]];
242    let chunk10 = [rows1[0], rows1[1], rows1[2], rows1[3]];
243    let chunk11 = [rows1[4], rows1[5], rows1[6], rows1[7]];
244
245    let transposed00 = transpose_4x4_f32(chunk00);
246    let transposed01 = transpose_4x4_f32(chunk10);
247    let transposed10 = transpose_4x4_f32(chunk01);
248    let transposed11 = transpose_4x4_f32(chunk11);
249
250    let output0 = [
251        transposed00[0],
252        transposed00[1],
253        transposed00[2],
254        transposed00[3],
255        transposed01[0],
256        transposed01[1],
257        transposed01[2],
258        transposed01[3],
259    ];
260    let output1 = [
261        transposed10[0],
262        transposed10[1],
263        transposed10[2],
264        transposed10[3],
265        transposed11[0],
266        transposed11[1],
267        transposed11[2],
268        transposed11[3],
269    ];
270
271    (output0, output1)
272}