// ndarray/linalg/impl_linalg.rs

1// Copyright 2014-2020 bluss and ndarray developers.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9use crate::imp_prelude::*;
10
11#[cfg(feature = "blas")]
12use crate::dimension::offset_from_low_addr_ptr_to_logical_ptr;
13use crate::numeric_util;
14
15use crate::{LinalgScalar, Zip};
16
17use std::any::TypeId;
18use std::mem::MaybeUninit;
19use alloc::vec::Vec;
20
21use num_complex::Complex;
22use num_complex::{Complex32 as c32, Complex64 as c64};
23
24#[cfg(feature = "blas")]
25use libc::c_int;
26#[cfg(feature = "blas")]
27use std::cmp;
28#[cfg(feature = "blas")]
29use std::mem::swap;
30
31#[cfg(feature = "blas")]
32use cblas_sys as blas_sys;
33#[cfg(feature = "blas")]
34use cblas_sys::{CblasNoTrans, CblasRowMajor, CblasTrans, CBLAS_LAYOUT};
35
/// Minimum vector length at which the 1-D dot product is handed to BLAS;
/// shorter vectors always use the generic implementation.
#[cfg(feature = "blas")]
const DOT_BLAS_CUTOFF: usize = 32;
/// Matrix side length cutoff for BLAS `gemm`: when no dimension of the
/// product exceeds this value, the general implementation is used instead.
#[cfg(feature = "blas")]
const GEMM_BLAS_CUTOFF: usize = 7;
/// Integer type used for lengths, strides and leading dimensions in BLAS calls.
#[cfg(feature = "blas")]
#[allow(non_camel_case_types)]
type blas_index = c_int;
45
46impl<A, S> ArrayBase<S, Ix1>
47where
48    S: Data<Elem = A>,
49{
50    /// Perform dot product or matrix multiplication of arrays `self` and `rhs`.
51    ///
52    /// `Rhs` may be either a one-dimensional or a two-dimensional array.
53    ///
54    /// If `Rhs` is one-dimensional, then the operation is a vector dot
55    /// product, which is the sum of the elementwise products (no conjugation
56    /// of complex operands, and thus not their inner product). In this case,
57    /// `self` and `rhs` must be the same length.
58    ///
59    /// If `Rhs` is two-dimensional, then the operation is matrix
60    /// multiplication, where `self` is treated as a row vector. In this case,
61    /// if `self` is shape *M*, then `rhs` is shape *M* × *N* and the result is
62    /// shape *N*.
63    ///
64    /// **Panics** if the array shapes are incompatible.<br>
65    /// *Note:* If enabled, uses blas `dot` for elements of `f32, f64` when memory
66    /// layout allows.
67    pub fn dot<Rhs>(&self, rhs: &Rhs) -> <Self as Dot<Rhs>>::Output
68    where
69        Self: Dot<Rhs>,
70    {
71        Dot::dot(self, rhs)
72    }
73
74    fn dot_generic<S2>(&self, rhs: &ArrayBase<S2, Ix1>) -> A
75    where
76        S2: Data<Elem = A>,
77        A: LinalgScalar,
78    {
79        debug_assert_eq!(self.len(), rhs.len());
80        assert!(self.len() == rhs.len());
81        if let Some(self_s) = self.as_slice() {
82            if let Some(rhs_s) = rhs.as_slice() {
83                return numeric_util::unrolled_dot(self_s, rhs_s);
84            }
85        }
86        let mut sum = A::zero();
87        for i in 0..self.len() {
88            unsafe {
89                sum = sum + *self.uget(i) * *rhs.uget(i);
90            }
91        }
92        sum
93    }
94
95    #[cfg(not(feature = "blas"))]
96    fn dot_impl<S2>(&self, rhs: &ArrayBase<S2, Ix1>) -> A
97    where
98        S2: Data<Elem = A>,
99        A: LinalgScalar,
100    {
101        self.dot_generic(rhs)
102    }
103
104    #[cfg(feature = "blas")]
105    fn dot_impl<S2>(&self, rhs: &ArrayBase<S2, Ix1>) -> A
106    where
107        S2: Data<Elem = A>,
108        A: LinalgScalar,
109    {
110        // Use only if the vector is large enough to be worth it
111        if self.len() >= DOT_BLAS_CUTOFF {
112            debug_assert_eq!(self.len(), rhs.len());
113            assert!(self.len() == rhs.len());
114            macro_rules! dot {
115                ($ty:ty, $func:ident) => {{
116                    if blas_compat_1d::<$ty, _>(self) && blas_compat_1d::<$ty, _>(rhs) {
117                        unsafe {
118                            let (lhs_ptr, n, incx) =
119                                blas_1d_params(self.ptr.as_ptr(), self.len(), self.strides()[0]);
120                            let (rhs_ptr, _, incy) =
121                                blas_1d_params(rhs.ptr.as_ptr(), rhs.len(), rhs.strides()[0]);
122                            let ret = blas_sys::$func(
123                                n,
124                                lhs_ptr as *const $ty,
125                                incx,
126                                rhs_ptr as *const $ty,
127                                incy,
128                            );
129                            return cast_as::<$ty, A>(&ret);
130                        }
131                    }
132                }};
133            }
134
135            dot! {f32, cblas_sdot};
136            dot! {f64, cblas_ddot};
137        }
138        self.dot_generic(rhs)
139    }
140}
141
/// Compute the `(pointer, length, increment)` triple BLAS expects for a
/// strided vector.
///
/// BLAS wants a pointer to the element with the lowest address. For a
/// non-negative stride (or an empty vector) that is `ptr` itself; for a
/// negative stride it is the element at the other end:
///
/// ```text
/// [x x x x]
///        ^--ptr, stride = -1
///  ^--blas_ptr = ptr + (len - 1) * stride
/// ```
///
/// # Safety
///
/// `ptr`, `len` and `stride` must describe a valid strided vector, so that
/// the offset pointer computed here stays within the same allocation.
#[cfg(feature = "blas")]
unsafe fn blas_1d_params<A>(
    ptr: *const A,
    len: usize,
    stride: isize,
) -> (*const A, blas_index, blas_index) {
    let lowest = if stride < 0 && len != 0 {
        ptr.offset((len - 1) as isize * stride)
    } else {
        ptr
    };
    (lowest, len as blas_index, stride as blas_index)
}
164
/// Matrix Multiplication
///
/// The `dot` method of this trait computes the matrix multiplication when
/// both operands are two-dimensional arrays.
pub trait Dot<Rhs> {
    /// The type produced by the operation.
    ///
    /// For two-dimensional arrays: a rectangular array.
    type Output;
    /// Compute the product of `self` and `rhs`.
    fn dot(&self, rhs: &Rhs) -> Self::Output;
}
176
177impl<A, S, S2> Dot<ArrayBase<S2, Ix1>> for ArrayBase<S, Ix1>
178where
179    S: Data<Elem = A>,
180    S2: Data<Elem = A>,
181    A: LinalgScalar,
182{
183    type Output = A;
184
185    /// Compute the dot product of one-dimensional arrays.
186    ///
187    /// The dot product is a sum of the elementwise products (no conjugation
188    /// of complex operands, and thus not their inner product).
189    ///
190    /// **Panics** if the arrays are not of the same length.<br>
191    /// *Note:* If enabled, uses blas `dot` for elements of `f32, f64` when memory
192    /// layout allows.
193    fn dot(&self, rhs: &ArrayBase<S2, Ix1>) -> A {
194        self.dot_impl(rhs)
195    }
196}
197
198impl<A, S, S2> Dot<ArrayBase<S2, Ix2>> for ArrayBase<S, Ix1>
199where
200    S: Data<Elem = A>,
201    S2: Data<Elem = A>,
202    A: LinalgScalar,
203{
204    type Output = Array<A, Ix1>;
205
206    /// Perform the matrix multiplication of the row vector `self` and
207    /// rectangular matrix `rhs`.
208    ///
209    /// The array shapes must agree in the way that
210    /// if `self` is *M*, then `rhs` is *M* × *N*.
211    ///
212    /// Return a result array with shape *N*.
213    ///
214    /// **Panics** if shapes are incompatible.
215    fn dot(&self, rhs: &ArrayBase<S2, Ix2>) -> Array<A, Ix1> {
216        rhs.t().dot(self)
217    }
218}
219
220impl<A, S> ArrayBase<S, Ix2>
221where
222    S: Data<Elem = A>,
223{
224    /// Perform matrix multiplication of rectangular arrays `self` and `rhs`.
225    ///
226    /// `Rhs` may be either a one-dimensional or a two-dimensional array.
227    ///
228    /// If Rhs is two-dimensional, they array shapes must agree in the way that
229    /// if `self` is *M* × *N*, then `rhs` is *N* × *K*.
230    ///
231    /// Return a result array with shape *M* × *K*.
232    ///
233    /// **Panics** if shapes are incompatible or the number of elements in the
234    /// result would overflow `isize`.
235    ///
236    /// *Note:* If enabled, uses blas `gemv/gemm` for elements of `f32, f64`
237    /// when memory layout allows. The default matrixmultiply backend
238    /// is otherwise used for `f32, f64` for all memory layouts.
239    ///
240    /// ```
241    /// use ndarray::arr2;
242    ///
243    /// let a = arr2(&[[1., 2.],
244    ///                [0., 1.]]);
245    /// let b = arr2(&[[1., 2.],
246    ///                [2., 3.]]);
247    ///
248    /// assert!(
249    ///     a.dot(&b) == arr2(&[[5., 8.],
250    ///                         [2., 3.]])
251    /// );
252    /// ```
253    pub fn dot<Rhs>(&self, rhs: &Rhs) -> <Self as Dot<Rhs>>::Output
254    where
255        Self: Dot<Rhs>,
256    {
257        Dot::dot(self, rhs)
258    }
259}
260
261impl<A, S, S2> Dot<ArrayBase<S2, Ix2>> for ArrayBase<S, Ix2>
262where
263    S: Data<Elem = A>,
264    S2: Data<Elem = A>,
265    A: LinalgScalar,
266{
267    type Output = Array2<A>;
268    fn dot(&self, b: &ArrayBase<S2, Ix2>) -> Array2<A> {
269        let a = self.view();
270        let b = b.view();
271        let ((m, k), (k2, n)) = (a.dim(), b.dim());
272        if k != k2 || m.checked_mul(n).is_none() {
273            dot_shape_error(m, k, k2, n);
274        }
275
276        let lhs_s0 = a.strides()[0];
277        let rhs_s0 = b.strides()[0];
278        let column_major = lhs_s0 == 1 && rhs_s0 == 1;
279        // A is Copy so this is safe
280        let mut v = Vec::with_capacity(m * n);
281        let mut c;
282        unsafe {
283            v.set_len(m * n);
284            c = Array::from_shape_vec_unchecked((m, n).set_f(column_major), v);
285        }
286        mat_mul_impl(A::one(), &a, &b, A::zero(), &mut c.view_mut());
287        c
288    }
289}
290
/// Panic with the appropriate message for a failed matrix-product shape check.
///
/// If `m * n` overflows `usize` or exceeds `isize::MAX`, the panic reports
/// the overflow; otherwise it reports that `m × k` and `k2 × n` are not
/// compatible.
///
/// Assumes that `m` and `n` are ≤ `isize::MAX`.
#[cold]
#[inline(never)]
fn dot_shape_error(m: usize, k: usize, k2: usize, n: usize) -> ! {
    let fits_isize = m
        .checked_mul(n)
        .map_or(false, |len| len <= ::std::isize::MAX as usize);
    if !fits_isize {
        panic!("ndarray: shape {} × {} overflows isize", m, n);
    }
    panic!(
        "ndarray: inputs {} × {} and {} × {} are not compatible for matrix multiplication",
        m, k, k2, n
    );
}
304
/// Panic reporting that inputs `m × k` and `k2 × n` and the output
/// `c1 × c2` do not form a valid matrix product.
#[cold]
#[inline(never)]
fn general_dot_shape_error(m: usize, k: usize, k2: usize, n: usize, c1: usize, c2: usize) -> ! {
    panic!(
        "ndarray: inputs {} × {}, {} × {}, and output {} × {} are not compatible for matrix multiplication",
        m, k, k2, n, c1, c2
    );
}
311
312/// Perform the matrix multiplication of the rectangular array `self` and
313/// column vector `rhs`.
314///
315/// The array shapes must agree in the way that
316/// if `self` is *M* × *N*, then `rhs` is *N*.
317///
318/// Return a result array with shape *M*.
319///
320/// **Panics** if shapes are incompatible.
321impl<A, S, S2> Dot<ArrayBase<S2, Ix1>> for ArrayBase<S, Ix2>
322where
323    S: Data<Elem = A>,
324    S2: Data<Elem = A>,
325    A: LinalgScalar,
326{
327    type Output = Array<A, Ix1>;
328    fn dot(&self, rhs: &ArrayBase<S2, Ix1>) -> Array<A, Ix1> {
329        let ((m, a), n) = (self.dim(), rhs.dim());
330        if a != n {
331            dot_shape_error(m, a, n, 1);
332        }
333
334        // Avoid initializing the memory in vec -- set it during iteration
335        unsafe {
336            let mut c = Array1::uninit(m);
337            general_mat_vec_mul_impl(A::one(), self, rhs, A::zero(), c.raw_view_mut().cast::<A>());
338            c.assume_init()
339        }
340    }
341}
342
343impl<A, S, D> ArrayBase<S, D>
344where
345    S: Data<Elem = A>,
346    D: Dimension,
347{
348    /// Perform the operation `self += alpha * rhs` efficiently, where
349    /// `alpha` is a scalar and `rhs` is another array. This operation is
350    /// also known as `axpy` in BLAS.
351    ///
352    /// If their shapes disagree, `rhs` is broadcast to the shape of `self`.
353    ///
354    /// **Panics** if broadcasting isn’t possible.
355    pub fn scaled_add<S2, E>(&mut self, alpha: A, rhs: &ArrayBase<S2, E>)
356    where
357        S: DataMut,
358        S2: Data<Elem = A>,
359        A: LinalgScalar,
360        E: Dimension,
361    {
362        self.zip_mut_with(rhs, move |y, &x| *y = *y + (alpha * x));
363    }
364}
365
366// mat_mul_impl uses ArrayView arguments to send all array kinds into
367// the same instantiated implementation.
368#[cfg(not(feature = "blas"))]
369use self::mat_mul_general as mat_mul_impl;
370
/// C ← α A B + β C, using BLAS `gemm` when possible.
///
/// BLAS is attempted only when at least one dimension exceeds
/// `GEMM_BLAS_CUTOFF`, the element type is `f32`, `f64`, `c32` or `c64`, and
/// the (possibly transposed) operands and output all pass
/// `blas_row_major_2d`. In every other case the work is delegated to
/// `mat_mul_general`.
#[cfg(feature = "blas")]
fn mat_mul_impl<A>(
    alpha: A,
    lhs: &ArrayView2<'_, A>,
    rhs: &ArrayView2<'_, A>,
    beta: A,
    c: &mut ArrayViewMut2<'_, A>,
) where
    A: LinalgScalar,
{
    // size cutoff for using BLAS
    let cutoff = GEMM_BLAS_CUTOFF;
    let ((mut m, k_len), (_, mut n)) = (lhs.dim(), rhs.dim());
    let all_small = m <= cutoff && n <= cutoff && k_len <= cutoff;
    let blas_elem = same_type::<A, f32>()
        || same_type::<A, f64>()
        || same_type::<A, c32>()
        || same_type::<A, c64>();
    if all_small || !blas_elem {
        return mat_mul_general(alpha, lhs, rhs, beta, c);
    }
    {
        // Below, "c" means a c-order and "f" an f-order matrix.
        // c * c and f * f are handled generally; c * f and f * c are handled
        // if the `f` matrix is square.
        let mut lhs_v = lhs.view();
        let mut rhs_v = rhs.view();
        let mut c_v = c.view_mut();
        let lhs_unit_s0 = lhs_v.strides()[0] == 1;
        let rhs_unit_s0 = rhs_v.strides()[0] == 1;
        let mut lhs_op = CblasNoTrans;
        let mut rhs_op = CblasNoTrans;
        if lhs_unit_s0 && rhs_unit_s0 {
            // A^t B^t = C^t => B A = C
            let lhs_reversed = lhs_v.reversed_axes();
            lhs_v = rhs_v.reversed_axes();
            rhs_v = lhs_reversed;
            c_v = c_v.reversed_axes();
            swap(&mut m, &mut n);
        } else if lhs_unit_s0 && m == k_len {
            lhs_v = lhs_v.reversed_axes();
            lhs_op = CblasTrans;
        } else if rhs_unit_s0 && k_len == n {
            rhs_v = rhs_v.reversed_axes();
            rhs_op = CblasTrans;
        }

        // Real scalars are passed to gemm by value, complex scalars by pointer.
        macro_rules! gemm_scalar_cast {
            (f32, $var:ident) => {
                cast_as(&$var)
            };
            (f64, $var:ident) => {
                cast_as(&$var)
            };
            (c32, $var:ident) => {
                &$var as *const A as *const _
            };
            (c64, $var:ident) => {
                &$var as *const A as *const _
            };
        }

        // Calls `$gemm` and returns from `mat_mul_impl` if all three views are
        // BLAS row-major matrices of element type `$ty`; otherwise falls through.
        macro_rules! gemm {
            ($ty:tt, $gemm:ident) => {
                if blas_row_major_2d::<$ty, _>(&lhs_v)
                    && blas_row_major_2d::<$ty, _>(&rhs_v)
                    && blas_row_major_2d::<$ty, _>(&c_v)
                {
                    // Dimensions of Op(a): swap the stored dims when transposed.
                    let (m, k) = match lhs_op {
                        CblasNoTrans => lhs_v.dim(),
                        _ => {
                            let (rows, cols) = lhs_v.dim();
                            (cols, rows)
                        }
                    };
                    let n = match rhs_op {
                        CblasNoTrans => rhs_v.raw_dim()[1],
                        _ => rhs_v.raw_dim()[0],
                    };
                    // adjust strides, these may be [1, 1] for column matrices
                    let lda = cmp::max(lhs_v.strides()[0] as blas_index, k as blas_index);
                    let ldb = cmp::max(rhs_v.strides()[0] as blas_index, n as blas_index);
                    let ldc = cmp::max(c_v.strides()[0] as blas_index, n as blas_index);

                    // gemm is C ← αA^Op B^Op + βC
                    // Where Op is notrans/trans/conjtrans
                    unsafe {
                        blas_sys::$gemm(
                            CblasRowMajor,
                            lhs_op,
                            rhs_op,
                            m as blas_index,                 // m, rows of Op(a)
                            n as blas_index,                 // n, cols of Op(b)
                            k as blas_index,                 // k, cols of Op(a)
                            gemm_scalar_cast!($ty, alpha),   // alpha
                            lhs_v.ptr.as_ptr() as *const _,  // a
                            lda,                             // lda
                            rhs_v.ptr.as_ptr() as *const _,  // b
                            ldb,                             // ldb
                            gemm_scalar_cast!($ty, beta),    // beta
                            c_v.ptr.as_ptr() as *mut _,      // c
                            ldc,                             // ldc
                        );
                    }
                    return;
                }
            };
        }
        gemm!(f32, cblas_sgemm);
        gemm!(f64, cblas_dgemm);

        gemm!(c32, cblas_cgemm);
        gemm!(c64, cblas_zgemm);
    }
    // No BLAS-compatible layout was found.
    mat_mul_general(alpha, lhs, rhs, beta, c)
}
488
489/// C ← α A B + β C
490fn mat_mul_general<A>(
491    alpha: A,
492    lhs: &ArrayView2<'_, A>,
493    rhs: &ArrayView2<'_, A>,
494    beta: A,
495    c: &mut ArrayViewMut2<'_, A>,
496) where
497    A: LinalgScalar,
498{
499    let ((m, k), (_, n)) = (lhs.dim(), rhs.dim());
500
501    // common parameters for gemm
502    let ap = lhs.as_ptr();
503    let bp = rhs.as_ptr();
504    let cp = c.as_mut_ptr();
505    let (rsc, csc) = (c.strides()[0], c.strides()[1]);
506    if same_type::<A, f32>() {
507        unsafe {
508            matrixmultiply::sgemm(
509                m,
510                k,
511                n,
512                cast_as(&alpha),
513                ap as *const _,
514                lhs.strides()[0],
515                lhs.strides()[1],
516                bp as *const _,
517                rhs.strides()[0],
518                rhs.strides()[1],
519                cast_as(&beta),
520                cp as *mut _,
521                rsc,
522                csc,
523            );
524        }
525    } else if same_type::<A, f64>() {
526        unsafe {
527            matrixmultiply::dgemm(
528                m,
529                k,
530                n,
531                cast_as(&alpha),
532                ap as *const _,
533                lhs.strides()[0],
534                lhs.strides()[1],
535                bp as *const _,
536                rhs.strides()[0],
537                rhs.strides()[1],
538                cast_as(&beta),
539                cp as *mut _,
540                rsc,
541                csc,
542            );
543        }
544    } else if same_type::<A, c32>() {
545        unsafe {
546            matrixmultiply::cgemm(
547                matrixmultiply::CGemmOption::Standard,
548                matrixmultiply::CGemmOption::Standard,
549                m,
550                k,
551                n,
552                complex_array(cast_as(&alpha)),
553                ap as *const _,
554                lhs.strides()[0],
555                lhs.strides()[1],
556                bp as *const _,
557                rhs.strides()[0],
558                rhs.strides()[1],
559                complex_array(cast_as(&beta)),
560                cp as *mut _,
561                rsc,
562                csc,
563            );
564        }
565    } else if same_type::<A, c64>() {
566        unsafe {
567            matrixmultiply::zgemm(
568                matrixmultiply::CGemmOption::Standard,
569                matrixmultiply::CGemmOption::Standard,
570                m,
571                k,
572                n,
573                complex_array(cast_as(&alpha)),
574                ap as *const _,
575                lhs.strides()[0],
576                lhs.strides()[1],
577                bp as *const _,
578                rhs.strides()[0],
579                rhs.strides()[1],
580                complex_array(cast_as(&beta)),
581                cp as *mut _,
582                rsc,
583                csc,
584            );
585        }
586    } else {
587        // It's a no-op if `c` has zero length.
588        if c.is_empty() {
589            return;
590        }
591
592        // initialize memory if beta is zero
593        if beta.is_zero() {
594            c.fill(beta);
595        }
596
597        let mut i = 0;
598        let mut j = 0;
599        loop {
600            unsafe {
601                let elt = c.uget_mut((i, j));
602                *elt = *elt * beta
603                    + alpha
604                        * (0..k).fold(A::zero(), move |s, x| {
605                            s + *lhs.uget((i, x)) * *rhs.uget((x, j))
606                        });
607            }
608            j += 1;
609            if j == n {
610                j = 0;
611                i += 1;
612                if i == m {
613                    break;
614                }
615            }
616        }
617    }
618}
619
620/// General matrix-matrix multiplication.
621///
622/// Compute C ← α A B + β C
623///
624/// The array shapes must agree in the way that
625/// if `a` is *M* × *N*, then `b` is *N* × *K* and `c` is *M* × *K*.
626///
627/// ***Panics*** if array shapes are not compatible<br>
628/// *Note:* If enabled, uses blas `gemm` for elements of `f32, f64` when memory
629/// layout allows.  The default matrixmultiply backend is otherwise used for
630/// `f32, f64` for all memory layouts.
631pub fn general_mat_mul<A, S1, S2, S3>(
632    alpha: A,
633    a: &ArrayBase<S1, Ix2>,
634    b: &ArrayBase<S2, Ix2>,
635    beta: A,
636    c: &mut ArrayBase<S3, Ix2>,
637) where
638    S1: Data<Elem = A>,
639    S2: Data<Elem = A>,
640    S3: DataMut<Elem = A>,
641    A: LinalgScalar,
642{
643    let ((m, k), (k2, n)) = (a.dim(), b.dim());
644    let (m2, n2) = c.dim();
645    if k != k2 || m != m2 || n != n2 {
646        general_dot_shape_error(m, k, k2, n, m2, n2);
647    } else {
648        mat_mul_impl(alpha, &a.view(), &b.view(), beta, &mut c.view_mut());
649    }
650}
651
652/// General matrix-vector multiplication.
653///
654/// Compute y ← α A x + β y
655///
656/// where A is a *M* × *N* matrix and x is an *N*-element column vector and
657/// y an *M*-element column vector (one dimensional arrays).
658///
659/// ***Panics*** if array shapes are not compatible<br>
660/// *Note:* If enabled, uses blas `gemv` for elements of `f32, f64` when memory
661/// layout allows.
662#[allow(clippy::collapsible_if)]
663pub fn general_mat_vec_mul<A, S1, S2, S3>(
664    alpha: A,
665    a: &ArrayBase<S1, Ix2>,
666    x: &ArrayBase<S2, Ix1>,
667    beta: A,
668    y: &mut ArrayBase<S3, Ix1>,
669) where
670    S1: Data<Elem = A>,
671    S2: Data<Elem = A>,
672    S3: DataMut<Elem = A>,
673    A: LinalgScalar,
674{
675    unsafe { general_mat_vec_mul_impl(alpha, a, x, beta, y.raw_view_mut()) }
676}
677
678/// General matrix-vector multiplication
679///
680/// Use a raw view for the destination vector, so that it can be uninitialized.
681///
682/// ## Safety
683///
684/// The caller must ensure that the raw view is valid for writing.
685/// the destination may be uninitialized iff beta is zero.
686#[allow(clippy::collapsible_else_if)]
687unsafe fn general_mat_vec_mul_impl<A, S1, S2>(
688    alpha: A,
689    a: &ArrayBase<S1, Ix2>,
690    x: &ArrayBase<S2, Ix1>,
691    beta: A,
692    y: RawArrayViewMut<A, Ix1>,
693) where
694    S1: Data<Elem = A>,
695    S2: Data<Elem = A>,
696    A: LinalgScalar,
697{
698    let ((m, k), k2) = (a.dim(), x.dim());
699    let m2 = y.dim();
700    if k != k2 || m != m2 {
701        general_dot_shape_error(m, k, k2, 1, m2, 1);
702    } else {
703        #[cfg(feature = "blas")]
704        macro_rules! gemv {
705            ($ty:ty, $gemv:ident) => {
706                if let Some(layout) = blas_layout::<$ty, _>(&a) {
707                    if blas_compat_1d::<$ty, _>(&x) && blas_compat_1d::<$ty, _>(&y) {
708                        // Determine stride between rows or columns. Note that the stride is
709                        // adjusted to at least `k` or `m` to handle the case of a matrix with a
710                        // trivial (length 1) dimension, since the stride for the trivial dimension
711                        // may be arbitrary.
712                        let a_trans = CblasNoTrans;
713                        let a_stride = match layout {
714                            CBLAS_LAYOUT::CblasRowMajor => {
715                                a.strides()[0].max(k as isize) as blas_index
716                            }
717                            CBLAS_LAYOUT::CblasColMajor => {
718                                a.strides()[1].max(m as isize) as blas_index
719                            }
720                        };
721
722                        // Low addr in memory pointers required for x, y
723                        let x_offset = offset_from_low_addr_ptr_to_logical_ptr(&x.dim, &x.strides);
724                        let x_ptr = x.ptr.as_ptr().sub(x_offset);
725                        let y_offset = offset_from_low_addr_ptr_to_logical_ptr(&y.dim, &y.strides);
726                        let y_ptr = y.ptr.as_ptr().sub(y_offset);
727
728                        let x_stride = x.strides()[0] as blas_index;
729                        let y_stride = y.strides()[0] as blas_index;
730
731                        blas_sys::$gemv(
732                            layout,
733                            a_trans,
734                            m as blas_index,            // m, rows of Op(a)
735                            k as blas_index,            // n, cols of Op(a)
736                            cast_as(&alpha),            // alpha
737                            a.ptr.as_ptr() as *const _, // a
738                            a_stride,                   // lda
739                            x_ptr as *const _,          // x
740                            x_stride,
741                            cast_as(&beta),             // beta
742                            y_ptr as *mut _,            // y
743                            y_stride,
744                        );
745                        return;
746                    }
747                }
748            };
749        }
750        #[cfg(feature = "blas")]
751        gemv!(f32, cblas_sgemv);
752        #[cfg(feature = "blas")]
753        gemv!(f64, cblas_dgemv);
754
755        /* general */
756
757        if beta.is_zero() {
758            // when beta is zero, c may be uninitialized
759            Zip::from(a.outer_iter()).and(y).for_each(|row, elt| {
760                elt.write(row.dot(x) * alpha);
761            });
762        } else {
763            Zip::from(a.outer_iter()).and(y).for_each(|row, elt| {
764                *elt = *elt * beta + row.dot(x) * alpha;
765            });
766        }
767    }
768}
769
770
771/// Kronecker product of 2D matrices.
772///
773/// The kronecker product of a LxN matrix A and a MxR matrix B is a (L*M)x(N*R)
774/// matrix K formed by the block multiplication A_ij * B.
775pub fn kron<A, S1, S2>(a: &ArrayBase<S1, Ix2>, b: &ArrayBase<S2, Ix2>) -> Array<A, Ix2>
776where
777    S1: Data<Elem = A>,
778    S2: Data<Elem = A>,
779    A: LinalgScalar,
780{
781    let dimar = a.shape()[0];
782    let dimac = a.shape()[1];
783    let dimbr = b.shape()[0];
784    let dimbc = b.shape()[1];
785    let mut out: Array2<MaybeUninit<A>> = Array2::uninit((
786        dimar
787            .checked_mul(dimbr)
788            .expect("Dimensions of kronecker product output array overflows usize."),
789        dimac
790            .checked_mul(dimbc)
791            .expect("Dimensions of kronecker product output array overflows usize."),
792    ));
793    Zip::from(out.exact_chunks_mut((dimbr, dimbc)))
794        .and(a)
795        .for_each(|out, &a| {
796            Zip::from(out).and(b).for_each(|out, &b| {
797                *out = MaybeUninit::new(a * b);
798            })
799        });
800    unsafe { out.assume_init() }
801}
802
/// Return `true` if `A` and `B` are the same type, decided by comparing
/// their `TypeId`s.
#[inline(always)]
fn same_type<A, B>() -> bool
where
    A: 'static,
    B: 'static,
{
    let first = TypeId::of::<A>();
    let second = TypeId::of::<B>();
    first == second
}
808
809// Read pointer to type `A` as type `B`.
810//
811// **Panics** if `A` and `B` are not the same type
812fn cast_as<A: 'static + Copy, B: 'static + Copy>(a: &A) -> B {
813    assert!(same_type::<A, B>(), "expect type {} and {} to match",
814            std::any::type_name::<A>(), std::any::type_name::<B>());
815    unsafe { ::std::ptr::read(a as *const _ as *const B) }
816}
817
818/// Return the complex in the form of an array [re, im]
819#[inline]
820fn complex_array<A: 'static + Copy>(z: Complex<A>) -> [A; 2] {
821    [z.re, z.im]
822}
823
/// Return `true` if the vector `a` can be passed to BLAS as a vector of
/// element type `A`.
///
/// This requires that the array's element type is `A`, that its length fits
/// in `blas_index`, and that its stride is non-zero and fits in `blas_index`.
#[cfg(feature = "blas")]
fn blas_compat_1d<A, S>(a: &ArrayBase<S, Ix1>) -> bool
where
    S: RawData,
    A: 'static,
    S::Elem: 'static,
{
    let index_max = blas_index::max_value();
    let index_min = blas_index::min_value();
    same_type::<A, S::Elem>() && a.len() <= index_max as usize && {
        let stride = a.strides()[0];
        stride != 0 && stride <= index_max as isize && stride >= index_min as isize
    }
}
846
/// Memory order that `is_blas_2d` checks a matrix against.
#[cfg(feature = "blas")]
enum MemoryOrder {
    /// Row-major: the stride of axis 1 is the inner stride.
    C,
    /// Column-major: the stride of axis 0 is the inner stride.
    F,
}
852
/// Return `true` if `a` has element type `A` and a layout BLAS accepts as a
/// row-major matrix (see `is_blas_2d`).
#[cfg(feature = "blas")]
fn blas_row_major_2d<A, S>(a: &ArrayBase<S, Ix2>) -> bool
where
    S: Data,
    A: 'static,
    S::Elem: 'static,
{
    same_type::<A, S::Elem>() && is_blas_2d(&a.dim, &a.strides, MemoryOrder::C)
}
865
/// Return `true` if `a` has element type `A` and a layout BLAS accepts as a
/// column-major matrix (see `is_blas_2d`).
#[cfg(feature = "blas")]
fn blas_column_major_2d<A, S>(a: &ArrayBase<S, Ix2>) -> bool
where
    S: Data,
    A: 'static,
    S::Elem: 'static,
{
    same_type::<A, S::Elem>() && is_blas_2d(&a.dim, &a.strides, MemoryOrder::F)
}
878
/// Return `true` if a matrix with the given dimensions and strides can be
/// handed to BLAS in the requested memory order.
///
/// The conditions are:
/// * the inner stride (axis 1 for C order, axis 0 for F order) is 1, or the
///   corresponding dimension has length 1;
/// * both strides are at least 1;
/// * both strides and both dimensions fit in `blas_index`.
#[cfg(feature = "blas")]
fn is_blas_2d(dim: &Ix2, stride: &Ix2, order: MemoryOrder) -> bool {
    let (m, n) = dim.into_pattern();
    let s0 = stride[0] as isize;
    let s1 = stride[1] as isize;
    let (inner_stride, outer_dim) = match order {
        MemoryOrder::C => (s1, n),
        MemoryOrder::F => (s0, m),
    };

    let index_max = blas_index::max_value();
    let index_min = blas_index::min_value();
    let stride_fits = |s: isize| s <= index_max as isize && s >= index_min as isize;
    let dim_fits = |d: usize| d <= index_max as usize;

    (inner_stride == 1 || outer_dim == 1)
        && s0 >= 1
        && s1 >= 1
        && stride_fits(s0)
        && stride_fits(s1)
        && dim_fits(m)
        && dim_fits(n)
}
904
/// Determine which BLAS layout, if any, describes the matrix `a` with
/// element type `A`.
///
/// Row-major is tried first, so a matrix that qualifies for both layouts is
/// reported as row-major. Returns `None` if neither layout applies.
#[cfg(feature = "blas")]
fn blas_layout<A, S>(a: &ArrayBase<S, Ix2>) -> Option<CBLAS_LAYOUT>
where
    S: Data,
    A: 'static,
    S::Elem: 'static,
{
    if blas_row_major_2d::<A, _>(a) {
        return Some(CBLAS_LAYOUT::CblasRowMajor);
    }
    if blas_column_major_2d::<A, _>(a) {
        return Some(CBLAS_LAYOUT::CblasColMajor);
    }
    None
}
920
/// Layout-detection tests for the BLAS compatibility predicates.
#[cfg(test)]
#[cfg(feature = "blas")]
mod blas_tests {
    use super::*;

    // A default (row-major) 3 × 5 matrix is row-major only.
    #[test]
    fn blas_row_major_2d_normal_matrix() {
        let matrix = Array2::<f32>::zeros((3, 5));
        assert!(blas_row_major_2d::<f32, _>(&matrix));
        assert!(!blas_column_major_2d::<f32, _>(&matrix));
    }

    // A single-row matrix qualifies for both layouts.
    #[test]
    fn blas_row_major_2d_row_matrix() {
        let matrix = Array2::<f32>::zeros((1, 5));
        assert!(blas_row_major_2d::<f32, _>(&matrix));
        assert!(blas_column_major_2d::<f32, _>(&matrix));
    }

    // A single-column matrix qualifies for both layouts.
    #[test]
    fn blas_row_major_2d_column_matrix() {
        let matrix = Array2::<f32>::zeros((5, 1));
        assert!(blas_row_major_2d::<f32, _>(&matrix));
        assert!(blas_column_major_2d::<f32, _>(&matrix));
    }

    // Transposing a single-row matrix keeps it compatible with both layouts.
    #[test]
    fn blas_row_major_2d_transposed_row_matrix() {
        let matrix = Array2::<f32>::zeros((1, 5));
        let transposed = matrix.t();
        assert!(blas_row_major_2d::<f32, _>(&transposed));
        assert!(blas_column_major_2d::<f32, _>(&transposed));
    }

    // Transposing a single-column matrix keeps it compatible with both layouts.
    #[test]
    fn blas_row_major_2d_transposed_column_matrix() {
        let matrix = Array2::<f32>::zeros((5, 1));
        let transposed = matrix.t();
        assert!(blas_row_major_2d::<f32, _>(&transposed));
        assert!(blas_column_major_2d::<f32, _>(&transposed));
    }

    // A column-major 3 × 5 matrix is column-major only.
    #[test]
    fn blas_column_major_2d_normal_matrix() {
        let matrix = Array2::<f32>::zeros((3, 5).f());
        assert!(!blas_row_major_2d::<f32, _>(&matrix));
        assert!(blas_column_major_2d::<f32, _>(&matrix));
    }
}
969}