matrixmultiply/
cgemm_kernel.rs

1// Copyright 2016 - 2021 Ulrik Sverdrup "bluss"
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9use crate::kernel::GemmKernel;
10use crate::kernel::GemmSelect;
11use crate::kernel::{U2, U4, c32, Element, c32_mul as mul};
12use crate::archparam;
13use crate::cgemm_common::pack_complex;
14
/// x86/x86_64 kernel whose micro-kernel is compiled with both the `avx2`
/// and `fma` target features.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
struct KernelAvx2;

/// x86/x86_64 kernel whose micro-kernel is compiled with the `fma` target
/// feature only.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
struct KernelFma;

/// aarch64 kernel whose micro-kernel is compiled with the `neon` target
/// feature; only exists when the build sets the `has_aarch64_simd` cfg.
#[cfg(all(target_arch = "aarch64", has_aarch64_simd))]
struct KernelNeon;

/// Portable kernel built without any target-specific features; always
/// compiled on every architecture.
struct KernelFallback;
25
26type T = c32;
27type TReal = f32;
28
29/// Detect which implementation to use and select it using the selector's
30/// .select(Kernel) method.
31///
32/// This function is called one or more times during a whole program's
33/// execution, it may be called for each gemm kernel invocation or fewer times.
34#[inline]
35pub(crate) fn detect<G>(selector: G) where G: GemmSelect<T> {
36    // dispatch to specific compiled versions
37    #[cfg(any(target_arch="x86", target_arch="x86_64"))]
38    {
39        if is_x86_feature_detected_!("fma") {
40            if is_x86_feature_detected_!("avx2") {
41                return selector.select(KernelAvx2);
42            }
43            return selector.select(KernelFma);
44        }
45    }
46    #[cfg(target_arch = "aarch64")]
47    #[cfg(has_aarch64_simd)]
48    {
49        if is_aarch64_feature_detected_!("neon") {
50            return selector.select(KernelNeon);
51        }
52    }
53    return selector.select(KernelFallback);
54}
55
56#[cfg(any(target_arch="x86", target_arch="x86_64"))]
57impl GemmKernel for KernelAvx2 {
58    type Elem = T;
59
60    type MRTy = U4;
61    type NRTy = U4;
62
63    #[inline(always)]
64    fn align_to() -> usize { 32 }
65
66    #[inline(always)]
67    fn always_masked() -> bool { KernelFallback::always_masked() }
68
69    #[inline(always)]
70    fn nc() -> usize { archparam::C_NC }
71    #[inline(always)]
72    fn kc() -> usize { archparam::C_KC }
73    #[inline(always)]
74    fn mc() -> usize { archparam::C_MC }
75
76    pack_methods!{}
77
78    #[inline(always)]
79    unsafe fn kernel(
80        k: usize,
81        alpha: T,
82        a: *const T,
83        b: *const T,
84        beta: T,
85        c: *mut T, rsc: isize, csc: isize) {
86        kernel_target_avx2(k, alpha, a, b, beta, c, rsc, csc)
87    }
88}
89
90#[cfg(any(target_arch="x86", target_arch="x86_64"))]
91impl GemmKernel for KernelFma {
92    type Elem = T;
93
94    type MRTy = U4;
95    type NRTy = U4;
96
97    #[inline(always)]
98    fn align_to() -> usize { 16 }
99
100    #[inline(always)]
101    fn always_masked() -> bool { KernelFallback::always_masked() }
102
103    #[inline(always)]
104    fn nc() -> usize { archparam::C_NC }
105    #[inline(always)]
106    fn kc() -> usize { archparam::C_KC }
107    #[inline(always)]
108    fn mc() -> usize { archparam::C_MC }
109
110    pack_methods!{}
111
112    #[inline(always)]
113    unsafe fn kernel(
114        k: usize,
115        alpha: T,
116        a: *const T,
117        b: *const T,
118        beta: T,
119        c: *mut T, rsc: isize, csc: isize) {
120        kernel_target_fma(k, alpha, a, b, beta, c, rsc, csc)
121    }
122}
123
/// NEON `c32` kernel: MR = 4, NR = 2, packing aligned to 16 bytes.
#[cfg(all(target_arch = "aarch64", has_aarch64_simd))]
impl GemmKernel for KernelNeon {
    type Elem = T;

    type MRTy = U4;
    type NRTy = U2;

    // Blocking sizes shared with the other c32 kernels.
    #[inline(always)]
    fn mc() -> usize { archparam::C_MC }
    #[inline(always)]
    fn kc() -> usize { archparam::C_KC }
    #[inline(always)]
    fn nc() -> usize { archparam::C_NC }

    #[inline(always)]
    fn align_to() -> usize { 16 }

    // Masking policy is inherited from the fallback kernel.
    #[inline(always)]
    fn always_masked() -> bool { <KernelFallback as GemmKernel>::always_masked() }

    pack_methods! {}

    /// Forwards to the NEON-compiled micro-kernel instantiation.
    #[inline(always)]
    unsafe fn kernel(
        k: usize,
        alpha: T,
        a: *const T,
        b: *const T,
        beta: T,
        c: *mut T,
        rsc: isize,
        csc: isize,
    ) {
        kernel_target_neon(k, alpha, a, b, beta, c, rsc, csc)
    }
}
158
159impl GemmKernel for KernelFallback {
160    type Elem = T;
161
162    type MRTy = U4;
163    type NRTy = U2;
164
165    #[inline(always)]
166    fn align_to() -> usize { 0 }
167
168    #[inline(always)]
169    fn always_masked() -> bool { true }
170
171    #[inline(always)]
172    fn nc() -> usize { archparam::C_NC }
173    #[inline(always)]
174    fn kc() -> usize { archparam::C_KC }
175    #[inline(always)]
176    fn mc() -> usize { archparam::C_MC }
177
178    pack_methods!{}
179
180    #[inline(always)]
181    unsafe fn kernel(
182        k: usize,
183        alpha: T,
184        a: *const T,
185        b: *const T,
186        beta: T,
187        c: *mut T, rsc: isize, csc: isize) {
188        kernel_fallback_impl(k, alpha, a, b, beta, c, rsc, csc)
189    }
190}
191
// Kernel AVX2
//
// `loop_m!` / `loop_n!` are redefined right before each
// `kernel_fallback_impl_complex!` invocation in this file. `macro_rules!`
// names are textually scoped, so a later definition shadows an earlier one;
// the generated kernel presumably expands `loop_m!`/`loop_n!` and therefore
// picks up the definitions directly above it — keep each pair adjacent to
// the invocation it configures. Here both loops use `loop4!`, consistent with
// `KernelAvx2`'s MRTy = U4 and NRTy = U4.
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
macro_rules! loop_m { ($i:ident, $e:expr) => { loop4!($i, $e) }; }
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
macro_rules! loop_n { ($j:ident, $e:expr) => { loop4!($j, $e) }; }

// Defines `kernel_target_avx2`, the function `KernelAvx2::kernel` calls,
// compiled with the `avx2` and `fma` target features.
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
kernel_fallback_impl_complex! {
    // instantiate separately
    // `[fma_yes]` presumably enables fused multiply-add in the generated body,
    // and the trailing `4` is a per-kernel tuning parameter defined by the
    // macro (likely an unroll factor) — TODO confirm against the macro.
    [inline target_feature(enable="avx2") target_feature(enable="fma")] [fma_yes]
    kernel_target_avx2, T, TReal, KernelAvx2::MR, KernelAvx2::NR, 4
}
204
205
// Kernel Fma
//
// Redefines `loop_m!` / `loop_n!` (shadowing the AVX2 definitions above) for
// the FMA instantiation below; both use `loop4!`, consistent with
// `KernelFma`'s MRTy = U4 and NRTy = U4. Keep these adjacent to the
// invocation, since `macro_rules!` shadowing is order-dependent.
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
macro_rules! loop_m { ($i:ident, $e:expr) => { loop4!($i, $e) }; }
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
macro_rules! loop_n { ($j:ident, $e:expr) => { loop4!($j, $e) }; }

// Defines `kernel_target_fma`, the function `KernelFma::kernel` calls,
// compiled with the `fma` target feature.
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
kernel_fallback_impl_complex! {
    // instantiate separately
    // NOTE(review): this instantiation enables the `fma` target feature yet
    // passes `[fma_no]` — presumably a deliberate tuning choice; confirm.
    // The trailing `2` is a per-kernel tuning parameter defined by the macro.
    [inline target_feature(enable="fma")] [fma_no]
    kernel_target_fma, T, TReal, KernelFma::MR, KernelFma::NR, 2
}
218
// Kernel neon
//
// Loop macros for the NEON instantiation below: `loop4!` over MR and
// `loop2!` over NR, consistent with `KernelNeon`'s MRTy = U4 and NRTy = U2.
// Keep these adjacent to the invocation, since `macro_rules!` shadowing is
// order-dependent.

#[cfg(target_arch = "aarch64")]
#[cfg(has_aarch64_simd)]
macro_rules! loop_m { ($i:ident, $e:expr) => { loop4!($i, $e) }; }
#[cfg(target_arch = "aarch64")]
#[cfg(has_aarch64_simd)]
macro_rules! loop_n { ($j:ident, $e:expr) => { loop2!($j, $e) }; }

// Defines `kernel_target_neon`, the function `KernelNeon::kernel` calls,
// compiled with the `neon` target feature and `[fma_yes]`.
#[cfg(target_arch = "aarch64")]
#[cfg(has_aarch64_simd)]
kernel_fallback_impl_complex! {
    [inline target_feature(enable="neon")] [fma_yes]
    kernel_target_neon, T, TReal, KernelNeon::MR, KernelNeon::NR, 1
}
234
// Kernel fallback
//
// These loop macros carry no `cfg`, so they are compiled on every target and
// shadow any architecture-specific definitions above. `loop4!` over MR and
// `loop2!` over NR are consistent with `KernelFallback`'s MRTy = U4 and
// NRTy = U2.

macro_rules! loop_m { ($i:ident, $e:expr) => { loop4!($i, $e) }; }
macro_rules! loop_n { ($j:ident, $e:expr) => { loop2!($j, $e) }; }

// Defines `kernel_fallback_impl`, the function `KernelFallback::kernel`
// calls; built without target features (`[inline(always)]`, `[fma_no]`).
kernel_fallback_impl_complex! {
    [inline(always)] [fma_no]
    kernel_fallback_impl, T, TReal, KernelFallback::MR, KernelFallback::NR, 1
}
244
245#[inline(always)]
246unsafe fn at(ptr: *const TReal, i: usize) -> TReal {
247    *ptr.add(i)
248}
249
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::test::test_complex_packed_kernel;

    #[test]
    fn test_kernel_fallback_impl() {
        test_complex_packed_kernel::<KernelFallback, _, TReal>("kernel");
    }

    /// NEON kernel tests; each test is skipped unless the host reports the
    /// listed feature.
    #[cfg(target_arch = "aarch64")]
    #[cfg(has_aarch64_simd)]
    mod test_kernel_aarch64 {
        use super::test_complex_packed_kernel;
        use super::super::*;
        #[cfg(feature = "std")]
        use std::println;
        macro_rules! test_arch_kernels {
            ($($feature_name:tt, $name:ident, $kernel_ty:ty),*) => {
                $(
                #[test]
                fn $name() {
                    if is_aarch64_feature_detected_!($feature_name) {
                        test_complex_packed_kernel::<$kernel_ty, _, TReal>(stringify!($name));
                    } else {
                        #[cfg(feature = "std")]
                        println!("Skipping, host does not have feature: {:?}", $feature_name);
                    }
                }
                )*
            }
        }

        test_arch_kernels! {
            "neon", neon, KernelNeon
        }
    }

    /// x86 SIMD kernel tests. Each entry lists every target feature its
    /// kernel was compiled with, and the test is skipped unless the host
    /// reports all of them.
    #[cfg(any(target_arch="x86", target_arch="x86_64"))]
    mod test_arch_kernels {
        use super::test_complex_packed_kernel;
        use super::super::*;
        #[cfg(feature = "std")]
        use std::println;
        macro_rules! test_arch_kernels_x86 {
            ($([$($feature_name:tt),+], $name:ident, $kernel_ty:ty),*) => {
                $(
                #[test]
                fn $name() {
                    // Run only when *every* required feature is present.
                    if $(is_x86_feature_detected_!($feature_name))&&+ {
                        test_complex_packed_kernel::<$kernel_ty, _, TReal>(stringify!($name));
                    } else {
                        #[cfg(feature = "std")]
                        println!("Skipping, host does not have features: {:?}", [$($feature_name),+]);
                    }
                }
                )*
            }
        }

        // `kernel_target_avx2` is compiled with both `avx2` and `fma`, and
        // `detect` only selects `KernelAvx2` when both are present. Gating its
        // test on `avx2` alone could execute FMA instructions on a host (e.g.
        // a VM) that exposes AVX2 but not FMA.
        test_arch_kernels_x86! {
            ["fma"], fma, KernelFma,
            ["fma", "avx2"], avx2, KernelAvx2
        }
    }
}