lax/
svddc.rs

1use crate::{error::*, layout::MatrixLayout, *};
2use cauchy::*;
3use num_traits::{ToPrimitive, Zero};
4
/// Specifies how many of the columns of *U* and rows of *V*ᵀ are computed and returned.
///
/// For an input array of shape *m*×*n*, each variant selects the outputs described on it.
///
/// The enum is `#[repr(u8)]` and every discriminant is an ASCII job character;
/// `svddc` passes it on to the `gesdd` routine as `jobz as u8`.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
#[repr(u8)]
pub enum UVTFlag {
    /// All *m* columns of *U* and all *n* rows of *V*ᵀ (job character `'A'`).
    Full = b'A',
    /// The first min(*m*,*n*) columns of *U* and the first min(*m*,*n*) rows of *V*ᵀ
    /// (job character `'S'`).
    Some = b'S',
    /// No columns of *U* or rows of *V*ᵀ (job character `'N'`).
    None = b'N',
}
18
19pub trait SVDDC_: Scalar {
20    fn svddc(l: MatrixLayout, jobz: UVTFlag, a: &mut [Self]) -> Result<SVDOutput<Self>>;
21}
22
/// Generates the `SVDDC_` implementation for one scalar type on top of a `gesdd` routine.
///
/// * `@real`    — the routine takes no `rwork` argument.
/// * `@complex` — an additional `rwork` buffer is allocated and passed.
/// * `@body`    — shared implementation; not meant to be invoked directly.
macro_rules! impl_svddc {
    // Real scalars: forward to the shared body with an empty `rwork` list.
    (@real, $scalar:ty, $gesdd:path) => {
        impl_svddc!(@body, $scalar, $gesdd, );
    };
    // Complex scalars: forward with the `rwork` identifier, enabling the extra buffer.
    (@complex, $scalar:ty, $gesdd:path) => {
        impl_svddc!(@body, $scalar, $gesdd, rwork);
    };
    (@body, $scalar:ty, $gesdd:path, $($rwork_ident:ident),*) => {
        impl SVDDC_ for $scalar {
            fn svddc(
                l: MatrixLayout,
                jobz: UVTFlag,
                a: &mut [Self],
            ) -> Result<SVDOutput<Self>> {
                // Dimensions as the backend sees them: leading dimension first.
                let m_dim = l.lda();
                let n_dim = l.len();
                let min_dim = m_dim.min(n_dim);

                // One singular value per unit of the smaller dimension.
                // NOTE(review): `vec_uninit` presumably yields an uninitialised buffer of
                // the requested length that the backend fills in — confirm at its definition.
                let mut sigma = unsafe { vec_uninit(min_dim as usize) };

                // Leading dimension of `vt` together with the two factor buffers.
                let (ldvt, mut u, mut vt) = match jobz {
                    UVTFlag::Full => (
                        n_dim,
                        Some(unsafe { vec_uninit((m_dim * m_dim) as usize) }),
                        Some(unsafe { vec_uninit((n_dim * n_dim) as usize) }),
                    ),
                    UVTFlag::Some => (
                        min_dim,
                        Some(unsafe { vec_uninit((m_dim * min_dim) as usize) }),
                        Some(unsafe { vec_uninit((n_dim * min_dim) as usize) }),
                    ),
                    UVTFlag::None => (n_dim, None, None),
                };

                $(
                // Complex routines only: size and allocate the extra `rwork` buffer.
                let mut $rwork_ident = {
                    let larger = n_dim.max(m_dim) as usize;
                    let smaller = n_dim.min(m_dim) as usize;
                    let rwork_len = match jobz {
                        UVTFlag::None => 7 * smaller,
                        UVTFlag::Full | UVTFlag::Some => std::cmp::max(
                            5 * smaller * smaller + 5 * smaller,
                            2 * larger * smaller + 2 * smaller * smaller + smaller,
                        ),
                    };
                    unsafe { vec_uninit(rwork_len) }
                };
                )*

                let mut info = 0;
                let mut iwork = unsafe { vec_uninit(8 * (min_dim as usize)) };

                // Pass 1: workspace query. With `lwork == -1` the routine only reports
                // the workspace size it wants in `query[0]`.
                let mut query = [Self::zero()];
                unsafe {
                    $gesdd(
                        jobz as u8,
                        m_dim,
                        n_dim,
                        &mut *a,
                        m_dim,
                        &mut sigma,
                        u.as_deref_mut().unwrap_or(&mut []),
                        m_dim,
                        vt.as_deref_mut().unwrap_or(&mut []),
                        ldvt,
                        &mut query,
                        -1,
                        $(&mut $rwork_ident,)*
                        &mut iwork,
                        &mut info,
                    );
                }
                info.as_lapack_result()?;

                // Pass 2: the actual decomposition with a workspace of the reported size.
                let lwork = query[0].to_usize().unwrap();
                let mut work = unsafe { vec_uninit(lwork) };
                unsafe {
                    $gesdd(
                        jobz as u8,
                        m_dim,
                        n_dim,
                        &mut *a,
                        m_dim,
                        &mut sigma,
                        u.as_deref_mut().unwrap_or(&mut []),
                        m_dim,
                        vt.as_deref_mut().unwrap_or(&mut []),
                        ldvt,
                        &mut work,
                        lwork as i32,
                        $(&mut $rwork_ident,)*
                        &mut iwork,
                        &mut info,
                    );
                }
                info.as_lapack_result()?;

                // For the C layout the two factor buffers trade places in the output.
                let (u, vt) = match l {
                    MatrixLayout::F { .. } => (u, vt),
                    MatrixLayout::C { .. } => (vt, u),
                };
                Ok(SVDOutput { s: sigma, u, vt })
            }
        }
    };
}
121
122impl_svddc!(@real, f32, lapack::sgesdd);
123impl_svddc!(@real, f64, lapack::dgesdd);
124impl_svddc!(@complex, c32, lapack::cgesdd);
125impl_svddc!(@complex, c64, lapack::zgesdd);