matrixmultiply/
zgemm_kernel.rs

1// Copyright 2016 - 2021 Ulrik Sverdrup "bluss"
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9use crate::kernel::GemmKernel;
10use crate::kernel::GemmSelect;
11use crate::kernel::{U2, U4, c64, Element, c64_mul as mul};
12use crate::archparam;
13use crate::cgemm_common::pack_complex;
14
/// Marker type for the x86/x86_64 kernel that `detect` selects when both the
/// "fma" and "avx2" CPU features are detected.
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
struct KernelAvx2;
/// Marker type for the x86/x86_64 kernel that `detect` selects when "fma" is
/// detected but "avx2" is not.
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
struct KernelFma;

/// Marker type for the aarch64 kernel that `detect` selects when "neon" is
/// detected. Only compiled when the `has_aarch64_simd` cfg is set.
#[cfg(target_arch = "aarch64")]
#[cfg(has_aarch64_simd)]
struct KernelNeon;

/// Marker type for the kernel that `detect` selects when none of the
/// feature-specific kernels apply.
struct KernelFallback;

/// Element type of every kernel in this file (`GemmKernel::Elem`).
type T = c64;
/// Scalar type handed to `kernel_fallback_impl_complex!` and to the packed
/// kernel tests alongside `T`.
type TReal = f64;
28
29/// Detect which implementation to use and select it using the selector's
30/// .select(Kernel) method.
31///
32/// This function is called one or more times during a whole program's
33/// execution, it may be called for each gemm kernel invocation or fewer times.
34#[inline]
35pub(crate) fn detect<G>(selector: G) where G: GemmSelect<T> {
36    // dispatch to specific compiled versions
37    #[cfg(any(target_arch="x86", target_arch="x86_64"))]
38    {
39        if is_x86_feature_detected_!("fma") {
40            if is_x86_feature_detected_!("avx2") {
41                return selector.select(KernelAvx2);
42            }
43            return selector.select(KernelFma);
44        }
45    }
46    #[cfg(target_arch = "aarch64")]
47    #[cfg(has_aarch64_simd)]
48    {
49        if is_aarch64_feature_detected_!("neon") {
50            return selector.select(KernelNeon);
51        }
52    }
53    return selector.select(KernelFallback);
54}
55
// `loop_m!(ident, expr)` forwards both arguments unchanged to `loop4!`.
// (4 matches the U4 used for MRTy below; presumably consumed by the
// kernel-generating macro, confirm against its definition.)
macro_rules! loop_m {
    ($index:ident, $body:expr) => {
        loop4!($index, $body)
    };
}
// `loop_n!(ident, expr)` forwards both arguments unchanged to `loop2!`.
// (2 matches the U2 used for NRTy below.)
macro_rules! loop_n {
    ($index:ident, $body:expr) => {
        loop2!($index, $body)
    };
}
58
59#[cfg(any(target_arch="x86", target_arch="x86_64"))]
60impl GemmKernel for KernelAvx2 {
61    type Elem = T;
62
63    type MRTy = U4;
64    type NRTy = U2;
65
66    #[inline(always)]
67    fn align_to() -> usize { 32 }
68
69    #[inline(always)]
70    fn always_masked() -> bool { KernelFallback::always_masked() }
71
72    #[inline(always)]
73    fn nc() -> usize { archparam::Z_NC }
74    #[inline(always)]
75    fn kc() -> usize { archparam::Z_KC }
76    #[inline(always)]
77    fn mc() -> usize { archparam::Z_MC }
78
79    pack_methods!{}
80
81    #[inline(always)]
82    unsafe fn kernel(
83        k: usize,
84        alpha: T,
85        a: *const T,
86        b: *const T,
87        beta: T,
88        c: *mut T, rsc: isize, csc: isize) {
89        kernel_target_avx2(k, alpha, a, b, beta, c, rsc, csc)
90    }
91}
92
93#[cfg(any(target_arch="x86", target_arch="x86_64"))]
94impl GemmKernel for KernelFma {
95    type Elem = T;
96
97    type MRTy = <KernelFallback as GemmKernel>::MRTy;
98    type NRTy = <KernelFallback as GemmKernel>::NRTy;
99
100    #[inline(always)]
101    fn align_to() -> usize { 16 }
102
103    #[inline(always)]
104    fn always_masked() -> bool { KernelFallback::always_masked() }
105
106    #[inline(always)]
107    fn nc() -> usize { archparam::Z_NC }
108    #[inline(always)]
109    fn kc() -> usize { archparam::Z_KC }
110    #[inline(always)]
111    fn mc() -> usize { archparam::Z_MC }
112
113    pack_methods!{}
114
115    #[inline(always)]
116    unsafe fn kernel(
117        k: usize,
118        alpha: T,
119        a: *const T,
120        b: *const T,
121        beta: T,
122        c: *mut T, rsc: isize, csc: isize) {
123        kernel_target_fma(k, alpha, a, b, beta, c, rsc, csc)
124    }
125}
126
/// `GemmKernel` implementation for the NEON kernel (aarch64 with the
/// `has_aarch64_simd` cfg only).
///
/// Blocking sizes come from `archparam`; the computation is delegated to
/// `kernel_target_neon`.
#[cfg(target_arch = "aarch64")]
#[cfg(has_aarch64_simd)]
impl GemmKernel for KernelNeon {
    type Elem = T;

    // Microkernel tile dimensions, expressed as type-level sizes.
    type MRTy = U4;
    type NRTy = U2;

    #[inline(always)]
    fn align_to() -> usize {
        16
    }

    // Same masking policy as the fallback kernel.
    #[inline(always)]
    fn always_masked() -> bool {
        <KernelFallback as GemmKernel>::always_masked()
    }

    #[inline(always)]
    fn nc() -> usize {
        archparam::Z_NC
    }

    #[inline(always)]
    fn kc() -> usize {
        archparam::Z_KC
    }

    #[inline(always)]
    fn mc() -> usize {
        archparam::Z_MC
    }

    pack_methods! {}

    /// Forwards every argument unchanged to `kernel_target_neon`.
    ///
    /// # Safety
    ///
    /// The callee is compiled with the "neon" target feature enabled, so the
    /// running CPU must support it. The pointer and stride requirements are
    /// those of `kernel_target_neon` itself (not visible here, see
    /// `kernel_fallback_impl_complex!`).
    #[inline(always)]
    unsafe fn kernel(
        k: usize,
        alpha: T,
        a: *const T,
        b: *const T,
        beta: T,
        c: *mut T,
        rsc: isize,
        csc: isize,
    ) {
        kernel_target_neon(k, alpha, a, b, beta, c, rsc, csc)
    }
}
161
162impl GemmKernel for KernelFallback {
163    type Elem = T;
164
165    type MRTy = U4;
166    type NRTy = U2;
167
168    #[inline(always)]
169    fn align_to() -> usize { 0 }
170
171    #[inline(always)]
172    fn always_masked() -> bool { true }
173
174    #[inline(always)]
175    fn nc() -> usize { archparam::Z_NC }
176    #[inline(always)]
177    fn kc() -> usize { archparam::Z_KC }
178    #[inline(always)]
179    fn mc() -> usize { archparam::Z_MC }
180
181    pack_methods!{}
182
183    #[inline(always)]
184    unsafe fn kernel(
185        k: usize,
186        alpha: T,
187        a: *const T,
188        b: *const T,
189        beta: T,
190        c: *mut T, rsc: isize, csc: isize) {
191        kernel_fallback_impl(k, alpha, a, b, beta, c, rsc, csc)
192    }
193}
194
// Generates `kernel_target_avx2`, the function `KernelAvx2::kernel` calls.
// It is compiled with both the "fma" and "avx2" target features enabled and
// with the `fma_yes` option.
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
kernel_fallback_impl_complex! {
    // instantiate the fma + avx2 variant separately
    [inline target_feature(enable="fma") target_feature(enable="avx2")] [fma_yes]
    // Arguments: function name, element type, real type, MR, NR, and a final
    // numeric argument (presumably an unroll factor; TODO confirm against
    // the macro definition).
    kernel_target_avx2, T, TReal, KernelAvx2::MR, KernelAvx2::NR, 4
}
201
// Generates `kernel_target_fma`, the function `KernelFma::kernel` calls.
// It is compiled with the "fma" target feature enabled.
// NOTE(review): the `fma_no` option is passed even though the "fma" target
// feature is enabled; confirm this is intentional.
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
kernel_fallback_impl_complex! {
    // instantiate fma separately
    [inline target_feature(enable="fma")] [fma_no]
    // Arguments: function name, element type, real type, MR, NR, and a final
    // numeric argument (presumably an unroll factor; TODO confirm against
    // the macro definition).
    kernel_target_fma, T, TReal, KernelFma::MR, KernelFma::NR, 2
}
208
// NEON kernel
//
// Generates `kernel_target_neon`, the function `KernelNeon::kernel` calls.
// It is compiled with the "neon" target feature enabled and with the
// `fma_yes` option.

#[cfg(target_arch = "aarch64")]
#[cfg(has_aarch64_simd)]
kernel_fallback_impl_complex! {
    [inline target_feature(enable="neon")] [fma_yes]
    // Arguments: function name, element type, real type, MR, NR, and a final
    // numeric argument (presumably an unroll factor; TODO confirm against
    // the macro definition).
    kernel_target_neon, T, TReal, KernelNeon::MR, KernelNeon::NR, 1
}
217
// Fallback kernel
//
// Generates `kernel_fallback_impl`, the function `KernelFallback::kernel`
// calls. No target features are enabled and the `fma_no` option is used.

kernel_fallback_impl_complex! {
    [inline] [fma_no]
    // Arguments: function name, element type, real type, MR, NR, and a final
    // numeric argument (presumably an unroll factor; TODO confirm against
    // the macro definition).
    kernel_fallback_impl, T, TReal, KernelFallback::MR, KernelFallback::NR, 1
}
224
225#[inline(always)]
226unsafe fn at(ptr: *const TReal, i: usize) -> TReal {
227    *ptr.add(i)
228}
229
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::test::test_complex_packed_kernel;

    /// The fallback kernel needs no feature detection, so it always runs.
    #[test]
    fn test_kernel_fallback_impl() {
        test_complex_packed_kernel::<KernelFallback, _, TReal>("kernel");
    }

    #[cfg(target_arch = "aarch64")]
    #[cfg(has_aarch64_simd)]
    mod test_kernel_aarch64 {
        use super::super::*;
        use super::test_complex_packed_kernel;
        #[cfg(feature = "std")]
        use std::println;

        // Expands each `"feature", test_name, KernelType` triple into a
        // #[test] that runs the packed-kernel check when the host reports
        // the feature and otherwise skips (printing a note when `std` is on).
        macro_rules! test_arch_kernels {
            ($($feature:tt, $test_fn:ident, $kernel:ty),*) => {
                $(
                #[test]
                fn $test_fn() {
                    if !is_aarch64_feature_detected_!($feature) {
                        #[cfg(feature = "std")]
                        println!("Skipping, host does not have feature: {:?}", $feature);
                        return;
                    }
                    test_complex_packed_kernel::<$kernel, _, TReal>(stringify!($test_fn));
                }
                )*
            }
        }

        test_arch_kernels! {
            "neon", neon, KernelNeon
        }
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    mod test_arch_kernels {
        use super::super::*;
        use super::test_complex_packed_kernel;
        #[cfg(feature = "std")]
        use std::println;

        // Expands each `"feature", test_name, KernelType` triple into a
        // #[test] that runs the packed-kernel check when the host reports
        // the feature and otherwise skips (printing a note when `std` is on).
        //
        // NOTE(review): the `avx2` entry checks only "avx2", while
        // `kernel_target_avx2` is also compiled with "fma" enabled; confirm
        // that a single-feature check is sufficient here.
        macro_rules! test_arch_kernels_x86 {
            ($($feature:tt, $test_fn:ident, $kernel:ty),*) => {
                $(
                #[test]
                fn $test_fn() {
                    if !is_x86_feature_detected_!($feature) {
                        #[cfg(feature = "std")]
                        println!("Skipping, host does not have feature: {:?}", $feature);
                        return;
                    }
                    test_complex_packed_kernel::<$kernel, _, TReal>(stringify!($test_fn));
                }
                )*
            }
        }

        test_arch_kernels_x86! {
            "fma", fma, KernelFma,
            "avx2", avx2, KernelAvx2
        }
    }
}