lax/
svd.rs

//! Singular-value decomposition (SVD) backed by LAPACK's `*gesvd` drivers.
2
3use crate::{error::*, layout::MatrixLayout, *};
4use cauchy::*;
5use num_traits::{ToPrimitive, Zero};
6
/// Job flag handed to LAPACK `*gesvd` as `JOBU` / `JOBVT`.
///
/// Only the two modes this module uses are representable. LAPACK additionally
/// accepts `b'S'` (return only the leading min(m, n) singular vectors) and
/// `b'O'` (overwrite the input matrix with them); neither is supported here.
#[derive(Debug, Clone, Copy)]
#[repr(u8)]
enum FlagSVD {
    /// `b'A'`: compute the full set of singular vectors.
    All = b'A',
    /// `b'N'`: skip the singular vectors entirely.
    No = b'N',
}

impl FlagSVD {
    /// Translates "are these singular vectors wanted?" into the matching job flag.
    fn from_bool(calc_uv: bool) -> Self {
        match calc_uv {
            true => Self::All,
            false => Self::No,
        }
    }
}
25
26/// Result of SVD
27pub struct SVDOutput<A: Scalar> {
28    /// diagonal values
29    pub s: Vec<A::Real>,
30    /// Unitary matrix for destination space
31    pub u: Option<Vec<A>>,
32    /// Unitary matrix for departure space
33    pub vt: Option<Vec<A>>,
34}
35
36/// Wraps `*gesvd`
37pub trait SVD_: Scalar {
38    /// Calculate singular value decomposition $ A = U \Sigma V^T $
39    fn svd(l: MatrixLayout, calc_u: bool, calc_vt: bool, a: &mut [Self])
40        -> Result<SVDOutput<Self>>;
41}
42
// Implements `SVD_` for one scalar type on top of the LAPACK driver `$gesvd`.
//
// - `@real`: real drivers (`sgesvd`, `dgesvd`).
// - `@complex`: complex drivers (`cgesvd`, `zgesvd`), which need an extra real
//   workspace. Its identifier is threaded through the optional `$rwork_ident`
//   repetition so that both variants share a single `@body`.
macro_rules! impl_svd {
    (@real, $scalar:ty, $gesvd:path) => {
        impl_svd!(@body, $scalar, $gesvd, );
    };
    (@complex, $scalar:ty, $gesvd:path) => {
        impl_svd!(@body, $scalar, $gesvd, rwork);
    };
    (@body, $scalar:ty, $gesvd:path, $($rwork_ident:ident),*) => {
        impl SVD_ for $scalar {
            fn svd(
                l: MatrixLayout,
                calc_u: bool,
                calc_vt: bool,
                a: &mut [Self],
            ) -> Result<SVDOutput<Self>> {
                // LAPACK works on column-major storage. A row-major (C) buffer read as
                // column-major is the transpose of A, whose left and right singular
                // vectors are those of A exchanged. For C layout, request U and V^T
                // swapped here, and swap the outputs back at the end.
                let (ju, jvt) = match l {
                    MatrixLayout::F { .. } => {
                        (FlagSVD::from_bool(calc_u), FlagSVD::from_bool(calc_vt))
                    }
                    MatrixLayout::C { .. } => {
                        (FlagSVD::from_bool(calc_vt), FlagSVD::from_bool(calc_u))
                    }
                };

                // LAPACK is told `a` is an `m x n` column-major matrix with leading
                // dimension `m` (assumes a tightly packed buffer -- TODO confirm
                // against `MatrixLayout`).
                let m = l.lda();
                let n = l.len();
                let k = std::cmp::min(m, n);

                let mut u = match ju {
                    FlagSVD::All => Some(unsafe { vec_uninit((m * m) as usize) }),
                    FlagSVD::No => None,
                };
                let mut vt = match jvt {
                    FlagSVD::All => Some(unsafe { vec_uninit((n * n) as usize) }),
                    FlagSVD::No => None,
                };
                let mut s = unsafe { vec_uninit(k as usize) };

                $(
                // Complex drivers require a real workspace of 5 * min(m, n) entries.
                let mut $rwork_ident = unsafe { vec_uninit(5 * (k as usize)) };
                )*

                // The workspace query and the real computation differ only in the
                // workspace buffer and its length, so both go through this one call
                // site. It returns LAPACK's `info` code.
                let mut call_gesvd = |work_buf: &mut [Self], work_len: i32| -> i32 {
                    let mut info = 0;
                    // SAFETY: the arguments follow the `*gesvd` contract:
                    // - `a` is presumed to hold the `m x n` matrix with LDA = m.
                    //   This is the caller's responsibility via `l`.
                    // - `s` has min(m, n) entries.
                    // - `u` / `vt` are `m x m` / `n x n` when their job is 'A'
                    //   (LDU = m, LDVT = n). LAPACK does not reference them when the
                    //   job is 'N'.
                    // - `work_buf` has `work_len` entries, or is a 1-element buffer
                    //   when `work_len == -1` (workspace query).
                    // - The complex `rwork` has 5 * min(m, n) entries.
                    // The `vec_uninit` buffers are output or workspace arguments that
                    // LAPACK fills in.
                    unsafe {
                        $gesvd(
                            ju as u8,
                            jvt as u8,
                            m,
                            n,
                            &mut *a,
                            m,
                            &mut s,
                            u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
                            m,
                            vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
                            n,
                            work_buf,
                            work_len,
                            $(&mut $rwork_ident,)*
                            &mut info,
                        );
                    }
                    info
                };

                // Workspace query: with LWORK = -1, LAPACK only stores the optimal
                // workspace length in the first element of the work array.
                let mut query = [Self::zero()];
                call_gesvd(&mut query[..], -1).as_lapack_result()?;

                let lwork = query[0]
                    .to_usize()
                    .expect("LAPACK workspace query returned a size not representable as usize");
                let mut work: Vec<Self> = unsafe { vec_uninit(lwork) };
                call_gesvd(&mut work[..], lwork as i32).as_lapack_result()?;

                Ok(match l {
                    MatrixLayout::F { .. } => SVDOutput { s, u, vt },
                    // Undo the U <-> V^T exchange performed for row-major input.
                    MatrixLayout::C { .. } => SVDOutput { s, u: vt, vt: u },
                })
            }
        }
    };
} // impl_svd!
136
137impl_svd!(@real, f64, lapack::dgesvd);
138impl_svd!(@real, f32, lapack::sgesvd);
139impl_svd!(@complex, c64, lapack::zgesvd);
140impl_svd!(@complex, c32, lapack::cgesvd);