1use super::*;
4use crate::{error::*, layout::*};
5use cauchy::*;
6
7pub trait Cholesky_: Sized {
8 fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>;
12
13 fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>;
17
18 fn solve_cholesky(l: MatrixLayout, uplo: UPLO, a: &[Self], b: &mut [Self]) -> Result<()>;
20}
21
22macro_rules! impl_cholesky {
23 ($scalar:ty, $trf:path, $tri:path, $trs:path) => {
24 impl Cholesky_ for $scalar {
25 fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> {
26 let (n, _) = l.size();
27 if matches!(l, MatrixLayout::C { .. }) {
28 square_transpose(l, a);
29 }
30 let mut info = 0;
31 unsafe {
32 $trf(uplo as u8, n, a, n, &mut info);
33 }
34 info.as_lapack_result()?;
35 if matches!(l, MatrixLayout::C { .. }) {
36 square_transpose(l, a);
37 }
38 Ok(())
39 }
40
41 fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> {
42 let (n, _) = l.size();
43 if matches!(l, MatrixLayout::C { .. }) {
44 square_transpose(l, a);
45 }
46 let mut info = 0;
47 unsafe {
48 $tri(uplo as u8, n, a, l.lda(), &mut info);
49 }
50 info.as_lapack_result()?;
51 if matches!(l, MatrixLayout::C { .. }) {
52 square_transpose(l, a);
53 }
54 Ok(())
55 }
56
57 fn solve_cholesky(
58 l: MatrixLayout,
59 mut uplo: UPLO,
60 a: &[Self],
61 b: &mut [Self],
62 ) -> Result<()> {
63 let (n, _) = l.size();
64 let nrhs = 1;
65 let mut info = 0;
66 if matches!(l, MatrixLayout::C { .. }) {
67 uplo = uplo.t();
68 for val in b.iter_mut() {
69 *val = val.conj();
70 }
71 }
72 unsafe {
73 $trs(uplo as u8, n, nrhs, a, l.lda(), b, n, &mut info);
74 }
75 info.as_lapack_result()?;
76 if matches!(l, MatrixLayout::C { .. }) {
77 for val in b.iter_mut() {
78 *val = val.conj();
79 }
80 }
81 Ok(())
82 }
83 }
84 };
85} impl_cholesky!(f64, lapack::dpotrf, lapack::dpotri, lapack::dpotrs);
88impl_cholesky!(f32, lapack::spotrf, lapack::spotri, lapack::spotrs);
89impl_cholesky!(c64, lapack::zpotrf, lapack::zpotri, lapack::zpotrs);
90impl_cholesky!(c32, lapack::cpotrf, lapack::cpotri, lapack::cpotrs);