1use crate::imp_prelude::*;
10
11#[cfg(feature = "blas")]
12use crate::dimension::offset_from_low_addr_ptr_to_logical_ptr;
13use crate::numeric_util;
14
15use crate::{LinalgScalar, Zip};
16
17use std::any::TypeId;
18use std::mem::MaybeUninit;
19use alloc::vec::Vec;
20
21use num_complex::Complex;
22use num_complex::{Complex32 as c32, Complex64 as c64};
23
24#[cfg(feature = "blas")]
25use libc::c_int;
26#[cfg(feature = "blas")]
27use std::cmp;
28#[cfg(feature = "blas")]
29use std::mem::swap;
30
31#[cfg(feature = "blas")]
32use cblas_sys as blas_sys;
33#[cfg(feature = "blas")]
34use cblas_sys::{CblasNoTrans, CblasRowMajor, CblasTrans, CBLAS_LAYOUT};
35
#[cfg(feature = "blas")]
/// Minimum vector length at which `dot_impl` tries the BLAS `?dot` routines.
const DOT_BLAS_CUTOFF: usize = 32;
#[cfg(feature = "blas")]
/// BLAS is used for a matrix product only when at least one of the dimensions
/// m, k, n exceeds this value; smaller products go to `mat_mul_general`.
const GEMM_BLAS_CUTOFF: usize = 7;
#[cfg(feature = "blas")]
#[allow(non_camel_case_types)]
/// Integer type of the lengths, increments and leading dimensions passed to CBLAS.
type blas_index = c_int;

// Dot products whose left operand is a one-dimensional array.
impl<A, S> ArrayBase<S, Ix1>
where
    S: Data<Elem = A>,
{
    /// Perform a dot product or a vector–matrix product of `self` and `rhs`.
    ///
    /// `Rhs` may be a one-dimensional array, giving the vector dot product: the
    /// sum of the elementwise products, with no conjugation of complex operands.
    /// It may also be a two-dimensional array, in which case `self` is treated as
    /// a row vector: for `self` of length M, `rhs` must be M × N and the result
    /// has length N.
    ///
    /// **Panics** if the shapes are incompatible.
    pub fn dot<Rhs>(&self, rhs: &Rhs) -> <Self as Dot<Rhs>>::Output
    where
        Self: Dot<Rhs>,
    {
        Dot::dot(self, rhs)
    }

    /// Portable dot product, used without BLAS or when BLAS does not apply.
    ///
    /// When both operands are available as slices, the work is delegated to
    /// `numeric_util::unrolled_dot`; otherwise the products are summed element
    /// by element.
    ///
    /// **Panics** if `self` and `rhs` have different lengths.
    fn dot_generic<S2>(&self, rhs: &ArrayBase<S2, Ix1>) -> A
    where
        S2: Data<Elem = A>,
        A: LinalgScalar,
    {
        debug_assert_eq!(self.len(), rhs.len());
        assert!(self.len() == rhs.len());
        // Fast path: both operands can be viewed as plain slices.
        if let Some(self_s) = self.as_slice() {
            if let Some(rhs_s) = rhs.as_slice() {
                return numeric_util::unrolled_dot(self_s, rhs_s);
            }
        }
        let mut sum = A::zero();
        for i in 0..self.len() {
            // `uget` skips bounds checks; `i < self.len() == rhs.len()` by the assert above.
            unsafe {
                sum = sum + *self.uget(i) * *rhs.uget(i);
            }
        }
        sum
    }

    /// Dot product when the crate is built without the `blas` feature.
    #[cfg(not(feature = "blas"))]
    fn dot_impl<S2>(&self, rhs: &ArrayBase<S2, Ix1>) -> A
    where
        S2: Data<Elem = A>,
        A: LinalgScalar,
    {
        self.dot_generic(rhs)
    }

    /// Dot product when the crate is built with the `blas` feature.
    ///
    /// For vectors of at least `DOT_BLAS_CUTOFF` elements whose element type is
    /// `f32` or `f64` and whose length/stride pass `blas_compat_1d`, the result
    /// comes from `cblas_sdot` / `cblas_ddot`. All other inputs use `dot_generic`.
    #[cfg(feature = "blas")]
    fn dot_impl<S2>(&self, rhs: &ArrayBase<S2, Ix1>) -> A
    where
        S2: Data<Elem = A>,
        A: LinalgScalar,
    {
        if self.len() >= DOT_BLAS_CUTOFF {
            debug_assert_eq!(self.len(), rhs.len());
            assert!(self.len() == rhs.len());
            // Expands to a BLAS call for element type `$ty` that returns early when
            // both operands are compatible, and does nothing otherwise.
            macro_rules! dot {
                ($ty:ty, $func:ident) => {{
                    if blas_compat_1d::<$ty, _>(self) && blas_compat_1d::<$ty, _>(rhs) {
                        unsafe {
                            // A negative stride must be handed over as a pointer to the
                            // lowest-address element (see `blas_1d_params`).
                            let (lhs_ptr, n, incx) =
                                blas_1d_params(self.ptr.as_ptr(), self.len(), self.strides()[0]);
                            let (rhs_ptr, _, incy) =
                                blas_1d_params(rhs.ptr.as_ptr(), rhs.len(), rhs.strides()[0]);
                            let ret = blas_sys::$func(
                                n,
                                lhs_ptr as *const $ty,
                                incx,
                                rhs_ptr as *const $ty,
                                incy,
                            );
                            // `blas_compat_1d` established that `A` is `$ty`, so this cast is exact.
                            return cast_as::<$ty, A>(&ret);
                        }
                    }
                }};
            }

            dot! {f32, cblas_sdot};
            dot! {f64, cblas_ddot};
        }
        self.dot_generic(rhs)
    }
}
141
/// Translate a 1-D ndarray `(ptr, len, stride)` triple into CBLAS vector arguments.
///
/// `ptr` addresses logical element 0. When the stride is negative and the vector
/// is non-empty, the returned pointer is moved to the last logical element,
/// which is the lowest address and the point where a CBLAS vector argument with
/// a negative increment begins. The length and stride are returned as
/// `blas_index` unchanged.
///
/// # Safety
///
/// For a negative `stride` and non-zero `len`, `ptr + (len - 1) * stride` must
/// lie in the same allocation as `ptr`. The caller must also ensure `len` and
/// `stride` fit in `blas_index`.
#[cfg(feature = "blas")]
unsafe fn blas_1d_params<A>(
    ptr: *const A,
    len: usize,
    stride: isize,
) -> (*const A, blas_index, blas_index) {
    // Only a non-empty vector traversed backwards needs its base pointer moved.
    let base = if stride < 0 && len != 0 {
        ptr.offset((len - 1) as isize * stride)
    } else {
        ptr
    };
    (base, len as blas_index, stride as blas_index)
}
164
/// Matrix multiplication and dot products between arrays.
///
/// Implemented in this module for Ix1·Ix1 (scalar result), Ix1·Ix2 and Ix2·Ix1
/// (one-dimensional result) and Ix2·Ix2 (two-dimensional result).
pub trait Dot<Rhs> {
    /// The result of the operation: a scalar for vector·vector, a 1-D array for
    /// vector·matrix or matrix·vector, and a 2-D array for matrix·matrix.
    type Output;
    /// Compute the product of `self` and `rhs`.
    ///
    /// The implementations in this module panic when the shapes are incompatible.
    fn dot(&self, rhs: &Rhs) -> Self::Output;
}
176
177impl<A, S, S2> Dot<ArrayBase<S2, Ix1>> for ArrayBase<S, Ix1>
178where
179 S: Data<Elem = A>,
180 S2: Data<Elem = A>,
181 A: LinalgScalar,
182{
183 type Output = A;
184
185 fn dot(&self, rhs: &ArrayBase<S2, Ix1>) -> A {
194 self.dot_impl(rhs)
195 }
196}
197
198impl<A, S, S2> Dot<ArrayBase<S2, Ix2>> for ArrayBase<S, Ix1>
199where
200 S: Data<Elem = A>,
201 S2: Data<Elem = A>,
202 A: LinalgScalar,
203{
204 type Output = Array<A, Ix1>;
205
206 fn dot(&self, rhs: &ArrayBase<S2, Ix2>) -> Array<A, Ix1> {
216 rhs.t().dot(self)
217 }
218}
219
220impl<A, S> ArrayBase<S, Ix2>
221where
222 S: Data<Elem = A>,
223{
224 pub fn dot<Rhs>(&self, rhs: &Rhs) -> <Self as Dot<Rhs>>::Output
254 where
255 Self: Dot<Rhs>,
256 {
257 Dot::dot(self, rhs)
258 }
259}
260
261impl<A, S, S2> Dot<ArrayBase<S2, Ix2>> for ArrayBase<S, Ix2>
262where
263 S: Data<Elem = A>,
264 S2: Data<Elem = A>,
265 A: LinalgScalar,
266{
267 type Output = Array2<A>;
268 fn dot(&self, b: &ArrayBase<S2, Ix2>) -> Array2<A> {
269 let a = self.view();
270 let b = b.view();
271 let ((m, k), (k2, n)) = (a.dim(), b.dim());
272 if k != k2 || m.checked_mul(n).is_none() {
273 dot_shape_error(m, k, k2, n);
274 }
275
276 let lhs_s0 = a.strides()[0];
277 let rhs_s0 = b.strides()[0];
278 let column_major = lhs_s0 == 1 && rhs_s0 == 1;
279 let mut v = Vec::with_capacity(m * n);
281 let mut c;
282 unsafe {
283 v.set_len(m * n);
284 c = Array::from_shape_vec_unchecked((m, n).set_f(column_major), v);
285 }
286 mat_mul_impl(A::one(), &a, &b, A::zero(), &mut c.view_mut());
287 c
288 }
289}
290
/// Panic with a description of an invalid matrix-product shape.
///
/// If `m × n` overflows `usize` or exceeds `isize::MAX`, the panic reports an
/// overflow. Otherwise it reports that the `m × k` and `k2 × n` operands are
/// incompatible.
#[cold]
#[inline(never)]
fn dot_shape_error(m: usize, k: usize, k2: usize, n: usize) -> ! {
    let len_fits = match m.checked_mul(n) {
        Some(len) => len <= isize::MAX as usize,
        None => false,
    };
    if !len_fits {
        panic!("ndarray: shape {} × {} overflows isize", m, n);
    }
    panic!(
        "ndarray: inputs {} × {} and {} × {} are not compatible for matrix multiplication",
        m, k, k2, n
    );
}
304
/// Panic because the operands `m × k`, `k2 × n` and the output `c1 × c2` of a
/// general matrix (or matrix–vector) product have incompatible shapes.
#[cold]
#[inline(never)]
fn general_dot_shape_error(m: usize, k: usize, k2: usize, n: usize, c1: usize, c2: usize) -> ! {
    panic!(
        "ndarray: inputs {} × {}, {} × {}, and output {} × {} are not compatible for matrix multiplication",
        m, k, k2, n, c1, c2
    )
}
311
312impl<A, S, S2> Dot<ArrayBase<S2, Ix1>> for ArrayBase<S, Ix2>
322where
323 S: Data<Elem = A>,
324 S2: Data<Elem = A>,
325 A: LinalgScalar,
326{
327 type Output = Array<A, Ix1>;
328 fn dot(&self, rhs: &ArrayBase<S2, Ix1>) -> Array<A, Ix1> {
329 let ((m, a), n) = (self.dim(), rhs.dim());
330 if a != n {
331 dot_shape_error(m, a, n, 1);
332 }
333
334 unsafe {
336 let mut c = Array1::uninit(m);
337 general_mat_vec_mul_impl(A::one(), self, rhs, A::zero(), c.raw_view_mut().cast::<A>());
338 c.assume_init()
339 }
340 }
341}
342
343impl<A, S, D> ArrayBase<S, D>
344where
345 S: Data<Elem = A>,
346 D: Dimension,
347{
348 pub fn scaled_add<S2, E>(&mut self, alpha: A, rhs: &ArrayBase<S2, E>)
356 where
357 S: DataMut,
358 S2: Data<Elem = A>,
359 A: LinalgScalar,
360 E: Dimension,
361 {
362 self.zip_mut_with(rhs, move |y, &x| *y = *y + (alpha * x));
363 }
364}
365
366#[cfg(not(feature = "blas"))]
369use self::mat_mul_general as mat_mul_impl;
370
#[cfg(feature = "blas")]
/// Compute `c ← alpha · lhs · rhs + beta · c`, using CBLAS `?gemm` when possible.
///
/// BLAS is used only when all of these hold:
/// - the element type is f32, f64, c32 or c64;
/// - at least one of m, k, n exceeds `GEMM_BLAS_CUTOFF`;
/// - the operand views, after any transposition below, pass `blas_row_major_2d`.
///
/// Every other case goes to `mat_mul_general`. Shapes are not checked here; the
/// callers in this module validate them before calling.
fn mat_mul_impl<A>(
    alpha: A,
    lhs: &ArrayView2<'_, A>,
    rhs: &ArrayView2<'_, A>,
    beta: A,
    c: &mut ArrayViewMut2<'_, A>,
) where
    A: LinalgScalar,
{
    let cut = GEMM_BLAS_CUTOFF;
    let ((mut m, a), (_, mut n)) = (lhs.dim(), rhs.dim());
    // Small products, and element types without a BLAS routine, use the portable path.
    if !(m > cut || n > cut || a > cut)
        || !(same_type::<A, f32>()
            || same_type::<A, f64>()
            || same_type::<A, c32>()
            || same_type::<A, c64>())
    {
        return mat_mul_general(alpha, lhs, rhs, beta, c);
    }
    {
        // Local views that can be re-oriented without affecting the caller's arrays.
        let mut lhs_ = lhs.view();
        let mut rhs_ = rhs.view();
        let mut c_ = c.view_mut();
        let lhs_s0 = lhs_.strides()[0];
        let rhs_s0 = rhs_.strides()[0];
        // A row stride of 1 suggests a column-major (F-order) operand.
        let both_f = lhs_s0 == 1 && rhs_s0 == 1;
        let mut lhs_trans = CblasNoTrans;
        let mut rhs_trans = CblasNoTrans;
        if both_f {
            // C = L·R  <=>  Cᵀ = Rᵀ·Lᵀ. Computing the transposed product lets
            // the column-major operands be passed as row-major transposed views.
            let lhs_t = lhs_.reversed_axes();
            lhs_ = rhs_.reversed_axes();
            rhs_ = lhs_t;
            c_ = c_.reversed_axes();
            swap(&mut m, &mut n);
        } else if lhs_s0 == 1 && m == a {
            // Square column-major lhs: pass its row-major transpose and let BLAS
            // transpose it back.
            lhs_ = lhs_.reversed_axes();
            lhs_trans = CblasTrans;
        } else if rhs_s0 == 1 && a == n {
            // Same for a square column-major rhs.
            rhs_ = rhs_.reversed_axes();
            rhs_trans = CblasTrans;
        }

        // The real gemm routines take alpha/beta by value. The complex ones take
        // pointers to them.
        macro_rules! gemm_scalar_cast {
            (f32, $var:ident) => {
                cast_as(&$var)
            };
            (f64, $var:ident) => {
                cast_as(&$var)
            };
            (c32, $var:ident) => {
                &$var as *const A as *const _
            };
            (c64, $var:ident) => {
                &$var as *const A as *const _
            };
        }

        // Calls `$gemm` and returns if all three views are BLAS-compatible
        // row-major arrays of `$ty`; otherwise does nothing.
        macro_rules! gemm {
            ($ty:tt, $gemm:ident) => {
                if blas_row_major_2d::<$ty, _>(&lhs_)
                    && blas_row_major_2d::<$ty, _>(&rhs_)
                    && blas_row_major_2d::<$ty, _>(&c_)
                {
                    // m × k is the shape of op(lhs); n is the column count of op(rhs).
                    let (m, k) = match lhs_trans {
                        CblasNoTrans => lhs_.dim(),
                        _ => {
                            let (rows, cols) = lhs_.dim();
                            (cols, rows)
                        }
                    };
                    let n = match rhs_trans {
                        CblasNoTrans => rhs_.raw_dim()[1],
                        _ => rhs_.raw_dim()[0],
                    };
                    // Leading dimensions are never passed below k / n, presumably
                    // because a matrix with a length-1 axis may report an
                    // arbitrary (e.g. 1) stride on the other axis.
                    let lhs_stride = cmp::max(lhs_.strides()[0] as blas_index, k as blas_index);
                    let rhs_stride = cmp::max(rhs_.strides()[0] as blas_index, n as blas_index);
                    let c_stride = cmp::max(c_.strides()[0] as blas_index, n as blas_index);

                    unsafe {
                        // gemm computes C ← α·op(A)·op(B) + β·C.
                        blas_sys::$gemm(
                            CblasRowMajor,
                            lhs_trans,
                            rhs_trans,
                            m as blas_index, // rows of op(lhs)
                            n as blas_index, // columns of op(rhs)
                            k as blas_index, // columns of op(lhs)
                            gemm_scalar_cast!($ty, alpha),
                            lhs_.ptr.as_ptr() as *const _,
                            lhs_stride,
                            rhs_.ptr.as_ptr() as *const _,
                            rhs_stride,
                            gemm_scalar_cast!($ty, beta),
                            c_.ptr.as_ptr() as *mut _,
                            c_stride,
                        );
                    }
                    return;
                }
            };
        }
        gemm!(f32, cblas_sgemm);
        gemm!(f64, cblas_dgemm);

        gemm!(c32, cblas_cgemm);
        gemm!(c64, cblas_zgemm);
    }
    // Layouts BLAS cannot take fall back to the portable implementation.
    mat_mul_general(alpha, lhs, rhs, beta, c)
}
488
/// Compute `c ← alpha · lhs · rhs + beta · c` without BLAS.
///
/// f32, f64, c32 and c64 are dispatched to `matrixmultiply::{sgemm, dgemm,
/// cgemm, zgemm}`, passing each array's row and column strides directly. Any
/// other element type uses a straightforward triple loop.
///
/// Shapes are not checked here. The caller must ensure `lhs` is m × k, `rhs` is
/// k × n and `c` is m × n, since the loops use unchecked indexing.
fn mat_mul_general<A>(
    alpha: A,
    lhs: &ArrayView2<'_, A>,
    rhs: &ArrayView2<'_, A>,
    beta: A,
    c: &mut ArrayViewMut2<'_, A>,
) where
    A: LinalgScalar,
{
    let ((m, k), (_, n)) = (lhs.dim(), rhs.dim());

    let ap = lhs.as_ptr();
    let bp = rhs.as_ptr();
    let cp = c.as_mut_ptr();
    let (rsc, csc) = (c.strides()[0], c.strides()[1]);
    if same_type::<A, f32>() {
        unsafe {
            matrixmultiply::sgemm(
                m,
                k,
                n,
                cast_as(&alpha),
                ap as *const _,
                lhs.strides()[0],
                lhs.strides()[1],
                bp as *const _,
                rhs.strides()[0],
                rhs.strides()[1],
                cast_as(&beta),
                cp as *mut _,
                rsc,
                csc,
            );
        }
    } else if same_type::<A, f64>() {
        unsafe {
            matrixmultiply::dgemm(
                m,
                k,
                n,
                cast_as(&alpha),
                ap as *const _,
                lhs.strides()[0],
                lhs.strides()[1],
                bp as *const _,
                rhs.strides()[0],
                rhs.strides()[1],
                cast_as(&beta),
                cp as *mut _,
                rsc,
                csc,
            );
        }
    } else if same_type::<A, c32>() {
        // The complex kernels take alpha/beta as [re, im] arrays.
        unsafe {
            matrixmultiply::cgemm(
                matrixmultiply::CGemmOption::Standard,
                matrixmultiply::CGemmOption::Standard,
                m,
                k,
                n,
                complex_array(cast_as(&alpha)),
                ap as *const _,
                lhs.strides()[0],
                lhs.strides()[1],
                bp as *const _,
                rhs.strides()[0],
                rhs.strides()[1],
                complex_array(cast_as(&beta)),
                cp as *mut _,
                rsc,
                csc,
            );
        }
    } else if same_type::<A, c64>() {
        unsafe {
            matrixmultiply::zgemm(
                matrixmultiply::CGemmOption::Standard,
                matrixmultiply::CGemmOption::Standard,
                m,
                k,
                n,
                complex_array(cast_as(&alpha)),
                ap as *const _,
                lhs.strides()[0],
                lhs.strides()[1],
                bp as *const _,
                rhs.strides()[0],
                rhs.strides()[1],
                complex_array(cast_as(&beta)),
                cp as *mut _,
                rsc,
                csc,
            );
        }
    } else {
        // Nothing to do for an empty output. Returning here also guarantees
        // m, n >= 1, which the loop below relies on to terminate.
        if c.is_empty() {
            return;
        }

        // With beta == 0, overwrite c first so its previous contents do not
        // feed into `*elt * beta` below.
        if beta.is_zero() {
            c.fill(beta);
        }

        // Visit every (i, j) of c in row-major order.
        let mut i = 0;
        let mut j = 0;
        loop {
            unsafe {
                let elt = c.uget_mut((i, j));
                *elt = *elt * beta
                    + alpha
                        * (0..k).fold(A::zero(), move |s, x| {
                            s + *lhs.uget((i, x)) * *rhs.uget((x, j))
                        });
            }
            j += 1;
            if j == n {
                j = 0;
                i += 1;
                if i == m {
                    break;
                }
            }
        }
    }
}
619
620pub fn general_mat_mul<A, S1, S2, S3>(
632 alpha: A,
633 a: &ArrayBase<S1, Ix2>,
634 b: &ArrayBase<S2, Ix2>,
635 beta: A,
636 c: &mut ArrayBase<S3, Ix2>,
637) where
638 S1: Data<Elem = A>,
639 S2: Data<Elem = A>,
640 S3: DataMut<Elem = A>,
641 A: LinalgScalar,
642{
643 let ((m, k), (k2, n)) = (a.dim(), b.dim());
644 let (m2, n2) = c.dim();
645 if k != k2 || m != m2 || n != n2 {
646 general_dot_shape_error(m, k, k2, n, m2, n2);
647 } else {
648 mat_mul_impl(alpha, &a.view(), &b.view(), beta, &mut c.view_mut());
649 }
650}
651
652#[allow(clippy::collapsible_if)]
663pub fn general_mat_vec_mul<A, S1, S2, S3>(
664 alpha: A,
665 a: &ArrayBase<S1, Ix2>,
666 x: &ArrayBase<S2, Ix1>,
667 beta: A,
668 y: &mut ArrayBase<S3, Ix1>,
669) where
670 S1: Data<Elem = A>,
671 S2: Data<Elem = A>,
672 S3: DataMut<Elem = A>,
673 A: LinalgScalar,
674{
675 unsafe { general_mat_vec_mul_impl(alpha, a, x, beta, y.raw_view_mut()) }
676}
677
#[allow(clippy::collapsible_else_if)]
/// Compute `y ← alpha · a · x + beta · y` for an m × k matrix `a`, a length-k
/// vector `x` and a length-m output `y` given as a raw mutable view.
///
/// With the `blas` feature, f32/f64 inputs whose layouts BLAS accepts are passed
/// to `cblas_sgemv` / `cblas_dgemv`. Otherwise each output element is the dot
/// product of the matching row of `a` with `x`.
///
/// **Panics** if the shapes are incompatible.
///
/// # Safety
///
/// `y` must be valid for writes to all of its elements. When `beta` is non-zero
/// it must also be valid for reads, because the non-BLAS path reads `*elt`.
/// When `beta` is zero the non-BLAS path only writes. NOTE(review): `y`
/// presumably must not alias `a` or `x`; confirm with callers.
unsafe fn general_mat_vec_mul_impl<A, S1, S2>(
    alpha: A,
    a: &ArrayBase<S1, Ix2>,
    x: &ArrayBase<S2, Ix1>,
    beta: A,
    y: RawArrayViewMut<A, Ix1>,
) where
    S1: Data<Elem = A>,
    S2: Data<Elem = A>,
    A: LinalgScalar,
{
    let ((m, k), k2) = (a.dim(), x.dim());
    let m2 = y.dim();
    if k != k2 || m != m2 {
        general_dot_shape_error(m, k, k2, 1, m2, 1);
    } else {
        // Calls `$gemv` and returns when `a`, `x` and `y` are all BLAS-compatible
        // for element type `$ty`; otherwise does nothing.
        #[cfg(feature = "blas")]
        macro_rules! gemv {
            ($ty:ty, $gemv:ident) => {
                if let Some(layout) = blas_layout::<$ty, _>(&a) {
                    if blas_compat_1d::<$ty, _>(&x) && blas_compat_1d::<$ty, _>(&y) {
                        let a_trans = CblasNoTrans;
                        // Stride between rows (row-major) or columns (column-major).
                        // It is raised to at least k or m because a matrix with a
                        // length-1 axis may report an arbitrary stride there.
                        let a_stride = match layout {
                            CBLAS_LAYOUT::CblasRowMajor => {
                                a.strides()[0].max(k as isize) as blas_index
                            }
                            CBLAS_LAYOUT::CblasColMajor => {
                                a.strides()[1].max(m as isize) as blas_index
                            }
                        };

                        // BLAS wants x and y addressed from their lowest-address
                        // element, which differs from the logical first element
                        // when the stride is negative.
                        let x_offset = offset_from_low_addr_ptr_to_logical_ptr(&x.dim, &x.strides);
                        let x_ptr = x.ptr.as_ptr().sub(x_offset);
                        let y_offset = offset_from_low_addr_ptr_to_logical_ptr(&y.dim, &y.strides);
                        let y_ptr = y.ptr.as_ptr().sub(y_offset);

                        let x_stride = x.strides()[0] as blas_index;
                        let y_stride = y.strides()[0] as blas_index;

                        blas_sys::$gemv(
                            layout,
                            a_trans,
                            m as blas_index, // rows of a
                            k as blas_index, // columns of a
                            cast_as(&alpha),
                            a.ptr.as_ptr() as *const _,
                            a_stride,
                            x_ptr as *const _,
                            x_stride,
                            cast_as(&beta),
                            y_ptr as *mut _,
                            y_stride,
                        );
                        return;
                    }
                }
            };
        }
        #[cfg(feature = "blas")]
        gemv!(f32, cblas_sgemv);
        #[cfg(feature = "blas")]
        gemv!(f64, cblas_dgemv);

        // Generic path: one row·x dot product per output element.
        if beta.is_zero() {
            // y may be uninitialized here, so write without reading the old value.
            Zip::from(a.outer_iter()).and(y).for_each(|row, elt| {
                elt.write(row.dot(x) * alpha);
            });
        } else {
            Zip::from(a.outer_iter()).and(y).for_each(|row, elt| {
                *elt = *elt * beta + row.dot(x) * alpha;
            });
        }
    }
}
769
770
771pub fn kron<A, S1, S2>(a: &ArrayBase<S1, Ix2>, b: &ArrayBase<S2, Ix2>) -> Array<A, Ix2>
776where
777 S1: Data<Elem = A>,
778 S2: Data<Elem = A>,
779 A: LinalgScalar,
780{
781 let dimar = a.shape()[0];
782 let dimac = a.shape()[1];
783 let dimbr = b.shape()[0];
784 let dimbc = b.shape()[1];
785 let mut out: Array2<MaybeUninit<A>> = Array2::uninit((
786 dimar
787 .checked_mul(dimbr)
788 .expect("Dimensions of kronecker product output array overflows usize."),
789 dimac
790 .checked_mul(dimbc)
791 .expect("Dimensions of kronecker product output array overflows usize."),
792 ));
793 Zip::from(out.exact_chunks_mut((dimbr, dimbc)))
794 .and(a)
795 .for_each(|out, &a| {
796 Zip::from(out).and(b).for_each(|out, &b| {
797 *out = MaybeUninit::new(a * b);
798 })
799 });
800 unsafe { out.assume_init() }
801}
802
/// Returns `true` when `A` and `B` are the same type, compared by `TypeId`.
#[inline(always)]
fn same_type<A: 'static, B: 'static>() -> bool {
    let lhs = TypeId::of::<A>();
    let rhs = TypeId::of::<B>();
    lhs == rhs
}
808
809fn cast_as<A: 'static + Copy, B: 'static + Copy>(a: &A) -> B {
813 assert!(same_type::<A, B>(), "expect type {} and {} to match",
814 std::any::type_name::<A>(), std::any::type_name::<B>());
815 unsafe { ::std::ptr::read(a as *const _ as *const B) }
816}
817
818#[inline]
820fn complex_array<A: 'static + Copy>(z: Complex<A>) -> [A; 2] {
821 [z.re, z.im]
822}
823
#[cfg(feature = "blas")]
/// Returns `true` if the 1-D array `a` can be passed to a BLAS routine expecting
/// elements of type `A`. That requires all of:
/// - the element types match;
/// - the length fits in `blas_index`;
/// - the stride is non-zero and fits in `blas_index`.
fn blas_compat_1d<A, S>(a: &ArrayBase<S, Ix1>) -> bool
where
    S: RawData,
    A: 'static,
    S::Elem: 'static,
{
    let idx_max = blas_index::max_value();
    let idx_min = blas_index::min_value();
    let step = a.strides()[0];
    let stride_ok = step != 0 && step <= idx_max as isize && step >= idx_min as isize;
    same_type::<A, S::Elem>() && a.len() <= idx_max as usize && stride_ok
}
846
#[cfg(feature = "blas")]
/// Memory order requested from `is_blas_2d`.
enum MemoryOrder {
    /// Row-major: axis 1 is the one required to be contiguous.
    C,
    /// Column-major: axis 0 is the one required to be contiguous.
    F,
}
852
#[cfg(feature = "blas")]
/// Returns `true` when `a` holds elements of type `A` and its layout can be
/// handed to CBLAS as a row-major matrix (see `is_blas_2d`).
fn blas_row_major_2d<A, S>(a: &ArrayBase<S, Ix2>) -> bool
where
    S: Data,
    A: 'static,
    S::Elem: 'static,
{
    same_type::<A, S::Elem>() && is_blas_2d(&a.dim, &a.strides, MemoryOrder::C)
}
865
#[cfg(feature = "blas")]
/// Returns `true` when `a` holds elements of type `A` and its layout can be
/// handed to CBLAS as a column-major matrix (see `is_blas_2d`).
fn blas_column_major_2d<A, S>(a: &ArrayBase<S, Ix2>) -> bool
where
    S: Data,
    A: 'static,
    S::Elem: 'static,
{
    same_type::<A, S::Elem>() && is_blas_2d(&a.dim, &a.strides, MemoryOrder::F)
}
878
#[cfg(feature = "blas")]
/// Checks whether a 2-D array with the given `dim` and `stride` can be passed to
/// CBLAS in the requested memory `order`. All of these must hold:
/// - the axis that must be contiguous for `order` (axis 1 for C, axis 0 for F)
///   has stride 1, unless that axis has length 1;
/// - both strides are positive;
/// - both strides and both axis lengths fit in `blas_index`.
fn is_blas_2d(dim: &Ix2, stride: &Ix2, order: MemoryOrder) -> bool {
    let (rows, cols) = dim.into_pattern();
    let row_stride = stride[0] as isize;
    let col_stride = stride[1] as isize;

    let contiguous = match order {
        MemoryOrder::C => col_stride == 1 || cols == 1,
        MemoryOrder::F => row_stride == 1 || rows == 1,
    };
    let positive = row_stride >= 1 && col_stride >= 1;

    let lo = blas_index::min_value() as isize;
    let hi = blas_index::max_value() as isize;
    let fits = |s: isize| lo <= s && s <= hi;
    let strides_fit = fits(row_stride) && fits(col_stride);

    let len_max = blas_index::max_value() as usize;
    let lens_fit = rows <= len_max && cols <= len_max;

    contiguous && positive && strides_fit && lens_fit
}
904
#[cfg(feature = "blas")]
/// Choose the CBLAS layout for `a`, or `None` if BLAS cannot take it.
///
/// Row-major is preferred when both layouts apply, for example a matrix with a
/// length-1 axis.
fn blas_layout<A, S>(a: &ArrayBase<S, Ix2>) -> Option<CBLAS_LAYOUT>
where
    S: Data,
    A: 'static,
    S::Elem: 'static,
{
    if blas_row_major_2d::<A, _>(a) {
        return Some(CBLAS_LAYOUT::CblasRowMajor);
    }
    if blas_column_major_2d::<A, _>(a) {
        return Some(CBLAS_LAYOUT::CblasColMajor);
    }
    None
}
920
#[cfg(test)]
#[cfg(feature = "blas")]
mod blas_tests {
    use super::*;

    #[test]
    fn blas_row_major_2d_normal_matrix() {
        // An ordinary C-order 3×5 matrix is row-major only.
        let mat: Array2<f32> = Array2::zeros((3, 5));
        assert!(blas_row_major_2d::<f32, _>(&mat));
        assert!(!blas_column_major_2d::<f32, _>(&mat));
    }

    #[test]
    fn blas_row_major_2d_row_matrix() {
        // A single row is acceptable in either layout.
        let mat: Array2<f32> = Array2::zeros((1, 5));
        assert!(blas_row_major_2d::<f32, _>(&mat));
        assert!(blas_column_major_2d::<f32, _>(&mat));
    }

    #[test]
    fn blas_row_major_2d_column_matrix() {
        // A single column is acceptable in either layout.
        let mat: Array2<f32> = Array2::zeros((5, 1));
        assert!(blas_row_major_2d::<f32, _>(&mat));
        assert!(blas_column_major_2d::<f32, _>(&mat));
    }

    #[test]
    fn blas_row_major_2d_transposed_row_matrix() {
        // The transpose of a row is a column, again acceptable in either layout.
        let mat: Array2<f32> = Array2::zeros((1, 5));
        let transposed = mat.t();
        assert!(blas_row_major_2d::<f32, _>(&transposed));
        assert!(blas_column_major_2d::<f32, _>(&transposed));
    }

    #[test]
    fn blas_row_major_2d_transposed_column_matrix() {
        // The transpose of a column is a row, again acceptable in either layout.
        let mat: Array2<f32> = Array2::zeros((5, 1));
        let transposed = mat.t();
        assert!(blas_row_major_2d::<f32, _>(&transposed));
        assert!(blas_column_major_2d::<f32, _>(&transposed));
    }

    #[test]
    fn blas_column_major_2d_normal_matrix() {
        // An ordinary F-order 3×5 matrix is column-major only.
        let mat: Array2<f32> = Array2::zeros((3, 5).f());
        assert!(!blas_row_major_2d::<f32, _>(&mat));
        assert!(blas_column_major_2d::<f32, _>(&mat));
    }
}