nalgebra/base/
blas_uninit.rs

1/*
2 * This file implements some BLAS operations in such a way that they work
3 * even if the first argument (the output parameter) is an uninitialized matrix.
4 *
5 * Because doing this makes the code harder to read, we only implemented the operations that we
6 * know would benefit from this performance-wise, namely, GEMM (which we use for our matrix
7 * multiplication code). If we identify other operations like that in the future, we could add
8 * them here.
9 */
10
11#[cfg(feature = "std")]
12use matrixmultiply;
13use num::{One, Zero};
14use simba::scalar::{ClosedAddAssign, ClosedMulAssign};
15#[cfg(feature = "std")]
16use std::{any::TypeId, mem};
17
18use crate::base::constraint::{
19    AreMultipliable, DimEq, SameNumberOfColumns, SameNumberOfRows, ShapeConstraint,
20};
21#[cfg(feature = "std")]
22use crate::base::dimension::Dyn;
23use crate::base::dimension::{Dim, U1};
24use crate::base::storage::{RawStorage, RawStorageMut};
25use crate::base::uninit::InitStatus;
26use crate::base::{Matrix, Scalar, Vector};
27
28// # Safety
29// The content of `y` must only contain values for which
30// `Status::assume_init_mut` is sound.
31#[allow(clippy::too_many_arguments)]
32unsafe fn array_axcpy<Status, T>(
33    _: Status,
34    y: &mut [Status::Value],
35    a: T,
36    x: &[T],
37    c: T,
38    beta: T,
39    stride1: usize,
40    stride2: usize,
41    len: usize,
42) where
43    Status: InitStatus<T>,
44    T: Scalar + Zero + ClosedAddAssign + ClosedMulAssign,
45{
46    for i in 0..len {
47        let y = Status::assume_init_mut(y.get_unchecked_mut(i * stride1));
48        *y =
49            a.clone() * x.get_unchecked(i * stride2).clone() * c.clone() + beta.clone() * y.clone();
50    }
51}
52
53fn array_axc<Status, T>(
54    _: Status,
55    y: &mut [Status::Value],
56    a: T,
57    x: &[T],
58    c: T,
59    stride1: usize,
60    stride2: usize,
61    len: usize,
62) where
63    Status: InitStatus<T>,
64    T: Scalar + Zero + ClosedAddAssign + ClosedMulAssign,
65{
66    for i in 0..len {
67        unsafe {
68            Status::init(
69                y.get_unchecked_mut(i * stride1),
70                a.clone() * x.get_unchecked(i * stride2).clone() * c.clone(),
71            );
72        }
73    }
74}
75
76/// Computes `y = a * x * c + b * y`.
77///
78/// If `b` is zero, `y` is never read from and may be uninitialized.
79///
80/// # Safety
81/// This is UB if b != 0 and any component of `y` is uninitialized.
82#[inline(always)]
83#[allow(clippy::many_single_char_names)]
84pub unsafe fn axcpy_uninit<Status, T, D1: Dim, D2: Dim, SA, SB>(
85    status: Status,
86    y: &mut Vector<Status::Value, D1, SA>,
87    a: T,
88    x: &Vector<T, D2, SB>,
89    c: T,
90    b: T,
91) where
92    T: Scalar + Zero + ClosedAddAssign + ClosedMulAssign,
93    SA: RawStorageMut<Status::Value, D1>,
94    SB: RawStorage<T, D2>,
95    ShapeConstraint: DimEq<D1, D2>,
96    Status: InitStatus<T>,
97{
98    assert_eq!(y.nrows(), x.nrows(), "Axcpy: mismatched vector shapes.");
99
100    let rstride1 = y.strides().0;
101    let rstride2 = x.strides().0;
102
103    // SAFETY: the conversion to slices is OK because we access the
104    //         elements taking the strides into account.
105    let y = y.data.as_mut_slice_unchecked();
106    let x = x.data.as_slice_unchecked();
107
108    if !b.is_zero() {
109        array_axcpy(status, y, a, x, c, b, rstride1, rstride2, x.len());
110    } else {
111        array_axc(status, y, a, x, c, rstride1, rstride2, x.len());
112    }
113}
114
115/// Computes `y = alpha * a * x + beta * y`, where `a` is a matrix, `x` a vector, and
116/// `alpha, beta` two scalars.
117///
118/// If `beta` is zero, `y` is never read from and may be uninitialized.
119///
120/// # Safety
121/// This is UB if beta != 0 and any component of `y` is uninitialized.
122#[inline(always)]
123pub unsafe fn gemv_uninit<Status, T, D1: Dim, R2: Dim, C2: Dim, D3: Dim, SA, SB, SC>(
124    status: Status,
125    y: &mut Vector<Status::Value, D1, SA>,
126    alpha: T,
127    a: &Matrix<T, R2, C2, SB>,
128    x: &Vector<T, D3, SC>,
129    beta: T,
130) where
131    Status: InitStatus<T>,
132    T: Scalar + Zero + One + ClosedAddAssign + ClosedMulAssign,
133    SA: RawStorageMut<Status::Value, D1>,
134    SB: RawStorage<T, R2, C2>,
135    SC: RawStorage<T, D3>,
136    ShapeConstraint: DimEq<D1, R2> + AreMultipliable<R2, C2, D3, U1>,
137{
138    let dim1 = y.nrows();
139    let (nrows2, ncols2) = a.shape();
140    let dim3 = x.nrows();
141
142    assert!(
143        ncols2 == dim3 && dim1 == nrows2,
144        "Gemv: dimensions mismatch."
145    );
146
147    if ncols2 == 0 {
148        if beta.is_zero() {
149            y.apply(|e| Status::init(e, T::zero()));
150        } else {
151            // SAFETY: this is UB if y is uninitialized.
152            y.apply(|e| *Status::assume_init_mut(e) *= beta.clone());
153        }
154        return;
155    }
156
157    // TODO: avoid bound checks.
158    let col2 = a.column(0);
159    let val = x.vget_unchecked(0).clone();
160
161    // SAFETY: this is the call that makes this method unsafe: it is UB if Status = Uninit and beta != 0.
162    axcpy_uninit(status, y, alpha.clone(), &col2, val, beta);
163
164    for j in 1..ncols2 {
165        let col2 = a.column(j);
166        let val = x.vget_unchecked(j).clone();
167
168        // SAFETY: safe because y was initialized above.
169        axcpy_uninit(status, y, alpha.clone(), &col2, val, T::one());
170    }
171}
172
173/// Computes `y = alpha * a * b + beta * y`, where `a, b, y` are matrices.
174/// `alpha` and `beta` are scalar.
175///
176/// If `beta` is zero, `y` is never read from and may be uninitialized.
177///
178/// # Safety
179/// This is UB if beta != 0 and any component of `y` is uninitialized.
180#[inline(always)]
181pub unsafe fn gemm_uninit<
182    Status,
183    T,
184    R1: Dim,
185    C1: Dim,
186    R2: Dim,
187    C2: Dim,
188    R3: Dim,
189    C3: Dim,
190    SA,
191    SB,
192    SC,
193>(
194    status: Status,
195    y: &mut Matrix<Status::Value, R1, C1, SA>,
196    alpha: T,
197    a: &Matrix<T, R2, C2, SB>,
198    b: &Matrix<T, R3, C3, SC>,
199    beta: T,
200) where
201    Status: InitStatus<T>,
202    T: Scalar + Zero + One + ClosedAddAssign + ClosedMulAssign,
203    SA: RawStorageMut<Status::Value, R1, C1>,
204    SB: RawStorage<T, R2, C2>,
205    SC: RawStorage<T, R3, C3>,
206    ShapeConstraint:
207        SameNumberOfRows<R1, R2> + SameNumberOfColumns<C1, C3> + AreMultipliable<R2, C2, R3, C3>,
208{
209    let ncols1 = y.ncols();
210
211    #[cfg(feature = "std")]
212    {
213        // We assume large matrices will be Dyn but small matrices static.
214        // We could use matrixmultiply for large statically-sized matrices but the performance
215        // threshold to activate it would be different from SMALL_DIM because our code optimizes
216        // better for statically-sized matrices.
217        if R1::is::<Dyn>()
218            || C1::is::<Dyn>()
219            || R2::is::<Dyn>()
220            || C2::is::<Dyn>()
221            || R3::is::<Dyn>()
222            || C3::is::<Dyn>()
223        {
224            // matrixmultiply can be used only if the std feature is available.
225            let nrows1 = y.nrows();
226            let (nrows2, ncols2) = a.shape();
227            let (nrows3, ncols3) = b.shape();
228
229            // Threshold determined empirically.
230            const SMALL_DIM: usize = 5;
231
232            if nrows1 > SMALL_DIM && ncols1 > SMALL_DIM && nrows2 > SMALL_DIM && ncols2 > SMALL_DIM
233            {
234                assert_eq!(
235                    ncols2, nrows3,
236                    "gemm: dimensions mismatch for multiplication."
237                );
238                assert_eq!(
239                    (nrows1, ncols1),
240                    (nrows2, ncols3),
241                    "gemm: dimensions mismatch for addition."
242                );
243
244                // NOTE: this case should never happen because we enter this
245                // codepath only when ncols2 > SMALL_DIM. Though we keep this
246                // here just in case if in the future we change the conditions to
247                // enter this codepath.
248                if ncols2 == 0 {
249                    // NOTE: we can't just always multiply by beta
250                    // because we documented the guaranty that `self` is
251                    // never read if `beta` is zero.
252                    if beta.is_zero() {
253                        y.apply(|e| Status::init(e, T::zero()));
254                    } else {
255                        // SAFETY: this is UB if Status = Uninit
256                        y.apply(|e| *Status::assume_init_mut(e) *= beta.clone());
257                    }
258                    return;
259                }
260
261                if TypeId::of::<T>() == TypeId::of::<f32>() {
262                    let (rsa, csa) = a.strides();
263                    let (rsb, csb) = b.strides();
264                    let (rsc, csc) = y.strides();
265
266                    matrixmultiply::sgemm(
267                        nrows2,
268                        ncols2,
269                        ncols3,
270                        mem::transmute_copy(&alpha),
271                        a.data.ptr() as *const f32,
272                        rsa as isize,
273                        csa as isize,
274                        b.data.ptr() as *const f32,
275                        rsb as isize,
276                        csb as isize,
277                        mem::transmute_copy(&beta),
278                        y.data.ptr_mut() as *mut f32,
279                        rsc as isize,
280                        csc as isize,
281                    );
282                    return;
283                } else if TypeId::of::<T>() == TypeId::of::<f64>() {
284                    let (rsa, csa) = a.strides();
285                    let (rsb, csb) = b.strides();
286                    let (rsc, csc) = y.strides();
287
288                    matrixmultiply::dgemm(
289                        nrows2,
290                        ncols2,
291                        ncols3,
292                        mem::transmute_copy(&alpha),
293                        a.data.ptr() as *const f64,
294                        rsa as isize,
295                        csa as isize,
296                        b.data.ptr() as *const f64,
297                        rsb as isize,
298                        csb as isize,
299                        mem::transmute_copy(&beta),
300                        y.data.ptr_mut() as *mut f64,
301                        rsc as isize,
302                        csc as isize,
303                    );
304                    return;
305                }
306            }
307        }
308    }
309
310    for j1 in 0..ncols1 {
311        // TODO: avoid bound checks.
312        // SAFETY: this is UB if Status = Uninit && beta != 0
313        gemv_uninit(
314            status,
315            &mut y.column_mut(j1),
316            alpha.clone(),
317            a,
318            &b.column(j1),
319            beta.clone(),
320        );
321    }
322}