use std::convert::TryInto;
use std::sync::Arc;
use std::{any::TypeId, arch::x86_64::*};

use num_complex::Complex;
use num_integer::{div_ceil, Integer};
use num_traits::Zero;
use primal_check::miller_rabin;
use strength_reduce::StrengthReducedUsize;

use crate::common::{fft_error_inplace, fft_error_outofplace};
use crate::{array_utils, FftDirection};
use crate::{math_utils, twiddles};
use crate::{Direction, Fft, FftNum, Length};

use super::avx_vector;
use super::{
    avx_vector::{AvxArray, AvxArrayMut, AvxVector, AvxVector128, AvxVector256},
    AvxNum,
};

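/// Vectorized modular multiplication: multiplies each unsigned 64-bit lane of a vector by a fixed
/// constant `b`, modulo a fixed `divisor`, without any hardware integer division. `intermediate`
/// is a 32.32 fixed-point approximation of `b / divisor`, precomputed so that quotients can be
/// estimated with a multiply and a shift. Rader's algorithm uses this to advance its vector of
/// input permutation indexes, which must repeatedly be multiplied by a power of the primitive
/// root modulo the FFT length.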
#[derive(Clone)]
struct VectorizedMultiplyMod {
    b: __m256i,
    divisor: __m256i,
    intermediate: __m256i,
}

impl VectorizedMultiplyMod {
    #[target_feature(enable = "avx")]
    unsafe fn new(b: u32, divisor: u32) -> Self {
        assert!(
            divisor.leading_zeros() > 0,
            "divisor must be less than {}, got {}",
            1u32 << 31,
            divisor
        );

        let b = b % divisor;
        // 32.32 fixed-point approximation of b / divisor, used below to estimate quotients
        // without dividing
        let intermediate = ((b as i64) << 32) / divisor as i64;

        Self {
            b: _mm256_set1_epi64x(b as i64),
            divisor: _mm256_set1_epi64x(divisor as i64),
            intermediate: _mm256_set1_epi64x(intermediate),
        }
    }

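    /// Compute `(a * b) % divisor` for each unsigned 64-bit lane of `a`, where `b` and `divisor`
    /// are the values captured in `new`. A scalar sketch of what each lane computes is below;
    /// `scalar_mul_rem` is purely illustrative and not part of this crate. It is valid here
    /// because `a`, `b`, and `divisor` all fit in 32 bits, so no product overflows a `u64`.
    ///
    /// ```ignore
    /// fn scalar_mul_rem(a: u64, b: u64, divisor: u64) -> u64 {
    ///     // 32.32 fixed-point approximation of b / divisor, precomputed once in `new`
    ///     let intermediate = (b << 32) / divisor;
    ///     // Estimated quotient of (a * b) / divisor; can undershoot the true quotient by one
    ///     let quotient = (a * intermediate) >> 32;
    ///     let remainder = a * b - quotient * divisor;
    ///     // A single conditional subtraction corrects the undershoot
    ///     if remainder >= divisor {
    ///         remainder - divisor
    ///     } else {
    ///         remainder
    ///     }
    /// }
    /// ```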
    #[allow(unused)]
    #[inline(always)]
    unsafe fn mul_rem(&self, a: __m256i) -> __m256i {
        // Mask out the upper 32 bits of each 64-bit divisor lane
        let masked_divisor = _mm256_blend_epi32(self.divisor, _mm256_setzero_si256(), 0xAA);

        // Estimate the quotient of (a * b) / divisor via the precomputed fixed-point value
        let quotient = _mm256_srli_epi64(_mm256_mul_epu32(a, self.intermediate), 32);

        let numerator = _mm256_mul_epu32(a, self.b);
        let quotient_product = _mm256_mul_epu32(quotient, masked_divisor);

        // remainder = a * b - quotient * divisor. Because the quotient may undershoot by one,
        // the remainder can still be one divisor too large
        let remainder = _mm256_sub_epi64(numerator, quotient_product);

        // Conditionally subtract one more divisor: wherever the subtraction underflowed
        // (sign bit set), keep the original remainder instead of the subtracted one
        let casted_remainder = _mm256_castsi256_pd(remainder);
        let subtracted_remainder = _mm256_castsi256_pd(_mm256_sub_epi64(remainder, masked_divisor));
        let wrapped_remainder = _mm256_castpd_si256(_mm256_blendv_pd(
            subtracted_remainder,
            casted_remainder,
            subtracted_remainder,
        ));
        wrapped_remainder
    }
}

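/// AVX2-accelerated implementation of Rader's Algorithm, which computes an FFT of prime size `N`
/// by re-expressing it as a cyclic convolution of size `N - 1`, evaluated with an inner FFT of
/// that length. The input and output are permuted by powers of a primitive root modulo `N`; both
/// permutations are vectorized with AVX2 gather instructions, and the convolution's pointwise
/// multiplication uses the precomputed `twiddles` buffer.
///
/// Construction sketch (it mirrors the unit tests at the bottom of this file; `new` returns
/// `Err(())` unless the CPU supports AVX, AVX2, and FMA):
///
/// ```ignore
/// use std::sync::Arc;
/// use crate::algorithm::Dft;
///
/// // Rader's algorithm for len = 17 (prime) needs an inner FFT of len - 1 = 16
/// let inner_fft = Arc::new(Dft::new(16, FftDirection::Forward));
/// let fft = RadersAvx2::<f32, f32>::new(inner_fft).expect("CPU must support AVX, AVX2, and FMA");
/// assert_eq!(fft.len(), 17);
/// ```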
pub struct RadersAvx2<A: AvxNum, T> {
    // Vectorized state used to generate the input permutation indexes on the fly
    input_index_multiplier: VectorizedMultiplyMod,
    input_index_init: __m256i,

    // Precomputed gather indexes for the output permutation, one __m128i per vector of output
    output_index_mapping: Box<[__m128i]>,
    twiddles: Box<[A::VectorType]>,

    inner_fft: Arc<dyn Fft<T>>,

    len: usize,

    inplace_scratch_len: usize,
    outofplace_scratch_len: usize,
    direction: FftDirection,

    _phantom: std::marker::PhantomData<T>,
}

impl<A: AvxNum, T: FftNum> RadersAvx2<A, T> {
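    /// Constructs a `RadersAvx2` instance that computes FFTs of size `inner_fft.len() + 1`, which
    /// must be prime. `A` and `T` must be the same floating-point type.
    ///
    /// Returns `Ok(instance)` if the host CPU supports the `avx`, `avx2`, and `fma` features, and
    /// `Err(())` otherwise.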
    #[inline]
    pub fn new(inner_fft: Arc<dyn Fft<T>>) -> Result<Self, ()> {
        // A and T must be the same type; verify it here, since the process methods transmute
        // between the two
        let id_a = TypeId::of::<A>();
        let id_t = TypeId::of::<T>();
        assert_eq!(id_a, id_t);

        let has_avx = is_x86_feature_detected!("avx");
        let has_avx2 = is_x86_feature_detected!("avx2");
        let has_fma = is_x86_feature_detected!("fma");
        if has_avx && has_avx2 && has_fma {
            Ok(unsafe { Self::new_with_avx(inner_fft) })
        } else {
            Err(())
        }
    }

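    // Safety: the caller must verify that the `avx` target feature is available before calling
    // this (done by `new` above via runtime feature detection).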
    #[target_feature(enable = "avx")]
    unsafe fn new_with_avx(inner_fft: Arc<dyn Fft<T>>) -> Self {
        let inner_fft_len = inner_fft.len();
        let len = inner_fft_len + 1;
        assert!(miller_rabin(len as u64), "For Rader's algorithm, inner_fft.len() + 1 must be prime. Expected prime number, got {} + 1 = {}", inner_fft_len, len);

        let direction = inner_fft.fft_direction();
        let reduced_len = StrengthReducedUsize::new(len);

        // Compute a primitive root of the multiplicative group mod `len`, and its multiplicative inverse
        let primitive_root = math_utils::primitive_root(len as u64).unwrap() as usize;

        let gcd_data = i64::extended_gcd(&(primitive_root as i64), &(len as i64));
        let primitive_root_inverse = if gcd_data.x >= 0 {
            gcd_data.x
        } else {
            gcd_data.x + len as i64
        } as usize;

        // Precompute the inner FFT of the twiddle factors, reordered by powers of the inverse
        // primitive root and scaled by 1/inner_fft_len to normalize the convolution
        let inner_fft_scale = T::one() / T::from_usize(inner_fft_len).unwrap();
        let mut inner_fft_input = vec![Complex::zero(); inner_fft_len];
        let mut twiddle_input = 1;
        for input_cell in &mut inner_fft_input {
            let twiddle = twiddles::compute_twiddle(twiddle_input, len, direction);
            *input_cell = twiddle * inner_fft_scale;

            twiddle_input = (twiddle_input * primitive_root_inverse) % reduced_len;
        }

        let required_inner_scratch = inner_fft.get_inplace_scratch_len();
        let extra_inner_scratch = if required_inner_scratch <= inner_fft_len {
            0
        } else {
            required_inner_scratch
        };

        let mut inner_fft_scratch = vec![Zero::zero(); required_inner_scratch];
        inner_fft.process_with_scratch(&mut inner_fft_input, &mut inner_fft_scratch);

        // XOR-ing a vector with this mask flips the sign bit of every imaginary component,
        // ie conjugates each complex element
        let conjugation_mask =
            AvxVector256::broadcast_complex_elements(Complex::new(A::zero(), -A::zero()));

        let inner_fft_multiplier: Box<[_]> = {
            // Reinterpret the T-typed buffer as A-typed; `new` asserted that A and T are the same type
            let transmuted_inner_input: &mut [Complex<A>] =
                array_utils::workaround_transmute_mut(&mut inner_fft_input);

            transmuted_inner_input
                .chunks(A::VectorType::COMPLEX_PER_VECTOR)
                .map(|chunk| {
                    let chunk_vector = match chunk.len() {
                        1 => chunk.load_partial1_complex(0).zero_extend(),
                        2 => {
                            if chunk.len() == A::VectorType::COMPLEX_PER_VECTOR {
                                chunk.load_complex(0)
                            } else {
                                chunk.load_partial2_complex(0).zero_extend()
                            }
                        }
                        3 => chunk.load_partial3_complex(0),
                        4 => chunk.load_complex(0),
                        _ => unreachable!(),
                    };

                    // Store the twiddle chunks pre-conjugated
                    AvxVector::xor(chunk_vector, conjugation_mask)
                })
                .collect()
        };

        // Precompute the first few powers of the primitive root; they seed the vectorized
        // computation of the input permutation indexes
        const NUM_POWERS: usize = 5;
        let mut root_powers = [0; NUM_POWERS];
        let mut current_power = 1;
        for i in 0..NUM_POWERS {
            root_powers[i] = current_power;
            current_power = (current_power * primitive_root) % reduced_len;
        }

        // Seed the index vector with the first powers of the root; each gather loop iteration
        // advances every lane by a fixed power via `VectorizedMultiplyMod`
        let (input_index_multiplier, input_index_init) = if A::VectorType::COMPLEX_PER_VECTOR == 4 {
            (
                VectorizedMultiplyMod::new(root_powers[4] as u32, len as u32),
                _mm256_loadu_si256(root_powers.as_ptr().add(1) as *const __m256i),
            )
        } else {
            let duplicated_powers = [
                root_powers[1],
                root_powers[1],
                root_powers[2],
                root_powers[2],
            ];
            (
                VectorizedMultiplyMod::new(root_powers[2] as u32, len as u32),
                _mm256_loadu_si256(duplicated_powers.as_ptr() as *const __m256i),
            )
        };

        // Build the inverse (output) permutation, padded up to a whole number of vectors
        let mapping_size = 1
            + div_ceil(len, A::VectorType::COMPLEX_PER_VECTOR) * A::VectorType::COMPLEX_PER_VECTOR;
        let mut output_mapping_inverse: Vec<i32> = vec![0; mapping_size];
        let mut output_index = 1;
        for i in 1..len {
            output_index = (output_index * primitive_root_inverse) % reduced_len;
            output_mapping_inverse[output_index] = i.try_into().unwrap();
        }

        let output_index_mapping = if A::VectorType::COMPLEX_PER_VECTOR == 4 {
            (&output_mapping_inverse[1..])
                .chunks_exact(A::VectorType::COMPLEX_PER_VECTOR)
                .map(|chunk| _mm_loadu_si128(chunk.as_ptr() as *const __m128i))
                .collect::<Box<[__m128i]>>()
        } else {
            (&output_mapping_inverse[1..])
                .chunks_exact(A::VectorType::COMPLEX_PER_VECTOR)
                .map(|chunk| {
                    let duplicated_indexes = [chunk[0], chunk[0], chunk[1], chunk[1]];
                    _mm_loadu_si128(duplicated_indexes.as_ptr() as *const __m128i)
                })
                .collect::<Box<[__m128i]>>()
        };
        Self {
            input_index_multiplier,
            input_index_init,

            output_index_mapping,

            inner_fft,
            twiddles: inner_fft_multiplier,

            len,

            inplace_scratch_len: len + extra_inner_scratch,
            outofplace_scratch_len: extra_inner_scratch,
            direction,

            _phantom: std::marker::PhantomData,
        }
    }

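    // First reordering step of Rader's algorithm: gather input elements into `output[1..]` in
    // primitive-root order, one vector at a time. The gather indexes start at `input_index_init`
    // and are advanced with the vectorized modular multiply. `output[0]` is not written here.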
    #[target_feature(enable = "avx2", enable = "avx", enable = "fma")]
    unsafe fn prepare_raders(&self, input: &[Complex<A>], output: &mut [Complex<A>]) {
        let mut indexes = self.input_index_init;

        let index_multiplier = self.input_index_multiplier.clone();

        // Gather one full vector of permuted input per iteration
        let mut chunks_iter =
            (&mut output[1..]).chunks_exact_mut(A::VectorType::COMPLEX_PER_VECTOR);
        for mut chunk in chunks_iter.by_ref() {
            let gathered_elements =
                A::VectorType::gather_complex_avx2_index64(input.as_ptr(), indexes);

            // Advance every index: multiply by a fixed power of the primitive root, mod len
            indexes = index_multiplier.mul_rem(indexes);

            chunk.store_complex(gathered_elements, 0);
        }

        // If a remainder exists, it's exactly half a vector (f32 only, since len - 1 is even)
        let mut output_remainder = chunks_iter.into_remainder();
        if output_remainder.len() == 2 {
            let half_data = AvxVector128::gather64_complex_avx2(
                input.as_ptr(),
                _mm256_castsi256_si128(indexes),
            );

            output_remainder.store_partial2_complex(half_data, 0);
        }
    }

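    // Final reordering step of Rader's algorithm: gather the second inner FFT's results from
    // `input` back into natural order in `output[1..]`, using the precomputed inverse-root
    // mapping, and conjugate every element on the way out. `output[0]` is written by the caller.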
    #[target_feature(enable = "avx2", enable = "avx", enable = "fma")]
    unsafe fn finalize_raders(&self, input: &[Complex<A>], output: &mut [Complex<A>]) {
        // XOR-ing with this mask flips the sign bit of every imaginary component, ie conjugates
        let conjugation_mask =
            AvxVector256::broadcast_complex_elements(Complex::new(A::zero(), -A::zero()));

        let mut chunks_iter =
            (&mut output[1..]).chunks_exact_mut(A::VectorType::COMPLEX_PER_VECTOR);
        for (i, mut chunk) in chunks_iter.by_ref().enumerate() {
            let index_chunk = *self.output_index_mapping.get_unchecked(i);
            let gathered_elements =
                A::VectorType::gather_complex_avx2_index32(input.as_ptr(), index_chunk);

            let conjugated_elements = AvxVector::xor(gathered_elements, conjugation_mask);
            chunk.store_complex(conjugated_elements, 0);
        }

        // If a remainder exists, it's exactly half a vector (f32 only, since len - 1 is even)
        let mut output_remainder = chunks_iter.into_remainder();
        if output_remainder.len() == 2 {
            let index_chunk = *self
                .output_index_mapping
                .get_unchecked(self.output_index_mapping.len() - 1);
            let half_data = AvxVector128::gather32_complex_avx2(input.as_ptr(), index_chunk);

            let conjugated_elements = AvxVector::xor(half_data, conjugation_mask.lo());
            output_remainder.store_partial2_complex(conjugated_elements, 0);
        }
    }

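    // Out-of-place Rader's FFT:
    //   1. Permute the input into the output buffer
    //   2. Run the inner FFT on the permuted data
    //   3. Multiply the spectrum elementwise by the precomputed (conjugated) twiddles
    //   4. Run the inner FFT again
    //   5. Permute and conjugate the result into the output buffer
    // The first element is handled separately, since the permutations only cover indexes 1..len.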
    fn perform_fft_out_of_place(
        &self,
        input: &mut [Complex<T>],
        output: &mut [Complex<T>],
        scratch: &mut [Complex<T>],
    ) {
        // Step 1: permute the input into the output buffer
        unsafe {
            let transmuted_input: &mut [Complex<A>] = array_utils::workaround_transmute_mut(input);
            let transmuted_output: &mut [Complex<A>] =
                array_utils::workaround_transmute_mut(output);
            self.prepare_raders(transmuted_input, transmuted_output)
        }

        let (first_input, inner_input) = input.split_first_mut().unwrap();
        let (first_output, inner_output) = output.split_first_mut().unwrap();

        // Step 2: inner FFT of the permuted data. If the caller didn't pass scratch, reuse the
        // input buffer as the inner FFT's scratch
        let inner_scratch = if scratch.len() > 0 {
            &mut scratch[..]
        } else {
            &mut inner_input[..]
        };
        self.inner_fft
            .process_with_scratch(inner_output, inner_scratch);

        // output[0] is the sum of all inputs: the inner FFT's DC bin holds the sum of inputs 1..len
        *first_output = inner_output[0] + *first_input;

        // Step 3: multiply the spectrum elementwise by the precomputed conjugated twiddles,
        // writing the result into the input buffer, which becomes the second inner FFT's input
        unsafe {
            let transmuted_inner_input: &mut [Complex<A>] =
                array_utils::workaround_transmute_mut(inner_input);
            let transmuted_inner_output: &mut [Complex<A>] =
                array_utils::workaround_transmute_mut(inner_output);
            avx_vector::pairwise_complex_mul_conjugated(
                transmuted_inner_output,
                transmuted_inner_input,
                &self.twiddles,
            )
        };

        // Adding first_input.conj() to the DC input adds it to every bin of the second FFT's
        // output; the final conjugation in finalize_raders turns that into adding first_input
        // to every output element, as Rader's algorithm requires
        inner_input[0] = inner_input[0] + first_input.conj();

        // Step 4: run the inner FFT again
        let inner_scratch = if scratch.len() > 0 {
            scratch
        } else {
            &mut inner_output[..]
        };
        self.inner_fft
            .process_with_scratch(inner_input, inner_scratch);

        // Step 5: permute and conjugate the result into the output buffer
        unsafe {
            let transmuted_input: &mut [Complex<A>] = array_utils::workaround_transmute_mut(input);
            let transmuted_output: &mut [Complex<A>] =
                array_utils::workaround_transmute_mut(output);
            self.finalize_raders(transmuted_input, transmuted_output);
        }
    }
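    // In-place variant of the same five steps: the permuted data lives in `scratch`, and the
    // final reordering writes back into `buffer`.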
    fn perform_fft_inplace(&self, buffer: &mut [Complex<T>], scratch: &mut [Complex<T>]) {
        let (scratch, extra_scratch) = scratch.split_at_mut(self.len());

        // Step 1: permute the buffer into the scratch
        unsafe {
            let transmuted_scratch: &mut [Complex<A>] =
                array_utils::workaround_transmute_mut(scratch);
            let transmuted_buffer: &mut [Complex<A>] =
                array_utils::workaround_transmute_mut(buffer);
            self.prepare_raders(transmuted_buffer, transmuted_scratch)
        }

        let first_input = buffer[0];

        let truncated_scratch = &mut scratch[1..];

        // Step 2: inner FFT of the permuted data. If no extra scratch was provided, reuse the
        // buffer as the inner FFT's scratch
        let inner_scratch = if extra_scratch.len() > 0 {
            extra_scratch
        } else {
            &mut buffer[..]
        };
        self.inner_fft
            .process_with_scratch(truncated_scratch, inner_scratch);

        // buffer[0] will be the sum of all inputs: the inner FFT's DC bin holds the sum of inputs 1..len
        let first_output = first_input + truncated_scratch[0];

        // Step 3: multiply the spectrum elementwise by the precomputed conjugated twiddles, in place
        unsafe {
            let transmuted_scratch: &mut [Complex<A>] =
                array_utils::workaround_transmute_mut(truncated_scratch);
            avx_vector::pairwise_complex_mul_assign_conjugated(transmuted_scratch, &self.twiddles)
        };

        // As in the out-of-place path: adding first_input.conj() to the DC input adds first_input
        // to every final output element once finalize_raders conjugates the result
        truncated_scratch[0] = truncated_scratch[0] + first_input.conj();

        // Step 4: run the inner FFT again
        self.inner_fft
            .process_with_scratch(truncated_scratch, inner_scratch);

        // Step 5: permute and conjugate the scratch back into the buffer
        buffer[0] = first_output;
        unsafe {
            let transmuted_scratch: &mut [Complex<A>] =
                array_utils::workaround_transmute_mut(scratch);
            let transmuted_buffer: &mut [Complex<A>] =
                array_utils::workaround_transmute_mut(buffer);
            self.finalize_raders(transmuted_scratch, transmuted_buffer);
        }
    }
}
boilerplate_avx_fft!(
    RadersAvx2,
    |this: &RadersAvx2<_, _>| this.len,
    |this: &RadersAvx2<_, _>| this.inplace_scratch_len,
    |this: &RadersAvx2<_, _>| this.outofplace_scratch_len
);

#[cfg(test)]
mod unit_tests {
    use num_traits::Float;
    use rand::distributions::uniform::SampleUniform;

    use super::*;
    use crate::algorithm::Dft;
    use crate::test_utils::check_fft_algorithm;
    use std::sync::Arc;

    #[test]
    fn test_raders_avx_f32() {
        for len in 3..100 {
            if miller_rabin(len as u64) {
                test_raders_with_length::<f32>(len, FftDirection::Forward);
                test_raders_with_length::<f32>(len, FftDirection::Inverse);
            }
        }
    }

    #[test]
    fn test_raders_avx_f64() {
        for len in 3..100 {
            if miller_rabin(len as u64) {
                test_raders_with_length::<f64>(len, FftDirection::Forward);
                test_raders_with_length::<f64>(len, FftDirection::Inverse);
            }
        }
    }

    fn test_raders_with_length<T: AvxNum + Float + SampleUniform>(
        len: usize,
        direction: FftDirection,
    ) {
        let inner_fft = Arc::new(Dft::new(len - 1, direction));
        let fft = RadersAvx2::<T, T>::new(inner_fft).unwrap();

        check_fft_algorithm::<T>(&fft, len, direction);
    }
}