lax/
solve.rs

//! Solve linear problems using LU decomposition.
2
3use crate::{error::*, layout::MatrixLayout, *};
4use cauchy::*;
5use num_traits::{ToPrimitive, Zero};
6
7pub trait Solve_: Scalar + Sized {
8    /// Computes the LU factorization of a general `m x n` matrix `a` using
9    /// partial pivoting with row interchanges.
10    ///
11    /// $ PA = LU $
12    ///
13    /// Error
14    /// ------
15    /// - `LapackComputationalFailure { return_code }` when the matrix is singular
16    ///   - Division by zero will occur if it is used to solve a system of equations
17    ///     because `U[(return_code-1, return_code-1)]` is exactly zero.
18    fn lu(l: MatrixLayout, a: &mut [Self]) -> Result<Pivot>;
19
20    fn inv(l: MatrixLayout, a: &mut [Self], p: &Pivot) -> Result<()>;
21
22    fn solve(l: MatrixLayout, t: Transpose, a: &[Self], p: &Pivot, b: &mut [Self]) -> Result<()>;
23}
24
// Implements `Solve_` for one scalar type on top of the matching LAPACK
// `?getrf` (LU factorization), `?getri` (inverse from LU) and `?getrs`
// (solve from LU) routines.
macro_rules! impl_solve {
    ($scalar:ty, $getrf:path, $getri:path, $getrs:path) => {
        impl Solve_ for $scalar {
            fn lu(l: MatrixLayout, a: &mut [Self]) -> Result<Pivot> {
                let (row, col) = l.size();
                // Compare in `usize` so neither the length cast nor the product
                // can silently wrap for very large matrices.
                assert_eq!(a.len(), row as usize * col as usize);
                if row == 0 || col == 0 {
                    // Do nothing for empty matrix
                    return Ok(Vec::new());
                }
                let k = ::std::cmp::min(row, col);
                // `getrf` writes `min(m, n)` pivot indices. Zero-initialize the
                // buffer so no uninitialized memory is handed out through `&mut`.
                let mut ipiv = vec![0_i32; k as usize];
                let mut info = 0;
                // For C layout this factorizes the Fortran-order view of the
                // buffer (i.e. the transpose); `solve` compensates for that by
                // flipping the transpose flag.
                // SAFETY: `a` holds exactly `row * col` elements (asserted above),
                // which is the `lda() x len()` Fortran view passed here, and
                // `ipiv` has the `min(m, n)` entries `getrf` writes.
                unsafe { $getrf(l.lda(), l.len(), a, l.lda(), &mut ipiv, &mut info) };
                info.as_lapack_result()?;
                Ok(ipiv)
            }

            fn inv(l: MatrixLayout, a: &mut [Self], ipiv: &Pivot) -> Result<()> {
                let (n, m) = l.size();
                // `getri` accesses `n * n` entries of `a` and `n` pivots; check
                // the slices up front so the FFI calls stay in bounds.
                assert_eq!(n, m, "only square matrices can be inverted");
                assert_eq!(a.len(), n as usize * n as usize);
                assert_eq!(ipiv.len(), n as usize);
                if n == 0 {
                    // Do nothing for empty matrices.
                    return Ok(());
                }
                let lda = l.lda();

                // Workspace query: with `lwork = -1`, `getri` only reports the
                // optimal workspace size in `work_size[0]`.
                let mut info = 0;
                let mut work_size = [Self::zero()];
                // SAFETY: slice lengths were checked against `n` above.
                unsafe { $getri(n, a, lda, ipiv, &mut work_size, -1, &mut info) };
                info.as_lapack_result()?;

                // Actual inversion, with the same order `n` as the query above.
                let lwork = work_size[0]
                    .to_usize()
                    .expect("getri reported a workspace size that is not a valid usize");
                let mut work = vec![Self::zero(); lwork];
                // SAFETY: slice lengths were checked against `n` above and
                // `work` holds exactly `lwork` elements.
                unsafe { $getri(n, a, lda, ipiv, &mut work, lwork as i32, &mut info) };
                info.as_lapack_result()?;

                Ok(())
            }

            fn solve(
                l: MatrixLayout,
                t: Transpose,
                a: &[Self],
                ipiv: &Pivot,
                b: &mut [Self],
            ) -> Result<()> {
                let (n, m) = l.size();
                // `getrs` reads `n * n` factor entries and `n` pivots and writes
                // `n` entries of `b` (one right-hand side). Check the slices up
                // front so the FFI call cannot touch memory outside them.
                assert_eq!(n, m, "the coefficient matrix must be square");
                assert_eq!(a.len(), n as usize * n as usize);
                assert_eq!(ipiv.len(), n as usize);
                assert_eq!(b.len(), n as usize);
                if n == 0 {
                    // Nothing to solve for an empty system (mirrors `lu` and `inv`).
                    return Ok(());
                }

                // If the array has C layout, then it needs to be handled
                // specially, since LAPACK expects a Fortran-layout array.
                // Reinterpreting a C layout array as Fortran layout is
                // equivalent to transposing it. So, we can handle the "no
                // transpose" and "transpose" cases by swapping to "transpose"
                // or "no transpose", respectively. For the "Hermite" case, we
                // can take advantage of the following:
                //
                // ```text
                // A^H x = b
                // ⟺ conj(A^T) x = b
                // ⟺ conj(conj(A^T) x) = conj(b)
                // ⟺ conj(conj(A^T)) conj(x) = conj(b)
                // ⟺ A^T conj(x) = conj(b)
                // ```
                //
                // So, we can handle this case by switching to "no transpose"
                // (which is equivalent to transposing the array since it will
                // be reinterpreted as Fortran layout) and applying the
                // elementwise conjugate to `x` and `b`.
                let (t, conj) = match l {
                    MatrixLayout::C { .. } => match t {
                        Transpose::No => (Transpose::Transpose, false),
                        Transpose::Transpose => (Transpose::No, false),
                        Transpose::Hermite => (Transpose::No, true),
                    },
                    MatrixLayout::F { .. } => (t, false),
                };
                let nrhs = 1;
                let ldb = l.lda();
                let mut info = 0;
                if conj {
                    for b_elem in b.iter_mut() {
                        *b_elem = b_elem.conj();
                    }
                }
                // SAFETY: slice lengths were checked against `n` above.
                unsafe { $getrs(t as u8, n, nrhs, a, l.lda(), ipiv, b, ldb, &mut info) };
                if conj {
                    // Undo the conjugation so `b` holds `x` rather than `conj(x)`.
                    for b_elem in b.iter_mut() {
                        *b_elem = b_elem.conj();
                    }
                }
                info.as_lapack_result()?;
                Ok(())
            }
        }
    };
} // impl_solve!
131
132impl_solve!(f64, lapack::dgetrf, lapack::dgetri, lapack::dgetrs);
133impl_solve!(f32, lapack::sgetrf, lapack::sgetri, lapack::sgetrs);
134impl_solve!(c64, lapack::zgetrf, lapack::zgetri, lapack::zgetrs);
135impl_solve!(c32, lapack::cgetrf, lapack::cgetri, lapack::cgetrs);