lax/
least_squares.rs

//! Linear least-squares solvers built on LAPACK's SVD divide-and-conquer driver `*gelsd`
2
3use crate::{error::*, layout::*, *};
4use cauchy::*;
5use num_traits::{ToPrimitive, Zero};
6
7/// Result of LeastSquares
8pub struct LeastSquaresOutput<A: Scalar> {
9    /// singular values
10    pub singular_values: Vec<A::Real>,
11    /// The rank of the input matrix A
12    pub rank: i32,
13}
14
15/// Wraps `*gelsd`
16pub trait LeastSquaresSvdDivideConquer_: Scalar {
17    fn least_squares(
18        a_layout: MatrixLayout,
19        a: &mut [Self],
20        b: &mut [Self],
21    ) -> Result<LeastSquaresOutput<Self>>;
22
23    fn least_squares_nrhs(
24        a_layout: MatrixLayout,
25        a: &mut [Self],
26        b_layout: MatrixLayout,
27        b: &mut [Self],
28    ) -> Result<LeastSquaresOutput<Self>>;
29}
30
/// Generates `LeastSquaresSvdDivideConquer_` implementations on top of a LAPACK `*gelsd` routine.
///
/// - `@real, $scalar, $gelsd`: real scalars; `*gelsd` is called without an `rwork` argument.
/// - `@complex, $scalar, $gelsd`: complex scalars; an extra real-valued `rwork` workspace is passed.
/// - `@body`: shared implementation. `$rwork` is either empty or the single ident `rwork`,
///   so every `$( ... )*` repetition below expands to nothing for real scalars.
macro_rules! impl_least_squares {
    (@real, $scalar:ty, $gelsd:path) => {
        impl_least_squares!(@body, $scalar, $gelsd, );
    };
    (@complex, $scalar:ty, $gelsd:path) => {
        impl_least_squares!(@body, $scalar, $gelsd, rwork);
    };

    (@body, $scalar:ty, $gelsd:path, $($rwork:ident),*) => {
        impl LeastSquaresSvdDivideConquer_ for $scalar {
            fn least_squares(
                l: MatrixLayout,
                a: &mut [Self],
                b: &mut [Self],
            ) -> Result<LeastSquaresOutput<Self>> {
                // View the single right-hand side as a (b.len(), 1) matrix laid out in the
                // same storage order as `a`, then defer to the multi-RHS solver.
                let b_layout = l.resized(b.len() as i32, 1);
                Self::least_squares_nrhs(l, a, b_layout, b)
            }

            fn least_squares_nrhs(
                a_layout: MatrixLayout,
                a: &mut [Self],
                b_layout: MatrixLayout,
                b: &mut [Self],
            ) -> Result<LeastSquaresOutput<Self>> {
                // Minimize |b - Ax|_2
                //
                // where
                //   A : (m, n)
                //   b : (max(m, n), nrhs)  // `b` has to store `x` on exit
                //   x : (n, nrhs)
                let (m, n) = a_layout.size();
                let (m_, nrhs) = b_layout.size();
                // Number of singular values returned by `*gelsd`.
                let k = m.min(n);
                // NOTE(review): this only checks rows(b) >= m. LAPACK requires
                // LDB >= max(1, m, n), so for m < n a too-short `b` is not caught here and is
                // presumably reported through `info` instead — confirm.
                assert!(m_ >= m);

                // `*gelsd` is called on column-major (F) data: a C-contiguous `a` is copied
                // into a transposed buffer `a_t`, and `a_layout` becomes that buffer's layout.
                let mut a_t = None;
                let a_layout = match a_layout {
                    MatrixLayout::C { .. } => {
                        a_t = Some(unsafe { vec_uninit( a.len()) });
                        transpose(a_layout, a, a_t.as_mut().unwrap())
                    }
                    MatrixLayout::F { .. } => a_layout,
                };

                // Same for `b`: a C-contiguous `b` is copied into a transposed F-layout buffer
                // `b_t`. The result is copied back into `b` at the end.
                let mut b_t = None;
                let b_layout = match b_layout {
                    MatrixLayout::C { .. } => {
                        b_t = Some(unsafe { vec_uninit( b.len()) });
                        transpose(b_layout, b, b_t.as_mut().unwrap())
                    }
                    MatrixLayout::F { .. } => b_layout,
                };

                // Negative RCOND: per the LAPACK `?gelsd` documentation, this selects machine
                // precision as the singular-value cutoff for the rank decision.
                let rcond: Self::Real = -1.;
                let mut singular_values: Vec<Self::Real> = unsafe { vec_uninit( k as usize) };
                let mut rank: i32 = 0;

                // eval work size
                // Passing lwork = -1 makes this a workspace-size query. The required sizes are
                // read back from `work_size`, `iwork_size` (and `rwork` for complex scalars).
                let mut info = 0;
                let mut work_size = [Self::zero()];
                let mut iwork_size = [0];
                $(
                let mut $rwork = [Self::Real::zero()];
                )*
                // SAFETY: FFI call into LAPACK. Every buffer is a live, exclusively borrowed
                // slice, and the dimensions/leading dimensions come from the layouts above.
                unsafe {
                    $gelsd(
                        m,
                        n,
                        nrhs,
                        a_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(a),
                        a_layout.lda(),
                        b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b),
                        b_layout.lda(),
                        &mut singular_values,
                        rcond,
                        &mut rank,
                        &mut work_size,
                        -1,
                        $(&mut $rwork,)*
                        &mut iwork_size,
                        &mut info,
                    )
                };
                info.as_lapack_result()?;

                // calc
                // Allocate the workspaces at the sizes reported by the query.
                // NOTE(review): the LWORK value comes back as a floating-point scalar and
                // `to_usize` truncates. For very large f32 sizes this may round below the
                // true requirement — confirm whether rounding up is needed.
                let lwork = work_size[0].to_usize().unwrap();
                let mut work = unsafe { vec_uninit( lwork) };
                let liwork = iwork_size[0].to_usize().unwrap();
                let mut iwork = unsafe { vec_uninit( liwork) };
                // For complex scalars, this shadows the 1-element query array with the real
                // workspace.
                $(
                let lrwork = $rwork[0].to_usize().unwrap();
                let mut $rwork = unsafe { vec_uninit( lrwork) };
                )*
                // SAFETY: same buffers and dimensions as the query call. `work`, `iwork`
                // (and `rwork`) have the lengths LAPACK requested. They are uninitialized;
                // presumably LAPACK writes before reading them — confirm `vec_uninit`'s
                // contract.
                unsafe {
                    $gelsd(
                        m,
                        n,
                        nrhs,
                        a_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(a),
                        a_layout.lda(),
                        b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b),
                        b_layout.lda(),
                        &mut singular_values,
                        rcond,
                        &mut rank,
                        &mut work,
                        lwork as i32,
                        $(&mut $rwork,)*
                        &mut iwork,
                        &mut info,
                    );
                }
                info.as_lapack_result()?;

                // Skip a_t -> a transpose because A has been destroyed
                // Re-transpose b: copy the solution from the F-layout scratch buffer back into
                // the caller's C-layout `b`.
                if let Some(b_t) = b_t {
                    transpose(b_layout, &b_t, b);
                }

                Ok(LeastSquaresOutput {
                    singular_values,
                    rank,
                })
            }
        }
    };
}
163
164impl_least_squares!(@real, f64, lapack::dgelsd);
165impl_least_squares!(@real, f32, lapack::sgelsd);
166impl_least_squares!(@complex, c64, lapack::zgelsd);
167impl_least_squares!(@complex, c32, lapack::cgelsd);