lax/
triangular.rs

//! Implement the triangular linear solver (LAPACK `*trtrs`)
2
3use crate::{error::*, layout::*, *};
4use cauchy::*;
5
/// Describes the diagonal of a triangular matrix handed to LAPACK.
///
/// Each discriminant is the ASCII character LAPACK expects for its `DIAG`
/// argument, so `diag as u8` can be forwarded to the FFI call unchanged.
#[repr(u8)]
#[derive(Debug, Clone, Copy)]
pub enum Diag {
    /// Unit triangular: the diagonal is taken to be all ones (`'U'`).
    Unit = b'U',
    /// Non-unit triangular: the diagonal is read from the matrix (`'N'`).
    NonUnit = b'N',
}
12
13/// Wraps `*trtri` and `*trtrs`
14pub trait Triangular_: Scalar {
15    fn solve_triangular(
16        al: MatrixLayout,
17        bl: MatrixLayout,
18        uplo: UPLO,
19        d: Diag,
20        a: &[Self],
21        b: &mut [Self],
22    ) -> Result<()>;
23}
24
/// Implements [`Triangular_`] for one scalar type using the given LAPACK routines.
macro_rules! impl_triangular {
    // `$trtri` is accepted so every invocation names both LAPACK routines,
    // but only `$trtrs` is called at the moment.
    ($scalar:ty, $trtri:path, $trtrs:path) => {
        impl Triangular_ for $scalar {
            fn solve_triangular(
                a_layout: MatrixLayout,
                b_layout: MatrixLayout,
                uplo: UPLO,
                diag: Diag,
                a: &[Self],
                b: &mut [Self],
            ) -> Result<()> {
                // LAPACK works on column-major (Fortran) storage, so a row-major
                // (C-contiguous) `a` is copied into a transposed buffer first.
                // `transpose` returns the layout describing that buffer.
                let mut a_t = None;
                let a_layout = match a_layout {
                    MatrixLayout::C { .. } => {
                        // SAFETY: assumes `transpose` overwrites every element of the
                        // buffer before anything reads it — TODO confirm in `layout`.
                        a_t = Some(unsafe { vec_uninit(a.len()) });
                        transpose(a_layout, a, a_t.as_mut().unwrap())
                    }
                    MatrixLayout::F { .. } => a_layout,
                };

                // Same for a row-major `b`; the solution is copied back at the end.
                let mut b_t = None;
                let b_layout = match b_layout {
                    MatrixLayout::C { .. } => {
                        // SAFETY: same assumption as for `a_t` above.
                        b_t = Some(unsafe { vec_uninit(b.len()) });
                        transpose(b_layout, b, b_t.as_mut().unwrap())
                    }
                    MatrixLayout::F { .. } => b_layout,
                };

                let (m, n) = a_layout.size();
                let (n_, nrhs) = b_layout.size();
                // `*trtrs` takes a single order `N` for both `A` (N x N) and
                // `B` (N x NRHS) and does not know the slice lengths, so a
                // non-square `a` or a `b` with the wrong row count would make
                // LAPACK access memory outside the given buffers.
                assert_eq!(m, n, "triangular matrix `a` must be square");
                assert_eq!(n, n_, "rows of `b` must match the order of `a`");

                let mut info = 0;
                // SAFETY: dimensions were validated above and both buffers are
                // column-major with leading dimensions taken from their layouts.
                unsafe {
                    $trtrs(
                        uplo as u8,
                        Transpose::No as u8,
                        diag as u8,
                        n,
                        nrhs,
                        a_t.as_ref().map(|v| v.as_slice()).unwrap_or(a),
                        a_layout.lda(),
                        b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b),
                        b_layout.lda(),
                        &mut info,
                    );
                }
                info.as_lapack_result()?;

                // Copy the column-major solution back into the caller's row-major `b`.
                if let Some(b_t) = b_t {
                    transpose(b_layout, &b_t, b);
                }
                Ok(())
            }
        }
    };
} // impl_triangular!
86
87impl_triangular!(f64, lapack::dtrtri, lapack::dtrtrs);
88impl_triangular!(f32, lapack::strtri, lapack::strtrs);
89impl_triangular!(c64, lapack::ztrtri, lapack::ztrtrs);
90impl_triangular!(c32, lapack::ctrtri, lapack::ctrtrs);