lax/solve.rs
//! Solve linear problems using LU decomposition.
2
3use crate::{error::*, layout::MatrixLayout, *};
4use cauchy::*;
5use num_traits::{ToPrimitive, Zero};
6
7pub trait Solve_: Scalar + Sized {
8 /// Computes the LU factorization of a general `m x n` matrix `a` using
9 /// partial pivoting with row interchanges.
10 ///
11 /// $ PA = LU $
12 ///
13 /// Error
14 /// ------
15 /// - `LapackComputationalFailure { return_code }` when the matrix is singular
16 /// - Division by zero will occur if it is used to solve a system of equations
17 /// because `U[(return_code-1, return_code-1)]` is exactly zero.
18 fn lu(l: MatrixLayout, a: &mut [Self]) -> Result<Pivot>;
19
20 fn inv(l: MatrixLayout, a: &mut [Self], p: &Pivot) -> Result<()>;
21
22 fn solve(l: MatrixLayout, t: Transpose, a: &[Self], p: &Pivot, b: &mut [Self]) -> Result<()>;
23}
24
/// Implements [`Solve_`] for one scalar type given the matching LAPACK
/// `?getrf` (LU factorization), `?getri` (inverse from LU) and `?getrs`
/// (solve from LU) routines.
macro_rules! impl_solve {
    ($scalar:ty, $getrf:path, $getri:path, $getrs:path) => {
        impl Solve_ for $scalar {
            fn lu(l: MatrixLayout, a: &mut [Self]) -> Result<Pivot> {
                let (row, col) = l.size();
                // Compare in `usize` so a long buffer cannot be truncated by an `i32` cast.
                assert_eq!(a.len(), (row as usize) * (col as usize));
                if row == 0 || col == 0 {
                    // Do nothing for empty matrix
                    return Ok(Vec::new());
                }
                let k = ::std::cmp::min(row, col);
                let mut ipiv = unsafe { vec_uninit(k as usize) };
                let mut info = 0;
                // LAPACK is column-major, so the buffer is described as `lda() x len()`.
                // For C layout this is the transpose of the logical matrix.
                // SAFETY: `a` holds `row * col` elements (asserted above), which is
                // assumed to equal `lda() * len()` for every layout (TODO confirm
                // against `MatrixLayout`). `ipiv` has the `min(m, n)` entries
                // `?getrf` writes.
                unsafe { $getrf(l.lda(), l.len(), a, l.lda(), &mut ipiv, &mut info) };
                info.as_lapack_result()?;
                Ok(ipiv)
            }

            fn inv(l: MatrixLayout, a: &mut [Self], ipiv: &Pivot) -> Result<()> {
                let (n, _) = l.size();
                if n == 0 {
                    // Do nothing for empty matrices.
                    return Ok(());
                }
                let lda = l.lda();
                // `?getri` reads an `n x n` matrix with leading dimension `lda` plus
                // `n` pivot indices. Reject short buffers here rather than letting
                // LAPACK read or write past their end.
                assert!(
                    a.len() >= (n as usize) * (lda as usize),
                    "matrix buffer has {} elements, but {} are required",
                    a.len(),
                    (n as usize) * (lda as usize)
                );
                assert!(
                    ipiv.len() >= n as usize,
                    "pivot has {} entries, but {} are required",
                    ipiv.len(),
                    n
                );

                // Workspace query: with `lwork == -1` only the optimal size is
                // written to `work_size[0]`.
                let mut info = 0;
                let mut work_size = [Self::zero()];
                // SAFETY: buffer lengths were checked above.
                unsafe { $getri(n, a, lda, ipiv, &mut work_size, -1, &mut info) };
                info.as_lapack_result()?;

                // Actual inversion. Use the same `n` as the query so both calls
                // describe the same matrix.
                let lwork = work_size[0]
                    .to_usize()
                    .expect("LAPACK workspace size must be a non-negative integer");
                let mut work = unsafe { vec_uninit(lwork) };
                // SAFETY: buffer lengths were checked above; `work` has exactly
                // `lwork` elements.
                unsafe { $getri(n, a, lda, ipiv, &mut work, lwork as i32, &mut info) };
                info.as_lapack_result()?;

                Ok(())
            }

            fn solve(
                l: MatrixLayout,
                t: Transpose,
                a: &[Self],
                ipiv: &Pivot,
                b: &mut [Self],
            ) -> Result<()> {
                // If the array has C layout, then it needs to be handled
                // specially, since LAPACK expects a Fortran-layout array.
                // Reinterpreting a C layout array as Fortran layout is
                // equivalent to transposing it. So, we can handle the "no
                // transpose" and "transpose" cases by swapping to "transpose"
                // or "no transpose", respectively. For the "Hermite" case, we
                // can take advantage of the following:
                //
                // ```text
                // A^H x = b
                // ⟺ conj(A^T) x = b
                // ⟺ conj(conj(A^T) x) = conj(b)
                // ⟺ conj(conj(A^T)) conj(x) = conj(b)
                // ⟺ A^T conj(x) = conj(b)
                // ```
                //
                // So, we can handle this case by switching to "no transpose"
                // (which is equivalent to transposing the array since it will
                // be reinterpreted as Fortran layout) and applying the
                // elementwise conjugate to `x` and `b`.
                let (t, conj) = match l {
                    MatrixLayout::C { .. } => match t {
                        Transpose::No => (Transpose::Transpose, false),
                        Transpose::Transpose => (Transpose::No, false),
                        Transpose::Hermite => (Transpose::No, true),
                    },
                    MatrixLayout::F { .. } => (t, false),
                };
                let (n, _) = l.size();
                let nrhs = 1;
                let lda = l.lda();
                let ldb = lda;

                // `?getrs` reads an `n x n` factor with leading dimension `lda`,
                // `n` pivots, and `n` entries of the single right-hand side.
                assert!(
                    a.len() >= (n as usize) * (lda as usize),
                    "matrix buffer has {} elements, but {} are required",
                    a.len(),
                    (n as usize) * (lda as usize)
                );
                assert!(
                    ipiv.len() >= n as usize,
                    "pivot has {} entries, but {} are required",
                    ipiv.len(),
                    n
                );
                assert!(
                    b.len() >= n as usize,
                    "right-hand side has {} elements, but {} are required",
                    b.len(),
                    n
                );

                if conj {
                    b.iter_mut().for_each(|elem| *elem = elem.conj());
                }
                let mut info = 0;
                // SAFETY: buffer lengths were checked above; `nrhs == 1`, so only
                // the first `n` elements of `b` are accessed.
                unsafe { $getrs(t as u8, n, nrhs, a, lda, ipiv, b, ldb, &mut info) };
                // Map `conj(x)` back to `x`. This is done before checking `info` so
                // `b` is never left conjugated.
                if conj {
                    b.iter_mut().for_each(|elem| *elem = elem.conj());
                }
                info.as_lapack_result()?;
                Ok(())
            }
        }
    };
} // impl_solve!
131
132impl_solve!(f64, lapack::dgetrf, lapack::dgetri, lapack::dgetrs);
133impl_solve!(f32, lapack::sgetrf, lapack::sgetri, lapack::sgetrs);
134impl_solve!(c64, lapack::zgetrf, lapack::zgetri, lapack::zgetrs);
135impl_solve!(c32, lapack::cgetrf, lapack::cgetri, lapack::cgetrs);