lax/
tridiagonal.rs

1//! Implement linear solver using LU decomposition
2//! for tridiagonal matrix
3
4use crate::{error::*, layout::*, *};
5use cauchy::*;
6use num_traits::Zero;
7use std::ops::{Index, IndexMut};
8
/// Represents a tridiagonal matrix as 3 one-dimensional vectors.
///
/// ```text
/// [d0, u1,  0,   ...,       0,
///  l1, d1, u2,            ...,
///   0, l2, d2,
///  ...           ...,  u{n-1},
///   0,  ...,  l{n-1},  d{n-1},]
/// ```
///
/// The labels in the diagram are one-based for the off-diagonals, while the
/// vectors are zero-based: through the `Index` implementations of this type,
/// element `(i + 1, i)` is `dl[i]` and element `(i, i + 1)` is `du[i]`.
#[derive(Clone, PartialEq, Eq)]
pub struct Tridiagonal<A: Scalar> {
    /// layout of raw matrix
    pub l: MatrixLayout,
    /// (n-1) sub-diagonal elements of matrix.
    pub dl: Vec<A>,
    /// (n) diagonal elements of matrix.
    pub d: Vec<A>,
    /// (n-1) super-diagonal elements of matrix.
    pub du: Vec<A>,
}
29
30impl<A: Scalar> Tridiagonal<A> {
31    fn opnorm_one(&self) -> A::Real {
32        let mut col_sum: Vec<A::Real> = self.d.iter().map(|val| val.abs()).collect();
33        for i in 0..col_sum.len() {
34            if i < self.dl.len() {
35                col_sum[i] += self.dl[i].abs();
36            }
37            if i > 0 {
38                col_sum[i] += self.du[i - 1].abs();
39            }
40        }
41        let mut max = A::Real::zero();
42        for &val in &col_sum {
43            if max < val {
44                max = val;
45            }
46        }
47        max
48    }
49}
50
/// Represents the LU factorization of a tridiagonal matrix `A` as `A = P*L*U`.
#[derive(Clone, PartialEq)]
pub struct LUFactorizedTridiagonal<A: Scalar> {
    /// A tridiagonal matrix which consists of
    /// - l : layout of raw matrix
    /// - dl: (n-1) multipliers that define the matrix L.
    /// - d : (n) diagonal elements of the upper triangular matrix U.
    /// - du: (n-1) elements of the first super-diagonal of U.
    pub a: Tridiagonal<A>,
    /// (n-2) elements of the second super-diagonal of U.
    pub du2: Vec<A>,
    /// The pivot indices that define the permutation matrix `P`.
    pub ipiv: Pivot,

    /// One-norm of the original (unfactorized) matrix. It is recorded by
    /// `lu_tridiagonal` before the factorization overwrites `a`, and is later
    /// passed to `*gtcon` by `rcond_tridiagonal`.
    a_opnorm_one: A::Real,
}
67
68impl<A: Scalar> Index<(i32, i32)> for Tridiagonal<A> {
69    type Output = A;
70    #[inline]
71    fn index(&self, (row, col): (i32, i32)) -> &A {
72        let (n, _) = self.l.size();
73        assert!(
74            std::cmp::max(row, col) < n,
75            "ndarray: index {:?} is out of bounds for array of shape {}",
76            [row, col],
77            n
78        );
79        match row - col {
80            0 => &self.d[row as usize],
81            1 => &self.dl[col as usize],
82            -1 => &self.du[row as usize],
83            _ => panic!(
84                "ndarray-linalg::tridiagonal: index {:?} is not tridiagonal element",
85                [row, col]
86            ),
87        }
88    }
89}
90
91impl<A: Scalar> Index<[i32; 2]> for Tridiagonal<A> {
92    type Output = A;
93    #[inline]
94    fn index(&self, [row, col]: [i32; 2]) -> &A {
95        &self[(row, col)]
96    }
97}
98
99impl<A: Scalar> IndexMut<(i32, i32)> for Tridiagonal<A> {
100    #[inline]
101    fn index_mut(&mut self, (row, col): (i32, i32)) -> &mut A {
102        let (n, _) = self.l.size();
103        assert!(
104            std::cmp::max(row, col) < n,
105            "ndarray: index {:?} is out of bounds for array of shape {}",
106            [row, col],
107            n
108        );
109        match row - col {
110            0 => &mut self.d[row as usize],
111            1 => &mut self.dl[col as usize],
112            -1 => &mut self.du[row as usize],
113            _ => panic!(
114                "ndarray-linalg::tridiagonal: index {:?} is not tridiagonal element",
115                [row, col]
116            ),
117        }
118    }
119}
120
121impl<A: Scalar> IndexMut<[i32; 2]> for Tridiagonal<A> {
122    #[inline]
123    fn index_mut(&mut self, [row, col]: [i32; 2]) -> &mut A {
124        &mut self[(row, col)]
125    }
126}
127
/// Wraps `*gttrf`, `*gtcon` and `*gttrs`
pub trait Tridiagonal_: Scalar + Sized {
    /// Computes the LU factorization of a tridiagonal matrix `a` using
    /// partial pivoting with row interchanges.
    ///
    /// Takes `a` by value; the returned factorization owns the (overwritten)
    /// diagonals together with the extra data produced by the factorization.
    fn lu_tridiagonal(a: Tridiagonal<Self>) -> Result<LUFactorizedTridiagonal<Self>>;

    /// Returns the `rcond` estimate (reciprocal condition number) that `*gtcon`
    /// computes from the factorization `lu`. The implementations in this file
    /// request the one-norm and use the matrix norm recorded at factorization time.
    fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal<Self>) -> Result<Self::Real>;

    /// Solves a linear system with the factorized matrix `lu` via `*gttrs`.
    ///
    /// - `bl`: layout of the right-hand side stored in `b`.
    /// - `t`: transpose mode, passed on to `*gttrs`.
    /// - `b`: right-hand side(s) on entry; overwritten in place with the result.
    fn solve_tridiagonal(
        lu: &LUFactorizedTridiagonal<Self>,
        bl: MatrixLayout,
        t: Transpose,
        b: &mut [Self],
    ) -> Result<()>;
}
143
// Generates `impl Tridiagonal_ for $scalar` on top of the given LAPACK routines.
//
// The `@real` and `@complex` entry points differ only in the optional `iwork`
// workspace: it is declared and passed to `$gtcon` for real scalars and omitted
// for complex ones. Both delegate to the shared `@body` arm.
macro_rules! impl_tridiagonal {
    (@real, $scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path) => {
        impl_tridiagonal!(@body, $scalar, $gttrf, $gtcon, $gttrs, iwork);
    };
    (@complex, $scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path) => {
        impl_tridiagonal!(@body, $scalar, $gttrf, $gtcon, $gttrs, );
    };
    (@body, $scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path, $($iwork:ident)*) => {
        impl Tridiagonal_ for $scalar {
            fn lu_tridiagonal(mut a: Tridiagonal<Self>) -> Result<LUFactorizedTridiagonal<Self>> {
                let (n, _) = a.l.size();
                // The second super-diagonal has `n - 2` entries. Clamp at zero so
                // that matrices with fewer than two rows do not turn a negative
                // length into a huge `usize` (which would abort the allocation).
                let du2_len = std::cmp::max(n - 2, 0) as usize;
                // NOTE(review): `du2` and `ipiv` are handed to `$gttrf` as output
                // buffers; this relies on the routine writing them before they are
                // read - confirm against the LAPACK documentation.
                let mut du2 = unsafe { vec_uninit(du2_len) };
                let mut ipiv = unsafe { vec_uninit(n as usize) };
                // We have to calc one-norm before LU factorization
                let a_opnorm_one = a.opnorm_one();
                let mut info = 0;
                unsafe { $gttrf(n, &mut a.dl, &mut a.d, &mut a.du, &mut du2, &mut ipiv, &mut info,) };
                info.as_lapack_result()?;
                Ok(LUFactorizedTridiagonal {
                    a,
                    du2,
                    ipiv,
                    a_opnorm_one,
                })
            }

            fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal<Self>) -> Result<Self::Real> {
                let (n, _) = lu.a.l.size();
                let ipiv = &lu.ipiv;
                // Scratch space of length `2 * n` for `$gtcon`
                // (`as` binds tighter than `*`, so this is `2 * (n as usize)`).
                let mut work = unsafe { vec_uninit(2 * n as usize) };
                // Only expanded for real scalars: additional workspace of length `n`.
                $(
                let mut $iwork = unsafe { vec_uninit(n as usize) };
                )*
                let mut rcond = Self::Real::zero();
                let mut info = 0;
                unsafe {
                    $gtcon(
                        NormType::One as u8,
                        n,
                        &lu.a.dl,
                        &lu.a.d,
                        &lu.a.du,
                        &lu.du2,
                        ipiv,
                        lu.a_opnorm_one,
                        &mut rcond,
                        &mut work,
                        $(&mut $iwork,)*
                        &mut info,
                    );
                }
                info.as_lapack_result()?;
                Ok(rcond)
            }

            fn solve_tridiagonal(
                lu: &LUFactorizedTridiagonal<Self>,
                b_layout: MatrixLayout,
                t: Transpose,
                b: &mut [Self],
            ) -> Result<()> {
                let (n, _) = lu.a.l.size();
                let ipiv = &lu.ipiv;
                // Transpose if b is C-continuous
                // `b_t` stays `None` for F layout, in which case `b` is used directly.
                let mut b_t = None;
                let b_layout = match b_layout {
                    MatrixLayout::C { .. } => {
                        b_t = Some(unsafe { vec_uninit(b.len()) });
                        // `transpose` fills `b_t` and returns the layout of the result.
                        transpose(b_layout, b, b_t.as_mut().unwrap())
                    }
                    MatrixLayout::F { .. } => b_layout,
                };
                let (ldb, nrhs) = b_layout.size();
                let mut info = 0;
                unsafe {
                    $gttrs(
                        t as u8,
                        n,
                        nrhs,
                        &lu.a.dl,
                        &lu.a.d,
                        &lu.a.du,
                        &lu.du2,
                        ipiv,
                        b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b),
                        ldb,
                        &mut info,
                    );
                }
                info.as_lapack_result()?;
                // If a temporary was used, transpose the result back into the
                // caller's buffer.
                if let Some(b_t) = b_t {
                    transpose(b_layout, &b_t, b);
                }
                Ok(())
            }
        }
    };
} // impl_tridiagonal!
242
// Instantiate `Tridiagonal_` for each supported scalar type with its matching
// LAPACK routines; the real variants additionally pass an `iwork` workspace
// to `*gtcon`.
impl_tridiagonal!(@real, f64, lapack::dgttrf, lapack::dgtcon, lapack::dgttrs);
impl_tridiagonal!(@real, f32, lapack::sgttrf, lapack::sgtcon, lapack::sgttrs);
impl_tridiagonal!(@complex, c64, lapack::zgttrf, lapack::zgtcon, lapack::zgttrs);
impl_tridiagonal!(@complex, c32, lapack::cgttrf, lapack::cgtcon, lapack::cgttrs);