1use crate::kernel::GemmKernel;
10use crate::kernel::GemmSelect;
11use crate::kernel::{U2, U4, c32, Element, c32_mul as mul};
12use crate::archparam;
13use crate::cgemm_common::pack_complex;
14
/// Kernel marker for x86/x86_64; `detect` selects it when both the "fma" and
/// "avx2" features are detected at runtime.
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
struct KernelAvx2;
/// Kernel marker for x86/x86_64; `detect` selects it when "fma" is detected
/// but "avx2" is not.
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
struct KernelFma;

/// Kernel marker for aarch64 (only compiled when `has_aarch64_simd` is set);
/// `detect` selects it when the "neon" feature is detected.
#[cfg(target_arch = "aarch64")]
#[cfg(has_aarch64_simd)]
struct KernelNeon;

/// Kernel marker available on every target; `detect` falls back to it when no
/// architecture-specific kernel was selected.
struct KernelFallback;
25
/// Element type of every kernel in this file (`type Elem = T` in each
/// `GemmKernel` impl).
type T = c32;
/// Scalar type handed to `kernel_fallback_impl_complex!` next to `T`, and the
/// pointee type read by `at`.
// NOTE(review): presumably the real/imaginary component type of `c32` — confirm.
type TReal = f32;
28
29#[inline]
35pub(crate) fn detect<G>(selector: G) where G: GemmSelect<T> {
36 #[cfg(any(target_arch="x86", target_arch="x86_64"))]
38 {
39 if is_x86_feature_detected_!("fma") {
40 if is_x86_feature_detected_!("avx2") {
41 return selector.select(KernelAvx2);
42 }
43 return selector.select(KernelFma);
44 }
45 }
46 #[cfg(target_arch = "aarch64")]
47 #[cfg(has_aarch64_simd)]
48 {
49 if is_aarch64_feature_detected_!("neon") {
50 return selector.select(KernelNeon);
51 }
52 }
53 return selector.select(KernelFallback);
54}
55
56#[cfg(any(target_arch="x86", target_arch="x86_64"))]
57impl GemmKernel for KernelAvx2 {
58 type Elem = T;
59
60 type MRTy = U4;
61 type NRTy = U4;
62
63 #[inline(always)]
64 fn align_to() -> usize { 32 }
65
66 #[inline(always)]
67 fn always_masked() -> bool { KernelFallback::always_masked() }
68
69 #[inline(always)]
70 fn nc() -> usize { archparam::C_NC }
71 #[inline(always)]
72 fn kc() -> usize { archparam::C_KC }
73 #[inline(always)]
74 fn mc() -> usize { archparam::C_MC }
75
76 pack_methods!{}
77
78 #[inline(always)]
79 unsafe fn kernel(
80 k: usize,
81 alpha: T,
82 a: *const T,
83 b: *const T,
84 beta: T,
85 c: *mut T, rsc: isize, csc: isize) {
86 kernel_target_avx2(k, alpha, a, b, beta, c, rsc, csc)
87 }
88}
89
90#[cfg(any(target_arch="x86", target_arch="x86_64"))]
91impl GemmKernel for KernelFma {
92 type Elem = T;
93
94 type MRTy = U4;
95 type NRTy = U4;
96
97 #[inline(always)]
98 fn align_to() -> usize { 16 }
99
100 #[inline(always)]
101 fn always_masked() -> bool { KernelFallback::always_masked() }
102
103 #[inline(always)]
104 fn nc() -> usize { archparam::C_NC }
105 #[inline(always)]
106 fn kc() -> usize { archparam::C_KC }
107 #[inline(always)]
108 fn mc() -> usize { archparam::C_MC }
109
110 pack_methods!{}
111
112 #[inline(always)]
113 unsafe fn kernel(
114 k: usize,
115 alpha: T,
116 a: *const T,
117 b: *const T,
118 beta: T,
119 c: *mut T, rsc: isize, csc: isize) {
120 kernel_target_fma(k, alpha, a, b, beta, c, rsc, csc)
121 }
122}
123
// `GemmKernel` for the NEON kernel marker: blocking sizes come from
// `archparam`, the computation is forwarded to `kernel_target_neon`.
#[cfg(target_arch = "aarch64")]
#[cfg(has_aarch64_simd)]
impl GemmKernel for KernelNeon {
    type Elem = T;

    // Type-level MR / NR of this kernel.
    type MRTy = U4;
    type NRTy = U2;

    #[inline(always)]
    fn align_to() -> usize {
        16
    }

    #[inline(always)]
    fn always_masked() -> bool {
        // Same answer as the fallback kernel gives.
        KernelFallback::always_masked()
    }

    #[inline(always)]
    fn mc() -> usize {
        archparam::C_MC
    }

    #[inline(always)]
    fn kc() -> usize {
        archparam::C_KC
    }

    #[inline(always)]
    fn nc() -> usize {
        archparam::C_NC
    }

    pack_methods! {}

    #[inline(always)]
    unsafe fn kernel(
        k: usize,
        alpha: T,
        a: *const T,
        b: *const T,
        beta: T,
        c: *mut T,
        rsc: isize,
        csc: isize,
    ) {
        // All arguments are passed through unchanged.
        kernel_target_neon(k, alpha, a, b, beta, c, rsc, csc)
    }
}
158
159impl GemmKernel for KernelFallback {
160 type Elem = T;
161
162 type MRTy = U4;
163 type NRTy = U2;
164
165 #[inline(always)]
166 fn align_to() -> usize { 0 }
167
168 #[inline(always)]
169 fn always_masked() -> bool { true }
170
171 #[inline(always)]
172 fn nc() -> usize { archparam::C_NC }
173 #[inline(always)]
174 fn kc() -> usize { archparam::C_KC }
175 #[inline(always)]
176 fn mc() -> usize { archparam::C_MC }
177
178 pack_methods!{}
179
180 #[inline(always)]
181 unsafe fn kernel(
182 k: usize,
183 alpha: T,
184 a: *const T,
185 b: *const T,
186 beta: T,
187 c: *mut T, rsc: isize, csc: isize) {
188 kernel_fallback_impl(k, alpha, a, b, beta, c, rsc, csc)
189 }
190}
191
// AVX2 kernel instantiation (x86/x86_64 only).
//
// `loop_m!` and `loop_n!` both forward to `loop4!` here. They are defined
// immediately before the `kernel_fallback_impl_complex!` invocation that
// generates `kernel_target_avx2`, and are redefined before each later
// instantiation in this file, so their position matters.
// NOTE(review): presumably the kernel macro expands to code that calls
// `loop_m!`/`loop_n!` to iterate over the MR/NR dimensions (both `U4` for
// `KernelAvx2`) — confirm in the macro definition.
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
macro_rules! loop_m { ($i:ident, $e:expr) => { loop4!($i, $e) }; }
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
macro_rules! loop_n { ($j:ident, $e:expr) => { loop4!($j, $e) }; }

// Generates `kernel_target_avx2` with the "avx2" and "fma" target features
// enabled and the `fma_yes` option.
// NOTE(review): the meaning of the trailing `4` is not visible here
// (possibly an unroll factor) — confirm in the macro definition.
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
kernel_fallback_impl_complex! {
    [inline target_feature(enable="avx2") target_feature(enable="fma")] [fma_yes]
    kernel_target_avx2, T, TReal, KernelAvx2::MR, KernelAvx2::NR, 4
}
204
205
// FMA kernel instantiation (x86/x86_64 only).
//
// `loop_m!` and `loop_n!` both forward to `loop4!` here. They are redefined
// immediately before the `kernel_fallback_impl_complex!` invocation that
// generates `kernel_target_fma`, so their position matters.
// NOTE(review): presumably the kernel macro expands to code that calls
// `loop_m!`/`loop_n!` to iterate over the MR/NR dimensions (both `U4` for
// `KernelFma`) — confirm in the macro definition.
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
macro_rules! loop_m { ($i:ident, $e:expr) => { loop4!($i, $e) }; }
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
macro_rules! loop_n { ($j:ident, $e:expr) => { loop4!($j, $e) }; }

// Generates `kernel_target_fma` with the "fma" target feature enabled but
// with the `fma_no` option.
// NOTE(review): `fma_no` on the FMA-targeted kernel looks deliberate but the
// reason is not visible here; the meaning of the trailing `2` is also not
// visible — confirm both in the macro definition.
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
kernel_fallback_impl_complex! {
    [inline target_feature(enable="fma")] [fma_no]
    kernel_target_fma, T, TReal, KernelFma::MR, KernelFma::NR, 2
}
218
// NEON kernel instantiation (aarch64 with `has_aarch64_simd` only).
//
// Here `loop_m!` forwards to `loop4!` and `loop_n!` to `loop2!`. They are
// redefined immediately before the `kernel_fallback_impl_complex!` invocation
// that generates `kernel_target_neon`, so their position matters.
// NOTE(review): presumably this mirrors `KernelNeon`'s `MRTy = U4` /
// `NRTy = U2` — confirm in the macro definition.
#[cfg(target_arch = "aarch64")]
#[cfg(has_aarch64_simd)]
macro_rules! loop_m { ($i:ident, $e:expr) => { loop4!($i, $e) }; }
#[cfg(target_arch = "aarch64")]
#[cfg(has_aarch64_simd)]
macro_rules! loop_n { ($j:ident, $e:expr) => { loop2!($j, $e) }; }

// Generates `kernel_target_neon` with the "neon" target feature enabled and
// the `fma_yes` option.
// NOTE(review): the meaning of the trailing `1` is not visible here — confirm
// in the macro definition.
#[cfg(target_arch = "aarch64")]
#[cfg(has_aarch64_simd)]
kernel_fallback_impl_complex! {
    [inline target_feature(enable="neon")] [fma_yes]
    kernel_target_neon, T, TReal, KernelNeon::MR, KernelNeon::NR, 1
}
234
// Portable kernel instantiation (compiled on every target).
//
// Here `loop_m!` forwards to `loop4!` and `loop_n!` to `loop2!`. These are the
// last definitions of the two macros in this file and precede the
// `kernel_fallback_impl_complex!` invocation that generates
// `kernel_fallback_impl`, so their position matters.
// NOTE(review): presumably this mirrors `KernelFallback`'s `MRTy = U4` /
// `NRTy = U2` — confirm in the macro definition.
macro_rules! loop_m { ($i:ident, $e:expr) => { loop4!($i, $e) }; }
macro_rules! loop_n { ($j:ident, $e:expr) => { loop2!($j, $e) }; }

// Generates `kernel_fallback_impl` as `#[inline(always)]`, with no target
// feature attribute and the `fma_no` option.
// NOTE(review): the meaning of the trailing `1` is not visible here — confirm
// in the macro definition.
kernel_fallback_impl_complex! {
    [inline(always)] [fma_no]
    kernel_fallback_impl, T, TReal, KernelFallback::MR, KernelFallback::NR, 1
}
244
245#[inline(always)]
246unsafe fn at(ptr: *const TReal, i: usize) -> TReal {
247 *ptr.add(i)
248}
249
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::test::test_complex_packed_kernel;

    // The portable kernel is not gated on any CPU feature, so it always runs.
    #[test]
    fn test_kernel_fallback_impl() {
        test_complex_packed_kernel::<KernelFallback, _, TReal>("kernel");
    }

    #[cfg(target_arch = "aarch64")]
    #[cfg(has_aarch64_simd)]
    mod test_kernel_aarch64 {
        use super::super::*;
        use super::test_complex_packed_kernel;
        #[cfg(feature = "std")]
        use std::println;

        // Expands each `feature, test name, kernel type` triple into one
        // `#[test]` that runs only when the host reports the feature.
        macro_rules! test_arch_kernels {
            ($($feature:tt, $test_fn:ident, $kernel:ty),*) => {
                $(
                #[test]
                fn $test_fn() {
                    if !is_aarch64_feature_detected_!($feature) {
                        #[cfg(feature = "std")]
                        println!("Skipping, host does not have feature: {:?}", $feature);
                        return;
                    }
                    test_complex_packed_kernel::<$kernel, _, TReal>(stringify!($test_fn));
                }
                )*
            }
        }

        test_arch_kernels! {
            "neon", neon, KernelNeon
        }
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    mod test_arch_kernels {
        use super::super::*;
        use super::test_complex_packed_kernel;
        #[cfg(feature = "std")]
        use std::println;

        // Expands each `feature, test name, kernel type` triple into one
        // `#[test]` that runs only when the host reports the feature.
        macro_rules! test_arch_kernels_x86 {
            ($($feature:tt, $test_fn:ident, $kernel:ty),*) => {
                $(
                #[test]
                fn $test_fn() {
                    if !is_x86_feature_detected_!($feature) {
                        #[cfg(feature = "std")]
                        println!("Skipping, host does not have feature: {:?}", $feature);
                        return;
                    }
                    test_complex_packed_kernel::<$kernel, _, TReal>(stringify!($test_fn));
                }
                )*
            }
        }

        test_arch_kernels_x86! {
            "fma", fma, KernelFma,
            "avx2", avx2, KernelAvx2
        }
    }
}