lax/
eigh.rs

//! Eigenvalue decomposition for symmetric/Hermitian matrices
2
3use super::*;
4use crate::{error::*, layout::MatrixLayout};
5use cauchy::*;
6use num_traits::{ToPrimitive, Zero};
7
8pub trait Eigh_: Scalar {
9    /// Wraps `*syev` for real and `*heev` for complex
10    fn eigh(
11        calc_eigenvec: bool,
12        layout: MatrixLayout,
13        uplo: UPLO,
14        a: &mut [Self],
15    ) -> Result<Vec<Self::Real>>;
16
17    /// Wraps `*sygv` for real and `*hegv` for complex
18    fn eigh_generalized(
19        calc_eigenvec: bool,
20        layout: MatrixLayout,
21        uplo: UPLO,
22        a: &mut [Self],
23        b: &mut [Self],
24    ) -> Result<Vec<Self::Real>>;
25}
26
// Implements `Eigh_` for one scalar type by wrapping the matching LAPACK drivers.
//
// * `@real, T, ev, evg`    — real symmetric drivers (`*syev` / `*sygv`), which
//                            take no RWORK argument.
// * `@complex, T, ev, evg` — complex Hermitian drivers (`*heev` / `*hegv`),
//                            which take an extra real RWORK buffer.
// * `@body, ...`           — shared implementation. The optional trailing
//                            identifier names the RWORK buffer and is spliced
//                            into the LAPACK argument lists only when present.
//
// Both methods follow LAPACK's two-call protocol: first a workspace query with
// `LWORK = -1`, then the real call with a buffer of the reported size.
macro_rules! impl_eigh {
    (@real, $scalar:ty, $ev:path, $evg:path) => {
        impl_eigh!(@body, $scalar, $ev, $evg, );
    };
    (@complex, $scalar:ty, $ev:path, $evg:path) => {
        impl_eigh!(@body, $scalar, $ev, $evg, rwork);
    };
    (@body, $scalar:ty, $ev:path, $evg:path, $($rwork_ident:ident),*) => {
        impl Eigh_ for $scalar {
            /// Wraps `*syev` / `*heev`.
            ///
            /// # Panics
            /// Panics if `layout.len() != layout.lda()`, or if the workspace size
            /// reported by LAPACK is not representable as `usize`.
            fn eigh(
                calc_v: bool,
                layout: MatrixLayout,
                uplo: UPLO,
                a: &mut [Self],
            ) -> Result<Vec<Self::Real>> {
                assert_eq!(layout.len(), layout.lda());
                let n = layout.len();
                // 'V': eigenvalues and eigenvectors, 'N': eigenvalues only.
                let jobz = if calc_v { b'V' } else { b'N' };
                let mut eigs = unsafe { vec_uninit(n as usize) };

                $(
                // LAPACK requires RWORK of length max(1, 3*N - 2). The saturating
                // form keeps N == 0 from underflowing `usize` (a panic in debug
                // builds, a wrap-around in release builds).
                let mut $rwork_ident = unsafe {
                    vec_uninit((3 * n as usize).saturating_sub(2).max(1))
                };
                )*

                // Workspace query: LWORK = -1 asks LAPACK to store the optimal
                // workspace size in work_size[0] without doing the decomposition.
                let mut info = 0;
                let mut work_size = [Self::zero()];
                // SAFETY (review): LAPACK treats `a` as an n x n array with leading
                // dimension n; this relies on the caller passing at least n*n
                // elements, which is not checked here.
                unsafe {
                    $ev(
                        jobz,
                        uplo as u8,
                        n,
                        a,
                        n,
                        &mut eigs,
                        &mut work_size,
                        -1,
                        $(&mut $rwork_ident,)*
                        &mut info,
                    );
                }
                info.as_lapack_result()?;

                // Actual decomposition using the queried workspace size.
                let lwork = work_size[0]
                    .to_usize()
                    .expect("LAPACK workspace query returned a size not representable as usize");
                let mut work = unsafe { vec_uninit(lwork) };
                // SAFETY (review): same buffer requirements as the query call;
                // `work` now has the `lwork` elements LAPACK asked for.
                unsafe {
                    $ev(
                        jobz,
                        uplo as u8,
                        n,
                        a,
                        n,
                        &mut eigs,
                        &mut work,
                        // LWORK is a 32-bit INTEGER in this binding, so the size
                        // LAPACK reported is expected to fit.
                        lwork as i32,
                        $(&mut $rwork_ident,)*
                        &mut info,
                    );
                }
                info.as_lapack_result()?;
                Ok(eigs)
            }

            /// Wraps `*sygv` / `*hegv` with ITYPE = 1, i.e. `A x = λ B x`.
            ///
            /// # Panics
            /// Panics if `layout.len() != layout.lda()`, or if the workspace size
            /// reported by LAPACK is not representable as `usize`.
            fn eigh_generalized(
                calc_v: bool,
                layout: MatrixLayout,
                uplo: UPLO,
                a: &mut [Self],
                b: &mut [Self],
            ) -> Result<Vec<Self::Real>> {
                assert_eq!(layout.len(), layout.lda());
                let n = layout.len();
                // 'V': eigenvalues and eigenvectors, 'N': eigenvalues only.
                let jobz = if calc_v { b'V' } else { b'N' };
                let mut eigs = unsafe { vec_uninit(n as usize) };

                $(
                // LAPACK requires RWORK of length max(1, 3*N - 2); saturating to
                // avoid a `usize` underflow when N == 0.
                let mut $rwork_ident = unsafe {
                    vec_uninit((3 * n as usize).saturating_sub(2).max(1))
                };
                )*

                // Workspace query (LWORK = -1); ITYPE = 1 selects A x = λ B x.
                let mut info = 0;
                let mut work_size = [Self::zero()];
                // SAFETY (review): `a` and `b` are treated as n x n arrays with
                // leading dimension n; their lengths are not checked here.
                unsafe {
                    $evg(
                        &[1],
                        jobz,
                        uplo as u8,
                        n,
                        a,
                        n,
                        b,
                        n,
                        &mut eigs,
                        &mut work_size,
                        -1,
                        $(&mut $rwork_ident,)*
                        &mut info,
                    );
                }
                info.as_lapack_result()?;

                // Actual generalized decomposition using the queried workspace size.
                let lwork = work_size[0]
                    .to_usize()
                    .expect("LAPACK workspace query returned a size not representable as usize");
                let mut work = unsafe { vec_uninit(lwork) };
                // SAFETY (review): same buffer requirements as the query call.
                unsafe {
                    $evg(
                        &[1],
                        jobz,
                        uplo as u8,
                        n,
                        a,
                        n,
                        b,
                        n,
                        &mut eigs,
                        &mut work,
                        lwork as i32,
                        $(&mut $rwork_ident,)*
                        &mut info,
                    );
                }
                info.as_lapack_result()?;
                Ok(eigs)
            }
        }
    };
} // impl_eigh!
155
156impl_eigh!(@real, f64, lapack::dsyev, lapack::dsygv);
157impl_eigh!(@real, f32, lapack::ssyev, lapack::ssygv);
158impl_eigh!(@complex, c64, lapack::zheev, lapack::zhegv);
159impl_eigh!(@complex, c32, lapack::cheev, lapack::chegv);