1use crate::{error::*, layout::*, *};
4use cauchy::*;
5
/// Whether a triangular matrix is treated as having a unit diagonal.
///
/// The discriminants are the ASCII flags LAPACK expects for its `DIAG`
/// argument (`'U'` / `'N'`), so a value can be handed to LAPACK directly as
/// `diag as u8`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum Diag {
    /// Unit diagonal: per LAPACK, the diagonal entries of the matrix are
    /// assumed to be one and are not referenced.
    Unit = b'U',
    /// General diagonal: the diagonal entries are taken from the matrix.
    NonUnit = b'N',
}
12
13pub trait Triangular_: Scalar {
15 fn solve_triangular(
16 al: MatrixLayout,
17 bl: MatrixLayout,
18 uplo: UPLO,
19 d: Diag,
20 a: &[Self],
21 b: &mut [Self],
22 ) -> Result<()>;
23}
24
25macro_rules! impl_triangular {
26 ($scalar:ty, $trtri:path, $trtrs:path) => {
27 impl Triangular_ for $scalar {
28 fn solve_triangular(
29 a_layout: MatrixLayout,
30 b_layout: MatrixLayout,
31 uplo: UPLO,
32 diag: Diag,
33 a: &[Self],
34 b: &mut [Self],
35 ) -> Result<()> {
36 let mut a_t = None;
38 let a_layout = match a_layout {
39 MatrixLayout::C { .. } => {
40 a_t = Some(unsafe { vec_uninit(a.len()) });
41 transpose(a_layout, a, a_t.as_mut().unwrap())
42 }
43 MatrixLayout::F { .. } => a_layout,
44 };
45
46 let mut b_t = None;
48 let b_layout = match b_layout {
49 MatrixLayout::C { .. } => {
50 b_t = Some(unsafe { vec_uninit(b.len()) });
51 transpose(b_layout, b, b_t.as_mut().unwrap())
52 }
53 MatrixLayout::F { .. } => b_layout,
54 };
55
56 let (m, n) = a_layout.size();
57 let (n_, nrhs) = b_layout.size();
58 assert_eq!(n, n_);
59
60 let mut info = 0;
61 unsafe {
62 $trtrs(
63 uplo as u8,
64 Transpose::No as u8,
65 diag as u8,
66 m,
67 nrhs,
68 a_t.as_ref().map(|v| v.as_slice()).unwrap_or(a),
69 a_layout.lda(),
70 b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b),
71 b_layout.lda(),
72 &mut info,
73 );
74 }
75 info.as_lapack_result()?;
76
77 if let Some(b_t) = b_t {
79 transpose(b_layout, &b_t, b);
80 }
81 Ok(())
82 }
83 }
84 };
85} impl_triangular!(f64, lapack::dtrtri, lapack::dtrtrs);
88impl_triangular!(f32, lapack::strtri, lapack::strtrs);
89impl_triangular!(c64, lapack::ztrtri, lapack::ztrtrs);
90impl_triangular!(c32, lapack::ctrtri, lapack::ctrtrs);