use core::arch::x86_64::*;
use num_complex::Complex;
use std::ops::{Deref, DerefMut};
use crate::array_utils::DoubleBuf;
/// Builds an array literal by calling `load_complex` on `$src` once per listed index.
///
/// The elements appear in the same order as the indices. The expansion does not
/// contain an `unsafe` block of its own, so when `load_complex` is an `unsafe fn`
/// the invocation has to sit inside an unsafe context.
macro_rules! read_complex_to_array {
    ($src:ident, { $($i:literal),* }) => {
        [ $( $src.load_complex($i) ),* ]
    };
}
/// Builds an array literal by calling `load1_complex` on `$src` once per listed index.
///
/// Note that, despite the macro's name, the method invoked is `load1_complex`
/// and not `load_partial1_complex`. The expansion does not contain an `unsafe`
/// block of its own.
macro_rules! read_partial1_complex_to_array {
    ($src:ident, { $($i:literal),* }) => {
        [ $( $src.load1_complex($i) ),* ]
    };
}
/// Stores `$values[i]` into `$dest` at position `i` for every listed index `i`.
///
/// Expands to one `store_complex` call statement per index, in the order given.
/// The expansion does not contain an `unsafe` block of its own.
macro_rules! write_complex_to_array {
    ($values:ident, $dest:ident, { $($i:literal),* }) => {
        $( $dest.store_complex($values[$i], $i); )*
    };
}
/// Calls `store_partial_lo_complex` on `$dest` with `$values[i]` at position `i`
/// for every listed index `i`.
///
/// Expands to one call statement per index, in the order given. The expansion
/// does not contain an `unsafe` block of its own.
macro_rules! write_partial_lo_complex_to_array {
    ($values:ident, $dest:ident, { $($i:literal),* }) => {
        $( $dest.store_partial_lo_complex($values[$i], $i); )*
    };
}
/// Stores `$values[i]` into `$dest` at position `i * $step` for every listed index `i`.
///
/// Expands to one `store_complex` call statement per index, in the order given.
/// Both the index and the stride are literals, so the destination offset is a
/// constant expression. The expansion does not contain an `unsafe` block of its own.
macro_rules! write_complex_to_array_strided {
    ($values:ident, $dest:ident, $step:literal, { $($i:literal),* }) => {
        $( $dest.store_complex($values[$i], $step * $i); )*
    };
}
/// Associates a scalar type with the SIMD register type used to process it.
pub trait SseNum {
    /// How many complex numbers of this scalar type one `VectorType` holds.
    const COMPLEX_PER_VECTOR: usize;

    /// The SIMD vector type that carries values of this scalar type.
    type VectorType;
}
/// `f32` data travels in `__m128`; one such vector holds two complex values.
impl SseNum for f32 {
    const COMPLEX_PER_VECTOR: usize = 2;
    type VectorType = __m128;
}
/// `f64` data travels in `__m128d`; one such vector holds a single complex value.
impl SseNum for f64 {
    const COMPLEX_PER_VECTOR: usize = 1;
    type VectorType = __m128d;
}
/// Read access to a buffer of complex numbers in units of SIMD vectors.
///
/// All methods are `unsafe`: the slice implementations in this file verify the
/// index only through `debug_assert!`, so in release builds nothing stops an
/// out-of-range access.
pub trait SseArray<T>: Deref
where
    T: SseNum,
{
    /// Loads one full vector beginning at element `index`.
    ///
    /// # Safety
    /// `index + T::COMPLEX_PER_VECTOR` must not exceed the buffer length.
    unsafe fn load_complex(&self, index: usize) -> T::VectorType;

    /// Loads the single complex value at `index` into part of a vector.
    ///
    /// The `f32` slice implementations place it in the low half; the `f64` slice
    /// implementations do not support this operation and panic.
    ///
    /// # Safety
    /// `index` must be in bounds for the buffer.
    unsafe fn load_partial1_complex(&self, index: usize) -> T::VectorType;

    /// Loads the single complex value at `index` and replicates it across the vector.
    ///
    /// The `f64` slice implementations do not support this operation and panic.
    ///
    /// # Safety
    /// `index` must be in bounds for the buffer.
    unsafe fn load1_complex(&self, index: usize) -> T::VectorType;
}
/// SSE loads from a shared slice of `Complex<f32>`.
///
/// Every method reinterprets the slice storage as raw floats; one complex value
/// is treated as a single 64-bit lane where a partial load is required.
impl SseArray<f32> for &[Complex<f32>] {
    /// Loads the two complex values at `index` and `index + 1` with an unaligned load.
    ///
    /// # Safety
    /// `index + 2` must not exceed `self.len()`; this is only checked in debug builds.
    #[inline(always)]
    unsafe fn load_complex(&self, index: usize) -> <f32 as SseNum>::VectorType {
        debug_assert!(self.len() >= index + <f32 as SseNum>::COMPLEX_PER_VECTOR);
        // SAFETY: the caller guarantees two complex values are readable at `index`.
        let first_float = self.as_ptr().add(index) as *const f32;
        _mm_loadu_ps(first_float)
    }

    /// Loads the complex value at `index` into the low 64 bits of the vector.
    ///
    /// `_mm_load_sd` clears the upper 64 bits.
    ///
    /// # Safety
    /// `index` must be less than `self.len()`; this is only checked in debug builds.
    #[inline(always)]
    unsafe fn load_partial1_complex(&self, index: usize) -> <f32 as SseNum>::VectorType {
        debug_assert!(self.len() >= index + 1);
        // One `Complex<f32>` is read as a single 64-bit lane.
        let as_lane = self.as_ptr().add(index) as *const f64;
        let low_only = _mm_load_sd(as_lane);
        _mm_castpd_ps(low_only)
    }

    /// Loads the complex value at `index` and duplicates it into both halves of the vector.
    ///
    /// # Safety
    /// `index` must be less than `self.len()`; this is only checked in debug builds.
    #[inline(always)]
    unsafe fn load1_complex(&self, index: usize) -> <f32 as SseNum>::VectorType {
        debug_assert!(self.len() >= index + 1);
        // One `Complex<f32>` is read as a single 64-bit lane and broadcast.
        let as_lane = self.as_ptr().add(index) as *const f64;
        let duplicated = _mm_load1_pd(as_lane);
        _mm_castpd_ps(duplicated)
    }
}
/// SSE loads from an exclusive slice of `Complex<f32>`.
///
/// The loads only read from the slice; the bodies mirror the shared-slice
/// implementation.
impl SseArray<f32> for &mut [Complex<f32>] {
    /// Loads the two complex values at `index` and `index + 1` with an unaligned load.
    ///
    /// # Safety
    /// `index + 2` must not exceed `self.len()`; this is only checked in debug builds.
    #[inline(always)]
    unsafe fn load_complex(&self, index: usize) -> <f32 as SseNum>::VectorType {
        debug_assert!(self.len() >= index + <f32 as SseNum>::COMPLEX_PER_VECTOR);
        // SAFETY: the caller guarantees two complex values are readable at `index`.
        let first_float = self.as_ptr().add(index) as *const f32;
        _mm_loadu_ps(first_float)
    }

    /// Loads the complex value at `index` into the low 64 bits of the vector.
    ///
    /// `_mm_load_sd` clears the upper 64 bits.
    ///
    /// # Safety
    /// `index` must be less than `self.len()`; this is only checked in debug builds.
    #[inline(always)]
    unsafe fn load_partial1_complex(&self, index: usize) -> <f32 as SseNum>::VectorType {
        debug_assert!(self.len() >= index + 1);
        // One `Complex<f32>` is read as a single 64-bit lane.
        let as_lane = self.as_ptr().add(index) as *const f64;
        let low_only = _mm_load_sd(as_lane);
        _mm_castpd_ps(low_only)
    }

    /// Loads the complex value at `index` and duplicates it into both halves of the vector.
    ///
    /// # Safety
    /// `index` must be less than `self.len()`; this is only checked in debug builds.
    #[inline(always)]
    unsafe fn load1_complex(&self, index: usize) -> <f32 as SseNum>::VectorType {
        debug_assert!(self.len() >= index + 1);
        // One `Complex<f32>` is read as a single 64-bit lane and broadcast.
        let as_lane = self.as_ptr().add(index) as *const f64;
        let duplicated = _mm_load1_pd(as_lane);
        _mm_castpd_ps(duplicated)
    }
}
/// SSE loads from a shared slice of `Complex<f64>`.
///
/// A `__m128d` holds exactly one `Complex<f64>`, so only the full load is
/// implemented; the other two methods panic.
impl SseArray<f64> for &[Complex<f64>] {
    /// Loads the complex value at `index` with an unaligned load.
    ///
    /// # Safety
    /// `index + 1` must not exceed `self.len()`; this is only checked in debug builds.
    #[inline(always)]
    unsafe fn load_complex(&self, index: usize) -> <f64 as SseNum>::VectorType {
        debug_assert!(self.len() >= index + <f64 as SseNum>::COMPLEX_PER_VECTOR);
        // SAFETY: the caller guarantees one complex value is readable at `index`.
        let first_double = self.as_ptr().add(index) as *const f64;
        _mm_loadu_pd(first_double)
    }

    /// Not available for `f64`.
    ///
    /// # Panics
    /// Always panics via `unimplemented!`.
    #[inline(always)]
    unsafe fn load_partial1_complex(&self, _index: usize) -> <f64 as SseNum>::VectorType {
        unimplemented!("Impossible to do a partial load of complex f64's")
    }

    /// Not available for `f64`.
    ///
    /// # Panics
    /// Always panics via `unimplemented!`.
    #[inline(always)]
    unsafe fn load1_complex(&self, _index: usize) -> <f64 as SseNum>::VectorType {
        unimplemented!("Impossible to do a partial load of complex f64's")
    }
}
/// SSE loads from an exclusive slice of `Complex<f64>`.
///
/// A `__m128d` holds exactly one `Complex<f64>`, so only the full load is
/// implemented; the other two methods panic.
impl SseArray<f64> for &mut [Complex<f64>] {
    /// Loads the complex value at `index` with an unaligned load.
    ///
    /// # Safety
    /// `index + 1` must not exceed `self.len()`; this is only checked in debug builds.
    #[inline(always)]
    unsafe fn load_complex(&self, index: usize) -> <f64 as SseNum>::VectorType {
        debug_assert!(self.len() >= index + <f64 as SseNum>::COMPLEX_PER_VECTOR);
        // SAFETY: the caller guarantees one complex value is readable at `index`.
        let first_double = self.as_ptr().add(index) as *const f64;
        _mm_loadu_pd(first_double)
    }

    /// Not available for `f64`.
    ///
    /// # Panics
    /// Always panics via `unimplemented!`.
    #[inline(always)]
    unsafe fn load_partial1_complex(&self, _index: usize) -> <f64 as SseNum>::VectorType {
        unimplemented!("Impossible to do a partial load of complex f64's")
    }

    /// Not available for `f64`.
    ///
    /// # Panics
    /// Always panics via `unimplemented!`.
    #[inline(always)]
    unsafe fn load1_complex(&self, _index: usize) -> <f64 as SseNum>::VectorType {
        unimplemented!("Impossible to do a partial load of complex f64's")
    }
}
/// Loads on a `DoubleBuf` are forwarded to its `input` field.
///
/// NOTE(review): `DoubleBuf` is defined in `crate::array_utils`; the where-bound
/// suggests `input` is a `&'a [Complex<T>]` — confirm against that definition.
impl<'a, T: SseNum> SseArray<T> for DoubleBuf<'a, T>
where
    &'a [Complex<T>]: SseArray<T>,
{
    /// Forwards to `load_complex` on the input buffer.
    ///
    /// # Safety
    /// Same requirements as the input buffer's `load_complex`.
    #[inline(always)]
    unsafe fn load_complex(&self, index: usize) -> T::VectorType {
        let source = &self.input;
        source.load_complex(index)
    }

    /// Forwards to `load_partial1_complex` on the input buffer.
    ///
    /// # Safety
    /// Same requirements as the input buffer's `load_partial1_complex`.
    #[inline(always)]
    unsafe fn load_partial1_complex(&self, index: usize) -> T::VectorType {
        let source = &self.input;
        source.load_partial1_complex(index)
    }

    /// Forwards to `load1_complex` on the input buffer.
    ///
    /// # Safety
    /// Same requirements as the input buffer's `load1_complex`.
    #[inline(always)]
    unsafe fn load1_complex(&self, index: usize) -> T::VectorType {
        let source = &self.input;
        source.load1_complex(index)
    }
}
/// Write access to a buffer of complex numbers in units of SIMD vectors.
///
/// All methods are `unsafe`: the slice implementations in this file verify the
/// index only through `debug_assert!`, so in release builds nothing stops an
/// out-of-range write.
pub trait SseArrayMut<T>: SseArray<T> + DerefMut
where
    T: SseNum,
{
    /// Stores one full vector beginning at element `index`.
    ///
    /// # Safety
    /// `index + T::COMPLEX_PER_VECTOR` must not exceed the buffer length.
    unsafe fn store_complex(&mut self, vector: T::VectorType, index: usize);

    /// Stores the low half of `vector` as the single complex value at `index`.
    ///
    /// The `f64` slice implementation does not support this operation and panics.
    ///
    /// # Safety
    /// `index` must be in bounds for the buffer.
    unsafe fn store_partial_lo_complex(&mut self, vector: T::VectorType, index: usize);

    /// Stores the high half of `vector` as the single complex value at `index`.
    ///
    /// The `f64` slice implementation does not support this operation and panics.
    ///
    /// # Safety
    /// `index` must be in bounds for the buffer.
    unsafe fn store_partial_hi_complex(&mut self, vector: T::VectorType, index: usize);
}
/// SSE stores into an exclusive slice of `Complex<f32>`.
///
/// For the partial stores, one `Complex<f32>` is written as a single 64-bit lane.
impl SseArrayMut<f32> for &mut [Complex<f32>] {
    /// Writes both complex values of `vector` to `index` and `index + 1` with an
    /// unaligned store.
    ///
    /// # Safety
    /// `index + 2` must not exceed `self.len()`; this is only checked in debug builds.
    #[inline(always)]
    unsafe fn store_complex(&mut self, vector: <f32 as SseNum>::VectorType, index: usize) {
        debug_assert!(self.len() >= index + <f32 as SseNum>::COMPLEX_PER_VECTOR);
        // SAFETY: the caller guarantees two complex values are writable at `index`.
        let first_float = self.as_mut_ptr().add(index) as *mut f32;
        _mm_storeu_ps(first_float, vector);
    }

    /// Writes the low 64 bits of `vector` to the complex value at `index`.
    ///
    /// # Safety
    /// `index` must be less than `self.len()`; this is only checked in debug builds.
    #[inline(always)]
    unsafe fn store_partial_lo_complex(
        &mut self,
        vector: <f32 as SseNum>::VectorType,
        index: usize,
    ) {
        debug_assert!(self.len() >= index + 1);
        let as_lane = self.as_mut_ptr().add(index) as *mut f64;
        let lanes = _mm_castps_pd(vector);
        _mm_storel_pd(as_lane, lanes);
    }

    /// Writes the high 64 bits of `vector` to the complex value at `index`.
    ///
    /// # Safety
    /// `index` must be less than `self.len()`; this is only checked in debug builds.
    #[inline(always)]
    unsafe fn store_partial_hi_complex(
        &mut self,
        vector: <f32 as SseNum>::VectorType,
        index: usize,
    ) {
        debug_assert!(self.len() >= index + 1);
        let as_lane = self.as_mut_ptr().add(index) as *mut f64;
        let lanes = _mm_castps_pd(vector);
        _mm_storeh_pd(as_lane, lanes);
    }
}
/// SSE stores into an exclusive slice of `Complex<f64>`.
///
/// A `__m128d` holds exactly one `Complex<f64>`, so only the full store is
/// implemented; the partial stores panic.
impl SseArrayMut<f64> for &mut [Complex<f64>] {
    /// Writes `vector` to the complex value at `index` with an unaligned store.
    ///
    /// # Safety
    /// `index + 1` must not exceed `self.len()`; this is only checked in debug builds.
    #[inline(always)]
    unsafe fn store_complex(&mut self, vector: <f64 as SseNum>::VectorType, index: usize) {
        debug_assert!(self.len() >= index + <f64 as SseNum>::COMPLEX_PER_VECTOR);
        // SAFETY: the caller guarantees one complex value is writable at `index`.
        let first_double = self.as_mut_ptr().add(index) as *mut f64;
        _mm_storeu_pd(first_double, vector);
    }

    /// Not available for `f64`.
    ///
    /// # Panics
    /// Always panics via `unimplemented!`.
    #[inline(always)]
    unsafe fn store_partial_lo_complex(
        &mut self,
        _vector: <f64 as SseNum>::VectorType,
        _index: usize,
    ) {
        unimplemented!("Impossible to do a partial store of complex f64's")
    }

    /// Not available for `f64`.
    ///
    /// # Panics
    /// Always panics via `unimplemented!`.
    #[inline(always)]
    unsafe fn store_partial_hi_complex(
        &mut self,
        _vector: <f64 as SseNum>::VectorType,
        _index: usize,
    ) {
        unimplemented!("Impossible to do a partial store of complex f64's")
    }
}
/// Stores on a `DoubleBuf` are forwarded to its `output` field.
///
/// NOTE(review): `DoubleBuf` is defined in `crate::array_utils`; the where-bound
/// suggests `output` is a `&'a mut [Complex<T>]` — confirm against that definition.
impl<'a, T: SseNum> SseArrayMut<T> for DoubleBuf<'a, T>
where
    Self: SseArray<T>,
    &'a mut [Complex<T>]: SseArrayMut<T>,
{
    /// Forwards to `store_complex` on the output buffer.
    ///
    /// # Safety
    /// Same requirements as the output buffer's `store_complex`.
    #[inline(always)]
    unsafe fn store_complex(&mut self, vector: T::VectorType, index: usize) {
        let sink = &mut self.output;
        sink.store_complex(vector, index);
    }

    /// Forwards to `store_partial_lo_complex` on the output buffer.
    ///
    /// # Safety
    /// Same requirements as the output buffer's `store_partial_lo_complex`.
    #[inline(always)]
    unsafe fn store_partial_lo_complex(&mut self, vector: T::VectorType, index: usize) {
        let sink = &mut self.output;
        sink.store_partial_lo_complex(vector, index);
    }

    /// Forwards to `store_partial_hi_complex` on the output buffer.
    ///
    /// # Safety
    /// Same requirements as the output buffer's `store_partial_hi_complex`.
    #[inline(always)]
    unsafe fn store_partial_hi_complex(&mut self, vector: T::VectorType, index: usize) {
        let sink = &mut self.output;
        sink.store_partial_hi_complex(vector, index);
    }
}