ndarray_linalg/solveh.rs
1//! Solve Hermitian (or real symmetric) linear problems and invert Hermitian
2//! (or real symmetric) matrices
3//!
4//! **Note that only the upper triangular portion of the matrix is used.**
5//!
6//! # Examples
7//!
8//! Solve `A * x = b`, where `A` is a Hermitian (or real symmetric) matrix:
9//!
10//! ```
11//! #[macro_use]
12//! extern crate ndarray;
13//! extern crate ndarray_linalg;
14//!
15//! use ndarray::prelude::*;
16//! use ndarray_linalg::SolveH;
17//! # fn main() {
18//!
19//! let a: Array2<f64> = array![
20//! [3., 2., -1.],
21//! [2., -2., 4.],
22//! [-1., 4., 5.]
23//! ];
24//! let b: Array1<f64> = array![11., -12., 1.];
25//! let x = a.solveh_into(b).unwrap();
26//! assert!(x.abs_diff_eq(&array![1., 3., -2.], 1e-9));
27//!
28//! # }
29//! ```
30//!
31//! If you are solving multiple systems of linear equations with the same
32//! Hermitian or real symmetric coefficient matrix `A`, it's faster to compute
33//! the factorization once at the beginning than solving directly using `A`:
34//!
35//! ```
36//! # extern crate ndarray;
37//! # extern crate ndarray_linalg;
38//! use ndarray::prelude::*;
39//! use ndarray_linalg::*;
40//! # fn main() {
41//!
42//! let a: Array2<f64> = random((3, 3));
43//! let f = a.factorizeh_into().unwrap(); // Factorize A (A is consumed)
44//! for _ in 0..10 {
45//! let b: Array1<f64> = random(3);
46//! let x = f.solveh_into(b).unwrap(); // Solve A * x = b using the factorization
47//! }
48//!
49//! # }
50//! ```
51
52use ndarray::*;
53use num_traits::{Float, One, Zero};
54
55use crate::convert::*;
56use crate::error::*;
57use crate::layout::*;
58use crate::types::*;
59
60pub use lax::{Pivot, UPLO};
61
62/// An interface for solving systems of Hermitian (or real symmetric) linear equations.
63///
64/// If you plan to solve many equations with the same Hermitian (or real
65/// symmetric) coefficient matrix `A` but different `b` vectors, it's faster to
66/// factor the `A` matrix once using the `FactorizeH` trait, and then solve
67/// using the `BKFactorized` struct.
68pub trait SolveH<A: Scalar> {
69 /// Solves a system of linear equations `A * x = b` with Hermitian (or real
70 /// symmetric) matrix `A`, where `A` is `self`, `b` is the argument, and
71 /// `x` is the successful result.
72 ///
73 /// # Panics
74 ///
75 /// Panics if the length of `b` is not the equal to the number of columns
76 /// of `A`.
77 fn solveh<S: Data<Elem = A>>(&self, b: &ArrayBase<S, Ix1>) -> Result<Array1<A>> {
78 let mut b = replicate(b);
79 self.solveh_inplace(&mut b)?;
80 Ok(b)
81 }
82
83 /// Solves a system of linear equations `A * x = b` with Hermitian (or real
84 /// symmetric) matrix `A`, where `A` is `self`, `b` is the argument, and
85 /// `x` is the successful result.
86 ///
87 /// # Panics
88 ///
89 /// Panics if the length of `b` is not the equal to the number of columns
90 /// of `A`.
91 fn solveh_into<S: DataMut<Elem = A>>(
92 &self,
93 mut b: ArrayBase<S, Ix1>,
94 ) -> Result<ArrayBase<S, Ix1>> {
95 self.solveh_inplace(&mut b)?;
96 Ok(b)
97 }
98
99 /// Solves a system of linear equations `A * x = b` with Hermitian (or real
100 /// symmetric) matrix `A`, where `A` is `self`, `b` is the argument, and
101 /// `x` is the successful result. The value of `x` is also assigned to the
102 /// argument.
103 ///
104 /// # Panics
105 ///
106 /// Panics if the length of `b` is not the equal to the number of columns
107 /// of `A`.
108 fn solveh_inplace<'a, S: DataMut<Elem = A>>(
109 &self,
110 b: &'a mut ArrayBase<S, Ix1>,
111 ) -> Result<&'a mut ArrayBase<S, Ix1>>;
112}
113
/// Represents the Bunch–Kaufman factorization of a Hermitian (or real
/// symmetric) matrix as `A = P * U * D * U^H * P^T`.
///
/// Values of this type are produced by `FactorizeH::factorizeh` and
/// `FactorizeHInto::factorizeh_into`, both of which factorize with
/// `UPLO::Upper`.
pub struct BKFactorized<S: Data> {
    /// Matrix storage that the factorization routine has overwritten with its
    /// output.
    pub a: ArrayBase<S, Ix2>,
    /// Pivot information returned by the factorization routine. The
    /// determinant code in this module treats a positive entry as a 1x1
    /// diagonal block and any other entry as the start of a 2x2 block.
    pub ipiv: Pivot,
}
120
121impl<A, S> SolveH<A> for BKFactorized<S>
122where
123 A: Scalar + Lapack,
124 S: Data<Elem = A>,
125{
126 fn solveh_inplace<'a, Sb>(
127 &self,
128 rhs: &'a mut ArrayBase<Sb, Ix1>,
129 ) -> Result<&'a mut ArrayBase<Sb, Ix1>>
130 where
131 Sb: DataMut<Elem = A>,
132 {
133 assert_eq!(
134 rhs.len(),
135 self.a.len_of(Axis(1)),
136 "The length of `rhs` must be compatible with the shape of the factored matrix.",
137 );
138 A::solveh(
139 self.a.square_layout()?,
140 UPLO::Upper,
141 self.a.as_allocated()?,
142 &self.ipiv,
143 rhs.as_slice_mut().unwrap(),
144 )?;
145 Ok(rhs)
146 }
147}
148
149impl<A, S> SolveH<A> for ArrayBase<S, Ix2>
150where
151 A: Scalar + Lapack,
152 S: Data<Elem = A>,
153{
154 fn solveh_inplace<'a, Sb>(
155 &self,
156 rhs: &'a mut ArrayBase<Sb, Ix1>,
157 ) -> Result<&'a mut ArrayBase<Sb, Ix1>>
158 where
159 Sb: DataMut<Elem = A>,
160 {
161 let f = self.factorizeh()?;
162 f.solveh_inplace(rhs)
163 }
164}
165
/// An interface for computing the Bunch–Kaufman factorization of Hermitian (or
/// real symmetric) matrix refs.
///
/// The receiver is only borrowed. The type parameter `S` is the storage type
/// of the matrix held by the returned `BKFactorized`.
pub trait FactorizeH<S: Data> {
    /// Computes the Bunch–Kaufman factorization of a Hermitian (or real
    /// symmetric) matrix.
    ///
    /// Returns the factorization wrapped in a `BKFactorized`, or an error if
    /// it could not be computed.
    fn factorizeh(&self) -> Result<BKFactorized<S>>;
}
173
/// An interface for computing the Bunch–Kaufman factorization of Hermitian (or
/// real symmetric) matrices.
///
/// The receiver is consumed. The type parameter `S` is the storage type of
/// the matrix held by the returned `BKFactorized`.
pub trait FactorizeHInto<S: Data> {
    /// Computes the Bunch–Kaufman factorization of a Hermitian (or real
    /// symmetric) matrix.
    ///
    /// Returns the factorization wrapped in a `BKFactorized`, or an error if
    /// it could not be computed.
    fn factorizeh_into(self) -> Result<BKFactorized<S>>;
}
181
182impl<A, S> FactorizeHInto<S> for ArrayBase<S, Ix2>
183where
184 A: Scalar + Lapack,
185 S: DataMut<Elem = A>,
186{
187 fn factorizeh_into(mut self) -> Result<BKFactorized<S>> {
188 let ipiv = A::bk(self.square_layout()?, UPLO::Upper, self.as_allocated_mut()?)?;
189 Ok(BKFactorized { a: self, ipiv })
190 }
191}
192
193impl<A, Si> FactorizeH<OwnedRepr<A>> for ArrayBase<Si, Ix2>
194where
195 A: Scalar + Lapack,
196 Si: Data<Elem = A>,
197{
198 fn factorizeh(&self) -> Result<BKFactorized<OwnedRepr<A>>> {
199 let mut a: Array2<A> = replicate(self);
200 let ipiv = A::bk(a.square_layout()?, UPLO::Upper, a.as_allocated_mut()?)?;
201 Ok(BKFactorized { a, ipiv })
202 }
203}
204
/// An interface for inverting Hermitian (or real symmetric) matrix refs.
///
/// The receiver is only borrowed.
pub trait InverseH {
    /// The type in which the inverse is returned.
    type Output;
    /// Computes the inverse of the Hermitian (or real symmetric) matrix.
    ///
    /// Returns an error if the inverse could not be computed.
    fn invh(&self) -> Result<Self::Output>;
}
211
/// An interface for inverting Hermitian (or real symmetric) matrices.
///
/// The receiver is consumed.
pub trait InverseHInto {
    /// The type in which the inverse is returned.
    type Output;
    /// Computes the inverse of the Hermitian (or real symmetric) matrix.
    ///
    /// Returns an error if the inverse could not be computed.
    fn invh_into(self) -> Result<Self::Output>;
}
218
219impl<A, S> InverseHInto for BKFactorized<S>
220where
221 A: Scalar + Lapack,
222 S: DataMut<Elem = A>,
223{
224 type Output = ArrayBase<S, Ix2>;
225
226 fn invh_into(mut self) -> Result<ArrayBase<S, Ix2>> {
227 A::invh(
228 self.a.square_layout()?,
229 UPLO::Upper,
230 self.a.as_allocated_mut()?,
231 &self.ipiv,
232 )?;
233 triangular_fill_hermitian(&mut self.a, UPLO::Upper);
234 Ok(self.a)
235 }
236}
237
238impl<A, S> InverseH for BKFactorized<S>
239where
240 A: Scalar + Lapack,
241 S: Data<Elem = A>,
242{
243 type Output = Array2<A>;
244
245 fn invh(&self) -> Result<Self::Output> {
246 let f = BKFactorized {
247 a: replicate(&self.a),
248 ipiv: self.ipiv.clone(),
249 };
250 f.invh_into()
251 }
252}
253
254impl<A, S> InverseHInto for ArrayBase<S, Ix2>
255where
256 A: Scalar + Lapack,
257 S: DataMut<Elem = A>,
258{
259 type Output = Self;
260
261 fn invh_into(self) -> Result<Self::Output> {
262 let f = self.factorizeh_into()?;
263 f.invh_into()
264 }
265}
266
267impl<A, Si> InverseH for ArrayBase<Si, Ix2>
268where
269 A: Scalar + Lapack,
270 Si: Data<Elem = A>,
271{
272 type Output = Array2<A>;
273
274 fn invh(&self) -> Result<Self::Output> {
275 let f = self.factorizeh()?;
276 f.invh_into()
277 }
278}
279
/// An interface for calculating determinants of Hermitian (or real symmetric) matrix refs.
///
/// The receiver is only borrowed. Results are expressed in the real type
/// associated with `Elem`.
pub trait DeterminantH {
    /// The element type of the matrix.
    type Elem: Scalar;

    /// Computes the determinant of the Hermitian (or real symmetric) matrix.
    ///
    /// Returns an error if the determinant could not be computed.
    fn deth(&self) -> Result<<Self::Elem as Scalar>::Real>;

    /// Computes the `(sign, natural_log)` of the determinant of the Hermitian
    /// (or real symmetric) matrix.
    ///
    /// The `natural_log` is the natural logarithm of the absolute value of the
    /// determinant. If the determinant is zero, `sign` is 0 and `natural_log`
    /// is negative infinity.
    ///
    /// To obtain the determinant, you can compute `sign * natural_log.exp()`
    /// or just call `.deth()` instead.
    ///
    /// This method is more robust than `.deth()` to very small or very large
    /// determinants since it returns the natural logarithm of the determinant
    /// rather than the determinant itself.
    fn sln_deth(&self) -> Result<(<Self::Elem as Scalar>::Real, <Self::Elem as Scalar>::Real)>;
}
303
/// An interface for calculating determinants of Hermitian (or real symmetric) matrices.
///
/// The receiver is consumed. Results are expressed in the real type
/// associated with `Elem`.
pub trait DeterminantHInto {
    /// The element type of the matrix.
    type Elem: Scalar;

    /// Computes the determinant of the Hermitian (or real symmetric) matrix.
    ///
    /// Returns an error if the determinant could not be computed.
    fn deth_into(self) -> Result<<Self::Elem as Scalar>::Real>;

    /// Computes the `(sign, natural_log)` of the determinant of the Hermitian
    /// (or real symmetric) matrix.
    ///
    /// The `natural_log` is the natural logarithm of the absolute value of the
    /// determinant. If the determinant is zero, `sign` is 0 and `natural_log`
    /// is negative infinity.
    ///
    /// To obtain the determinant, you can compute `sign * natural_log.exp()`
    /// or just call `.deth_into()` instead.
    ///
    /// This method is more robust than `.deth_into()` to very small or very
    /// large determinants since it returns the natural logarithm of the
    /// determinant rather than the determinant itself.
    fn sln_deth_into(self) -> Result<(<Self::Elem as Scalar>::Real, <Self::Elem as Scalar>::Real)>;
}
327
328/// Returns the sign and natural log of the determinant.
329fn bk_sln_det<P, S, A>(uplo: UPLO, ipiv_iter: P, a: &ArrayBase<S, Ix2>) -> (A::Real, A::Real)
330where
331 P: Iterator<Item = i32>,
332 S: Data<Elem = A>,
333 A: Scalar + Lapack,
334{
335 let layout = a.layout().unwrap();
336 let mut sign = A::Real::one();
337 let mut ln_det = A::Real::zero();
338 let mut ipiv_enum = ipiv_iter.enumerate();
339 while let Some((k, ipiv_k)) = ipiv_enum.next() {
340 debug_assert!(k < a.nrows() && k < a.ncols());
341 if ipiv_k > 0 {
342 // 1x1 block at k, must be real.
343 let elem = unsafe { a.uget((k, k)) }.re();
344 debug_assert_eq!(elem.im(), Zero::zero());
345 sign *= elem.signum();
346 ln_det += Float::ln(Float::abs(elem));
347 } else {
348 // 2x2 block at k..k+2.
349
350 // Upper left diagonal elem, must be real.
351 let upper_diag = unsafe { a.uget((k, k)) }.re();
352 debug_assert_eq!(upper_diag.im(), Zero::zero());
353
354 // Lower right diagonal elem, must be real.
355 let lower_diag = unsafe { a.uget((k + 1, k + 1)) }.re();
356 debug_assert_eq!(lower_diag.im(), Zero::zero());
357
358 // Off-diagonal elements, can be complex.
359 let off_diag = match layout {
360 MatrixLayout::C { .. } => match uplo {
361 UPLO::Upper => unsafe { a.uget((k + 1, k)) },
362 UPLO::Lower => unsafe { a.uget((k, k + 1)) },
363 },
364 MatrixLayout::F { .. } => match uplo {
365 UPLO::Upper => unsafe { a.uget((k, k + 1)) },
366 UPLO::Lower => unsafe { a.uget((k + 1, k)) },
367 },
368 };
369
370 // Determinant of 2x2 block.
371 let block_det = upper_diag * lower_diag - off_diag.square();
372 sign *= block_det.signum();
373 ln_det += Float::ln(Float::abs(block_det));
374
375 // Skip the k+1 ipiv value.
376 ipiv_enum.next();
377 }
378 }
379 (sign, ln_det)
380}
381
382impl<A, S> BKFactorized<S>
383where
384 A: Scalar + Lapack,
385 S: Data<Elem = A>,
386{
387 /// Computes the determinant of the factorized Hermitian (or real
388 /// symmetric) matrix.
389 pub fn deth(&self) -> A::Real {
390 let (sign, ln_det) = self.sln_deth();
391 sign * Float::exp(ln_det)
392 }
393
394 /// Computes the `(sign, natural_log)` of the determinant of the factorized
395 /// Hermitian (or real symmetric) matrix.
396 ///
397 /// The `natural_log` is the natural logarithm of the absolute value of the
398 /// determinant. If the determinant is zero, `sign` is 0 and `natural_log`
399 /// is negative infinity.
400 ///
401 /// To obtain the determinant, you can compute `sign * natural_log.exp()`
402 /// or just call `.deth()` instead.
403 ///
404 /// This method is more robust than `.deth()` to very small or very large
405 /// determinants since it returns the natural logarithm of the determinant
406 /// rather than the determinant itself.
407 pub fn sln_deth(&self) -> (A::Real, A::Real) {
408 bk_sln_det(UPLO::Upper, self.ipiv.iter().cloned(), &self.a)
409 }
410
411 /// Computes the determinant of the factorized Hermitian (or real
412 /// symmetric) matrix.
413 pub fn deth_into(self) -> A::Real {
414 let (sign, ln_det) = self.sln_deth_into();
415 sign * Float::exp(ln_det)
416 }
417
418 /// Computes the `(sign, natural_log)` of the determinant of the factorized
419 /// Hermitian (or real symmetric) matrix.
420 ///
421 /// The `natural_log` is the natural logarithm of the absolute value of the
422 /// determinant. If the determinant is zero, `sign` is 0 and `natural_log`
423 /// is negative infinity.
424 ///
425 /// To obtain the determinant, you can compute `sign * natural_log.exp()`
426 /// or just call `.deth_into()` instead.
427 ///
428 /// This method is more robust than `.deth_into()` to very small or very
429 /// large determinants since it returns the natural logarithm of the
430 /// determinant rather than the determinant itself.
431 pub fn sln_deth_into(self) -> (A::Real, A::Real) {
432 bk_sln_det(UPLO::Upper, self.ipiv.into_iter(), &self.a)
433 }
434}
435
436impl<A, S> DeterminantH for ArrayBase<S, Ix2>
437where
438 A: Scalar + Lapack,
439 S: Data<Elem = A>,
440{
441 type Elem = A;
442
443 fn deth(&self) -> Result<A::Real> {
444 let (sign, ln_det) = self.sln_deth()?;
445 Ok(sign * Float::exp(ln_det))
446 }
447
448 fn sln_deth(&self) -> Result<(A::Real, A::Real)> {
449 match self.factorizeh() {
450 Ok(fac) => Ok(fac.sln_deth()),
451 Err(LinalgError::Lapack(e))
452 if matches!(e, lax::error::Error::LapackComputationalFailure { .. }) =>
453 {
454 // Determinant is zero.
455 Ok((A::Real::zero(), A::Real::neg_infinity()))
456 }
457 Err(err) => Err(err),
458 }
459 }
460}
461
462impl<A, S> DeterminantHInto for ArrayBase<S, Ix2>
463where
464 A: Scalar + Lapack,
465 S: DataMut<Elem = A>,
466{
467 type Elem = A;
468
469 fn deth_into(self) -> Result<A::Real> {
470 let (sign, ln_det) = self.sln_deth_into()?;
471 Ok(sign * Float::exp(ln_det))
472 }
473
474 fn sln_deth_into(self) -> Result<(A::Real, A::Real)> {
475 match self.factorizeh_into() {
476 Ok(fac) => Ok(fac.sln_deth_into()),
477 Err(LinalgError::Lapack(e))
478 if matches!(e, lax::error::Error::LapackComputationalFailure { .. }) =>
479 {
480 // Determinant is zero.
481 Ok((A::Real::zero(), A::Real::neg_infinity()))
482 }
483 Err(err) => Err(err),
484 }
485 }
486}