// lax/solve.rs
//! Solve linear problem using LU decomposition
use crate::{error::*, layout::MatrixLayout, *};
use cauchy::*;
use num_traits::{ToPrimitive, Zero};
/// LU-decomposition based routines (factorize, invert, solve) provided per
/// scalar type.
///
/// All matrices are passed as flat slices whose shape and memory order are
/// described by the accompanying `MatrixLayout`.
pub trait Solve_: Scalar + Sized {
/// Computes the LU factorization of a general `m x n` matrix `a` using
/// partial pivoting with row interchanges.
///
/// $ PA = LU $
///
/// `a` is passed mutably and handed to the LAPACK `*getrf` routine by the
/// implementations in this file; the returned `Pivot` is the pivot array
/// filled by that routine.
///
/// Error
/// ------
/// - `LapackComputationalFailure { return_code }` when the matrix is singular
///   - Division by zero will occur if it is used to solve a system of equations
///     because `U[(return_code-1, return_code-1)]` is exactly zero.
fn lu(l: MatrixLayout, a: &mut [Self]) -> Result<Pivot>;
/// Computes the inverse of a matrix in place from its LU factorization.
///
/// `a` is expected to hold the factors written by `lu`, and `p` the pivot
/// returned by it (implementations in this file forward both to LAPACK
/// `*getri`). An empty matrix is accepted and left untouched.
fn inv(l: MatrixLayout, a: &mut [Self], p: &Pivot) -> Result<()>;
/// Solves a linear system with a single right-hand side using an LU
/// factorization (`a`, `p`) obtained from `lu`.
///
/// `t` selects which system is solved: with `A` (`Transpose::No`), `A^T`
/// (`Transpose::Transpose`) or `A^H` (`Transpose::Hermite`). `b` holds the
/// right-hand side and is passed mutably to LAPACK `*getrs`, which writes
/// the solution into it.
fn solve(l: MatrixLayout, t: Transpose, a: &[Self], p: &Pivot, b: &mut [Self]) -> Result<()>;
}
/// Implements `Solve_` for one scalar type by forwarding to the matching
/// LAPACK routines.
///
/// Arguments, in order:
/// - the scalar type receiving the implementation,
/// - the LU factorization routine (`*getrf`, backs `lu`),
/// - the inversion-from-LU routine (`*getri`, backs `inv`),
/// - the solve-from-LU routine (`*getrs`, backs `solve`).
macro_rules! impl_solve {
    ($scalar:ty, $getrf:path, $getri:path, $getrs:path) => {
        impl Solve_ for $scalar {
            // LU-factorizes `a` in place and returns the pivot indices.
            //
            // Panics if `a.len()` does not match the size reported by `l`.
            fn lu(l: MatrixLayout, a: &mut [Self]) -> Result<Pivot> {
                let (row, col) = l.size();
                assert_eq!(a.len() as i32, row * col);
                if row == 0 || col == 0 {
                    // Nothing to factorize for an empty matrix.
                    return Ok(Vec::new());
                }
                // The pivot array has min(m, n) entries.
                let k = ::std::cmp::min(row, col);
                // NOTE(review): `vec_uninit` presumably yields `k` uninitialized
                // elements that the factorization routine then fills in — confirm.
                let mut ipiv = unsafe { vec_uninit(k as usize) };
                let mut info = 0;
                unsafe { $getrf(l.lda(), l.len(), a, l.lda(), &mut ipiv, &mut info) };
                info.as_lapack_result()?;
                Ok(ipiv)
            }

            // Inverts, in place, the matrix whose LU factors are stored in `a`
            // with pivot indices `ipiv`.
            fn inv(l: MatrixLayout, a: &mut [Self], ipiv: &Pivot) -> Result<()> {
                let (n, _) = l.size();
                if n == 0 {
                    // Do nothing for empty matrices.
                    return Ok(());
                }

                // Workspace query: a work length of -1 asks the routine to
                // report the required workspace size in the first work element
                // without doing any computation.
                let mut info = 0;
                let mut work_size = [Self::zero()];
                unsafe { $getri(n, a, l.lda(), ipiv, &mut work_size, -1, &mut info) };
                info.as_lapack_result()?;

                // Actual inversion with a workspace of the reported size.
                let lwork = work_size[0]
                    .to_usize()
                    .expect("LAPACK workspace query must report a size representable as usize");
                let mut work = unsafe { vec_uninit(lwork) };
                // NOTE(review): the query above passes `n` as the matrix order
                // while this call passes `l.len()`; presumably they agree for
                // the square matrices this routine is meant for — confirm.
                unsafe {
                    $getri(
                        l.len(),
                        a,
                        l.lda(),
                        ipiv,
                        &mut work,
                        lwork as i32,
                        &mut info,
                    )
                };
                info.as_lapack_result()?;

                Ok(())
            }

            // Solves a system with one right-hand side `b` using the LU
            // factors in `a` and pivot indices `ipiv`; `b` receives the
            // solution.
            fn solve(
                l: MatrixLayout,
                t: Transpose,
                a: &[Self],
                ipiv: &Pivot,
                b: &mut [Self],
            ) -> Result<()> {
                let (n, _) = l.size();
                if n == 0 {
                    // Do nothing for empty systems, mirroring `lu` and `inv`.
                    // Without this guard a zero leading dimension is rejected
                    // by LAPACK, which requires lda/ldb >= max(1, n).
                    return Ok(());
                }

                // If the array has C layout, then it needs to be handled
                // specially, since LAPACK expects a Fortran-layout array.
                // Reinterpreting a C layout array as Fortran layout is
                // equivalent to transposing it. So, we can handle the "no
                // transpose" and "transpose" cases by swapping to "transpose"
                // or "no transpose", respectively. For the "Hermite" case, we
                // can take advantage of the following:
                //
                // ```text
                // A^H x = b
                // ⟺ conj(A^T) x = b
                // ⟺ conj(conj(A^T) x) = conj(b)
                // ⟺ conj(conj(A^T)) conj(x) = conj(b)
                // ⟺ A^T conj(x) = conj(b)
                // ```
                //
                // So, we can handle this case by switching to "no transpose"
                // (which is equivalent to transposing the array since it will
                // be reinterpreted as Fortran layout) and applying the
                // elementwise conjugate to `x` and `b`.
                let (t, conj) = match l {
                    MatrixLayout::C { .. } => match t {
                        Transpose::No => (Transpose::Transpose, false),
                        Transpose::Transpose => (Transpose::No, false),
                        Transpose::Hermite => (Transpose::No, true),
                    },
                    MatrixLayout::F { .. } => (t, false),
                };
                let nrhs = 1;
                let ldb = l.lda();
                let mut info = 0;

                // conj(b) on the way in ...
                if conj {
                    for b_elem in b.iter_mut() {
                        *b_elem = b_elem.conj();
                    }
                }
                unsafe { $getrs(t as u8, n, nrhs, a, l.lda(), ipiv, b, ldb, &mut info) };
                // ... and conj(x) on the way out. This runs before the status
                // check so that `b` is conjugated back even on failure.
                if conj {
                    for b_elem in b.iter_mut() {
                        *b_elem = b_elem.conj();
                    }
                }
                info.as_lapack_result()?;
                Ok(())
            }
        }
    };
} // impl_solve!
// Instantiate `Solve_` for each supported scalar type, pairing it with the
// `getrf` / `getri` / `getrs` routines carrying the matching LAPACK prefix
// (`d`, `s`, `z`, `c`).
impl_solve!(f64, lapack::dgetrf, lapack::dgetri, lapack::dgetrs);
impl_solve!(f32, lapack::sgetrf, lapack::sgetri, lapack::sgetrs);
impl_solve!(c64, lapack::zgetrf, lapack::zgetri, lapack::zgetrs);
impl_solve!(c32, lapack::cgetrf, lapack::cgetri, lapack::cgetrs);