ndarray_linalg/
solveh.rs

1//! Solve Hermitian (or real symmetric) linear problems and invert Hermitian
2//! (or real symmetric) matrices
3//!
4//! **Note that only the upper triangular portion of the matrix is used.**
5//!
6//! # Examples
7//!
8//! Solve `A * x = b`, where `A` is a Hermitian (or real symmetric) matrix:
9//!
10//! ```
11//! #[macro_use]
12//! extern crate ndarray;
13//! extern crate ndarray_linalg;
14//!
15//! use ndarray::prelude::*;
16//! use ndarray_linalg::SolveH;
17//! # fn main() {
18//!
19//! let a: Array2<f64> = array![
20//!     [3., 2., -1.],
21//!     [2., -2., 4.],
22//!     [-1., 4., 5.]
23//! ];
24//! let b: Array1<f64> = array![11., -12., 1.];
25//! let x = a.solveh_into(b).unwrap();
26//! assert!(x.abs_diff_eq(&array![1., 3., -2.], 1e-9));
27//!
28//! # }
29//! ```
30//!
//! If you are solving multiple systems of linear equations with the same
//! Hermitian or real symmetric coefficient matrix `A`, it's faster to compute
//! the factorization once up front than to solve directly with `A` every time:
34//!
35//! ```
36//! # extern crate ndarray;
37//! # extern crate ndarray_linalg;
38//! use ndarray::prelude::*;
39//! use ndarray_linalg::*;
40//! # fn main() {
41//!
42//! let a: Array2<f64> = random((3, 3));
43//! let f = a.factorizeh_into().unwrap(); // Factorize A (A is consumed)
44//! for _ in 0..10 {
45//!     let b: Array1<f64> = random(3);
46//!     let x = f.solveh_into(b).unwrap(); // Solve A * x = b using the factorization
47//! }
48//!
49//! # }
50//! ```
51
52use ndarray::*;
53use num_traits::{Float, One, Zero};
54
55use crate::convert::*;
56use crate::error::*;
57use crate::layout::*;
58use crate::types::*;
59
60pub use lax::{Pivot, UPLO};
61
62/// An interface for solving systems of Hermitian (or real symmetric) linear equations.
63///
64/// If you plan to solve many equations with the same Hermitian (or real
65/// symmetric) coefficient matrix `A` but different `b` vectors, it's faster to
66/// factor the `A` matrix once using the `FactorizeH` trait, and then solve
67/// using the `BKFactorized` struct.
68pub trait SolveH<A: Scalar> {
69    /// Solves a system of linear equations `A * x = b` with Hermitian (or real
70    /// symmetric) matrix `A`, where `A` is `self`, `b` is the argument, and
71    /// `x` is the successful result.
72    ///
73    /// # Panics
74    ///
75    /// Panics if the length of `b` is not the equal to the number of columns
76    /// of `A`.
77    fn solveh<S: Data<Elem = A>>(&self, b: &ArrayBase<S, Ix1>) -> Result<Array1<A>> {
78        let mut b = replicate(b);
79        self.solveh_inplace(&mut b)?;
80        Ok(b)
81    }
82
83    /// Solves a system of linear equations `A * x = b` with Hermitian (or real
84    /// symmetric) matrix `A`, where `A` is `self`, `b` is the argument, and
85    /// `x` is the successful result.
86    ///
87    /// # Panics
88    ///
89    /// Panics if the length of `b` is not the equal to the number of columns
90    /// of `A`.
91    fn solveh_into<S: DataMut<Elem = A>>(
92        &self,
93        mut b: ArrayBase<S, Ix1>,
94    ) -> Result<ArrayBase<S, Ix1>> {
95        self.solveh_inplace(&mut b)?;
96        Ok(b)
97    }
98
99    /// Solves a system of linear equations `A * x = b` with Hermitian (or real
100    /// symmetric) matrix `A`, where `A` is `self`, `b` is the argument, and
101    /// `x` is the successful result. The value of `x` is also assigned to the
102    /// argument.
103    ///
104    /// # Panics
105    ///
106    /// Panics if the length of `b` is not the equal to the number of columns
107    /// of `A`.
108    fn solveh_inplace<'a, S: DataMut<Elem = A>>(
109        &self,
110        b: &'a mut ArrayBase<S, Ix1>,
111    ) -> Result<&'a mut ArrayBase<S, Ix1>>;
112}
113
114/// Represents the Bunch–Kaufman factorization of a Hermitian (or real
115/// symmetric) matrix as `A = P * U * D * U^H * P^T`.
116pub struct BKFactorized<S: Data> {
117    pub a: ArrayBase<S, Ix2>,
118    pub ipiv: Pivot,
119}
120
121impl<A, S> SolveH<A> for BKFactorized<S>
122where
123    A: Scalar + Lapack,
124    S: Data<Elem = A>,
125{
126    fn solveh_inplace<'a, Sb>(
127        &self,
128        rhs: &'a mut ArrayBase<Sb, Ix1>,
129    ) -> Result<&'a mut ArrayBase<Sb, Ix1>>
130    where
131        Sb: DataMut<Elem = A>,
132    {
133        assert_eq!(
134            rhs.len(),
135            self.a.len_of(Axis(1)),
136            "The length of `rhs` must be compatible with the shape of the factored matrix.",
137        );
138        A::solveh(
139            self.a.square_layout()?,
140            UPLO::Upper,
141            self.a.as_allocated()?,
142            &self.ipiv,
143            rhs.as_slice_mut().unwrap(),
144        )?;
145        Ok(rhs)
146    }
147}
148
149impl<A, S> SolveH<A> for ArrayBase<S, Ix2>
150where
151    A: Scalar + Lapack,
152    S: Data<Elem = A>,
153{
154    fn solveh_inplace<'a, Sb>(
155        &self,
156        rhs: &'a mut ArrayBase<Sb, Ix1>,
157    ) -> Result<&'a mut ArrayBase<Sb, Ix1>>
158    where
159        Sb: DataMut<Elem = A>,
160    {
161        let f = self.factorizeh()?;
162        f.solveh_inplace(rhs)
163    }
164}
165
166/// An interface for computing the Bunch–Kaufman factorization of Hermitian (or
167/// real symmetric) matrix refs.
168pub trait FactorizeH<S: Data> {
169    /// Computes the Bunch–Kaufman factorization of a Hermitian (or real
170    /// symmetric) matrix.
171    fn factorizeh(&self) -> Result<BKFactorized<S>>;
172}
173
174/// An interface for computing the Bunch–Kaufman factorization of Hermitian (or
175/// real symmetric) matrices.
176pub trait FactorizeHInto<S: Data> {
177    /// Computes the Bunch–Kaufman factorization of a Hermitian (or real
178    /// symmetric) matrix.
179    fn factorizeh_into(self) -> Result<BKFactorized<S>>;
180}
181
182impl<A, S> FactorizeHInto<S> for ArrayBase<S, Ix2>
183where
184    A: Scalar + Lapack,
185    S: DataMut<Elem = A>,
186{
187    fn factorizeh_into(mut self) -> Result<BKFactorized<S>> {
188        let ipiv = A::bk(self.square_layout()?, UPLO::Upper, self.as_allocated_mut()?)?;
189        Ok(BKFactorized { a: self, ipiv })
190    }
191}
192
193impl<A, Si> FactorizeH<OwnedRepr<A>> for ArrayBase<Si, Ix2>
194where
195    A: Scalar + Lapack,
196    Si: Data<Elem = A>,
197{
198    fn factorizeh(&self) -> Result<BKFactorized<OwnedRepr<A>>> {
199        let mut a: Array2<A> = replicate(self);
200        let ipiv = A::bk(a.square_layout()?, UPLO::Upper, a.as_allocated_mut()?)?;
201        Ok(BKFactorized { a, ipiv })
202    }
203}
204
205/// An interface for inverting Hermitian (or real symmetric) matrix refs.
206pub trait InverseH {
207    type Output;
208    /// Computes the inverse of the Hermitian (or real symmetric) matrix.
209    fn invh(&self) -> Result<Self::Output>;
210}
211
212/// An interface for inverting Hermitian (or real symmetric) matrices.
213pub trait InverseHInto {
214    type Output;
215    /// Computes the inverse of the Hermitian (or real symmetric) matrix.
216    fn invh_into(self) -> Result<Self::Output>;
217}
218
219impl<A, S> InverseHInto for BKFactorized<S>
220where
221    A: Scalar + Lapack,
222    S: DataMut<Elem = A>,
223{
224    type Output = ArrayBase<S, Ix2>;
225
226    fn invh_into(mut self) -> Result<ArrayBase<S, Ix2>> {
227        A::invh(
228            self.a.square_layout()?,
229            UPLO::Upper,
230            self.a.as_allocated_mut()?,
231            &self.ipiv,
232        )?;
233        triangular_fill_hermitian(&mut self.a, UPLO::Upper);
234        Ok(self.a)
235    }
236}
237
238impl<A, S> InverseH for BKFactorized<S>
239where
240    A: Scalar + Lapack,
241    S: Data<Elem = A>,
242{
243    type Output = Array2<A>;
244
245    fn invh(&self) -> Result<Self::Output> {
246        let f = BKFactorized {
247            a: replicate(&self.a),
248            ipiv: self.ipiv.clone(),
249        };
250        f.invh_into()
251    }
252}
253
254impl<A, S> InverseHInto for ArrayBase<S, Ix2>
255where
256    A: Scalar + Lapack,
257    S: DataMut<Elem = A>,
258{
259    type Output = Self;
260
261    fn invh_into(self) -> Result<Self::Output> {
262        let f = self.factorizeh_into()?;
263        f.invh_into()
264    }
265}
266
267impl<A, Si> InverseH for ArrayBase<Si, Ix2>
268where
269    A: Scalar + Lapack,
270    Si: Data<Elem = A>,
271{
272    type Output = Array2<A>;
273
274    fn invh(&self) -> Result<Self::Output> {
275        let f = self.factorizeh()?;
276        f.invh_into()
277    }
278}
279
280/// An interface for calculating determinants of Hermitian (or real symmetric) matrix refs.
281pub trait DeterminantH {
282    /// The element type of the matrix.
283    type Elem: Scalar;
284
285    /// Computes the determinant of the Hermitian (or real symmetric) matrix.
286    fn deth(&self) -> Result<<Self::Elem as Scalar>::Real>;
287
288    /// Computes the `(sign, natural_log)` of the determinant of the Hermitian
289    /// (or real symmetric) matrix.
290    ///
291    /// The `natural_log` is the natural logarithm of the absolute value of the
292    /// determinant. If the determinant is zero, `sign` is 0 and `natural_log`
293    /// is negative infinity.
294    ///
295    /// To obtain the determinant, you can compute `sign * natural_log.exp()`
296    /// or just call `.deth()` instead.
297    ///
298    /// This method is more robust than `.deth()` to very small or very large
299    /// determinants since it returns the natural logarithm of the determinant
300    /// rather than the determinant itself.
301    fn sln_deth(&self) -> Result<(<Self::Elem as Scalar>::Real, <Self::Elem as Scalar>::Real)>;
302}
303
304/// An interface for calculating determinants of Hermitian (or real symmetric) matrices.
305pub trait DeterminantHInto {
306    /// The element type of the matrix.
307    type Elem: Scalar;
308
309    /// Computes the determinant of the Hermitian (or real symmetric) matrix.
310    fn deth_into(self) -> Result<<Self::Elem as Scalar>::Real>;
311
312    /// Computes the `(sign, natural_log)` of the determinant of the Hermitian
313    /// (or real symmetric) matrix.
314    ///
315    /// The `natural_log` is the natural logarithm of the absolute value of the
316    /// determinant. If the determinant is zero, `sign` is 0 and `natural_log`
317    /// is negative infinity.
318    ///
319    /// To obtain the determinant, you can compute `sign * natural_log.exp()`
320    /// or just call `.deth_into()` instead.
321    ///
322    /// This method is more robust than `.deth_into()` to very small or very
323    /// large determinants since it returns the natural logarithm of the
324    /// determinant rather than the determinant itself.
325    fn sln_deth_into(self) -> Result<(<Self::Elem as Scalar>::Real, <Self::Elem as Scalar>::Real)>;
326}
327
328/// Returns the sign and natural log of the determinant.
329fn bk_sln_det<P, S, A>(uplo: UPLO, ipiv_iter: P, a: &ArrayBase<S, Ix2>) -> (A::Real, A::Real)
330where
331    P: Iterator<Item = i32>,
332    S: Data<Elem = A>,
333    A: Scalar + Lapack,
334{
335    let layout = a.layout().unwrap();
336    let mut sign = A::Real::one();
337    let mut ln_det = A::Real::zero();
338    let mut ipiv_enum = ipiv_iter.enumerate();
339    while let Some((k, ipiv_k)) = ipiv_enum.next() {
340        debug_assert!(k < a.nrows() && k < a.ncols());
341        if ipiv_k > 0 {
342            // 1x1 block at k, must be real.
343            let elem = unsafe { a.uget((k, k)) }.re();
344            debug_assert_eq!(elem.im(), Zero::zero());
345            sign *= elem.signum();
346            ln_det += Float::ln(Float::abs(elem));
347        } else {
348            // 2x2 block at k..k+2.
349
350            // Upper left diagonal elem, must be real.
351            let upper_diag = unsafe { a.uget((k, k)) }.re();
352            debug_assert_eq!(upper_diag.im(), Zero::zero());
353
354            // Lower right diagonal elem, must be real.
355            let lower_diag = unsafe { a.uget((k + 1, k + 1)) }.re();
356            debug_assert_eq!(lower_diag.im(), Zero::zero());
357
358            // Off-diagonal elements, can be complex.
359            let off_diag = match layout {
360                MatrixLayout::C { .. } => match uplo {
361                    UPLO::Upper => unsafe { a.uget((k + 1, k)) },
362                    UPLO::Lower => unsafe { a.uget((k, k + 1)) },
363                },
364                MatrixLayout::F { .. } => match uplo {
365                    UPLO::Upper => unsafe { a.uget((k, k + 1)) },
366                    UPLO::Lower => unsafe { a.uget((k + 1, k)) },
367                },
368            };
369
370            // Determinant of 2x2 block.
371            let block_det = upper_diag * lower_diag - off_diag.square();
372            sign *= block_det.signum();
373            ln_det += Float::ln(Float::abs(block_det));
374
375            // Skip the k+1 ipiv value.
376            ipiv_enum.next();
377        }
378    }
379    (sign, ln_det)
380}
381
382impl<A, S> BKFactorized<S>
383where
384    A: Scalar + Lapack,
385    S: Data<Elem = A>,
386{
387    /// Computes the determinant of the factorized Hermitian (or real
388    /// symmetric) matrix.
389    pub fn deth(&self) -> A::Real {
390        let (sign, ln_det) = self.sln_deth();
391        sign * Float::exp(ln_det)
392    }
393
394    /// Computes the `(sign, natural_log)` of the determinant of the factorized
395    /// Hermitian (or real symmetric) matrix.
396    ///
397    /// The `natural_log` is the natural logarithm of the absolute value of the
398    /// determinant. If the determinant is zero, `sign` is 0 and `natural_log`
399    /// is negative infinity.
400    ///
401    /// To obtain the determinant, you can compute `sign * natural_log.exp()`
402    /// or just call `.deth()` instead.
403    ///
404    /// This method is more robust than `.deth()` to very small or very large
405    /// determinants since it returns the natural logarithm of the determinant
406    /// rather than the determinant itself.
407    pub fn sln_deth(&self) -> (A::Real, A::Real) {
408        bk_sln_det(UPLO::Upper, self.ipiv.iter().cloned(), &self.a)
409    }
410
411    /// Computes the determinant of the factorized Hermitian (or real
412    /// symmetric) matrix.
413    pub fn deth_into(self) -> A::Real {
414        let (sign, ln_det) = self.sln_deth_into();
415        sign * Float::exp(ln_det)
416    }
417
418    /// Computes the `(sign, natural_log)` of the determinant of the factorized
419    /// Hermitian (or real symmetric) matrix.
420    ///
421    /// The `natural_log` is the natural logarithm of the absolute value of the
422    /// determinant. If the determinant is zero, `sign` is 0 and `natural_log`
423    /// is negative infinity.
424    ///
425    /// To obtain the determinant, you can compute `sign * natural_log.exp()`
426    /// or just call `.deth_into()` instead.
427    ///
428    /// This method is more robust than `.deth_into()` to very small or very
429    /// large determinants since it returns the natural logarithm of the
430    /// determinant rather than the determinant itself.
431    pub fn sln_deth_into(self) -> (A::Real, A::Real) {
432        bk_sln_det(UPLO::Upper, self.ipiv.into_iter(), &self.a)
433    }
434}
435
436impl<A, S> DeterminantH for ArrayBase<S, Ix2>
437where
438    A: Scalar + Lapack,
439    S: Data<Elem = A>,
440{
441    type Elem = A;
442
443    fn deth(&self) -> Result<A::Real> {
444        let (sign, ln_det) = self.sln_deth()?;
445        Ok(sign * Float::exp(ln_det))
446    }
447
448    fn sln_deth(&self) -> Result<(A::Real, A::Real)> {
449        match self.factorizeh() {
450            Ok(fac) => Ok(fac.sln_deth()),
451            Err(LinalgError::Lapack(e))
452                if matches!(e, lax::error::Error::LapackComputationalFailure { .. }) =>
453            {
454                // Determinant is zero.
455                Ok((A::Real::zero(), A::Real::neg_infinity()))
456            }
457            Err(err) => Err(err),
458        }
459    }
460}
461
462impl<A, S> DeterminantHInto for ArrayBase<S, Ix2>
463where
464    A: Scalar + Lapack,
465    S: DataMut<Elem = A>,
466{
467    type Elem = A;
468
469    fn deth_into(self) -> Result<A::Real> {
470        let (sign, ln_det) = self.sln_deth_into()?;
471        Ok(sign * Float::exp(ln_det))
472    }
473
474    fn sln_deth_into(self) -> Result<(A::Real, A::Real)> {
475        match self.factorizeh_into() {
476            Ok(fac) => Ok(fac.sln_deth_into()),
477            Err(LinalgError::Lapack(e))
478                if matches!(e, lax::error::Error::LapackComputationalFailure { .. }) =>
479            {
480                // Determinant is zero.
481                Ok((A::Real::zero(), A::Real::neg_infinity()))
482            }
483            Err(err) => Err(err),
484        }
485    }
486}