lax/
cholesky.rs

1//! Cholesky decomposition
2
3use super::*;
4use crate::{error::*, layout::*};
5use cauchy::*;
6
7pub trait Cholesky_: Sized {
8    /// Cholesky: wrapper of `*potrf`
9    ///
10    /// **Warning: Only the portion of `a` corresponding to `UPLO` is written.**
11    fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>;
12
13    /// Wrapper of `*potri`
14    ///
15    /// **Warning: Only the portion of `a` corresponding to `UPLO` is written.**
16    fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>;
17
18    /// Wrapper of `*potrs`
19    fn solve_cholesky(l: MatrixLayout, uplo: UPLO, a: &[Self], b: &mut [Self]) -> Result<()>;
20}
21
/// Implements [`Cholesky_`] for one scalar type on top of its LAPACK routines.
///
/// * `$scalar` — the element type that receives the impl.
/// * `$trf`, `$tri`, `$trs` — the matching `*potrf` (factorize), `*potri` (invert from
///   the factor) and `*potrs` (solve with the factor) wrappers from the `lapack` crate.
///
/// LAPACK works on column-major storage. Row-major (`MatrixLayout::C`) input is handled
/// in one of two ways:
/// * `cholesky` and `inv_cholesky` transpose it in place around the call.
/// * `solve_cholesky` swaps the triangle and conjugates the right-hand side.
///
/// In every case the caller's storage order is restored *before* a LAPACK error is
/// propagated, so buffers are never returned in that internal representation.
///
/// # Panics
///
/// Each method panics if `a` is shorter than the `n x n` matrix described by `l`, and
/// `solve_cholesky` also panics if `b` has fewer than `n` entries. LAPACK receives the raw
/// slice memory, so a short slice would otherwise be read or written out of bounds.
macro_rules! impl_cholesky {
    ($scalar:ty, $trf:path, $tri:path, $trs:path) => {
        impl Cholesky_ for $scalar {
            fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> {
                let (n, _) = l.size();
                // Same leading dimension as `inv_cholesky` / `solve_cholesky` use.
                let lda = l.lda();
                assert!(
                    a.len() >= (lda as usize).saturating_mul(n as usize),
                    "cholesky: matrix slice is too short for its layout"
                );
                let row_major = matches!(l, MatrixLayout::C { .. });
                if row_major {
                    square_transpose(l, a);
                }
                let mut info = 0;
                // SAFETY: the assertion above guarantees `a` spans the `n x n` matrix with
                // leading dimension `lda`, which bounds every element `*potrf` accesses.
                unsafe {
                    $trf(uplo as u8, n, a, lda, &mut info);
                }
                // Undo the transpose before inspecting `info` so `a` is back in the
                // caller's layout even when LAPACK reports a failure.
                if row_major {
                    square_transpose(l, a);
                }
                info.as_lapack_result()?;
                Ok(())
            }

            fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> {
                let (n, _) = l.size();
                let lda = l.lda();
                assert!(
                    a.len() >= (lda as usize).saturating_mul(n as usize),
                    "inv_cholesky: matrix slice is too short for its layout"
                );
                let row_major = matches!(l, MatrixLayout::C { .. });
                if row_major {
                    square_transpose(l, a);
                }
                let mut info = 0;
                // SAFETY: the assertion above guarantees `a` spans the `n x n` matrix with
                // leading dimension `lda`, which bounds every element `*potri` accesses.
                unsafe {
                    $tri(uplo as u8, n, a, lda, &mut info);
                }
                // Restore the caller's layout before propagating any LAPACK error.
                if row_major {
                    square_transpose(l, a);
                }
                info.as_lapack_result()?;
                Ok(())
            }

            fn solve_cholesky(
                l: MatrixLayout,
                mut uplo: UPLO,
                a: &[Self],
                b: &mut [Self],
            ) -> Result<()> {
                let (n, _) = l.size();
                let lda = l.lda();
                assert!(
                    a.len() >= (lda as usize).saturating_mul(n as usize),
                    "solve_cholesky: factor slice is too short for its layout"
                );
                assert!(
                    b.len() >= n as usize,
                    "solve_cholesky: right-hand side is shorter than the matrix order"
                );
                // Only a single right-hand-side vector is supported.
                let nrhs = 1;
                // LAPACK requires LDB >= max(1, N). Clamping lets an empty system take
                // LAPACK's quick-return path instead of failing with an illegal argument.
                let ldb = n.max(1);
                let row_major = matches!(l, MatrixLayout::C { .. });
                if row_major {
                    // Row-major data read as column-major is the transpose. For a Hermitian
                    // factorization, that is the conjugated factor in the opposite triangle.
                    // So swap the triangle and solve conj(A) conj(x) = conj(b) instead.
                    uplo = uplo.t();
                    b.iter_mut().for_each(|v| *v = v.conj());
                }
                let mut info = 0;
                // SAFETY: the assertions above guarantee `a` spans the `n x n` factor with
                // leading dimension `lda` and `b` holds at least `n` entries (one column).
                unsafe {
                    $trs(uplo as u8, n, nrhs, a, lda, b, ldb, &mut info);
                }
                // Undo the conjugation before propagating any error, so `b` is never
                // handed back conjugated.
                if row_major {
                    b.iter_mut().for_each(|v| *v = v.conj());
                }
                info.as_lapack_result()?;
                Ok(())
            }
        }
    };
} // end macro_rules
86
87impl_cholesky!(f64, lapack::dpotrf, lapack::dpotri, lapack::dpotrs);
88impl_cholesky!(f32, lapack::spotrf, lapack::spotri, lapack::spotrs);
89impl_cholesky!(c64, lapack::zpotrf, lapack::zpotri, lapack::zpotrs);
90impl_cholesky!(c32, lapack::cpotrf, lapack::cpotri, lapack::cpotrs);