1use crate::{error::*, layout::MatrixLayout, *};
4use cauchy::*;
5use num_traits::{ToPrimitive, Zero};
6
/// Job selector for the `JOBU` / `JOBVT` arguments of LAPACK `?gesvd`.
///
/// The discriminant is the ASCII byte LAPACK expects, so `flag as u8` can be
/// passed straight through to the driver.
#[derive(Copy, Clone, Debug)]
#[repr(u8)]
enum FlagSVD {
    /// `'A'`: request every column of U (or every row of V^T).
    All = b'A',
    /// `'N'`: request no singular vectors for this side.
    No = b'N',
}

impl FlagSVD {
    /// Maps "should these singular vectors be computed?" onto the matching job flag.
    fn from_bool(calc_uv: bool) -> Self {
        match calc_uv {
            true => Self::All,
            false => Self::No,
        }
    }
}
25
26pub struct SVDOutput<A: Scalar> {
28 pub s: Vec<A::Real>,
30 pub u: Option<Vec<A>>,
32 pub vt: Option<Vec<A>>,
34}
35
36pub trait SVD_: Scalar {
38 fn svd(l: MatrixLayout, calc_u: bool, calc_vt: bool, a: &mut [Self])
40 -> Result<SVDOutput<Self>>;
41}
42
43macro_rules! impl_svd {
44 (@real, $scalar:ty, $gesvd:path) => {
45 impl_svd!(@body, $scalar, $gesvd, );
46 };
47 (@complex, $scalar:ty, $gesvd:path) => {
48 impl_svd!(@body, $scalar, $gesvd, rwork);
49 };
50 (@body, $scalar:ty, $gesvd:path, $($rwork_ident:ident),*) => {
51 impl SVD_ for $scalar {
52 fn svd(l: MatrixLayout, calc_u: bool, calc_vt: bool, mut a: &mut [Self],) -> Result<SVDOutput<Self>> {
53 let ju = match l {
54 MatrixLayout::F { .. } => FlagSVD::from_bool(calc_u),
55 MatrixLayout::C { .. } => FlagSVD::from_bool(calc_vt),
56 };
57 let jvt = match l {
58 MatrixLayout::F { .. } => FlagSVD::from_bool(calc_vt),
59 MatrixLayout::C { .. } => FlagSVD::from_bool(calc_u),
60 };
61
62 let m = l.lda();
63 let mut u = match ju {
64 FlagSVD::All => Some(unsafe { vec_uninit( (m * m) as usize) }),
65 FlagSVD::No => None,
66 };
67
68 let n = l.len();
69 let mut vt = match jvt {
70 FlagSVD::All => Some(unsafe { vec_uninit( (n * n) as usize) }),
71 FlagSVD::No => None,
72 };
73
74 let k = std::cmp::min(m, n);
75 let mut s = unsafe { vec_uninit( k as usize) };
76
77 $(
78 let mut $rwork_ident = unsafe { vec_uninit( 5 * k as usize) };
79 )*
80
81 let mut info = 0;
83 let mut work_size = [Self::zero()];
84 unsafe {
85 $gesvd(
86 ju as u8,
87 jvt as u8,
88 m,
89 n,
90 &mut a,
91 m,
92 &mut s,
93 u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
94 m,
95 vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
96 n,
97 &mut work_size,
98 -1,
99 $(&mut $rwork_ident,)*
100 &mut info,
101 );
102 }
103 info.as_lapack_result()?;
104
105 let lwork = work_size[0].to_usize().unwrap();
107 let mut work = unsafe { vec_uninit( lwork) };
108 unsafe {
109 $gesvd(
110 ju as u8,
111 jvt as u8,
112 m,
113 n,
114 &mut a,
115 m,
116 &mut s,
117 u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
118 m,
119 vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
120 n,
121 &mut work,
122 lwork as i32,
123 $(&mut $rwork_ident,)*
124 &mut info,
125 );
126 }
127 info.as_lapack_result()?;
128 match l {
129 MatrixLayout::F { .. } => Ok(SVDOutput { s, u, vt }),
130 MatrixLayout::C { .. } => Ok(SVDOutput { s, u: vt, vt: u }),
131 }
132 }
133 }
134 };
135} impl_svd!(@real, f64, lapack::dgesvd);
138impl_svd!(@real, f32, lapack::sgesvd);
139impl_svd!(@complex, c64, lapack::zgesvd);
140impl_svd!(@complex, c32, lapack::cgesvd);