ndarray_linalg/
solve.rs

1//! Solve systems of linear equations and invert matrices
2//!
3//! # Examples
4//!
5//! Solve `A * x = b`:
6//!
7//! ```
8//! #[macro_use]
9//! extern crate ndarray;
10//! extern crate ndarray_linalg;
11//!
12//! use ndarray::prelude::*;
13//! use ndarray_linalg::Solve;
14//! # fn main() {
15//!
16//! let a: Array2<f64> = array![[3., 2., -1.], [2., -2., 4.], [-2., 1., -2.]];
17//! let b: Array1<f64> = array![1., -2., 0.];
18//! let x = a.solve_into(b).unwrap();
19//! assert!(x.abs_diff_eq(&array![1., -2., -2.], 1e-9));
20//!
21//! # }
22//! ```
23//!
24//! There are also special functions for solving `A^T * x = b` and
25//! `A^H * x = b`.
26//!
27//! If you are solving multiple systems of linear equations with the same
28//! coefficient matrix `A`, it's faster to compute the LU factorization once at
29//! the beginning than solving directly using `A`:
30//!
31//! ```
32//! # extern crate ndarray;
33//! # extern crate ndarray_linalg;
34//!
35//! use ndarray::prelude::*;
36//! use ndarray_linalg::*;
37//! # fn main() {
38//!
39//! let a: Array2<f64> = random((3, 3));
40//! let f = a.factorize_into().unwrap(); // LU factorize A (A is consumed)
41//! for _ in 0..10 {
42//!     let b: Array1<f64> = random(3);
43//!     let x = f.solve_into(b).unwrap(); // Solve A * x = b using factorized L, U
44//! }
45//!
46//! # }
47//! ```
48
49use ndarray::*;
50use num_traits::{Float, Zero};
51
52use crate::convert::*;
53use crate::error::*;
54use crate::layout::*;
55use crate::opnorm::OperationNorm;
56use crate::types::*;
57
58pub use lax::{Pivot, Transpose};
59
60/// An interface for solving systems of linear equations.
61///
62/// There are three groups of methods:
63///
64/// * `solve*` (normal) methods solve `A * x = b` for `x`.
65/// * `solve_t*` (transpose) methods solve `A^T * x = b` for `x`.
66/// * `solve_h*` (Hermitian conjugate) methods solve `A^H * x = b` for `x`.
67///
68/// Within each group, there are three methods that handle ownership differently:
69///
70/// * `*` methods take a reference to `b` and return `x` as a new array.
71/// * `*_into` methods take ownership of `b`, store the result in it, and return it.
72/// * `*_inplace` methods take a mutable reference to `b` and store the result in that array.
73///
74/// If you plan to solve many equations with the same `A` matrix but different
75/// `b` vectors, it's faster to factor the `A` matrix once using the
76/// `Factorize` trait, and then solve using the `LUFactorized` struct.
77pub trait Solve<A: Scalar> {
78    /// Solves a system of linear equations `A * x = b` where `A` is `self`, `b`
79    /// is the argument, and `x` is the successful result.
80    ///
81    /// # Panics
82    ///
83    /// Panics if the length of `b` is not the equal to the number of columns
84    /// of `A`.
85    fn solve<S: Data<Elem = A>>(&self, b: &ArrayBase<S, Ix1>) -> Result<Array1<A>> {
86        let mut b = replicate(b);
87        self.solve_inplace(&mut b)?;
88        Ok(b)
89    }
90
91    /// Solves a system of linear equations `A * x = b` where `A` is `self`, `b`
92    /// is the argument, and `x` is the successful result.
93    ///
94    /// # Panics
95    ///
96    /// Panics if the length of `b` is not the equal to the number of columns
97    /// of `A`.
98    fn solve_into<S: DataMut<Elem = A>>(
99        &self,
100        mut b: ArrayBase<S, Ix1>,
101    ) -> Result<ArrayBase<S, Ix1>> {
102        self.solve_inplace(&mut b)?;
103        Ok(b)
104    }
105
106    /// Solves a system of linear equations `A * x = b` where `A` is `self`, `b`
107    /// is the argument, and `x` is the successful result.
108    ///
109    /// # Panics
110    ///
111    /// Panics if the length of `b` is not the equal to the number of columns
112    /// of `A`.
113    fn solve_inplace<'a, S: DataMut<Elem = A>>(
114        &self,
115        b: &'a mut ArrayBase<S, Ix1>,
116    ) -> Result<&'a mut ArrayBase<S, Ix1>>;
117
118    /// Solves a system of linear equations `A^T * x = b` where `A` is `self`, `b`
119    /// is the argument, and `x` is the successful result.
120    ///
121    /// # Panics
122    ///
123    /// Panics if the length of `b` is not the equal to the number of rows of
124    /// `A`.
125    fn solve_t<S: Data<Elem = A>>(&self, b: &ArrayBase<S, Ix1>) -> Result<Array1<A>> {
126        let mut b = replicate(b);
127        self.solve_t_inplace(&mut b)?;
128        Ok(b)
129    }
130
131    /// Solves a system of linear equations `A^T * x = b` where `A` is `self`, `b`
132    /// is the argument, and `x` is the successful result.
133    ///
134    /// # Panics
135    ///
136    /// Panics if the length of `b` is not the equal to the number of rows of
137    /// `A`.
138    fn solve_t_into<S: DataMut<Elem = A>>(
139        &self,
140        mut b: ArrayBase<S, Ix1>,
141    ) -> Result<ArrayBase<S, Ix1>> {
142        self.solve_t_inplace(&mut b)?;
143        Ok(b)
144    }
145
146    /// Solves a system of linear equations `A^T * x = b` where `A` is `self`, `b`
147    /// is the argument, and `x` is the successful result.
148    ///
149    /// # Panics
150    ///
151    /// Panics if the length of `b` is not the equal to the number of rows of
152    /// `A`.
153    fn solve_t_inplace<'a, S: DataMut<Elem = A>>(
154        &self,
155        b: &'a mut ArrayBase<S, Ix1>,
156    ) -> Result<&'a mut ArrayBase<S, Ix1>>;
157
158    /// Solves a system of linear equations `A^H * x = b` where `A` is `self`, `b`
159    /// is the argument, and `x` is the successful result.
160    ///
161    /// # Panics
162    ///
163    /// Panics if the length of `b` is not the equal to the number of rows of
164    /// `A`.
165    fn solve_h<S: Data<Elem = A>>(&self, b: &ArrayBase<S, Ix1>) -> Result<Array1<A>> {
166        let mut b = replicate(b);
167        self.solve_h_inplace(&mut b)?;
168        Ok(b)
169    }
170    /// Solves a system of linear equations `A^H * x = b` where `A` is `self`, `b`
171    /// is the argument, and `x` is the successful result.
172    ///
173    /// # Panics
174    ///
175    /// Panics if the length of `b` is not the equal to the number of rows of
176    /// `A`.
177    fn solve_h_into<S: DataMut<Elem = A>>(
178        &self,
179        mut b: ArrayBase<S, Ix1>,
180    ) -> Result<ArrayBase<S, Ix1>> {
181        self.solve_h_inplace(&mut b)?;
182        Ok(b)
183    }
184    /// Solves a system of linear equations `A^H * x = b` where `A` is `self`, `b`
185    /// is the argument, and `x` is the successful result.
186    ///
187    /// # Panics
188    ///
189    /// Panics if the length of `b` is not the equal to the number of rows of
190    /// `A`.
191    fn solve_h_inplace<'a, S: DataMut<Elem = A>>(
192        &self,
193        b: &'a mut ArrayBase<S, Ix1>,
194    ) -> Result<&'a mut ArrayBase<S, Ix1>>;
195}
196
/// Represents the LU factorization of a matrix `A` as `A = P*L*U`.
///
/// Values of this type are created by `Factorize::factorize` (which works on
/// a copy of the matrix) and `FactorizeInto::factorize_into` (which reuses
/// the storage of the matrix it consumes). Both fields are private, so the
/// factors can only be used through the traits implemented in this module.
#[derive(Clone)]
pub struct LUFactorized<S: Data + RawDataClone> {
    /// The factors `L` and `U`; the unit diagonal elements of `L` are not
    /// stored.
    a: ArrayBase<S, Ix2>,
    /// The pivot indices that define the permutation matrix `P`.
    ///
    /// The determinant code in this module treats entry `i` being equal to
    /// `i + 1` as "no row exchange", i.e. the indices are handled as 1-based.
    ipiv: Pivot,
}
206
207impl<A, S> Solve<A> for LUFactorized<S>
208where
209    A: Scalar + Lapack,
210    S: Data<Elem = A> + RawDataClone,
211{
212    fn solve_inplace<'a, Sb>(
213        &self,
214        rhs: &'a mut ArrayBase<Sb, Ix1>,
215    ) -> Result<&'a mut ArrayBase<Sb, Ix1>>
216    where
217        Sb: DataMut<Elem = A>,
218    {
219        assert_eq!(
220            rhs.len(),
221            self.a.len_of(Axis(1)),
222            "The length of `rhs` must be compatible with the shape of the factored matrix.",
223        );
224        A::solve(
225            self.a.square_layout()?,
226            Transpose::No,
227            self.a.as_allocated()?,
228            &self.ipiv,
229            rhs.as_slice_mut().unwrap(),
230        )?;
231        Ok(rhs)
232    }
233    fn solve_t_inplace<'a, Sb>(
234        &self,
235        rhs: &'a mut ArrayBase<Sb, Ix1>,
236    ) -> Result<&'a mut ArrayBase<Sb, Ix1>>
237    where
238        Sb: DataMut<Elem = A>,
239    {
240        assert_eq!(
241            rhs.len(),
242            self.a.len_of(Axis(0)),
243            "The length of `rhs` must be compatible with the shape of the factored matrix.",
244        );
245        A::solve(
246            self.a.square_layout()?,
247            Transpose::Transpose,
248            self.a.as_allocated()?,
249            &self.ipiv,
250            rhs.as_slice_mut().unwrap(),
251        )?;
252        Ok(rhs)
253    }
254    fn solve_h_inplace<'a, Sb>(
255        &self,
256        rhs: &'a mut ArrayBase<Sb, Ix1>,
257    ) -> Result<&'a mut ArrayBase<Sb, Ix1>>
258    where
259        Sb: DataMut<Elem = A>,
260    {
261        assert_eq!(
262            rhs.len(),
263            self.a.len_of(Axis(0)),
264            "The length of `rhs` must be compatible with the shape of the factored matrix.",
265        );
266        A::solve(
267            self.a.square_layout()?,
268            Transpose::Hermite,
269            self.a.as_allocated()?,
270            &self.ipiv,
271            rhs.as_slice_mut().unwrap(),
272        )?;
273        Ok(rhs)
274    }
275}
276
277impl<A, S> Solve<A> for ArrayBase<S, Ix2>
278where
279    A: Scalar + Lapack,
280    S: Data<Elem = A>,
281{
282    fn solve_inplace<'a, Sb>(
283        &self,
284        rhs: &'a mut ArrayBase<Sb, Ix1>,
285    ) -> Result<&'a mut ArrayBase<Sb, Ix1>>
286    where
287        Sb: DataMut<Elem = A>,
288    {
289        let f = self.factorize()?;
290        f.solve_inplace(rhs)
291    }
292    fn solve_t_inplace<'a, Sb>(
293        &self,
294        rhs: &'a mut ArrayBase<Sb, Ix1>,
295    ) -> Result<&'a mut ArrayBase<Sb, Ix1>>
296    where
297        Sb: DataMut<Elem = A>,
298    {
299        let f = self.factorize()?;
300        f.solve_t_inplace(rhs)
301    }
302    fn solve_h_inplace<'a, Sb>(
303        &self,
304        rhs: &'a mut ArrayBase<Sb, Ix1>,
305    ) -> Result<&'a mut ArrayBase<Sb, Ix1>>
306    where
307        Sb: DataMut<Elem = A>,
308    {
309        let f = self.factorize()?;
310        f.solve_h_inplace(rhs)
311    }
312}
313
/// An interface for computing LU factorizations of matrix refs.
///
/// The matrix is only borrowed; `S` is the storage type of the returned
/// factorization.
pub trait Factorize<S: Data + RawDataClone> {
    /// Computes the LU factorization `A = P*L*U`, where `P` is a permutation
    /// matrix.
    fn factorize(&self) -> Result<LUFactorized<S>>;
}
320
/// An interface for computing LU factorizations of matrices.
///
/// Unlike `Factorize`, the matrix is taken by value; `S` is the storage type
/// of the returned factorization.
pub trait FactorizeInto<S: Data + RawDataClone> {
    /// Computes the LU factorization `A = P*L*U`, where `P` is a permutation
    /// matrix.
    fn factorize_into(self) -> Result<LUFactorized<S>>;
}
327
328impl<A, S> FactorizeInto<S> for ArrayBase<S, Ix2>
329where
330    A: Scalar + Lapack,
331    S: DataMut<Elem = A> + RawDataClone,
332{
333    fn factorize_into(mut self) -> Result<LUFactorized<S>> {
334        let ipiv = A::lu(self.layout()?, self.as_allocated_mut()?)?;
335        Ok(LUFactorized { a: self, ipiv })
336    }
337}
338
339impl<A, Si> Factorize<OwnedRepr<A>> for ArrayBase<Si, Ix2>
340where
341    A: Scalar + Lapack,
342    Si: Data<Elem = A>,
343{
344    fn factorize(&self) -> Result<LUFactorized<OwnedRepr<A>>> {
345        let mut a: Array2<A> = replicate(self);
346        let ipiv = A::lu(a.layout()?, a.as_allocated_mut()?)?;
347        Ok(LUFactorized { a, ipiv })
348    }
349}
350
/// An interface for inverting matrix refs.
///
/// The receiver is only borrowed.
pub trait Inverse {
    /// The type of the inverse returned on success.
    type Output;
    /// Computes the inverse of the matrix.
    fn inv(&self) -> Result<Self::Output>;
}
357
/// An interface for inverting matrices.
///
/// The receiver is taken by value.
pub trait InverseInto {
    /// The type of the inverse returned on success.
    type Output;
    /// Computes the inverse of the matrix.
    fn inv_into(self) -> Result<Self::Output>;
}
364
365impl<A, S> InverseInto for LUFactorized<S>
366where
367    A: Scalar + Lapack,
368    S: DataMut<Elem = A> + RawDataClone,
369{
370    type Output = ArrayBase<S, Ix2>;
371
372    fn inv_into(mut self) -> Result<ArrayBase<S, Ix2>> {
373        A::inv(
374            self.a.square_layout()?,
375            self.a.as_allocated_mut()?,
376            &self.ipiv,
377        )?;
378        Ok(self.a)
379    }
380}
381
382impl<A, S> Inverse for LUFactorized<S>
383where
384    A: Scalar + Lapack,
385    S: Data<Elem = A> + RawDataClone,
386{
387    type Output = Array2<A>;
388
389    fn inv(&self) -> Result<Array2<A>> {
390        // Preserve the existing layout. This is required to obtain the correct
391        // result, because the result of `A::inv` is layout-dependent.
392        let a = if self.a.is_standard_layout() {
393            replicate(&self.a)
394        } else {
395            replicate(&self.a.t()).reversed_axes()
396        };
397        let f = LUFactorized {
398            a,
399            ipiv: self.ipiv.clone(),
400        };
401        f.inv_into()
402    }
403}
404
405impl<A, S> InverseInto for ArrayBase<S, Ix2>
406where
407    A: Scalar + Lapack,
408    S: DataMut<Elem = A> + RawDataClone,
409{
410    type Output = Self;
411
412    fn inv_into(self) -> Result<Self::Output> {
413        let f = self.factorize_into()?;
414        f.inv_into()
415    }
416}
417
418impl<A, Si> Inverse for ArrayBase<Si, Ix2>
419where
420    A: Scalar + Lapack,
421    Si: Data<Elem = A>,
422{
423    type Output = Array2<A>;
424
425    fn inv(&self) -> Result<Self::Output> {
426        let f = self.factorize()?;
427        f.inv_into()
428    }
429}
430
431/// An interface for calculating determinants of matrix refs.
432pub trait Determinant<A: Scalar> {
433    /// Computes the determinant of the matrix.
434    fn det(&self) -> Result<A> {
435        let (sign, ln_det) = self.sln_det()?;
436        Ok(sign * A::from_real(Float::exp(ln_det)))
437    }
438
439    /// Computes the `(sign, natural_log)` of the determinant of the matrix.
440    ///
441    /// For real matrices, `sign` is `1`, `0`, or `-1`. For complex matrices,
442    /// `sign` is `0` or a complex number with absolute value 1. The
443    /// `natural_log` is the natural logarithm of the absolute value of the
444    /// determinant. If the determinant is zero, `sign` is 0 and `natural_log`
445    /// is negative infinity.
446    ///
447    /// To obtain the determinant, you can compute `sign * natural_log.exp()`
448    /// or just call `.det()` instead.
449    ///
450    /// This method is more robust than `.det()` to very small or very large
451    /// determinants since it returns the natural logarithm of the determinant
452    /// rather than the determinant itself.
453    fn sln_det(&self) -> Result<(A, A::Real)>;
454}
455
456/// An interface for calculating determinants of matrices.
457pub trait DeterminantInto<A: Scalar>: Sized {
458    /// Computes the determinant of the matrix.
459    fn det_into(self) -> Result<A> {
460        let (sign, ln_det) = self.sln_det_into()?;
461        Ok(sign * A::from_real(Float::exp(ln_det)))
462    }
463
464    /// Computes the `(sign, natural_log)` of the determinant of the matrix.
465    ///
466    /// For real matrices, `sign` is `1`, `0`, or `-1`. For complex matrices,
467    /// `sign` is `0` or a complex number with absolute value 1. The
468    /// `natural_log` is the natural logarithm of the absolute value of the
469    /// determinant. If the determinant is zero, `sign` is 0 and `natural_log`
470    /// is negative infinity.
471    ///
472    /// To obtain the determinant, you can compute `sign * natural_log.exp()`
473    /// or just call `.det_into()` instead.
474    ///
475    /// This method is more robust than `.det()` to very small or very large
476    /// determinants since it returns the natural logarithm of the determinant
477    /// rather than the determinant itself.
478    fn sln_det_into(self) -> Result<(A, A::Real)>;
479}
480
481fn lu_sln_det<'a, A, P, U>(ipiv_iter: P, u_diag_iter: U) -> (A, A::Real)
482where
483    A: Scalar + Lapack,
484    P: Iterator<Item = i32>,
485    U: Iterator<Item = &'a A>,
486{
487    let pivot_sign = if ipiv_iter
488        .enumerate()
489        .filter(|&(i, pivot)| pivot != i as i32 + 1)
490        .count()
491        % 2
492        == 0
493    {
494        A::one()
495    } else {
496        -A::one()
497    };
498    let (upper_sign, ln_det) = u_diag_iter.fold(
499        (A::one(), A::Real::zero()),
500        |(upper_sign, ln_det), &elem| {
501            let abs_elem: A::Real = elem.abs();
502            (
503                upper_sign * elem / A::from_real(abs_elem),
504                ln_det + Float::ln(abs_elem),
505            )
506        },
507    );
508    (pivot_sign * upper_sign, ln_det)
509}
510
511impl<A, S> Determinant<A> for LUFactorized<S>
512where
513    A: Scalar + Lapack,
514    S: Data<Elem = A> + RawDataClone,
515{
516    fn sln_det(&self) -> Result<(A, A::Real)> {
517        self.a.ensure_square()?;
518        Ok(lu_sln_det(self.ipiv.iter().cloned(), self.a.diag().iter()))
519    }
520}
521
522impl<A, S> DeterminantInto<A> for LUFactorized<S>
523where
524    A: Scalar + Lapack,
525    S: Data<Elem = A> + RawDataClone,
526{
527    fn sln_det_into(self) -> Result<(A, A::Real)> {
528        self.a.ensure_square()?;
529        Ok(lu_sln_det(self.ipiv.into_iter(), self.a.into_diag().iter()))
530    }
531}
532
533impl<A, S> Determinant<A> for ArrayBase<S, Ix2>
534where
535    A: Scalar + Lapack,
536    S: Data<Elem = A>,
537{
538    fn sln_det(&self) -> Result<(A, A::Real)> {
539        self.ensure_square()?;
540        match self.factorize() {
541            Ok(fac) => fac.sln_det(),
542            Err(LinalgError::Lapack(e))
543                if matches!(e, lax::error::Error::LapackComputationalFailure { .. }) =>
544            {
545                // The determinant is zero.
546                Ok((A::zero(), A::Real::neg_infinity()))
547            }
548            Err(err) => Err(err),
549        }
550    }
551}
552
553impl<A, S> DeterminantInto<A> for ArrayBase<S, Ix2>
554where
555    A: Scalar + Lapack,
556    S: DataMut<Elem = A> + RawDataClone,
557{
558    fn sln_det_into(self) -> Result<(A, A::Real)> {
559        self.ensure_square()?;
560        match self.factorize_into() {
561            Ok(fac) => fac.sln_det_into(),
562            Err(LinalgError::Lapack(e))
563                if matches!(e, lax::error::Error::LapackComputationalFailure { .. }) =>
564            {
565                // The determinant is zero.
566                Ok((A::zero(), A::Real::neg_infinity()))
567            }
568            Err(err) => Err(err),
569        }
570    }
571}
572
/// An interface for *estimating* the reciprocal condition number of matrix refs.
///
/// The receiver is only borrowed.
pub trait ReciprocalConditionNum<A: Scalar> {
    /// *Estimates* the reciprocal of the condition number of the matrix in
    /// 1-norm.
    ///
    /// This method uses the LAPACK `*gecon` routines, which *estimate*
    /// `self.inv().opnorm_one()` and then compute `rcond = 1. /
    /// (self.opnorm_one() * self.inv().opnorm_one())`.
    ///
    /// * If `rcond` is near `0.`, the matrix is badly conditioned.
    /// * If `rcond` is near `1.`, the matrix is well conditioned.
    fn rcond(&self) -> Result<A::Real>;
}
586
/// An interface for *estimating* the reciprocal condition number of matrices.
///
/// The receiver is taken by value.
pub trait ReciprocalConditionNumInto<A: Scalar> {
    /// *Estimates* the reciprocal of the condition number of the matrix in
    /// 1-norm.
    ///
    /// This method uses the LAPACK `*gecon` routines, which *estimate*
    /// `self.inv().opnorm_one()` and then compute `rcond = 1. /
    /// (self.opnorm_one() * self.inv().opnorm_one())`.
    ///
    /// * If `rcond` is near `0.`, the matrix is badly conditioned.
    /// * If `rcond` is near `1.`, the matrix is well conditioned.
    fn rcond_into(self) -> Result<A::Real>;
}
600
601impl<A, S> ReciprocalConditionNum<A> for LUFactorized<S>
602where
603    A: Scalar + Lapack,
604    S: Data<Elem = A> + RawDataClone,
605{
606    fn rcond(&self) -> Result<A::Real> {
607        Ok(A::rcond(
608            self.a.layout()?,
609            self.a.as_allocated()?,
610            self.a.opnorm_one()?,
611        )?)
612    }
613}
614
615impl<A, S> ReciprocalConditionNumInto<A> for LUFactorized<S>
616where
617    A: Scalar + Lapack,
618    S: Data<Elem = A> + RawDataClone,
619{
620    fn rcond_into(self) -> Result<A::Real> {
621        self.rcond()
622    }
623}
624
625impl<A, S> ReciprocalConditionNum<A> for ArrayBase<S, Ix2>
626where
627    A: Scalar + Lapack,
628    S: Data<Elem = A>,
629{
630    fn rcond(&self) -> Result<A::Real> {
631        self.factorize()?.rcond_into()
632    }
633}
634
635impl<A, S> ReciprocalConditionNumInto<A> for ArrayBase<S, Ix2>
636where
637    A: Scalar + Lapack,
638    S: DataMut<Elem = A> + RawDataClone,
639{
640    fn rcond_into(self) -> Result<A::Real> {
641        self.factorize_into()?.rcond_into()
642    }
643}