lax/
solveh.rs

//! Solve symmetric (real) or Hermitian (complex) linear problems using the
//! Bunch-Kaufman diagonal pivoting method.
//!
//! See also [the manual of dsytrf](http://www.netlib.org/lapack/lapack-3.1.1/html/dsytrf.f.html)
4
5use crate::{error::*, layout::MatrixLayout, *};
6use cauchy::*;
7use num_traits::{ToPrimitive, Zero};
8
9pub trait Solveh_: Sized {
10    /// Bunch-Kaufman: wrapper of `*sytrf` and `*hetrf`
11    fn bk(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<Pivot>;
12    /// Wrapper of `*sytri` and `*hetri`
13    fn invh(l: MatrixLayout, uplo: UPLO, a: &mut [Self], ipiv: &Pivot) -> Result<()>;
14    /// Wrapper of `*sytrs` and `*hetrs`
15    fn solveh(l: MatrixLayout, uplo: UPLO, a: &[Self], ipiv: &Pivot, b: &mut [Self]) -> Result<()>;
16}
17
/// Implements [`Solveh_`] for one scalar type by forwarding to LAPACK.
///
/// * `$scalar` — the element type.
/// * `$trf` — Bunch-Kaufman factorization routine (`?sytrf` / `?hetrf`).
/// * `$tri` — inverse-from-factorization routine (`?sytri` / `?hetri`).
/// * `$trs` — solve-from-factorization routine (`?sytrs` / `?hetrs`).
macro_rules! impl_solveh {
    ($scalar:ty, $trf:path, $tri:path, $trs:path) => {
        impl Solveh_ for $scalar {
            fn bk(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<Pivot> {
                let (n, _) = l.size();
                // An empty matrix has an empty pivot vector; skip LAPACK (and the
                // allocation below) entirely.
                if n == 0 {
                    return Ok(Vec::new());
                }
                // `ipiv` is an output argument of `?sytrf`/`?hetrf`, filled by LAPACK.
                let mut ipiv = unsafe { vec_uninit(n as usize) };

                // Workspace query: with `lwork == -1` LAPACK only reports the
                // optimal workspace size in `work_size[0]`.
                let mut info = 0;
                let mut work_size = [Self::zero()];
                unsafe {
                    $trf(
                        uplo as u8,
                        n,
                        a,
                        l.lda(),
                        &mut ipiv,
                        &mut work_size,
                        -1,
                        &mut info,
                    )
                };
                info.as_lapack_result()?;

                // Actual factorization, overwriting `a` in place.
                let lwork = work_size[0]
                    .to_usize()
                    .expect("LAPACK returned a workspace size that does not fit in usize");
                let mut work = unsafe { vec_uninit(lwork) };
                unsafe {
                    $trf(
                        uplo as u8,
                        n,
                        a,
                        l.lda(),
                        &mut ipiv,
                        &mut work,
                        lwork as i32,
                        &mut info,
                    )
                };
                info.as_lapack_result()?;
                Ok(ipiv)
            }

            fn invh(l: MatrixLayout, uplo: UPLO, a: &mut [Self], ipiv: &Pivot) -> Result<()> {
                let (n, _) = l.size();
                // Inverting a 0x0 matrix is a no-op.
                if n == 0 {
                    return Ok(());
                }
                // LAPACK does not bounds-check slices; a short `ipiv` would be read
                // out of bounds.
                assert!(
                    ipiv.len() >= n as usize,
                    "pivot vector has {} elements, expected at least {}",
                    ipiv.len(),
                    n
                );
                let mut info = 0;
                // `?sytri`/`?hetri` take a workspace of dimension `n`.
                let mut work = unsafe { vec_uninit(n as usize) };
                unsafe { $tri(uplo as u8, n, a, l.lda(), ipiv, &mut work, &mut info) };
                info.as_lapack_result()?;
                Ok(())
            }

            fn solveh(
                l: MatrixLayout,
                uplo: UPLO,
                a: &[Self],
                ipiv: &Pivot,
                b: &mut [Self],
            ) -> Result<()> {
                let (n, _) = l.size();
                // LAPACK requires `ldb >= max(1, n)`, so forwarding `ldb = n = 0`
                // would be reported as an invalid argument. An empty system is
                // trivially solved.
                if n == 0 {
                    return Ok(());
                }
                // LAPACK does not bounds-check slices; guard the single right-hand
                // side and the pivot vector against out-of-bounds access.
                assert!(
                    b.len() >= n as usize,
                    "right-hand side has {} elements, expected at least {}",
                    b.len(),
                    n
                );
                assert!(
                    ipiv.len() >= n as usize,
                    "pivot vector has {} elements, expected at least {}",
                    ipiv.len(),
                    n
                );
                let mut info = 0;
                // One right-hand side (`nrhs = 1`) stored contiguously, so `ldb = n`.
                unsafe { $trs(uplo as u8, n, 1, a, l.lda(), ipiv, b, n, &mut info) };
                info.as_lapack_result()?;
                Ok(())
            }
        }
    };
} // impl_solveh!
89
90impl_solveh!(f64, lapack::dsytrf, lapack::dsytri, lapack::dsytrs);
91impl_solveh!(f32, lapack::ssytrf, lapack::ssytri, lapack::ssytrs);
92impl_solveh!(c64, lapack::zhetrf, lapack::zhetri, lapack::zhetrs);
93impl_solveh!(c32, lapack::chetrf, lapack::chetri, lapack::chetrs);