1use crate::kernel::GemmKernel;
10use crate::kernel::GemmSelect;
11use crate::kernel::{U2, U4, c64, Element, c64_mul as mul};
12use crate::archparam;
13use crate::cgemm_common::pack_complex;
14
/// Marker type for the x86/x86_64 `c64` kernel compiled with the `fma` and
/// `avx2` target features.
///
/// `detect` selects it only when both CPU features are detected.
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
struct KernelAvx2;
/// Marker type for the x86/x86_64 `c64` kernel compiled with the `fma`
/// target feature only.
///
/// `detect` selects it when `fma` is detected but `avx2` is not.
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
struct KernelFma;
19
/// Marker type for the aarch64 `c64` kernel compiled with the `neon` target
/// feature.
///
/// Only exists when the build sets the `has_aarch64_simd` cfg; `detect`
/// selects it when `neon` is detected.
#[cfg(target_arch = "aarch64")]
#[cfg(has_aarch64_simd)]
struct KernelNeon;
23
/// Marker type for the portable `c64` kernel, available on every target.
///
/// `detect` falls back to it when no architecture-specific kernel was
/// selected; the other kernels in this file also delegate their
/// `always_masked()` answer to it.
struct KernelFallback;
25
/// Element type handled by every kernel in this file (`GemmKernel::Elem`).
type T = c64;
/// Real scalar type paired with `T`; passed to `kernel_fallback_impl_complex!`
/// and read through the `at` helper.
type TReal = f64;
28
29#[inline]
35pub(crate) fn detect<G>(selector: G) where G: GemmSelect<T> {
36 #[cfg(any(target_arch="x86", target_arch="x86_64"))]
38 {
39 if is_x86_feature_detected_!("fma") {
40 if is_x86_feature_detected_!("avx2") {
41 return selector.select(KernelAvx2);
42 }
43 return selector.select(KernelFma);
44 }
45 }
46 #[cfg(target_arch = "aarch64")]
47 #[cfg(has_aarch64_simd)]
48 {
49 if is_aarch64_feature_detected_!("neon") {
50 return selector.select(KernelNeon);
51 }
52 }
53 return selector.select(KernelFallback);
54}
55
/// Forwards an index identifier and an expression to `loop4!` (defined
/// elsewhere in the crate).
///
/// NOTE(review): the name and the `U4` row-tile type used by the kernels in
/// this file suggest this is the loop over the MR = 4 dimension — confirm
/// against `loop4!` and `kernel_fallback_impl_complex!`.
macro_rules! loop_m {
    ($row:ident, $body:expr) => {
        loop4!($row, $body)
    };
}
/// Forwards an index identifier and an expression to `loop2!` (defined
/// elsewhere in the crate).
///
/// NOTE(review): the name and the `U2` column-tile type used by the kernels in
/// this file suggest this is the loop over the NR = 2 dimension — confirm
/// against `loop2!` and `kernel_fallback_impl_complex!`.
macro_rules! loop_n {
    ($col:ident, $body:expr) => {
        loop2!($col, $body)
    };
}
58
59#[cfg(any(target_arch="x86", target_arch="x86_64"))]
60impl GemmKernel for KernelAvx2 {
61 type Elem = T;
62
63 type MRTy = U4;
64 type NRTy = U2;
65
66 #[inline(always)]
67 fn align_to() -> usize { 32 }
68
69 #[inline(always)]
70 fn always_masked() -> bool { KernelFallback::always_masked() }
71
72 #[inline(always)]
73 fn nc() -> usize { archparam::Z_NC }
74 #[inline(always)]
75 fn kc() -> usize { archparam::Z_KC }
76 #[inline(always)]
77 fn mc() -> usize { archparam::Z_MC }
78
79 pack_methods!{}
80
81 #[inline(always)]
82 unsafe fn kernel(
83 k: usize,
84 alpha: T,
85 a: *const T,
86 b: *const T,
87 beta: T,
88 c: *mut T, rsc: isize, csc: isize) {
89 kernel_target_avx2(k, alpha, a, b, beta, c, rsc, csc)
90 }
91}
92
93#[cfg(any(target_arch="x86", target_arch="x86_64"))]
94impl GemmKernel for KernelFma {
95 type Elem = T;
96
97 type MRTy = <KernelFallback as GemmKernel>::MRTy;
98 type NRTy = <KernelFallback as GemmKernel>::NRTy;
99
100 #[inline(always)]
101 fn align_to() -> usize { 16 }
102
103 #[inline(always)]
104 fn always_masked() -> bool { KernelFallback::always_masked() }
105
106 #[inline(always)]
107 fn nc() -> usize { archparam::Z_NC }
108 #[inline(always)]
109 fn kc() -> usize { archparam::Z_KC }
110 #[inline(always)]
111 fn mc() -> usize { archparam::Z_MC }
112
113 pack_methods!{}
114
115 #[inline(always)]
116 unsafe fn kernel(
117 k: usize,
118 alpha: T,
119 a: *const T,
120 b: *const T,
121 beta: T,
122 c: *mut T, rsc: isize, csc: isize) {
123 kernel_target_fma(k, alpha, a, b, beta, c, rsc, csc)
124 }
125}
126
/// `GemmKernel` implementation backed by `kernel_target_neon`.
///
/// Tile types are `U4` x `U2`, the reported alignment is 16, and the blocking
/// sizes are the `Z_*` constants from `archparam`.
#[cfg(all(target_arch = "aarch64", has_aarch64_simd))]
impl GemmKernel for KernelNeon {
    type Elem = T;

    type MRTy = U4;
    type NRTy = U2;

    /// Returns 16.
    ///
    /// NOTE(review): presumably a byte alignment for packed buffers — confirm
    /// against the `GemmKernel` trait documentation.
    #[inline(always)]
    fn align_to() -> usize {
        16
    }

    /// Same answer as the portable kernel (`KernelFallback::always_masked()`).
    #[inline(always)]
    fn always_masked() -> bool {
        KernelFallback::always_masked()
    }

    /// Blocking size taken from `archparam::Z_NC`.
    #[inline(always)]
    fn nc() -> usize {
        archparam::Z_NC
    }

    /// Blocking size taken from `archparam::Z_KC`.
    #[inline(always)]
    fn kc() -> usize {
        archparam::Z_KC
    }

    /// Blocking size taken from `archparam::Z_MC`.
    #[inline(always)]
    fn mc() -> usize {
        archparam::Z_MC
    }

    // Packing methods are generated by a macro defined elsewhere in the crate.
    pack_methods! {}

    /// Forwards every argument unchanged to `kernel_target_neon`.
    ///
    /// # Safety
    ///
    /// The callee is instantiated with the `neon` target feature, so the
    /// running CPU must support it (`detect` checks this before selecting
    /// this kernel). The pointer and stride requirements are those of
    /// `GemmKernel::kernel`, which this block does not restate.
    #[inline(always)]
    unsafe fn kernel(
        k: usize,
        alpha: T,
        a: *const T,
        b: *const T,
        beta: T,
        c: *mut T,
        rsc: isize,
        csc: isize,
    ) {
        kernel_target_neon(k, alpha, a, b, beta, c, rsc, csc)
    }
}
161
162impl GemmKernel for KernelFallback {
163 type Elem = T;
164
165 type MRTy = U4;
166 type NRTy = U2;
167
168 #[inline(always)]
169 fn align_to() -> usize { 0 }
170
171 #[inline(always)]
172 fn always_masked() -> bool { true }
173
174 #[inline(always)]
175 fn nc() -> usize { archparam::Z_NC }
176 #[inline(always)]
177 fn kc() -> usize { archparam::Z_KC }
178 #[inline(always)]
179 fn mc() -> usize { archparam::Z_MC }
180
181 pack_methods!{}
182
183 #[inline(always)]
184 unsafe fn kernel(
185 k: usize,
186 alpha: T,
187 a: *const T,
188 b: *const T,
189 beta: T,
190 c: *mut T, rsc: isize, csc: isize) {
191 kernel_fallback_impl(k, alpha, a, b, beta, c, rsc, csc)
192 }
193}
194
// Instantiates `kernel_target_avx2`, the function `KernelAvx2::kernel` calls.
// The first bracket carries the attributes for the generated function
// (`inline` plus the `fma` and `avx2` target features); the second selects the
// `fma_yes` variant of the macro.
// NOTE(review): the trailing `4` is presumably an unroll factor — confirm in
// the definition of `kernel_fallback_impl_complex!`.
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
kernel_fallback_impl_complex! {
    [inline target_feature(enable="fma") target_feature(enable="avx2")] [fma_yes]
    kernel_target_avx2, T, TReal, KernelAvx2::MR, KernelAvx2::NR, 4
}
201
// Instantiates `kernel_target_fma`, the function `KernelFma::kernel` calls.
// The generated function gets `inline` and the `fma` target feature.
// NOTE(review): `[fma_no]` is passed even though the `fma` target feature is
// enabled — confirm this is intentional.
// NOTE(review): the trailing `2` is presumably an unroll factor — confirm in
// the definition of `kernel_fallback_impl_complex!`.
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
kernel_fallback_impl_complex! {
    [inline target_feature(enable="fma")] [fma_no]
    kernel_target_fma, T, TReal, KernelFma::MR, KernelFma::NR, 2
}
208
// Instantiates `kernel_target_neon`, the function `KernelNeon::kernel` calls.
// The generated function gets `inline` and the `neon` target feature, and the
// `fma_yes` variant of the macro is selected.
// NOTE(review): the trailing `1` is presumably an unroll factor — confirm in
// the definition of `kernel_fallback_impl_complex!`.
#[cfg(target_arch = "aarch64")]
#[cfg(has_aarch64_simd)]
kernel_fallback_impl_complex! {
    [inline target_feature(enable="neon")] [fma_yes]
    kernel_target_neon, T, TReal, KernelNeon::MR, KernelNeon::NR, 1
}
217
// Instantiates `kernel_fallback_impl`, the function `KernelFallback::kernel`
// calls. No target feature is enabled (only `inline`), and the `fma_no`
// variant of the macro is selected, so this instance is built for every target.
// NOTE(review): the trailing `1` is presumably an unroll factor — confirm in
// the definition of `kernel_fallback_impl_complex!`.
kernel_fallback_impl_complex! {
    [inline] [fma_no]
    kernel_fallback_impl, T, TReal, KernelFallback::MR, KernelFallback::NR, 1
}
224
225#[inline(always)]
226unsafe fn at(ptr: *const TReal, i: usize) -> TReal {
227 *ptr.add(i)
228}
229
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::test::test_complex_packed_kernel;

    /// The portable kernel needs no CPU feature, so it is exercised on every host.
    #[test]
    fn test_kernel_fallback_impl() {
        test_complex_packed_kernel::<KernelFallback, _, TReal>("kernel");
    }

    /// aarch64-only kernels; each test is skipped at run time when the host
    /// lacks the required CPU feature.
    #[cfg(all(target_arch = "aarch64", has_aarch64_simd))]
    mod test_kernel_aarch64 {
        use super::super::*;
        use super::test_complex_packed_kernel;
        #[cfg(feature = "std")]
        use std::println;

        // Expands each `feature, test name, kernel type` triple into one `#[test]`.
        macro_rules! test_arch_kernels {
            ($($feature:tt, $test_fn:ident, $kernel:ty),*) => {
                $(
                    #[test]
                    fn $test_fn() {
                        if !is_aarch64_feature_detected_!($feature) {
                            // Without `std` there is no `println!`, so the skip is silent.
                            #[cfg(feature = "std")]
                            println!("Skipping, host does not have feature: {:?}", $feature);
                            return;
                        }
                        test_complex_packed_kernel::<$kernel, _, TReal>(stringify!($test_fn));
                    }
                )*
            };
        }

        test_arch_kernels! {
            "neon", neon, KernelNeon
        }
    }

    /// x86 / x86_64-only kernels; each test is skipped at run time when the
    /// host lacks the required CPU feature.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    mod test_arch_kernels {
        use super::super::*;
        use super::test_complex_packed_kernel;
        #[cfg(feature = "std")]
        use std::println;

        // Expands each `feature, test name, kernel type` triple into one `#[test]`.
        macro_rules! test_arch_kernels_x86 {
            ($($feature:tt, $test_fn:ident, $kernel:ty),*) => {
                $(
                    #[test]
                    fn $test_fn() {
                        if !is_x86_feature_detected_!($feature) {
                            // Without `std` there is no `println!`, so the skip is silent.
                            #[cfg(feature = "std")]
                            println!("Skipping, host does not have feature: {:?}", $feature);
                            return;
                        }
                        test_complex_packed_kernel::<$kernel, _, TReal>(stringify!($test_fn));
                    }
                )*
            };
        }

        test_arch_kernels_x86! {
            "fma", fma, KernelFma,
            "avx2", avx2, KernelAvx2
        }
    }
}