1use crate::{error::*, layout::MatrixLayout, *};
6use cauchy::*;
7use num_traits::{ToPrimitive, Zero};
8
9pub trait Solveh_: Sized {
10 fn bk(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<Pivot>;
12 fn invh(l: MatrixLayout, uplo: UPLO, a: &mut [Self], ipiv: &Pivot) -> Result<()>;
14 fn solveh(l: MatrixLayout, uplo: UPLO, a: &[Self], ipiv: &Pivot, b: &mut [Self]) -> Result<()>;
16}
17
18macro_rules! impl_solveh {
19 ($scalar:ty, $trf:path, $tri:path, $trs:path) => {
20 impl Solveh_ for $scalar {
21 fn bk(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<Pivot> {
22 let (n, _) = l.size();
23 let mut ipiv = unsafe { vec_uninit(n as usize) };
24 if n == 0 {
25 return Ok(Vec::new());
26 }
27
28 let mut info = 0;
30 let mut work_size = [Self::zero()];
31 unsafe {
32 $trf(
33 uplo as u8,
34 n,
35 a,
36 l.lda(),
37 &mut ipiv,
38 &mut work_size,
39 -1,
40 &mut info,
41 )
42 };
43 info.as_lapack_result()?;
44
45 let lwork = work_size[0].to_usize().unwrap();
47 let mut work = unsafe { vec_uninit(lwork) };
48 unsafe {
49 $trf(
50 uplo as u8,
51 n,
52 a,
53 l.lda(),
54 &mut ipiv,
55 &mut work,
56 lwork as i32,
57 &mut info,
58 )
59 };
60 info.as_lapack_result()?;
61 Ok(ipiv)
62 }
63
64 fn invh(l: MatrixLayout, uplo: UPLO, a: &mut [Self], ipiv: &Pivot) -> Result<()> {
65 let (n, _) = l.size();
66 let mut info = 0;
67 let mut work = unsafe { vec_uninit(n as usize) };
68 unsafe { $tri(uplo as u8, n, a, l.lda(), ipiv, &mut work, &mut info) };
69 info.as_lapack_result()?;
70 Ok(())
71 }
72
73 fn solveh(
74 l: MatrixLayout,
75 uplo: UPLO,
76 a: &[Self],
77 ipiv: &Pivot,
78 b: &mut [Self],
79 ) -> Result<()> {
80 let (n, _) = l.size();
81 let mut info = 0;
82 unsafe { $trs(uplo as u8, n, 1, a, l.lda(), ipiv, b, n, &mut info) };
83 info.as_lapack_result()?;
84 Ok(())
85 }
86 }
87 };
88} impl_solveh!(f64, lapack::dsytrf, lapack::dsytri, lapack::dsytrs);
91impl_solveh!(f32, lapack::ssytrf, lapack::ssytri, lapack::ssytrs);
92impl_solveh!(c64, lapack::zhetrf, lapack::zhetri, lapack::zhetrs);
93impl_solveh!(c32, lapack::chetrf, lapack::chetri, lapack::chetrs);