ndarray_linalg/solve.rs
1//! Solve systems of linear equations and invert matrices
2//!
3//! # Examples
4//!
5//! Solve `A * x = b`:
6//!
7//! ```
8//! #[macro_use]
9//! extern crate ndarray;
10//! extern crate ndarray_linalg;
11//!
12//! use ndarray::prelude::*;
13//! use ndarray_linalg::Solve;
14//! # fn main() {
15//!
16//! let a: Array2<f64> = array![[3., 2., -1.], [2., -2., 4.], [-2., 1., -2.]];
17//! let b: Array1<f64> = array![1., -2., 0.];
18//! let x = a.solve_into(b).unwrap();
19//! assert!(x.abs_diff_eq(&array![1., -2., -2.], 1e-9));
20//!
21//! # }
22//! ```
23//!
24//! There are also special functions for solving `A^T * x = b` and
25//! `A^H * x = b`.
26//!
27//! If you are solving multiple systems of linear equations with the same
28//! coefficient matrix `A`, it's faster to compute the LU factorization once at
29//! the beginning than solving directly using `A`:
30//!
31//! ```
32//! # extern crate ndarray;
33//! # extern crate ndarray_linalg;
34//!
35//! use ndarray::prelude::*;
36//! use ndarray_linalg::*;
37//! # fn main() {
38//!
39//! let a: Array2<f64> = random((3, 3));
40//! let f = a.factorize_into().unwrap(); // LU factorize A (A is consumed)
41//! for _ in 0..10 {
42//! let b: Array1<f64> = random(3);
43//! let x = f.solve_into(b).unwrap(); // Solve A * x = b using factorized L, U
44//! }
45//!
46//! # }
47//! ```
48
49use ndarray::*;
50use num_traits::{Float, Zero};
51
52use crate::convert::*;
53use crate::error::*;
54use crate::layout::*;
55use crate::opnorm::OperationNorm;
56use crate::types::*;
57
58pub use lax::{Pivot, Transpose};
59
60/// An interface for solving systems of linear equations.
61///
62/// There are three groups of methods:
63///
64/// * `solve*` (normal) methods solve `A * x = b` for `x`.
65/// * `solve_t*` (transpose) methods solve `A^T * x = b` for `x`.
66/// * `solve_h*` (Hermitian conjugate) methods solve `A^H * x = b` for `x`.
67///
68/// Within each group, there are three methods that handle ownership differently:
69///
70/// * `*` methods take a reference to `b` and return `x` as a new array.
71/// * `*_into` methods take ownership of `b`, store the result in it, and return it.
72/// * `*_inplace` methods take a mutable reference to `b` and store the result in that array.
73///
74/// If you plan to solve many equations with the same `A` matrix but different
75/// `b` vectors, it's faster to factor the `A` matrix once using the
76/// `Factorize` trait, and then solve using the `LUFactorized` struct.
77pub trait Solve<A: Scalar> {
78 /// Solves a system of linear equations `A * x = b` where `A` is `self`, `b`
79 /// is the argument, and `x` is the successful result.
80 ///
81 /// # Panics
82 ///
83 /// Panics if the length of `b` is not the equal to the number of columns
84 /// of `A`.
85 fn solve<S: Data<Elem = A>>(&self, b: &ArrayBase<S, Ix1>) -> Result<Array1<A>> {
86 let mut b = replicate(b);
87 self.solve_inplace(&mut b)?;
88 Ok(b)
89 }
90
91 /// Solves a system of linear equations `A * x = b` where `A` is `self`, `b`
92 /// is the argument, and `x` is the successful result.
93 ///
94 /// # Panics
95 ///
96 /// Panics if the length of `b` is not the equal to the number of columns
97 /// of `A`.
98 fn solve_into<S: DataMut<Elem = A>>(
99 &self,
100 mut b: ArrayBase<S, Ix1>,
101 ) -> Result<ArrayBase<S, Ix1>> {
102 self.solve_inplace(&mut b)?;
103 Ok(b)
104 }
105
106 /// Solves a system of linear equations `A * x = b` where `A` is `self`, `b`
107 /// is the argument, and `x` is the successful result.
108 ///
109 /// # Panics
110 ///
111 /// Panics if the length of `b` is not the equal to the number of columns
112 /// of `A`.
113 fn solve_inplace<'a, S: DataMut<Elem = A>>(
114 &self,
115 b: &'a mut ArrayBase<S, Ix1>,
116 ) -> Result<&'a mut ArrayBase<S, Ix1>>;
117
118 /// Solves a system of linear equations `A^T * x = b` where `A` is `self`, `b`
119 /// is the argument, and `x` is the successful result.
120 ///
121 /// # Panics
122 ///
123 /// Panics if the length of `b` is not the equal to the number of rows of
124 /// `A`.
125 fn solve_t<S: Data<Elem = A>>(&self, b: &ArrayBase<S, Ix1>) -> Result<Array1<A>> {
126 let mut b = replicate(b);
127 self.solve_t_inplace(&mut b)?;
128 Ok(b)
129 }
130
131 /// Solves a system of linear equations `A^T * x = b` where `A` is `self`, `b`
132 /// is the argument, and `x` is the successful result.
133 ///
134 /// # Panics
135 ///
136 /// Panics if the length of `b` is not the equal to the number of rows of
137 /// `A`.
138 fn solve_t_into<S: DataMut<Elem = A>>(
139 &self,
140 mut b: ArrayBase<S, Ix1>,
141 ) -> Result<ArrayBase<S, Ix1>> {
142 self.solve_t_inplace(&mut b)?;
143 Ok(b)
144 }
145
146 /// Solves a system of linear equations `A^T * x = b` where `A` is `self`, `b`
147 /// is the argument, and `x` is the successful result.
148 ///
149 /// # Panics
150 ///
151 /// Panics if the length of `b` is not the equal to the number of rows of
152 /// `A`.
153 fn solve_t_inplace<'a, S: DataMut<Elem = A>>(
154 &self,
155 b: &'a mut ArrayBase<S, Ix1>,
156 ) -> Result<&'a mut ArrayBase<S, Ix1>>;
157
158 /// Solves a system of linear equations `A^H * x = b` where `A` is `self`, `b`
159 /// is the argument, and `x` is the successful result.
160 ///
161 /// # Panics
162 ///
163 /// Panics if the length of `b` is not the equal to the number of rows of
164 /// `A`.
165 fn solve_h<S: Data<Elem = A>>(&self, b: &ArrayBase<S, Ix1>) -> Result<Array1<A>> {
166 let mut b = replicate(b);
167 self.solve_h_inplace(&mut b)?;
168 Ok(b)
169 }
170 /// Solves a system of linear equations `A^H * x = b` where `A` is `self`, `b`
171 /// is the argument, and `x` is the successful result.
172 ///
173 /// # Panics
174 ///
175 /// Panics if the length of `b` is not the equal to the number of rows of
176 /// `A`.
177 fn solve_h_into<S: DataMut<Elem = A>>(
178 &self,
179 mut b: ArrayBase<S, Ix1>,
180 ) -> Result<ArrayBase<S, Ix1>> {
181 self.solve_h_inplace(&mut b)?;
182 Ok(b)
183 }
184 /// Solves a system of linear equations `A^H * x = b` where `A` is `self`, `b`
185 /// is the argument, and `x` is the successful result.
186 ///
187 /// # Panics
188 ///
189 /// Panics if the length of `b` is not the equal to the number of rows of
190 /// `A`.
191 fn solve_h_inplace<'a, S: DataMut<Elem = A>>(
192 &self,
193 b: &'a mut ArrayBase<S, Ix1>,
194 ) -> Result<&'a mut ArrayBase<S, Ix1>>;
195}
196
197/// Represents the LU factorization of a matrix `A` as `A = P*L*U`.
198#[derive(Clone)]
199pub struct LUFactorized<S: Data + RawDataClone> {
200 /// The factors `L` and `U`; the unit diagonal elements of `L` are not
201 /// stored.
202 a: ArrayBase<S, Ix2>,
203 /// The pivot indices that define the permutation matrix `P`.
204 ipiv: Pivot,
205}
206
207impl<A, S> Solve<A> for LUFactorized<S>
208where
209 A: Scalar + Lapack,
210 S: Data<Elem = A> + RawDataClone,
211{
212 fn solve_inplace<'a, Sb>(
213 &self,
214 rhs: &'a mut ArrayBase<Sb, Ix1>,
215 ) -> Result<&'a mut ArrayBase<Sb, Ix1>>
216 where
217 Sb: DataMut<Elem = A>,
218 {
219 assert_eq!(
220 rhs.len(),
221 self.a.len_of(Axis(1)),
222 "The length of `rhs` must be compatible with the shape of the factored matrix.",
223 );
224 A::solve(
225 self.a.square_layout()?,
226 Transpose::No,
227 self.a.as_allocated()?,
228 &self.ipiv,
229 rhs.as_slice_mut().unwrap(),
230 )?;
231 Ok(rhs)
232 }
233 fn solve_t_inplace<'a, Sb>(
234 &self,
235 rhs: &'a mut ArrayBase<Sb, Ix1>,
236 ) -> Result<&'a mut ArrayBase<Sb, Ix1>>
237 where
238 Sb: DataMut<Elem = A>,
239 {
240 assert_eq!(
241 rhs.len(),
242 self.a.len_of(Axis(0)),
243 "The length of `rhs` must be compatible with the shape of the factored matrix.",
244 );
245 A::solve(
246 self.a.square_layout()?,
247 Transpose::Transpose,
248 self.a.as_allocated()?,
249 &self.ipiv,
250 rhs.as_slice_mut().unwrap(),
251 )?;
252 Ok(rhs)
253 }
254 fn solve_h_inplace<'a, Sb>(
255 &self,
256 rhs: &'a mut ArrayBase<Sb, Ix1>,
257 ) -> Result<&'a mut ArrayBase<Sb, Ix1>>
258 where
259 Sb: DataMut<Elem = A>,
260 {
261 assert_eq!(
262 rhs.len(),
263 self.a.len_of(Axis(0)),
264 "The length of `rhs` must be compatible with the shape of the factored matrix.",
265 );
266 A::solve(
267 self.a.square_layout()?,
268 Transpose::Hermite,
269 self.a.as_allocated()?,
270 &self.ipiv,
271 rhs.as_slice_mut().unwrap(),
272 )?;
273 Ok(rhs)
274 }
275}
276
277impl<A, S> Solve<A> for ArrayBase<S, Ix2>
278where
279 A: Scalar + Lapack,
280 S: Data<Elem = A>,
281{
282 fn solve_inplace<'a, Sb>(
283 &self,
284 rhs: &'a mut ArrayBase<Sb, Ix1>,
285 ) -> Result<&'a mut ArrayBase<Sb, Ix1>>
286 where
287 Sb: DataMut<Elem = A>,
288 {
289 let f = self.factorize()?;
290 f.solve_inplace(rhs)
291 }
292 fn solve_t_inplace<'a, Sb>(
293 &self,
294 rhs: &'a mut ArrayBase<Sb, Ix1>,
295 ) -> Result<&'a mut ArrayBase<Sb, Ix1>>
296 where
297 Sb: DataMut<Elem = A>,
298 {
299 let f = self.factorize()?;
300 f.solve_t_inplace(rhs)
301 }
302 fn solve_h_inplace<'a, Sb>(
303 &self,
304 rhs: &'a mut ArrayBase<Sb, Ix1>,
305 ) -> Result<&'a mut ArrayBase<Sb, Ix1>>
306 where
307 Sb: DataMut<Elem = A>,
308 {
309 let f = self.factorize()?;
310 f.solve_h_inplace(rhs)
311 }
312}
313
314/// An interface for computing LU factorizations of matrix refs.
315pub trait Factorize<S: Data + RawDataClone> {
316 /// Computes the LU factorization `A = P*L*U`, where `P` is a permutation
317 /// matrix.
318 fn factorize(&self) -> Result<LUFactorized<S>>;
319}
320
321/// An interface for computing LU factorizations of matrices.
322pub trait FactorizeInto<S: Data + RawDataClone> {
323 /// Computes the LU factorization `A = P*L*U`, where `P` is a permutation
324 /// matrix.
325 fn factorize_into(self) -> Result<LUFactorized<S>>;
326}
327
328impl<A, S> FactorizeInto<S> for ArrayBase<S, Ix2>
329where
330 A: Scalar + Lapack,
331 S: DataMut<Elem = A> + RawDataClone,
332{
333 fn factorize_into(mut self) -> Result<LUFactorized<S>> {
334 let ipiv = A::lu(self.layout()?, self.as_allocated_mut()?)?;
335 Ok(LUFactorized { a: self, ipiv })
336 }
337}
338
339impl<A, Si> Factorize<OwnedRepr<A>> for ArrayBase<Si, Ix2>
340where
341 A: Scalar + Lapack,
342 Si: Data<Elem = A>,
343{
344 fn factorize(&self) -> Result<LUFactorized<OwnedRepr<A>>> {
345 let mut a: Array2<A> = replicate(self);
346 let ipiv = A::lu(a.layout()?, a.as_allocated_mut()?)?;
347 Ok(LUFactorized { a, ipiv })
348 }
349}
350
351/// An interface for inverting matrix refs.
352pub trait Inverse {
353 type Output;
354 /// Computes the inverse of the matrix.
355 fn inv(&self) -> Result<Self::Output>;
356}
357
358/// An interface for inverting matrices.
359pub trait InverseInto {
360 type Output;
361 /// Computes the inverse of the matrix.
362 fn inv_into(self) -> Result<Self::Output>;
363}
364
365impl<A, S> InverseInto for LUFactorized<S>
366where
367 A: Scalar + Lapack,
368 S: DataMut<Elem = A> + RawDataClone,
369{
370 type Output = ArrayBase<S, Ix2>;
371
372 fn inv_into(mut self) -> Result<ArrayBase<S, Ix2>> {
373 A::inv(
374 self.a.square_layout()?,
375 self.a.as_allocated_mut()?,
376 &self.ipiv,
377 )?;
378 Ok(self.a)
379 }
380}
381
382impl<A, S> Inverse for LUFactorized<S>
383where
384 A: Scalar + Lapack,
385 S: Data<Elem = A> + RawDataClone,
386{
387 type Output = Array2<A>;
388
389 fn inv(&self) -> Result<Array2<A>> {
390 // Preserve the existing layout. This is required to obtain the correct
391 // result, because the result of `A::inv` is layout-dependent.
392 let a = if self.a.is_standard_layout() {
393 replicate(&self.a)
394 } else {
395 replicate(&self.a.t()).reversed_axes()
396 };
397 let f = LUFactorized {
398 a,
399 ipiv: self.ipiv.clone(),
400 };
401 f.inv_into()
402 }
403}
404
405impl<A, S> InverseInto for ArrayBase<S, Ix2>
406where
407 A: Scalar + Lapack,
408 S: DataMut<Elem = A> + RawDataClone,
409{
410 type Output = Self;
411
412 fn inv_into(self) -> Result<Self::Output> {
413 let f = self.factorize_into()?;
414 f.inv_into()
415 }
416}
417
418impl<A, Si> Inverse for ArrayBase<Si, Ix2>
419where
420 A: Scalar + Lapack,
421 Si: Data<Elem = A>,
422{
423 type Output = Array2<A>;
424
425 fn inv(&self) -> Result<Self::Output> {
426 let f = self.factorize()?;
427 f.inv_into()
428 }
429}
430
431/// An interface for calculating determinants of matrix refs.
432pub trait Determinant<A: Scalar> {
433 /// Computes the determinant of the matrix.
434 fn det(&self) -> Result<A> {
435 let (sign, ln_det) = self.sln_det()?;
436 Ok(sign * A::from_real(Float::exp(ln_det)))
437 }
438
439 /// Computes the `(sign, natural_log)` of the determinant of the matrix.
440 ///
441 /// For real matrices, `sign` is `1`, `0`, or `-1`. For complex matrices,
442 /// `sign` is `0` or a complex number with absolute value 1. The
443 /// `natural_log` is the natural logarithm of the absolute value of the
444 /// determinant. If the determinant is zero, `sign` is 0 and `natural_log`
445 /// is negative infinity.
446 ///
447 /// To obtain the determinant, you can compute `sign * natural_log.exp()`
448 /// or just call `.det()` instead.
449 ///
450 /// This method is more robust than `.det()` to very small or very large
451 /// determinants since it returns the natural logarithm of the determinant
452 /// rather than the determinant itself.
453 fn sln_det(&self) -> Result<(A, A::Real)>;
454}
455
456/// An interface for calculating determinants of matrices.
457pub trait DeterminantInto<A: Scalar>: Sized {
458 /// Computes the determinant of the matrix.
459 fn det_into(self) -> Result<A> {
460 let (sign, ln_det) = self.sln_det_into()?;
461 Ok(sign * A::from_real(Float::exp(ln_det)))
462 }
463
464 /// Computes the `(sign, natural_log)` of the determinant of the matrix.
465 ///
466 /// For real matrices, `sign` is `1`, `0`, or `-1`. For complex matrices,
467 /// `sign` is `0` or a complex number with absolute value 1. The
468 /// `natural_log` is the natural logarithm of the absolute value of the
469 /// determinant. If the determinant is zero, `sign` is 0 and `natural_log`
470 /// is negative infinity.
471 ///
472 /// To obtain the determinant, you can compute `sign * natural_log.exp()`
473 /// or just call `.det_into()` instead.
474 ///
475 /// This method is more robust than `.det()` to very small or very large
476 /// determinants since it returns the natural logarithm of the determinant
477 /// rather than the determinant itself.
478 fn sln_det_into(self) -> Result<(A, A::Real)>;
479}
480
481fn lu_sln_det<'a, A, P, U>(ipiv_iter: P, u_diag_iter: U) -> (A, A::Real)
482where
483 A: Scalar + Lapack,
484 P: Iterator<Item = i32>,
485 U: Iterator<Item = &'a A>,
486{
487 let pivot_sign = if ipiv_iter
488 .enumerate()
489 .filter(|&(i, pivot)| pivot != i as i32 + 1)
490 .count()
491 % 2
492 == 0
493 {
494 A::one()
495 } else {
496 -A::one()
497 };
498 let (upper_sign, ln_det) = u_diag_iter.fold(
499 (A::one(), A::Real::zero()),
500 |(upper_sign, ln_det), &elem| {
501 let abs_elem: A::Real = elem.abs();
502 (
503 upper_sign * elem / A::from_real(abs_elem),
504 ln_det + Float::ln(abs_elem),
505 )
506 },
507 );
508 (pivot_sign * upper_sign, ln_det)
509}
510
511impl<A, S> Determinant<A> for LUFactorized<S>
512where
513 A: Scalar + Lapack,
514 S: Data<Elem = A> + RawDataClone,
515{
516 fn sln_det(&self) -> Result<(A, A::Real)> {
517 self.a.ensure_square()?;
518 Ok(lu_sln_det(self.ipiv.iter().cloned(), self.a.diag().iter()))
519 }
520}
521
522impl<A, S> DeterminantInto<A> for LUFactorized<S>
523where
524 A: Scalar + Lapack,
525 S: Data<Elem = A> + RawDataClone,
526{
527 fn sln_det_into(self) -> Result<(A, A::Real)> {
528 self.a.ensure_square()?;
529 Ok(lu_sln_det(self.ipiv.into_iter(), self.a.into_diag().iter()))
530 }
531}
532
533impl<A, S> Determinant<A> for ArrayBase<S, Ix2>
534where
535 A: Scalar + Lapack,
536 S: Data<Elem = A>,
537{
538 fn sln_det(&self) -> Result<(A, A::Real)> {
539 self.ensure_square()?;
540 match self.factorize() {
541 Ok(fac) => fac.sln_det(),
542 Err(LinalgError::Lapack(e))
543 if matches!(e, lax::error::Error::LapackComputationalFailure { .. }) =>
544 {
545 // The determinant is zero.
546 Ok((A::zero(), A::Real::neg_infinity()))
547 }
548 Err(err) => Err(err),
549 }
550 }
551}
552
553impl<A, S> DeterminantInto<A> for ArrayBase<S, Ix2>
554where
555 A: Scalar + Lapack,
556 S: DataMut<Elem = A> + RawDataClone,
557{
558 fn sln_det_into(self) -> Result<(A, A::Real)> {
559 self.ensure_square()?;
560 match self.factorize_into() {
561 Ok(fac) => fac.sln_det_into(),
562 Err(LinalgError::Lapack(e))
563 if matches!(e, lax::error::Error::LapackComputationalFailure { .. }) =>
564 {
565 // The determinant is zero.
566 Ok((A::zero(), A::Real::neg_infinity()))
567 }
568 Err(err) => Err(err),
569 }
570 }
571}
572
573/// An interface for *estimating* the reciprocal condition number of matrix refs.
574pub trait ReciprocalConditionNum<A: Scalar> {
575 /// *Estimates* the reciprocal of the condition number of the matrix in
576 /// 1-norm.
577 ///
578 /// This method uses the LAPACK `*gecon` routines, which *estimate*
579 /// `self.inv().opnorm_one()` and then compute `rcond = 1. /
580 /// (self.opnorm_one() * self.inv().opnorm_one())`.
581 ///
582 /// * If `rcond` is near `0.`, the matrix is badly conditioned.
583 /// * If `rcond` is near `1.`, the matrix is well conditioned.
584 fn rcond(&self) -> Result<A::Real>;
585}
586
587/// An interface for *estimating* the reciprocal condition number of matrices.
588pub trait ReciprocalConditionNumInto<A: Scalar> {
589 /// *Estimates* the reciprocal of the condition number of the matrix in
590 /// 1-norm.
591 ///
592 /// This method uses the LAPACK `*gecon` routines, which *estimate*
593 /// `self.inv().opnorm_one()` and then compute `rcond = 1. /
594 /// (self.opnorm_one() * self.inv().opnorm_one())`.
595 ///
596 /// * If `rcond` is near `0.`, the matrix is badly conditioned.
597 /// * If `rcond` is near `1.`, the matrix is well conditioned.
598 fn rcond_into(self) -> Result<A::Real>;
599}
600
601impl<A, S> ReciprocalConditionNum<A> for LUFactorized<S>
602where
603 A: Scalar + Lapack,
604 S: Data<Elem = A> + RawDataClone,
605{
606 fn rcond(&self) -> Result<A::Real> {
607 Ok(A::rcond(
608 self.a.layout()?,
609 self.a.as_allocated()?,
610 self.a.opnorm_one()?,
611 )?)
612 }
613}
614
615impl<A, S> ReciprocalConditionNumInto<A> for LUFactorized<S>
616where
617 A: Scalar + Lapack,
618 S: Data<Elem = A> + RawDataClone,
619{
620 fn rcond_into(self) -> Result<A::Real> {
621 self.rcond()
622 }
623}
624
625impl<A, S> ReciprocalConditionNum<A> for ArrayBase<S, Ix2>
626where
627 A: Scalar + Lapack,
628 S: Data<Elem = A>,
629{
630 fn rcond(&self) -> Result<A::Real> {
631 self.factorize()?.rcond_into()
632 }
633}
634
635impl<A, S> ReciprocalConditionNumInto<A> for ArrayBase<S, Ix2>
636where
637 A: Scalar + Lapack,
638 S: DataMut<Elem = A> + RawDataClone,
639{
640 fn rcond_into(self) -> Result<A::Real> {
641 self.factorize_into()?.rcond_into()
642 }
643}