1#[cfg(feature = "std")]
12use matrixmultiply;
13use num::{One, Zero};
14use simba::scalar::{ClosedAddAssign, ClosedMulAssign};
15#[cfg(feature = "std")]
16use std::{any::TypeId, mem};
17
18use crate::base::constraint::{
19 AreMultipliable, DimEq, SameNumberOfColumns, SameNumberOfRows, ShapeConstraint,
20};
21#[cfg(feature = "std")]
22use crate::base::dimension::Dyn;
23use crate::base::dimension::{Dim, U1};
24use crate::base::storage::{RawStorage, RawStorageMut};
25use crate::base::uninit::InitStatus;
26use crate::base::{Matrix, Scalar, Vector};
27
28#[allow(clippy::too_many_arguments)]
32unsafe fn array_axcpy<Status, T>(
33 _: Status,
34 y: &mut [Status::Value],
35 a: T,
36 x: &[T],
37 c: T,
38 beta: T,
39 stride1: usize,
40 stride2: usize,
41 len: usize,
42) where
43 Status: InitStatus<T>,
44 T: Scalar + Zero + ClosedAddAssign + ClosedMulAssign,
45{
46 for i in 0..len {
47 let y = Status::assume_init_mut(y.get_unchecked_mut(i * stride1));
48 *y =
49 a.clone() * x.get_unchecked(i * stride2).clone() * c.clone() + beta.clone() * y.clone();
50 }
51}
52
53fn array_axc<Status, T>(
54 _: Status,
55 y: &mut [Status::Value],
56 a: T,
57 x: &[T],
58 c: T,
59 stride1: usize,
60 stride2: usize,
61 len: usize,
62) where
63 Status: InitStatus<T>,
64 T: Scalar + Zero + ClosedAddAssign + ClosedMulAssign,
65{
66 for i in 0..len {
67 unsafe {
68 Status::init(
69 y.get_unchecked_mut(i * stride1),
70 a.clone() * x.get_unchecked(i * stride2).clone() * c.clone(),
71 );
72 }
73 }
74}
75
76#[inline(always)]
83#[allow(clippy::many_single_char_names)]
84pub unsafe fn axcpy_uninit<Status, T, D1: Dim, D2: Dim, SA, SB>(
85 status: Status,
86 y: &mut Vector<Status::Value, D1, SA>,
87 a: T,
88 x: &Vector<T, D2, SB>,
89 c: T,
90 b: T,
91) where
92 T: Scalar + Zero + ClosedAddAssign + ClosedMulAssign,
93 SA: RawStorageMut<Status::Value, D1>,
94 SB: RawStorage<T, D2>,
95 ShapeConstraint: DimEq<D1, D2>,
96 Status: InitStatus<T>,
97{
98 assert_eq!(y.nrows(), x.nrows(), "Axcpy: mismatched vector shapes.");
99
100 let rstride1 = y.strides().0;
101 let rstride2 = x.strides().0;
102
103 let y = y.data.as_mut_slice_unchecked();
106 let x = x.data.as_slice_unchecked();
107
108 if !b.is_zero() {
109 array_axcpy(status, y, a, x, c, b, rstride1, rstride2, x.len());
110 } else {
111 array_axc(status, y, a, x, c, rstride1, rstride2, x.len());
112 }
113}
114
115#[inline(always)]
123pub unsafe fn gemv_uninit<Status, T, D1: Dim, R2: Dim, C2: Dim, D3: Dim, SA, SB, SC>(
124 status: Status,
125 y: &mut Vector<Status::Value, D1, SA>,
126 alpha: T,
127 a: &Matrix<T, R2, C2, SB>,
128 x: &Vector<T, D3, SC>,
129 beta: T,
130) where
131 Status: InitStatus<T>,
132 T: Scalar + Zero + One + ClosedAddAssign + ClosedMulAssign,
133 SA: RawStorageMut<Status::Value, D1>,
134 SB: RawStorage<T, R2, C2>,
135 SC: RawStorage<T, D3>,
136 ShapeConstraint: DimEq<D1, R2> + AreMultipliable<R2, C2, D3, U1>,
137{
138 let dim1 = y.nrows();
139 let (nrows2, ncols2) = a.shape();
140 let dim3 = x.nrows();
141
142 assert!(
143 ncols2 == dim3 && dim1 == nrows2,
144 "Gemv: dimensions mismatch."
145 );
146
147 if ncols2 == 0 {
148 if beta.is_zero() {
149 y.apply(|e| Status::init(e, T::zero()));
150 } else {
151 y.apply(|e| *Status::assume_init_mut(e) *= beta.clone());
153 }
154 return;
155 }
156
157 let col2 = a.column(0);
159 let val = x.vget_unchecked(0).clone();
160
161 axcpy_uninit(status, y, alpha.clone(), &col2, val, beta);
163
164 for j in 1..ncols2 {
165 let col2 = a.column(j);
166 let val = x.vget_unchecked(j).clone();
167
168 axcpy_uninit(status, y, alpha.clone(), &col2, val, T::one());
170 }
171}
172
173#[inline(always)]
181pub unsafe fn gemm_uninit<
182 Status,
183 T,
184 R1: Dim,
185 C1: Dim,
186 R2: Dim,
187 C2: Dim,
188 R3: Dim,
189 C3: Dim,
190 SA,
191 SB,
192 SC,
193>(
194 status: Status,
195 y: &mut Matrix<Status::Value, R1, C1, SA>,
196 alpha: T,
197 a: &Matrix<T, R2, C2, SB>,
198 b: &Matrix<T, R3, C3, SC>,
199 beta: T,
200) where
201 Status: InitStatus<T>,
202 T: Scalar + Zero + One + ClosedAddAssign + ClosedMulAssign,
203 SA: RawStorageMut<Status::Value, R1, C1>,
204 SB: RawStorage<T, R2, C2>,
205 SC: RawStorage<T, R3, C3>,
206 ShapeConstraint:
207 SameNumberOfRows<R1, R2> + SameNumberOfColumns<C1, C3> + AreMultipliable<R2, C2, R3, C3>,
208{
209 let ncols1 = y.ncols();
210
211 #[cfg(feature = "std")]
212 {
213 if R1::is::<Dyn>()
218 || C1::is::<Dyn>()
219 || R2::is::<Dyn>()
220 || C2::is::<Dyn>()
221 || R3::is::<Dyn>()
222 || C3::is::<Dyn>()
223 {
224 let nrows1 = y.nrows();
226 let (nrows2, ncols2) = a.shape();
227 let (nrows3, ncols3) = b.shape();
228
229 const SMALL_DIM: usize = 5;
231
232 if nrows1 > SMALL_DIM && ncols1 > SMALL_DIM && nrows2 > SMALL_DIM && ncols2 > SMALL_DIM
233 {
234 assert_eq!(
235 ncols2, nrows3,
236 "gemm: dimensions mismatch for multiplication."
237 );
238 assert_eq!(
239 (nrows1, ncols1),
240 (nrows2, ncols3),
241 "gemm: dimensions mismatch for addition."
242 );
243
244 if ncols2 == 0 {
249 if beta.is_zero() {
253 y.apply(|e| Status::init(e, T::zero()));
254 } else {
255 y.apply(|e| *Status::assume_init_mut(e) *= beta.clone());
257 }
258 return;
259 }
260
261 if TypeId::of::<T>() == TypeId::of::<f32>() {
262 let (rsa, csa) = a.strides();
263 let (rsb, csb) = b.strides();
264 let (rsc, csc) = y.strides();
265
266 matrixmultiply::sgemm(
267 nrows2,
268 ncols2,
269 ncols3,
270 mem::transmute_copy(&alpha),
271 a.data.ptr() as *const f32,
272 rsa as isize,
273 csa as isize,
274 b.data.ptr() as *const f32,
275 rsb as isize,
276 csb as isize,
277 mem::transmute_copy(&beta),
278 y.data.ptr_mut() as *mut f32,
279 rsc as isize,
280 csc as isize,
281 );
282 return;
283 } else if TypeId::of::<T>() == TypeId::of::<f64>() {
284 let (rsa, csa) = a.strides();
285 let (rsb, csb) = b.strides();
286 let (rsc, csc) = y.strides();
287
288 matrixmultiply::dgemm(
289 nrows2,
290 ncols2,
291 ncols3,
292 mem::transmute_copy(&alpha),
293 a.data.ptr() as *const f64,
294 rsa as isize,
295 csa as isize,
296 b.data.ptr() as *const f64,
297 rsb as isize,
298 csb as isize,
299 mem::transmute_copy(&beta),
300 y.data.ptr_mut() as *mut f64,
301 rsc as isize,
302 csc as isize,
303 );
304 return;
305 }
306 }
307 }
308 }
309
310 for j1 in 0..ncols1 {
311 gemv_uninit(
314 status,
315 &mut y.column_mut(j1),
316 alpha.clone(),
317 a,
318 &b.column(j1),
319 beta.clone(),
320 );
321 }
322}