lax/
eig.rs

//! Eigenvalue decomposition for general (non-symmetric) square matrices
2
3use crate::{error::*, layout::MatrixLayout, *};
4use cauchy::*;
5use num_traits::{ToPrimitive, Zero};
6
7/// Wraps `*geev` for general matrices
8pub trait Eig_: Scalar {
9    /// Calculate Right eigenvalue
10    fn eig(
11        calc_v: bool,
12        l: MatrixLayout,
13        a: &mut [Self],
14    ) -> Result<(Vec<Self::Complex>, Vec<Self::Complex>)>;
15}
16
17macro_rules! impl_eig_complex {
18    ($scalar:ty, $ev:path) => {
19        impl Eig_ for $scalar {
20            fn eig(
21                calc_v: bool,
22                l: MatrixLayout,
23                mut a: &mut [Self],
24            ) -> Result<(Vec<Self::Complex>, Vec<Self::Complex>)> {
25                let (n, _) = l.size();
26                // LAPACK assumes a column-major input. A row-major input can
27                // be interpreted as the transpose of a column-major input. So,
28                // for row-major inputs, we we want to solve the following,
29                // given the column-major input `A`:
30                //
31                //   A^T V = V Λ ⟺ V^T A = Λ V^T ⟺ conj(V)^H A = Λ conj(V)^H
32                //
33                // So, in this case, the right eigenvectors are the conjugates
34                // of the left eigenvectors computed with `A`, and the
35                // eigenvalues are the eigenvalues computed with `A`.
36                let (jobvl, jobvr) = if calc_v {
37                    match l {
38                        MatrixLayout::C { .. } => (b'V', b'N'),
39                        MatrixLayout::F { .. } => (b'N', b'V'),
40                    }
41                } else {
42                    (b'N', b'N')
43                };
44                let mut eigs = unsafe { vec_uninit(n as usize) };
45                let mut rwork = unsafe { vec_uninit(2 * n as usize) };
46
47                let mut vl = if jobvl == b'V' {
48                    Some(unsafe { vec_uninit((n * n) as usize) })
49                } else {
50                    None
51                };
52                let mut vr = if jobvr == b'V' {
53                    Some(unsafe { vec_uninit((n * n) as usize) })
54                } else {
55                    None
56                };
57
58                // calc work size
59                let mut info = 0;
60                let mut work_size = [Self::zero()];
61                unsafe {
62                    $ev(
63                        jobvl,
64                        jobvr,
65                        n,
66                        &mut a,
67                        n,
68                        &mut eigs,
69                        &mut vl.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []),
70                        n,
71                        &mut vr.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []),
72                        n,
73                        &mut work_size,
74                        -1,
75                        &mut rwork,
76                        &mut info,
77                    )
78                };
79                info.as_lapack_result()?;
80
81                // actal ev
82                let lwork = work_size[0].to_usize().unwrap();
83                let mut work = unsafe { vec_uninit(lwork) };
84                unsafe {
85                    $ev(
86                        jobvl,
87                        jobvr,
88                        n,
89                        &mut a,
90                        n,
91                        &mut eigs,
92                        &mut vl.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []),
93                        n,
94                        &mut vr.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []),
95                        n,
96                        &mut work,
97                        lwork as i32,
98                        &mut rwork,
99                        &mut info,
100                    )
101                };
102                info.as_lapack_result()?;
103
104                // Hermite conjugate
105                if jobvl == b'V' {
106                    for c in vl.as_mut().unwrap().iter_mut() {
107                        c.im = -c.im
108                    }
109                }
110
111                Ok((eigs, vr.or(vl).unwrap_or(Vec::new())))
112            }
113        }
114    };
115}
116
117impl_eig_complex!(c64, lapack::zgeev);
118impl_eig_complex!(c32, lapack::cgeev);
119
120macro_rules! impl_eig_real {
121    ($scalar:ty, $ev:path) => {
122        impl Eig_ for $scalar {
123            fn eig(
124                calc_v: bool,
125                l: MatrixLayout,
126                mut a: &mut [Self],
127            ) -> Result<(Vec<Self::Complex>, Vec<Self::Complex>)> {
128                let (n, _) = l.size();
129                // LAPACK assumes a column-major input. A row-major input can
130                // be interpreted as the transpose of a column-major input. So,
131                // for row-major inputs, we we want to solve the following,
132                // given the column-major input `A`:
133                //
134                //   A^T V = V Λ ⟺ V^T A = Λ V^T ⟺ conj(V)^H A = Λ conj(V)^H
135                //
136                // So, in this case, the right eigenvectors are the conjugates
137                // of the left eigenvectors computed with `A`, and the
138                // eigenvalues are the eigenvalues computed with `A`.
139                //
140                // We could conjugate the eigenvalues instead of the
141                // eigenvectors, but we have to reconstruct the eigenvectors
142                // into new matrices anyway, and by not modifying the
143                // eigenvalues, we preserve the nice ordering specified by
144                // `sgeev`/`dgeev`.
145                let (jobvl, jobvr) = if calc_v {
146                    match l {
147                        MatrixLayout::C { .. } => (b'V', b'N'),
148                        MatrixLayout::F { .. } => (b'N', b'V'),
149                    }
150                } else {
151                    (b'N', b'N')
152                };
153                let mut eig_re = unsafe { vec_uninit(n as usize) };
154                let mut eig_im = unsafe { vec_uninit(n as usize) };
155
156                let mut vl = if jobvl == b'V' {
157                    Some(unsafe { vec_uninit((n * n) as usize) })
158                } else {
159                    None
160                };
161                let mut vr = if jobvr == b'V' {
162                    Some(unsafe { vec_uninit((n * n) as usize) })
163                } else {
164                    None
165                };
166
167                // calc work size
168                let mut info = 0;
169                let mut work_size = [0.0];
170                unsafe {
171                    $ev(
172                        jobvl,
173                        jobvr,
174                        n,
175                        &mut a,
176                        n,
177                        &mut eig_re,
178                        &mut eig_im,
179                        vl.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []),
180                        n,
181                        vr.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []),
182                        n,
183                        &mut work_size,
184                        -1,
185                        &mut info,
186                    )
187                };
188                info.as_lapack_result()?;
189
190                // actual ev
191                let lwork = work_size[0].to_usize().unwrap();
192                let mut work = unsafe { vec_uninit(lwork) };
193                unsafe {
194                    $ev(
195                        jobvl,
196                        jobvr,
197                        n,
198                        &mut a,
199                        n,
200                        &mut eig_re,
201                        &mut eig_im,
202                        vl.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []),
203                        n,
204                        vr.as_mut().map(|v| v.as_mut_slice()).unwrap_or(&mut []),
205                        n,
206                        &mut work,
207                        lwork as i32,
208                        &mut info,
209                    )
210                };
211                info.as_lapack_result()?;
212
213                // reconstruct eigenvalues
214                let eigs: Vec<Self::Complex> = eig_re
215                    .iter()
216                    .zip(eig_im.iter())
217                    .map(|(&re, &im)| Self::complex(re, im))
218                    .collect();
219
220                if !calc_v {
221                    return Ok((eigs, Vec::new()));
222                }
223
224                // Reconstruct eigenvectors into complex-array
225                // --------------------------------------------
226                //
227                // From LAPACK API https://software.intel.com/en-us/node/469230
228                //
229                // - If the j-th eigenvalue is real,
230                //   - v(j) = VR(:,j), the j-th column of VR.
231                //
232                // - If the j-th and (j+1)-st eigenvalues form a complex conjugate pair,
233                //   - v(j)   = VR(:,j) + i*VR(:,j+1)
234                //   - v(j+1) = VR(:,j) - i*VR(:,j+1).
235                //
236                // In the C-layout case, we need the conjugates of the left
237                // eigenvectors, so the signs should be reversed.
238
239                let n = n as usize;
240                let v = vr.or(vl).unwrap();
241                let mut eigvecs = unsafe { vec_uninit(n * n) };
242                let mut col = 0;
243                while col < n {
244                    if eig_im[col] == 0. {
245                        // The corresponding eigenvalue is real.
246                        for row in 0..n {
247                            let re = v[row + col * n];
248                            eigvecs[row + col * n] = Self::complex(re, 0.);
249                        }
250                        col += 1;
251                    } else {
252                        // This is a complex conjugate pair.
253                        assert!(col + 1 < n);
254                        for row in 0..n {
255                            let re = v[row + col * n];
256                            let mut im = v[row + (col + 1) * n];
257                            if jobvl == b'V' {
258                                im = -im;
259                            }
260                            eigvecs[row + col * n] = Self::complex(re, im);
261                            eigvecs[row + (col + 1) * n] = Self::complex(re, -im);
262                        }
263                        col += 2;
264                    }
265                }
266
267                Ok((eigs, eigvecs))
268            }
269        }
270    };
271}
272
273impl_eig_real!(f64, lapack::dgeev);
274impl_eig_real!(f32, lapack::sgeev);