1use crate::{error::*, layout::MatrixLayout, *};
2use cauchy::*;
3use num_traits::{ToPrimitive, Zero};
4
/// Selects which singular vectors the `?gesdd`-based SVD computes.
///
/// Each discriminant is the LAPACK `JOBZ` character handed to `?gesdd`
/// (obtained with `flag as u8`). For an `m`×`n` input with `k = min(m, n)`:
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u8)]
pub enum UVTFlag {
    /// `JOBZ = 'A'`: all `m` columns of U and all `n` rows of Vᵀ.
    Full = b'A',
    /// `JOBZ = 'S'`: the first `k` columns of U and the first `k` rows of Vᵀ.
    Some = b'S',
    /// `JOBZ = 'N'`: no singular vectors; only the singular values.
    None = b'N',
}
18
19pub trait SVDDC_: Scalar {
20 fn svddc(l: MatrixLayout, jobz: UVTFlag, a: &mut [Self]) -> Result<SVDOutput<Self>>;
21}
22
// Implements `SVDDC_` for one scalar type on top of the matching LAPACK `?gesdd`.
//
// * `@real`    - for `f32`/`f64`, whose `?gesdd` has no `rwork` argument.
// * `@complex` - for `c32`/`c64`, whose `?gesdd` also takes a real `rwork`
//   workspace; the identifier `rwork` is threaded through the optional
//   repetition `$($rwork_ident),*` so the shared body can declare and pass it.
// * `@body`    - the shared implementation.
macro_rules! impl_svddc {
    (@real, $scalar:ty, $gesdd:path) => {
        impl_svddc!(@body, $scalar, $gesdd, );
    };
    (@complex, $scalar:ty, $gesdd:path) => {
        impl_svddc!(@body, $scalar, $gesdd, rwork);
    };
    (@body, $scalar:ty, $gesdd:path, $($rwork_ident:ident),*) => {
        impl SVDDC_ for $scalar {
            fn svddc(l: MatrixLayout, jobz: UVTFlag, a: &mut [Self]) -> Result<SVDOutput<Self>> {
                // LAPACK sees the buffer as a column-major `m`-by-`n` matrix with
                // leading dimension `m`. For a row-major (`C`) layout that is the
                // transpose of the logical matrix; this is undone at the end by
                // swapping `u` and `vt`.
                let m = l.lda();
                let n = l.len();
                let k = m.min(n);
                let mut s = unsafe { vec_uninit(k as usize) };

                // Columns of U and leading dimension of VT as passed to LAPACK.
                let (u_col, vt_row) = match jobz {
                    UVTFlag::Full | UVTFlag::None => (m, n),
                    UVTFlag::Some => (k, k),
                };
                let (mut u, mut vt) = match jobz {
                    UVTFlag::Full => (
                        Some(unsafe { vec_uninit((m * m) as usize) }),
                        Some(unsafe { vec_uninit((n * n) as usize) }),
                    ),
                    UVTFlag::Some => (
                        Some(unsafe { vec_uninit((m * u_col) as usize) }),
                        Some(unsafe { vec_uninit((n * vt_row) as usize) }),
                    ),
                    UVTFlag::None => (None, None),
                };

                // Real workspace `rwork`, expanded only for complex scalars.
                // Sizes follow the LRWORK requirement of `cgesdd`/`zgesdd`.
                $(
                let mx = n.max(m) as usize;
                let mn = n.min(m) as usize;
                let lrwork = match jobz {
                    UVTFlag::None => 7 * mn,
                    _ => std::cmp::max(5 * mn * mn + 5 * mn, 2 * mx * mn + 2 * mn * mn + mn),
                };
                let mut $rwork_ident = unsafe { vec_uninit(lrwork) };
                )*

                // Integer workspace of 8 * min(m, n) entries, as `?gesdd` requires.
                let mut iwork = unsafe { vec_uninit(8 * k as usize) };

                // Runs `?gesdd` with the given `work` buffer and `lwork` and returns
                // LAPACK's `info`. The two calls below differ only in these two
                // arguments: `lwork == -1` is the workspace-size query, which writes
                // the optimal size into `work[0]`.
                let mut call_gesdd = |work: &mut [Self], lwork: i32| -> i32 {
                    let mut info = 0;
                    unsafe {
                        $gesdd(
                            jobz as u8,
                            m,
                            n,
                            a,
                            m,
                            &mut s,
                            u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
                            m,
                            vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
                            vt_row,
                            work,
                            lwork,
                            $(&mut $rwork_ident,)*
                            &mut iwork,
                            &mut info,
                        );
                    }
                    info
                };

                // Query the optimal workspace size.
                let mut work_size = [Self::zero()];
                call_gesdd(&mut work_size, -1).as_lapack_result()?;

                // Compute the decomposition with the requested workspace.
                let lwork = work_size[0]
                    .to_usize()
                    .expect("LAPACK workspace query returned a size that does not fit in usize");
                let mut work: Vec<Self> = unsafe { vec_uninit(lwork) };
                call_gesdd(&mut work, lwork as i32).as_lapack_result()?;

                match l {
                    MatrixLayout::F { .. } => Ok(SVDOutput { s, u, vt }),
                    // LAPACK factored the transpose: A^T = U S V^T  =>  A = V S U^T.
                    MatrixLayout::C { .. } => Ok(SVDOutput { s, u: vt, vt: u }),
                }
            }
        }
    };
}
121
122impl_svddc!(@real, f32, lapack::sgesdd);
123impl_svddc!(@real, f64, lapack::dgesdd);
124impl_svddc!(@complex, c32, lapack::cgesdd);
125impl_svddc!(@complex, c64, lapack::zgesdd);