1use crate::{error::*, layout::*, *};
5use cauchy::*;
6use num_traits::Zero;
7use std::ops::{Index, IndexMut};
8
/// Tridiagonal matrix stored as its three bands.
///
/// Element `(i, i)` is stored in `d[i]`, element `(i + 1, i)` in `dl[i]` and element
/// `(i, i + 1)` in `du[i]`; all other elements are implicitly zero and cannot be indexed.
#[derive(Clone, PartialEq, Eq)]
pub struct Tridiagonal<A: Scalar> {
    /// Layout of the full matrix; its first size component is used as the dimension `n`.
    pub l: MatrixLayout,
    /// Sub-diagonal entries; `dl[i]` is element `(i + 1, i)`.
    /// NOTE(review): expected to hold `n - 1` elements; the length is not checked here.
    pub dl: Vec<A>,
    /// Diagonal entries; `d[i]` is element `(i, i)` (expected length `n`).
    pub d: Vec<A>,
    /// Super-diagonal entries; `du[i]` is element `(i, i + 1)` (expected length `n - 1`).
    pub du: Vec<A>,
}
29
30impl<A: Scalar> Tridiagonal<A> {
31 fn opnorm_one(&self) -> A::Real {
32 let mut col_sum: Vec<A::Real> = self.d.iter().map(|val| val.abs()).collect();
33 for i in 0..col_sum.len() {
34 if i < self.dl.len() {
35 col_sum[i] += self.dl[i].abs();
36 }
37 if i > 0 {
38 col_sum[i] += self.du[i - 1].abs();
39 }
40 }
41 let mut max = A::Real::zero();
42 for &val in &col_sum {
43 if max < val {
44 max = val;
45 }
46 }
47 max
48 }
49}
50
/// LU factorization of a [`Tridiagonal`] matrix, as produced by `lu_tridiagonal` (`?gttrf`).
#[derive(Clone, PartialEq)]
pub struct LUFactorizedTridiagonal<A: Scalar> {
    /// Bands of the factorization: `lu_tridiagonal` hands `dl`, `d` and `du` to `?gttrf`
    /// mutably, so they hold the factors rather than the original matrix.
    pub a: Tridiagonal<A>,
    /// Extra band passed to `?gttrf` as `DU2` (per LAPACK, the second super-diagonal of `U`).
    pub du2: Vec<A>,
    /// Pivot indices passed to `?gttrf` as `IPIV`.
    pub ipiv: Pivot,

    /// One-norm of the matrix computed before factorization; passed to `?gtcon` as `ANORM`.
    a_opnorm_one: A::Real,
}
67
68impl<A: Scalar> Index<(i32, i32)> for Tridiagonal<A> {
69 type Output = A;
70 #[inline]
71 fn index(&self, (row, col): (i32, i32)) -> &A {
72 let (n, _) = self.l.size();
73 assert!(
74 std::cmp::max(row, col) < n,
75 "ndarray: index {:?} is out of bounds for array of shape {}",
76 [row, col],
77 n
78 );
79 match row - col {
80 0 => &self.d[row as usize],
81 1 => &self.dl[col as usize],
82 -1 => &self.du[row as usize],
83 _ => panic!(
84 "ndarray-linalg::tridiagonal: index {:?} is not tridiagonal element",
85 [row, col]
86 ),
87 }
88 }
89}
90
91impl<A: Scalar> Index<[i32; 2]> for Tridiagonal<A> {
92 type Output = A;
93 #[inline]
94 fn index(&self, [row, col]: [i32; 2]) -> &A {
95 &self[(row, col)]
96 }
97}
98
99impl<A: Scalar> IndexMut<(i32, i32)> for Tridiagonal<A> {
100 #[inline]
101 fn index_mut(&mut self, (row, col): (i32, i32)) -> &mut A {
102 let (n, _) = self.l.size();
103 assert!(
104 std::cmp::max(row, col) < n,
105 "ndarray: index {:?} is out of bounds for array of shape {}",
106 [row, col],
107 n
108 );
109 match row - col {
110 0 => &mut self.d[row as usize],
111 1 => &mut self.dl[col as usize],
112 -1 => &mut self.du[row as usize],
113 _ => panic!(
114 "ndarray-linalg::tridiagonal: index {:?} is not tridiagonal element",
115 [row, col]
116 ),
117 }
118 }
119}
120
121impl<A: Scalar> IndexMut<[i32; 2]> for Tridiagonal<A> {
122 #[inline]
123 fn index_mut(&mut self, [row, col]: [i32; 2]) -> &mut A {
124 &mut self[(row, col)]
125 }
126}
127
/// LAPACK-backed operations on [`Tridiagonal`] matrices.
pub trait Tridiagonal_: Scalar + Sized {
    /// Computes the LU factorization of `a` (implemented with `?gttrf`), consuming it.
    fn lu_tridiagonal(a: Tridiagonal<Self>) -> Result<LUFactorizedTridiagonal<Self>>;

    /// Estimates the reciprocal condition number in the one-norm from an LU
    /// factorization (implemented with `?gtcon`).
    fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal<Self>) -> Result<Self::Real>;

    /// Solves the system in place using an LU factorization (implemented with `?gttrs`).
    ///
    /// `bl` is the layout of `b`; `b` is overwritten with the solution. `t` is forwarded to
    /// `?gttrs` as its `TRANS` argument.
    fn solve_tridiagonal(
        lu: &LUFactorizedTridiagonal<Self>,
        bl: MatrixLayout,
        t: Transpose,
        b: &mut [Self],
    ) -> Result<()>;
}
143
144macro_rules! impl_tridiagonal {
145 (@real, $scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path) => {
146 impl_tridiagonal!(@body, $scalar, $gttrf, $gtcon, $gttrs, iwork);
147 };
148 (@complex, $scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path) => {
149 impl_tridiagonal!(@body, $scalar, $gttrf, $gtcon, $gttrs, );
150 };
151 (@body, $scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path, $($iwork:ident)*) => {
152 impl Tridiagonal_ for $scalar {
153 fn lu_tridiagonal(mut a: Tridiagonal<Self>) -> Result<LUFactorizedTridiagonal<Self>> {
154 let (n, _) = a.l.size();
155 let mut du2 = unsafe { vec_uninit( (n - 2) as usize) };
156 let mut ipiv = unsafe { vec_uninit( n as usize) };
157 let a_opnorm_one = a.opnorm_one();
159 let mut info = 0;
160 unsafe { $gttrf(n, &mut a.dl, &mut a.d, &mut a.du, &mut du2, &mut ipiv, &mut info,) };
161 info.as_lapack_result()?;
162 Ok(LUFactorizedTridiagonal {
163 a,
164 du2,
165 ipiv,
166 a_opnorm_one,
167 })
168 }
169
170 fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal<Self>) -> Result<Self::Real> {
171 let (n, _) = lu.a.l.size();
172 let ipiv = &lu.ipiv;
173 let mut work = unsafe { vec_uninit( 2 * n as usize) };
174 $(
175 let mut $iwork = unsafe { vec_uninit( n as usize) };
176 )*
177 let mut rcond = Self::Real::zero();
178 let mut info = 0;
179 unsafe {
180 $gtcon(
181 NormType::One as u8,
182 n,
183 &lu.a.dl,
184 &lu.a.d,
185 &lu.a.du,
186 &lu.du2,
187 ipiv,
188 lu.a_opnorm_one,
189 &mut rcond,
190 &mut work,
191 $(&mut $iwork,)*
192 &mut info,
193 );
194 }
195 info.as_lapack_result()?;
196 Ok(rcond)
197 }
198
199 fn solve_tridiagonal(
200 lu: &LUFactorizedTridiagonal<Self>,
201 b_layout: MatrixLayout,
202 t: Transpose,
203 b: &mut [Self],
204 ) -> Result<()> {
205 let (n, _) = lu.a.l.size();
206 let ipiv = &lu.ipiv;
207 let mut b_t = None;
209 let b_layout = match b_layout {
210 MatrixLayout::C { .. } => {
211 b_t = Some(unsafe { vec_uninit( b.len()) });
212 transpose(b_layout, b, b_t.as_mut().unwrap())
213 }
214 MatrixLayout::F { .. } => b_layout,
215 };
216 let (ldb, nrhs) = b_layout.size();
217 let mut info = 0;
218 unsafe {
219 $gttrs(
220 t as u8,
221 n,
222 nrhs,
223 &lu.a.dl,
224 &lu.a.d,
225 &lu.a.du,
226 &lu.du2,
227 ipiv,
228 b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b),
229 ldb,
230 &mut info,
231 );
232 }
233 info.as_lapack_result()?;
234 if let Some(b_t) = b_t {
235 transpose(b_layout, &b_t, b);
236 }
237 Ok(())
238 }
239 }
240 };
241} impl_tridiagonal!(@real, f64, lapack::dgttrf, lapack::dgtcon, lapack::dgttrs);
// `f32` uses the `@real` arm (which supplies an `iwork` workspace to `?gtcon`);
// the complex types `c64` and `c32` use the `@complex` arm (no `iwork`).
impl_tridiagonal!(@real, f32, lapack::sgttrf, lapack::sgtcon, lapack::sgttrs);
impl_tridiagonal!(@complex, c64, lapack::zgttrf, lapack::zgtcon, lapack::zgttrs);
impl_tridiagonal!(@complex, c32, lapack::cgttrf, lapack::cgtcon, lapack::cgttrs);