// ndarray/linalg/impl_linalg.rs

// Copyright 2014-2020 bluss and ndarray developers.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

use crate::imp_prelude::*;

#[cfg(feature = "blas")]
use crate::dimension::offset_from_low_addr_ptr_to_logical_ptr;
use crate::numeric_util;

use crate::{LinalgScalar, Zip};

#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
use std::any::TypeId;
use std::mem::MaybeUninit;

use num_complex::Complex;
use num_complex::{Complex32 as c32, Complex64 as c64};

#[cfg(feature = "blas")]
use libc::c_int;

#[cfg(feature = "blas")]
use cblas_sys as blas_sys;
#[cfg(feature = "blas")]
use cblas_sys::{CblasNoTrans, CblasTrans, CBLAS_LAYOUT, CBLAS_TRANSPOSE};
/// Length of a vector at or above which we use BLAS for `dot`
#[cfg(feature = "blas")]
const DOT_BLAS_CUTOFF: usize = 32;
/// Side of a matrix above which we use BLAS for `gemm`
#[cfg(feature = "blas")]
const GEMM_BLAS_CUTOFF: usize = 7;
#[cfg(feature = "blas")]
#[allow(non_camel_case_types)]
type blas_index = c_int; // blas index type

impl<A, S> ArrayBase<S, Ix1>
where S: Data<Elem = A>
{
    /// Perform dot product or matrix multiplication of arrays `self` and `rhs`.
    ///
    /// `Rhs` may be either a one-dimensional or a two-dimensional array.
    ///
    /// If `Rhs` is one-dimensional, then the operation is a vector dot
    /// product, which is the sum of the elementwise products (no conjugation
    /// of complex operands, and thus not their inner product). In this case,
    /// `self` and `rhs` must be the same length.
    ///
    /// If `Rhs` is two-dimensional, then the operation is matrix
    /// multiplication, where `self` is treated as a row vector. In this case,
    /// if `self` is shape *M*, then `rhs` is shape *M* × *N* and the result is
    /// shape *N*.
    ///
    /// **Panics** if the array shapes are incompatible.<br>
    /// *Note:* If enabled, uses blas `dot` for elements of `f32, f64` when memory
    /// layout allows.
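    ///
    /// A short doctest of the vector–vector case (the values here are
    /// illustrative):
    ///
    /// ```
    /// use ndarray::array;
    ///
    /// let a = array![1., 2., 3.];
    /// let b = array![4., 5., 6.];
    /// // 1·4 + 2·5 + 3·6 = 32
    /// assert_eq!(a.dot(&b), 32.);
    /// ```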
    #[track_caller]
    pub fn dot<Rhs>(&self, rhs: &Rhs) -> <Self as Dot<Rhs>>::Output
    where Self: Dot<Rhs>
    {
        Dot::dot(self, rhs)
    }

    fn dot_generic<S2>(&self, rhs: &ArrayBase<S2, Ix1>) -> A
    where
        S2: Data<Elem = A>,
        A: LinalgScalar,
    {
        debug_assert_eq!(self.len(), rhs.len());
        assert!(self.len() == rhs.len());
        if let Some(self_s) = self.as_slice() {
            if let Some(rhs_s) = rhs.as_slice() {
                return numeric_util::unrolled_dot(self_s, rhs_s);
            }
        }
        let mut sum = A::zero();
        for i in 0..self.len() {
            unsafe {
                sum = sum + *self.uget(i) * *rhs.uget(i);
            }
        }
        sum
    }

    #[cfg(not(feature = "blas"))]
    fn dot_impl<S2>(&self, rhs: &ArrayBase<S2, Ix1>) -> A
    where
        S2: Data<Elem = A>,
        A: LinalgScalar,
    {
        self.dot_generic(rhs)
    }

    #[cfg(feature = "blas")]
    fn dot_impl<S2>(&self, rhs: &ArrayBase<S2, Ix1>) -> A
    where
        S2: Data<Elem = A>,
        A: LinalgScalar,
    {
        // Use only if the vector is large enough to be worth it
        if self.len() >= DOT_BLAS_CUTOFF {
            debug_assert_eq!(self.len(), rhs.len());
            assert!(self.len() == rhs.len());
            macro_rules! dot {
                ($ty:ty, $func:ident) => {{
                    if blas_compat_1d::<$ty, _>(self) && blas_compat_1d::<$ty, _>(rhs) {
                        unsafe {
                            let (lhs_ptr, n, incx) =
                                blas_1d_params(self.ptr.as_ptr(), self.len(), self.strides()[0]);
                            let (rhs_ptr, _, incy) =
                                blas_1d_params(rhs.ptr.as_ptr(), rhs.len(), rhs.strides()[0]);
                            let ret = blas_sys::$func(
                                n,
                                lhs_ptr as *const $ty,
                                incx,
                                rhs_ptr as *const $ty,
                                incy,
                            );
                            return cast_as::<$ty, A>(&ret);
                        }
                    }
                }};
            }

            dot! {f32, cblas_sdot};
            dot! {f64, cblas_ddot};
        }
        self.dot_generic(rhs)
    }
}

/// Return a pointer to the starting element in BLAS's view.
///
/// BLAS wants a pointer to the element with lowest address,
/// which agrees with our pointer for non-negative strides, but
/// is at the opposite end for negative strides.
#[cfg(feature = "blas")]
unsafe fn blas_1d_params<A>(ptr: *const A, len: usize, stride: isize) -> (*const A, blas_index, blas_index)
{
    // [x x x x]
    //        ^--ptr
    //        stride = -1
    //  ^--blas_ptr = ptr + (len - 1) * stride
    if stride >= 0 || len == 0 {
        (ptr, len as blas_index, stride as blas_index)
    } else {
        let ptr = ptr.offset((len - 1) as isize * stride);
        (ptr, len as blas_index, stride as blas_index)
    }
}

/// Matrix Multiplication
///
/// For two-dimensional arrays, the dot method computes the matrix
/// multiplication.
pub trait Dot<Rhs>
{
    /// The result of the operation.
    ///
    /// For two-dimensional arrays: a rectangular array.
    type Output;
    fn dot(&self, rhs: &Rhs) -> Self::Output;
}

impl<A, S, S2> Dot<ArrayBase<S2, Ix1>> for ArrayBase<S, Ix1>
where
    S: Data<Elem = A>,
    S2: Data<Elem = A>,
    A: LinalgScalar,
{
    type Output = A;

    /// Compute the dot product of one-dimensional arrays.
    ///
    /// The dot product is a sum of the elementwise products (no conjugation
    /// of complex operands, and thus not their inner product).
    ///
    /// **Panics** if the arrays are not of the same length.<br>
    /// *Note:* If enabled, uses blas `dot` for elements of `f32, f64` when memory
    /// layout allows.
    #[track_caller]
    fn dot(&self, rhs: &ArrayBase<S2, Ix1>) -> A
    {
        self.dot_impl(rhs)
    }
}

impl<A, S, S2> Dot<ArrayBase<S2, Ix2>> for ArrayBase<S, Ix1>
where
    S: Data<Elem = A>,
    S2: Data<Elem = A>,
    A: LinalgScalar,
{
    type Output = Array<A, Ix1>;

    /// Perform the matrix multiplication of the row vector `self` and
    /// rectangular matrix `rhs`.
    ///
    /// The array shapes must agree in the way that
    /// if `self` is *M*, then `rhs` is *M* × *N*.
    ///
    /// Return a result array with shape *N*.
    ///
    /// **Panics** if shapes are incompatible.
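    ///
    /// A short doctest sketch (the values here are illustrative):
    ///
    /// ```
    /// use ndarray::array;
    ///
    /// let v = array![1., 2.];
    /// let m = array![[1., 2., 3.],
    ///                [4., 5., 6.]];
    /// // `v` is treated as a row vector: shape 2 times shape 2 × 3 gives shape 3
    /// assert_eq!(v.dot(&m), array![9., 12., 15.]);
    /// ```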
    #[track_caller]
    fn dot(&self, rhs: &ArrayBase<S2, Ix2>) -> Array<A, Ix1>
    {
        rhs.t().dot(self)
    }
}

impl<A, S> ArrayBase<S, Ix2>
where S: Data<Elem = A>
{
    /// Perform matrix multiplication of rectangular arrays `self` and `rhs`.
    ///
    /// `Rhs` may be either a one-dimensional or a two-dimensional array.
    ///
    /// If `Rhs` is two-dimensional, the array shapes must agree in the way that
    /// if `self` is *M* × *N*, then `rhs` is *N* × *K*.
    ///
    /// Return a result array with shape *M* × *K*.
    ///
    /// **Panics** if shapes are incompatible or the number of elements in the
    /// result would overflow `isize`.
    ///
    /// *Note:* If enabled, uses blas `gemv/gemm` for elements of `f32, f64`
    /// when memory layout allows. The default matrixmultiply backend
    /// is otherwise used for `f32, f64` for all memory layouts.
    ///
    /// ```
    /// use ndarray::arr2;
    ///
    /// let a = arr2(&[[1., 2.],
    ///                [0., 1.]]);
    /// let b = arr2(&[[1., 2.],
    ///                [2., 3.]]);
    ///
    /// assert!(
    ///     a.dot(&b) == arr2(&[[5., 8.],
    ///                         [2., 3.]])
    /// );
    /// ```
    #[track_caller]
    pub fn dot<Rhs>(&self, rhs: &Rhs) -> <Self as Dot<Rhs>>::Output
    where Self: Dot<Rhs>
    {
        Dot::dot(self, rhs)
    }
}

impl<A, S, S2> Dot<ArrayBase<S2, Ix2>> for ArrayBase<S, Ix2>
where
    S: Data<Elem = A>,
    S2: Data<Elem = A>,
    A: LinalgScalar,
{
    type Output = Array2<A>;
    fn dot(&self, b: &ArrayBase<S2, Ix2>) -> Array2<A>
    {
        let a = self.view();
        let b = b.view();
        let ((m, k), (k2, n)) = (a.dim(), b.dim());
        if k != k2 || m.checked_mul(n).is_none() {
            dot_shape_error(m, k, k2, n);
        }

        let lhs_s0 = a.strides()[0];
        let rhs_s0 = b.strides()[0];
        let column_major = lhs_s0 == 1 && rhs_s0 == 1;
        // A is Copy so this is safe
        let mut v = Vec::with_capacity(m * n);
        let mut c;
        unsafe {
            v.set_len(m * n);
            c = Array::from_shape_vec_unchecked((m, n).set_f(column_major), v);
        }
        mat_mul_impl(A::one(), &a, &b, A::zero(), &mut c.view_mut());
        c
    }
}

/// Assumes that `m` and `n` are ≤ `isize::MAX`.
#[cold]
#[inline(never)]
fn dot_shape_error(m: usize, k: usize, k2: usize, n: usize) -> !
{
    match m.checked_mul(n) {
        Some(len) if len <= isize::MAX as usize => {}
        _ => panic!("ndarray: shape {} × {} overflows isize", m, n),
    }
    panic!(
        "ndarray: inputs {} × {} and {} × {} are not compatible for matrix multiplication",
        m, k, k2, n
    );
}

#[cold]
#[inline(never)]
fn general_dot_shape_error(m: usize, k: usize, k2: usize, n: usize, c1: usize, c2: usize) -> !
{
    panic!("ndarray: inputs {} × {}, {} × {}, and output {} × {} are not compatible for matrix multiplication",
           m, k, k2, n, c1, c2);
}

/// Perform the matrix multiplication of the rectangular array `self` and
/// column vector `rhs`.
///
/// The array shapes must agree in the way that
/// if `self` is *M* × *N*, then `rhs` is *N*.
///
/// Return a result array with shape *M*.
///
/// **Panics** if shapes are incompatible.
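///
/// A short doctest sketch (the values here are illustrative):
///
/// ```
/// use ndarray::array;
///
/// let m = array![[1., 2.],
///                [3., 4.]];
/// let v = array![1., 1.];
/// // shape 2 × 2 times shape 2 gives shape 2
/// assert_eq!(m.dot(&v), array![3., 7.]);
/// ```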
impl<A, S, S2> Dot<ArrayBase<S2, Ix1>> for ArrayBase<S, Ix2>
where
    S: Data<Elem = A>,
    S2: Data<Elem = A>,
    A: LinalgScalar,
{
    type Output = Array<A, Ix1>;
    #[track_caller]
    fn dot(&self, rhs: &ArrayBase<S2, Ix1>) -> Array<A, Ix1>
    {
        let ((m, a), n) = (self.dim(), rhs.dim());
        if a != n {
            dot_shape_error(m, a, n, 1);
        }

        // Avoid initializing the memory in vec -- set it during iteration
        unsafe {
            let mut c = Array1::uninit(m);
            general_mat_vec_mul_impl(A::one(), self, rhs, A::zero(), c.raw_view_mut().cast::<A>());
            c.assume_init()
        }
    }
}

impl<A, S, D> ArrayBase<S, D>
where
    S: Data<Elem = A>,
    D: Dimension,
{
    /// Perform the operation `self += alpha * rhs` efficiently, where
    /// `alpha` is a scalar and `rhs` is another array. This operation is
    /// also known as `axpy` in BLAS.
    ///
    /// If their shapes disagree, `rhs` is broadcast to the shape of `self`.
    ///
    /// **Panics** if broadcasting isn’t possible.
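    ///
    /// A short doctest sketch (the values here are illustrative):
    ///
    /// ```
    /// use ndarray::array;
    ///
    /// let mut y = array![1., 2., 3.];
    /// let x = array![1., 1., 1.];
    /// y.scaled_add(2., &x); // y ← y + 2·x
    /// assert_eq!(y, array![3., 4., 5.]);
    /// ```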
    #[track_caller]
    pub fn scaled_add<S2, E>(&mut self, alpha: A, rhs: &ArrayBase<S2, E>)
    where
        S: DataMut,
        S2: Data<Elem = A>,
        A: LinalgScalar,
        E: Dimension,
    {
        self.zip_mut_with(rhs, move |y, &x| *y = *y + (alpha * x));
    }
}

// mat_mul_impl uses ArrayView arguments to send all array kinds into
// the same instantiated implementation.
#[cfg(not(feature = "blas"))]
use self::mat_mul_general as mat_mul_impl;

#[cfg(feature = "blas")]
fn mat_mul_impl<A>(alpha: A, a: &ArrayView2<'_, A>, b: &ArrayView2<'_, A>, beta: A, c: &mut ArrayViewMut2<'_, A>)
where A: LinalgScalar
{
    let ((m, k), (k2, n)) = (a.dim(), b.dim());
    debug_assert_eq!(k, k2);
    if (m > GEMM_BLAS_CUTOFF || n > GEMM_BLAS_CUTOFF || k > GEMM_BLAS_CUTOFF)
        && (same_type::<A, f32>() || same_type::<A, f64>() || same_type::<A, c32>() || same_type::<A, c64>())
    {
        // Compute A B -> C
        // We require for BLAS compatibility that:
        // A, B, C are contiguous (stride=1) in their fastest dimension,
        // but they can be either row major/"c" or col major/"f".
        //
        // The "normal case" is CblasRowMajor for cblas.
        // Select CblasRowMajor / CblasColMajor to fit C's memory order.
        //
        // Apply transpose to A, B as needed if they differ from the row major case.
        // If C is CblasColMajor then transpose both A, B (again!)

        if let (Some(a_layout), Some(b_layout), Some(c_layout)) =
            (get_blas_compatible_layout(a), get_blas_compatible_layout(b), get_blas_compatible_layout(c))
        {
            let cblas_layout = c_layout.to_cblas_layout();
            let a_trans = a_layout.to_cblas_transpose_for(cblas_layout);
            let lda = blas_stride(&a, a_layout);

            let b_trans = b_layout.to_cblas_transpose_for(cblas_layout);
            let ldb = blas_stride(&b, b_layout);

            let ldc = blas_stride(&c, c_layout);

            macro_rules! gemm_scalar_cast {
                (f32, $var:ident) => {
                    cast_as(&$var)
                };
                (f64, $var:ident) => {
                    cast_as(&$var)
                };
                (c32, $var:ident) => {
                    &$var as *const A as *const _
                };
                (c64, $var:ident) => {
                    &$var as *const A as *const _
                };
            }

            macro_rules! gemm {
                ($ty:tt, $gemm:ident) => {
                    if same_type::<A, $ty>() {
                        // gemm is C ← αA^Op B^Op + βC
                        // Where Op is notrans/trans/conjtrans
                        unsafe {
                            blas_sys::$gemm(
                                cblas_layout,
                                a_trans,
                                b_trans,
                                m as blas_index,                 // m, rows of Op(a)
                                n as blas_index,                 // n, cols of Op(b)
                                k as blas_index,                 // k, cols of Op(a)
                                gemm_scalar_cast!($ty, alpha),   // alpha
                                a.ptr.as_ptr() as *const _,      // a
                                lda,                             // lda
                                b.ptr.as_ptr() as *const _,      // b
                                ldb,                             // ldb
                                gemm_scalar_cast!($ty, beta),    // beta
                                c.ptr.as_ptr() as *mut _,        // c
                                ldc,                             // ldc
                            );
                        }
                        return;
                    }
                };
            }

            gemm!(f32, cblas_sgemm);
            gemm!(f64, cblas_dgemm);
            gemm!(c32, cblas_cgemm);
            gemm!(c64, cblas_zgemm);

            unreachable!() // we checked above that A is one of f32, f64, c32, c64
        }
    }
    mat_mul_general(alpha, a, b, beta, c)
}

/// C ← α A B + β C
fn mat_mul_general<A>(
    alpha: A, lhs: &ArrayView2<'_, A>, rhs: &ArrayView2<'_, A>, beta: A, c: &mut ArrayViewMut2<'_, A>,
) where A: LinalgScalar
{
    let ((m, k), (_, n)) = (lhs.dim(), rhs.dim());

    // common parameters for gemm
    let ap = lhs.as_ptr();
    let bp = rhs.as_ptr();
    let cp = c.as_mut_ptr();
    let (rsc, csc) = (c.strides()[0], c.strides()[1]);
    if same_type::<A, f32>() {
        unsafe {
            matrixmultiply::sgemm(
                m,
                k,
                n,
                cast_as(&alpha),
                ap as *const _,
                lhs.strides()[0],
                lhs.strides()[1],
                bp as *const _,
                rhs.strides()[0],
                rhs.strides()[1],
                cast_as(&beta),
                cp as *mut _,
                rsc,
                csc,
            );
        }
    } else if same_type::<A, f64>() {
        unsafe {
            matrixmultiply::dgemm(
                m,
                k,
                n,
                cast_as(&alpha),
                ap as *const _,
                lhs.strides()[0],
                lhs.strides()[1],
                bp as *const _,
                rhs.strides()[0],
                rhs.strides()[1],
                cast_as(&beta),
                cp as *mut _,
                rsc,
                csc,
            );
        }
    } else if same_type::<A, c32>() {
        unsafe {
            matrixmultiply::cgemm(
                matrixmultiply::CGemmOption::Standard,
                matrixmultiply::CGemmOption::Standard,
                m,
                k,
                n,
                complex_array(cast_as(&alpha)),
                ap as *const _,
                lhs.strides()[0],
                lhs.strides()[1],
                bp as *const _,
                rhs.strides()[0],
                rhs.strides()[1],
                complex_array(cast_as(&beta)),
                cp as *mut _,
                rsc,
                csc,
            );
        }
    } else if same_type::<A, c64>() {
        unsafe {
            matrixmultiply::zgemm(
                matrixmultiply::CGemmOption::Standard,
                matrixmultiply::CGemmOption::Standard,
                m,
                k,
                n,
                complex_array(cast_as(&alpha)),
                ap as *const _,
                lhs.strides()[0],
                lhs.strides()[1],
                bp as *const _,
                rhs.strides()[0],
                rhs.strides()[1],
                complex_array(cast_as(&beta)),
                cp as *mut _,
                rsc,
                csc,
            );
        }
    } else {
        // It's a no-op if `c` has zero length.
        if c.is_empty() {
            return;
        }

        // initialize memory if beta is zero
        if beta.is_zero() {
            c.fill(beta);
        }

        let mut i = 0;
        let mut j = 0;
        loop {
            unsafe {
                let elt = c.uget_mut((i, j));
                *elt =
                    *elt * beta + alpha * (0..k).fold(A::zero(), move |s, x| s + *lhs.uget((i, x)) * *rhs.uget((x, j)));
            }
            j += 1;
            if j == n {
                j = 0;
                i += 1;
                if i == m {
                    break;
                }
            }
        }
    }
}

/// General matrix-matrix multiplication.
///
/// Compute C ← α A B + β C
///
/// The array shapes must agree in the way that
/// if `a` is *M* × *N*, then `b` is *N* × *K* and `c` is *M* × *K*.
///
/// ***Panics*** if array shapes are not compatible<br>
/// *Note:* If enabled, uses blas `gemm` for elements of `f32, f64` when memory
/// layout allows.  The default matrixmultiply backend is otherwise used for
/// `f32, f64` for all memory layouts.
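///
/// A short doctest sketch, assuming this function's usual re-export at
/// `ndarray::linalg::general_mat_mul` (the values here are illustrative):
///
/// ```
/// use ndarray::array;
/// use ndarray::linalg::general_mat_mul;
///
/// let a = array![[1., 2.],
///                [3., 4.]];
/// let b = array![[1., 0.],
///                [0., 1.]];
/// let mut c = array![[1., 1.],
///                    [1., 1.]];
/// // c ← 1·a·b + 1·c; b is the identity, so c becomes a plus ones
/// general_mat_mul(1., &a, &b, 1., &mut c);
/// assert_eq!(c, array![[2., 3.],
///                      [4., 5.]]);
/// ```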
#[track_caller]
pub fn general_mat_mul<A, S1, S2, S3>(
    alpha: A, a: &ArrayBase<S1, Ix2>, b: &ArrayBase<S2, Ix2>, beta: A, c: &mut ArrayBase<S3, Ix2>,
) where
    S1: Data<Elem = A>,
    S2: Data<Elem = A>,
    S3: DataMut<Elem = A>,
    A: LinalgScalar,
{
    let ((m, k), (k2, n)) = (a.dim(), b.dim());
    let (m2, n2) = c.dim();
    if k != k2 || m != m2 || n != n2 {
        general_dot_shape_error(m, k, k2, n, m2, n2);
    } else {
        mat_mul_impl(alpha, &a.view(), &b.view(), beta, &mut c.view_mut());
    }
}

/// General matrix-vector multiplication.
///
/// Compute y ← α A x + β y
///
/// where A is an *M* × *N* matrix, x is an *N*-element column vector, and
/// y is an *M*-element column vector (one-dimensional arrays).
///
/// ***Panics*** if array shapes are not compatible<br>
/// *Note:* If enabled, uses blas `gemv` for elements of `f32, f64` when memory
/// layout allows.
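///
/// A short doctest sketch, assuming this function's usual re-export at
/// `ndarray::linalg::general_mat_vec_mul` (the values here are illustrative):
///
/// ```
/// use ndarray::array;
/// use ndarray::linalg::general_mat_vec_mul;
///
/// let a = array![[1., 2.],
///                [3., 4.]];
/// let x = array![1., 1.];
/// let mut y = array![0., 0.];
/// // y ← 2·A·x + 0·y
/// general_mat_vec_mul(2., &a, &x, 0., &mut y);
/// assert_eq!(y, array![6., 14.]);
/// ```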
#[track_caller]
#[allow(clippy::collapsible_if)]
pub fn general_mat_vec_mul<A, S1, S2, S3>(
    alpha: A, a: &ArrayBase<S1, Ix2>, x: &ArrayBase<S2, Ix1>, beta: A, y: &mut ArrayBase<S3, Ix1>,
) where
    S1: Data<Elem = A>,
    S2: Data<Elem = A>,
    S3: DataMut<Elem = A>,
    A: LinalgScalar,
{
    unsafe { general_mat_vec_mul_impl(alpha, a, x, beta, y.raw_view_mut()) }
}

/// General matrix-vector multiplication
///
/// Use a raw view for the destination vector, so that it can be uninitialized.
///
/// ## Safety
///
/// The caller must ensure that the raw view is valid for writing.
/// The destination may be uninitialized iff `beta` is zero.
#[allow(clippy::collapsible_else_if)]
unsafe fn general_mat_vec_mul_impl<A, S1, S2>(
    alpha: A, a: &ArrayBase<S1, Ix2>, x: &ArrayBase<S2, Ix1>, beta: A, y: RawArrayViewMut<A, Ix1>,
) where
    S1: Data<Elem = A>,
    S2: Data<Elem = A>,
    A: LinalgScalar,
{
    let ((m, k), k2) = (a.dim(), x.dim());
    let m2 = y.dim();
    if k != k2 || m != m2 {
        general_dot_shape_error(m, k, k2, 1, m2, 1);
    } else {
        #[cfg(feature = "blas")]
        macro_rules! gemv {
            ($ty:ty, $gemv:ident) => {
                if same_type::<A, $ty>() {
                    if let Some(layout) = get_blas_compatible_layout(&a) {
                        if blas_compat_1d::<$ty, _>(&x) && blas_compat_1d::<$ty, _>(&y) {
                            // Determine stride between rows or columns. Note that the stride is
                            // adjusted to at least `k` or `m` to handle the case of a matrix with a
                            // trivial (length 1) dimension, since the stride for the trivial dimension
                            // may be arbitrary.
                            let a_trans = CblasNoTrans;

                            let a_stride = blas_stride(&a, layout);
                            let cblas_layout = layout.to_cblas_layout();

                            // Low addr in memory pointers required for x, y
                            let x_offset = offset_from_low_addr_ptr_to_logical_ptr(&x.dim, &x.strides);
                            let x_ptr = x.ptr.as_ptr().sub(x_offset);
                            let y_offset = offset_from_low_addr_ptr_to_logical_ptr(&y.dim, &y.strides);
                            let y_ptr = y.ptr.as_ptr().sub(y_offset);

                            let x_stride = x.strides()[0] as blas_index;
                            let y_stride = y.strides()[0] as blas_index;

                            blas_sys::$gemv(
                                cblas_layout,
                                a_trans,
                                m as blas_index,            // m, rows of Op(a)
                                k as blas_index,            // n, cols of Op(a)
                                cast_as(&alpha),            // alpha
                                a.ptr.as_ptr() as *const _, // a
                                a_stride,                   // lda
                                x_ptr as *const _,          // x
                                x_stride,
                                cast_as(&beta),             // beta
                                y_ptr as *mut _,            // y
                                y_stride,
                            );
                            return;
                        }
                    }
                }
            };
        }
        #[cfg(feature = "blas")]
        gemv!(f32, cblas_sgemv);
        #[cfg(feature = "blas")]
        gemv!(f64, cblas_dgemv);

        /* general */

        if beta.is_zero() {
            // when beta is zero, c may be uninitialized
            Zip::from(a.outer_iter()).and(y).for_each(|row, elt| {
                elt.write(row.dot(x) * alpha);
            });
        } else {
            Zip::from(a.outer_iter()).and(y).for_each(|row, elt| {
                *elt = *elt * beta + row.dot(x) * alpha;
            });
        }
    }
}

/// Kronecker product of 2D matrices.
///
/// The Kronecker product of an L×N matrix A and an M×R matrix B is an
/// (L*M)×(N*R) matrix K formed by the block multiplication A_ij * B.
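///
/// A short doctest sketch, assuming this function's usual re-export at
/// `ndarray::linalg::kron` (the values here are illustrative):
///
/// ```
/// use ndarray::array;
/// use ndarray::linalg::kron;
///
/// let a = array![[1., 2.],
///                [3., 4.]];
/// let b = array![[0., 1.],
///                [1., 0.]];
/// // Each block (i, j) of the result is a[[i, j]] * b
/// assert_eq!(kron(&a, &b),
///            array![[0., 1., 0., 2.],
///                   [1., 0., 2., 0.],
///                   [0., 3., 0., 4.],
///                   [3., 0., 4., 0.]]);
/// ```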
pub fn kron<A, S1, S2>(a: &ArrayBase<S1, Ix2>, b: &ArrayBase<S2, Ix2>) -> Array<A, Ix2>
where
    S1: Data<Elem = A>,
    S2: Data<Elem = A>,
    A: LinalgScalar,
{
    let dimar = a.shape()[0];
    let dimac = a.shape()[1];
    let dimbr = b.shape()[0];
    let dimbc = b.shape()[1];
    let mut out: Array2<MaybeUninit<A>> = Array2::uninit((
        dimar
            .checked_mul(dimbr)
            .expect("Dimensions of kronecker product output array overflows usize."),
        dimac
            .checked_mul(dimbc)
            .expect("Dimensions of kronecker product output array overflows usize."),
    ));
    Zip::from(out.exact_chunks_mut((dimbr, dimbc)))
        .and(a)
        .for_each(|out, &a| {
            Zip::from(out).and(b).for_each(|out, &b| {
                *out = MaybeUninit::new(a * b);
            })
        });
    unsafe { out.assume_init() }
}

#[inline(always)]
/// Return `true` if `A` and `B` are the same type
fn same_type<A: 'static, B: 'static>() -> bool
{
    TypeId::of::<A>() == TypeId::of::<B>()
}

// Read pointer to type `A` as type `B`.
//
// **Panics** if `A` and `B` are not the same type
fn cast_as<A: 'static + Copy, B: 'static + Copy>(a: &A) -> B
{
    assert!(same_type::<A, B>(), "expect type {} and {} to match",
            std::any::type_name::<A>(), std::any::type_name::<B>());
    unsafe { ::std::ptr::read(a as *const _ as *const B) }
}

/// Return the complex in the form of an array [re, im]
#[inline]
fn complex_array<A: 'static + Copy>(z: Complex<A>) -> [A; 2]
{
    [z.re, z.im]
}

#[cfg(feature = "blas")]
fn blas_compat_1d<A, S>(a: &ArrayBase<S, Ix1>) -> bool
where
    S: RawData,
    A: 'static,
    S::Elem: 'static,
{
    if !same_type::<A, S::Elem>() {
        return false;
    }
    if a.len() > blas_index::MAX as usize {
        return false;
    }
    let stride = a.strides()[0];
    if stride == 0 || stride > blas_index::MAX as isize || stride < blas_index::MIN as isize {
        return false;
    }
    true
}

#[cfg(feature = "blas")]
#[derive(Copy, Clone)]
#[cfg_attr(test, derive(PartialEq, Eq, Debug))]
enum BlasOrder
{
    C,
    F,
}

#[cfg(feature = "blas")]
impl BlasOrder
{
    fn transpose(self) -> Self
    {
        match self {
            Self::C => Self::F,
            Self::F => Self::C,
        }
    }

    #[inline]
    /// Axis of leading stride (opposite of contiguous axis)
    fn get_blas_lead_axis(self) -> usize
    {
        match self {
            Self::C => 0,
            Self::F => 1,
        }
    }

    fn to_cblas_layout(self) -> CBLAS_LAYOUT
    {
        match self {
            Self::C => CBLAS_LAYOUT::CblasRowMajor,
            Self::F => CBLAS_LAYOUT::CblasColMajor,
        }
    }

    /// When using cblas_sgemm (etc.) with the C matrix stored in `for_layout`,
    /// decide how this `self` matrix should be transposed.
    fn to_cblas_transpose_for(self, for_layout: CBLAS_LAYOUT) -> CBLAS_TRANSPOSE
    {
        let effective_order = match for_layout {
            CBLAS_LAYOUT::CblasRowMajor => self,
            CBLAS_LAYOUT::CblasColMajor => self.transpose(),
        };

        match effective_order {
            Self::C => CblasNoTrans,
            Self::F => CblasTrans,
        }
    }
}

#[cfg(feature = "blas")]
fn is_blas_2d(dim: &Ix2, stride: &Ix2, order: BlasOrder) -> bool
{
    let (m, n) = dim.into_pattern();
    let s0 = stride[0] as isize;
    let s1 = stride[1] as isize;
    let (inner_stride, outer_stride, inner_dim, outer_dim) = match order {
        BlasOrder::C => (s1, s0, m, n),
        BlasOrder::F => (s0, s1, n, m),
    };

    if !(inner_stride == 1 || outer_dim == 1) {
        return false;
    }

    if s0 < 1 || s1 < 1 {
        return false;
    }

    if (s0 > blas_index::MAX as isize || s0 < blas_index::MIN as isize)
        || (s1 > blas_index::MAX as isize || s1 < blas_index::MIN as isize)
    {
        return false;
    }

    // The leading stride must be >= the length of the other dimension (no broadcasting/aliasing).
    if inner_dim > 1 && (outer_stride as usize) < outer_dim {
        return false;
    }

    if m > blas_index::MAX as usize || n > blas_index::MAX as usize {
        return false;
    }

    true
}

/// Get BLAS compatible layout if any (C or F, preferring the former)
#[cfg(feature = "blas")]
fn get_blas_compatible_layout<S>(a: &ArrayBase<S, Ix2>) -> Option<BlasOrder>
where S: Data
{
    if is_blas_2d(&a.dim, &a.strides, BlasOrder::C) {
        Some(BlasOrder::C)
    } else if is_blas_2d(&a.dim, &a.strides, BlasOrder::F) {
        Some(BlasOrder::F)
    } else {
        None
    }
}

/// `a` should be BLAS compatible in the given `order`.
///
/// Return the leading stride (lda, ldb, ldc) of the array.
#[cfg(feature = "blas")]
fn blas_stride<S>(a: &ArrayBase<S, Ix2>, order: BlasOrder) -> blas_index
where S: Data
{
    let axis = order.get_blas_lead_axis();
    let other_axis = 1 - axis;
    let len_this = a.shape()[axis];
    let len_other = a.shape()[other_axis];
    let stride = a.strides()[axis];

    // If the leading axis has length == 1, its stride does not matter for ndarray,
    // but BLAS needs a stride that makes sense, i.e. one that is >= the other axis' length.

    // cast: a should already be blas compatible
    (if len_this <= 1 {
        Ord::max(stride, len_other as isize)
    } else {
        stride
    }) as blas_index
}

#[cfg(test)]
#[cfg(feature = "blas")]
fn blas_row_major_2d<A, S>(a: &ArrayBase<S, Ix2>) -> bool
where
    S: Data,
    A: 'static,
    S::Elem: 'static,
{
    if !same_type::<A, S::Elem>() {
        return false;
    }
    is_blas_2d(&a.dim, &a.strides, BlasOrder::C)
}

#[cfg(test)]
#[cfg(feature = "blas")]
fn blas_column_major_2d<A, S>(a: &ArrayBase<S, Ix2>) -> bool
where
    S: Data,
    A: 'static,
    S::Elem: 'static,
{
    if !same_type::<A, S::Elem>() {
        return false;
    }
    is_blas_2d(&a.dim, &a.strides, BlasOrder::F)
}

#[cfg(test)]
#[cfg(feature = "blas")]
mod blas_tests
{
    use super::*;

    #[test]
    fn blas_row_major_2d_normal_matrix()
    {
        let m: Array2<f32> = Array2::zeros((3, 5));
        assert!(blas_row_major_2d::<f32, _>(&m));
        assert!(!blas_column_major_2d::<f32, _>(&m));
    }

    #[test]
    fn blas_row_major_2d_row_matrix()
    {
        let m: Array2<f32> = Array2::zeros((1, 5));
        assert!(blas_row_major_2d::<f32, _>(&m));
        assert!(blas_column_major_2d::<f32, _>(&m));
    }

    #[test]
    fn blas_row_major_2d_column_matrix()
    {
        let m: Array2<f32> = Array2::zeros((5, 1));
        assert!(blas_row_major_2d::<f32, _>(&m));
        assert!(blas_column_major_2d::<f32, _>(&m));
    }

    #[test]
    fn blas_row_major_2d_transposed_row_matrix()
    {
        let m: Array2<f32> = Array2::zeros((1, 5));
        let m_t = m.t();
        assert!(blas_row_major_2d::<f32, _>(&m_t));
        assert!(blas_column_major_2d::<f32, _>(&m_t));
    }

    #[test]
    fn blas_row_major_2d_transposed_column_matrix()
    {
        let m: Array2<f32> = Array2::zeros((5, 1));
        let m_t = m.t();
        assert!(blas_row_major_2d::<f32, _>(&m_t));
        assert!(blas_column_major_2d::<f32, _>(&m_t));
    }

    #[test]
    fn blas_column_major_2d_normal_matrix()
    {
        let m: Array2<f32> = Array2::zeros((3, 5).f());
        assert!(!blas_row_major_2d::<f32, _>(&m));
        assert!(blas_column_major_2d::<f32, _>(&m));
    }

    #[test]
    fn blas_row_major_2d_skip_rows_ok()
    {
        let m: Array2<f32> = Array2::zeros((5, 5));
        let mv = m.slice(s![..;2, ..]);
        assert!(blas_row_major_2d::<f32, _>(&mv));
        assert!(!blas_column_major_2d::<f32, _>(&mv));
    }

    #[test]
    fn blas_row_major_2d_skip_columns_fail()
    {
        let m: Array2<f32> = Array2::zeros((5, 5));
        let mv = m.slice(s![.., ..;2]);
        assert!(!blas_row_major_2d::<f32, _>(&mv));
        assert!(!blas_column_major_2d::<f32, _>(&mv));
    }

    #[test]
    fn blas_col_major_2d_skip_columns_ok()
    {
        let m: Array2<f32> = Array2::zeros((5, 5).f());
        let mv = m.slice(s![.., ..;2]);
        assert!(blas_column_major_2d::<f32, _>(&mv));
        assert!(!blas_row_major_2d::<f32, _>(&mv));
    }

    #[test]
    fn blas_col_major_2d_skip_rows_fail()
    {
        let m: Array2<f32> = Array2::zeros((5, 5).f());
        let mv = m.slice(s![..;2, ..]);
        assert!(!blas_column_major_2d::<f32, _>(&mv));
        assert!(!blas_row_major_2d::<f32, _>(&mv));
    }

    #[test]
    fn blas_too_short_stride()
    {
        // The leading stride must be at least as long as the other dimension;
        // for example, in a 5 × 5 matrix, the leading stride must be >= 5 for BLAS.

        const N: usize = 5;
        const MAXSTRIDE: usize = N + 2;
        let mut data = [0; MAXSTRIDE * N];
        let mut iter = 0..data.len();
        data.fill_with(|| iter.next().unwrap());

        for stride in 1..=MAXSTRIDE {
            let m = ArrayView::from_shape((N, N).strides((stride, 1)), &data).unwrap();
            eprintln!("{:?}", m);

            if stride < N {
                assert_eq!(get_blas_compatible_layout(&m), None);
            } else {
                assert_eq!(get_blas_compatible_layout(&m), Some(BlasOrder::C));
            }
        }
    }
}