use crate::kernel::GemmKernel;
use crate::kernel::GemmSelect;
use crate::kernel::{U2, U4, c32, Element, c32_mul as mul};
use crate::archparam;
use crate::cgemm_common::pack_complex;
/// Marker type for the x86/x86_64 kernel whose implementation
/// (`kernel_target_avx2`) is compiled with the "avx2" and "fma" target
/// features. `detect` selects it when both features are detected.
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
struct KernelAvx2;
/// Marker type for the x86/x86_64 kernel whose implementation
/// (`kernel_target_fma`) is compiled with the "fma" target feature.
/// `detect` selects it when "fma" is detected but "avx2" is not.
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
struct KernelFma;
/// Marker type for the aarch64 kernel whose implementation
/// (`kernel_target_neon`) is compiled with the "neon" target feature.
/// Only exists when the `has_aarch64_simd` cfg is set; `detect` selects it
/// when "neon" is detected.
#[cfg(target_arch = "aarch64")]
#[cfg(has_aarch64_simd)]
struct KernelNeon;
/// Marker type for the kernel that is available on every target
/// (`kernel_fallback_impl`). `detect` selects it when none of the
/// architecture-specific kernels was chosen.
struct KernelFallback;
/// Element type used by every kernel in this file (`GemmKernel::Elem`).
type T = c32;
/// Scalar type handed to `kernel_fallback_impl_complex!` next to `T` and read
/// by `at` (presumably the real/imaginary component type of `c32` — confirm).
type TReal = f32;
/// Pick a kernel implementation for the running CPU and pass it to
/// `selector` through its `select` method.
///
/// Order of preference:
///
/// * x86 / x86_64: `KernelAvx2` when both "fma" and "avx2" are detected,
///   `KernelFma` when "fma" is detected without "avx2";
/// * aarch64 built with `has_aarch64_simd`: `KernelNeon` when "neon" is
///   detected;
/// * everything else: `KernelFallback`.
#[inline]
pub(crate) fn detect<G>(selector: G)
where
    G: GemmSelect<T>,
{
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        // Both x86 kernels are compiled with the "fma" target feature, so
        // neither is eligible unless that feature is present.
        if is_x86_feature_detected_!("fma") {
            return if is_x86_feature_detected_!("avx2") {
                selector.select(KernelAvx2)
            } else {
                selector.select(KernelFma)
            };
        }
    }

    #[cfg(all(target_arch = "aarch64", has_aarch64_simd))]
    {
        if is_aarch64_feature_detected_!("neon") {
            return selector.select(KernelNeon);
        }
    }

    // No architecture-specific kernel applied.
    selector.select(KernelFallback)
}
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
impl GemmKernel for KernelAvx2 {
type Elem = T;
type MRTy = U4;
type NRTy = U4;
#[inline(always)]
fn align_to() -> usize { 32 }
#[inline(always)]
fn always_masked() -> bool { KernelFallback::always_masked() }
#[inline(always)]
fn nc() -> usize { archparam::C_NC }
#[inline(always)]
fn kc() -> usize { archparam::C_KC }
#[inline(always)]
fn mc() -> usize { archparam::C_MC }
pack_methods!{}
#[inline(always)]
unsafe fn kernel(
k: usize,
alpha: T,
a: *const T,
b: *const T,
beta: T,
c: *mut T, rsc: isize, csc: isize) {
kernel_target_avx2(k, alpha, a, b, beta, c, rsc, csc)
}
}
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
impl GemmKernel for KernelFma {
type Elem = T;
type MRTy = U4;
type NRTy = U4;
#[inline(always)]
fn align_to() -> usize { 16 }
#[inline(always)]
fn always_masked() -> bool { KernelFallback::always_masked() }
#[inline(always)]
fn nc() -> usize { archparam::C_NC }
#[inline(always)]
fn kc() -> usize { archparam::C_KC }
#[inline(always)]
fn mc() -> usize { archparam::C_MC }
pack_methods!{}
#[inline(always)]
unsafe fn kernel(
k: usize,
alpha: T,
a: *const T,
b: *const T,
beta: T,
c: *mut T, rsc: isize, csc: isize) {
kernel_target_fma(k, alpha, a, b, beta, c, rsc, csc)
}
}
/// `GemmKernel` implementation that forwards to `kernel_target_neon`
/// (aarch64 with the `has_aarch64_simd` cfg only).
#[cfg(all(target_arch = "aarch64", has_aarch64_simd))]
impl GemmKernel for KernelNeon {
    type Elem = T;

    // Tile dimensions; U4 by U2, in line with the loop4!/loop2! expansions of
    // loop_m!/loop_n! used when `kernel_target_neon` is instantiated.
    type MRTy = U4;
    type NRTy = U2;

    /// Alignment reported by this kernel (presumably in bytes — confirm
    /// against the `GemmKernel` trait documentation).
    #[inline(always)]
    fn align_to() -> usize {
        16
    }

    /// Uses the same setting as `KernelFallback`.
    #[inline(always)]
    fn always_masked() -> bool {
        <KernelFallback as GemmKernel>::always_masked()
    }

    // The three size parameters below come straight from the `C_*` constants
    // in `archparam`.
    #[inline(always)]
    fn mc() -> usize {
        archparam::C_MC
    }

    #[inline(always)]
    fn kc() -> usize {
        archparam::C_KC
    }

    #[inline(always)]
    fn nc() -> usize {
        archparam::C_NC
    }

    pack_methods! {}

    /// Run the kernel by forwarding every argument unchanged to
    /// `kernel_target_neon`.
    ///
    /// # Safety
    ///
    /// The callee is compiled with the "neon" target feature, so the running
    /// CPU must support it. The pointer arguments must satisfy whatever
    /// `kernel_target_neon` requires (not visible here).
    #[inline(always)]
    unsafe fn kernel(
        k: usize,
        alpha: T,
        a: *const T,
        b: *const T,
        beta: T,
        c: *mut T,
        rsc: isize,
        csc: isize,
    ) {
        kernel_target_neon(k, alpha, a, b, beta, c, rsc, csc)
    }
}
/// `GemmKernel` implementation that forwards to `kernel_fallback_impl`;
/// available on every target.
impl GemmKernel for KernelFallback {
    type Elem = T;

    // Tile dimensions; U4 by U2, in line with the loop4!/loop2! expansions of
    // loop_m!/loop_n! used when `kernel_fallback_impl` is instantiated.
    type MRTy = U4;
    type NRTy = U2;

    /// This kernel reports an alignment of 0 (presumably "no particular
    /// alignment requested" — confirm against the `GemmKernel` trait).
    #[inline(always)]
    fn align_to() -> usize {
        0
    }

    /// Always `true`; the other kernels in this file delegate to this value.
    #[inline(always)]
    fn always_masked() -> bool {
        true
    }

    // The three size parameters below come straight from the `C_*` constants
    // in `archparam`.
    #[inline(always)]
    fn mc() -> usize {
        archparam::C_MC
    }

    #[inline(always)]
    fn kc() -> usize {
        archparam::C_KC
    }

    #[inline(always)]
    fn nc() -> usize {
        archparam::C_NC
    }

    pack_methods! {}

    /// Run the kernel by forwarding every argument unchanged to
    /// `kernel_fallback_impl`.
    ///
    /// # Safety
    ///
    /// The pointer arguments must satisfy whatever `kernel_fallback_impl`
    /// requires (not visible here).
    #[inline(always)]
    unsafe fn kernel(
        k: usize,
        alpha: T,
        a: *const T,
        b: *const T,
        beta: T,
        c: *mut T,
        rsc: isize,
        csc: isize,
    ) {
        kernel_fallback_impl(k, alpha, a, b, beta, c, rsc, csc)
    }
}
// Kernel AVX2 (x86 / x86_64 only).
//
// loop_m!/loop_n! are redefined right before each kernel instantiation in this
// file; a later macro_rules! definition shadows an earlier one for the code
// that follows it, so these two must stay directly ahead of the invocation
// below. Both forward to loop4!, in line with KernelAvx2's U4 x U4 tile.
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
macro_rules! loop_m { ($i:ident, $e:expr) => { loop4!($i, $e) }; }
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
macro_rules! loop_n { ($j:ident, $e:expr) => { loop4!($j, $e) }; }
// Generates `kernel_target_avx2` with the "avx2" and "fma" target features
// enabled and the [fma_yes] selector. The trailing `4` is presumably an unroll
// factor — confirm in kernel_fallback_impl_complex!.
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
kernel_fallback_impl_complex! {
[inline target_feature(enable="avx2") target_feature(enable="fma")] [fma_yes]
kernel_target_avx2, T, TReal, KernelAvx2::MR, KernelAvx2::NR, 4
}
// Kernel FMA (x86 / x86_64 only).
//
// loop_m!/loop_n! are redefined right before each kernel instantiation in this
// file; a later macro_rules! definition shadows an earlier one for the code
// that follows it, so these two must stay directly ahead of the invocation
// below. Both forward to loop4!, in line with KernelFma's U4 x U4 tile.
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
macro_rules! loop_m { ($i:ident, $e:expr) => { loop4!($i, $e) }; }
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
macro_rules! loop_n { ($j:ident, $e:expr) => { loop4!($j, $e) }; }
// Generates `kernel_target_fma` with the "fma" target feature enabled. The
// trailing `2` is presumably an unroll factor — confirm in
// kernel_fallback_impl_complex!.
// NOTE(review): the function enables the fma target feature but passes the
// [fma_no] selector — confirm this combination is intentional.
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
kernel_fallback_impl_complex! {
[inline target_feature(enable="fma")] [fma_no]
kernel_target_fma, T, TReal, KernelFma::MR, KernelFma::NR, 2
}
// Kernel NEON (aarch64 with the `has_aarch64_simd` cfg only).
//
// loop_m!/loop_n! are redefined right before each kernel instantiation in this
// file; a later macro_rules! definition shadows an earlier one for the code
// that follows it, so these two must stay directly ahead of the invocation
// below. They forward to loop4! and loop2!, in line with KernelNeon's
// U4 x U2 tile.
#[cfg(target_arch = "aarch64")]
#[cfg(has_aarch64_simd)]
macro_rules! loop_m { ($i:ident, $e:expr) => { loop4!($i, $e) }; }
#[cfg(target_arch = "aarch64")]
#[cfg(has_aarch64_simd)]
macro_rules! loop_n { ($j:ident, $e:expr) => { loop2!($j, $e) }; }
// Generates `kernel_target_neon` with the "neon" target feature enabled and
// the [fma_yes] selector. The trailing `1` is presumably an unroll factor —
// confirm in kernel_fallback_impl_complex!.
#[cfg(target_arch = "aarch64")]
#[cfg(has_aarch64_simd)]
kernel_fallback_impl_complex! {
[inline target_feature(enable="neon")] [fma_yes]
kernel_target_neon, T, TReal, KernelNeon::MR, KernelNeon::NR, 1
}
// Kernel fallback (all targets).
//
// These unconditional definitions shadow any architecture-specific
// loop_m!/loop_n! defined earlier in the file, so they must stay directly
// ahead of the invocation below. They forward to loop4! and loop2!, in line
// with KernelFallback's U4 x U2 tile.
macro_rules! loop_m { ($i:ident, $e:expr) => { loop4!($i, $e) }; }
macro_rules! loop_n { ($j:ident, $e:expr) => { loop2!($j, $e) }; }
// Generates `kernel_fallback_impl` as an #[inline(always)] function with no
// target features and the [fma_no] selector. The trailing `1` is presumably an
// unroll factor — confirm in kernel_fallback_impl_complex!.
kernel_fallback_impl_complex! {
[inline(always)] [fma_no]
kernel_fallback_impl, T, TReal, KernelFallback::MR, KernelFallback::NR, 1
}
/// Read the `i`-th `TReal` counting from `ptr`.
///
/// # Safety
///
/// `ptr.add(i)` must stay within a single live allocation, be properly
/// aligned for `TReal`, and point to an initialized value.
#[inline(always)]
unsafe fn at(ptr: *const TReal, i: usize) -> TReal {
    let element = ptr.add(i);
    element.read()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::test::test_complex_packed_kernel;

    /// The fallback kernel has no CPU feature requirement, so it always runs.
    #[test]
    fn test_kernel_fallback_impl() {
        test_complex_packed_kernel::<KernelFallback, _, TReal>("kernel");
    }

    #[cfg(all(target_arch = "aarch64", has_aarch64_simd))]
    mod test_kernel_aarch64 {
        use super::super::*;
        use super::test_complex_packed_kernel;
        #[cfg(feature = "std")]
        use std::println;

        // Expands every `"feature", test_name, KernelType` triple into a
        // #[test] that runs the packed-kernel check for that kernel, or skips
        // (printing a note when `std` is enabled) if the host lacks the feature.
        macro_rules! test_arch_kernels {
            ($($feature:tt, $test_fn:ident, $kernel:ty),*) => {
                $(
                    #[test]
                    fn $test_fn() {
                        if !is_aarch64_feature_detected_!($feature) {
                            #[cfg(feature = "std")]
                            println!("Skipping, host does not have feature: {:?}", $feature);
                            return;
                        }
                        test_complex_packed_kernel::<$kernel, _, TReal>(stringify!($test_fn));
                    }
                )*
            }
        }

        test_arch_kernels! {
            "neon", neon, KernelNeon
        }
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    mod test_arch_kernels {
        use super::super::*;
        use super::test_complex_packed_kernel;
        #[cfg(feature = "std")]
        use std::println;

        // Expands every `"feature", test_name, KernelType` triple into a
        // #[test] that runs the packed-kernel check for that kernel, or skips
        // (printing a note when `std` is enabled) if the host lacks the feature.
        macro_rules! test_arch_kernels_x86 {
            ($($feature:tt, $test_fn:ident, $kernel:ty),*) => {
                $(
                    #[test]
                    fn $test_fn() {
                        if !is_x86_feature_detected_!($feature) {
                            #[cfg(feature = "std")]
                            println!("Skipping, host does not have feature: {:?}", $feature);
                            return;
                        }
                        test_complex_packed_kernel::<$kernel, _, TReal>(stringify!($test_fn));
                    }
                )*
            }
        }

        // NOTE(review): KernelAvx2's kernel is compiled with both "avx2" and
        // "fma", but its test is gated on "avx2" alone — confirm that is
        // sufficient for every host this runs on.
        test_arch_kernels_x86! {
            "fma", fma, KernelFma,
            "avx2", avx2, KernelAvx2
        }
    }
}