1use crate::imp_prelude::*;
10
11#[cfg(feature = "blas")]
12use crate::dimension::offset_from_low_addr_ptr_to_logical_ptr;
13use crate::numeric_util;
14
15use crate::{LinalgScalar, Zip};
16
17#[cfg(not(feature = "std"))]
18use alloc::vec::Vec;
19use std::any::TypeId;
20use std::mem::MaybeUninit;
21
22use num_complex::Complex;
23use num_complex::{Complex32 as c32, Complex64 as c64};
24
25#[cfg(feature = "blas")]
26use libc::c_int;
27
28#[cfg(feature = "blas")]
29use cblas_sys as blas_sys;
30#[cfg(feature = "blas")]
31use cblas_sys::{CblasNoTrans, CblasTrans, CBLAS_LAYOUT, CBLAS_TRANSPOSE};
32
#[cfg(feature = "blas")]
/// Minimum vector length at which the 1D dot product hands `f32`/`f64`
/// vectors to BLAS (`cblas_sdot` / `cblas_ddot`); shorter vectors use the
/// generic loop.
const DOT_BLAS_CUTOFF: usize = 32;
#[cfg(feature = "blas")]
/// BLAS `?gemm` is only considered when at least one of the matrix
/// dimensions m, n, k is strictly greater than this value.
const GEMM_BLAS_CUTOFF: usize = 7;
39#[cfg(feature = "blas")]
40#[allow(non_camel_case_types)]
41type blas_index = c_int; impl<A, S> ArrayBase<S, Ix1>
44where S: Data<Elem = A>
45{
46 #[track_caller]
64 pub fn dot<Rhs>(&self, rhs: &Rhs) -> <Self as Dot<Rhs>>::Output
65 where Self: Dot<Rhs>
66 {
67 Dot::dot(self, rhs)
68 }
69
70 fn dot_generic<S2>(&self, rhs: &ArrayBase<S2, Ix1>) -> A
71 where
72 S2: Data<Elem = A>,
73 A: LinalgScalar,
74 {
75 debug_assert_eq!(self.len(), rhs.len());
76 assert!(self.len() == rhs.len());
77 if let Some(self_s) = self.as_slice() {
78 if let Some(rhs_s) = rhs.as_slice() {
79 return numeric_util::unrolled_dot(self_s, rhs_s);
80 }
81 }
82 let mut sum = A::zero();
83 for i in 0..self.len() {
84 unsafe {
85 sum = sum + *self.uget(i) * *rhs.uget(i);
86 }
87 }
88 sum
89 }
90
91 #[cfg(not(feature = "blas"))]
92 fn dot_impl<S2>(&self, rhs: &ArrayBase<S2, Ix1>) -> A
93 where
94 S2: Data<Elem = A>,
95 A: LinalgScalar,
96 {
97 self.dot_generic(rhs)
98 }
99
100 #[cfg(feature = "blas")]
101 fn dot_impl<S2>(&self, rhs: &ArrayBase<S2, Ix1>) -> A
102 where
103 S2: Data<Elem = A>,
104 A: LinalgScalar,
105 {
106 if self.len() >= DOT_BLAS_CUTOFF {
108 debug_assert_eq!(self.len(), rhs.len());
109 assert!(self.len() == rhs.len());
110 macro_rules! dot {
111 ($ty:ty, $func:ident) => {{
112 if blas_compat_1d::<$ty, _>(self) && blas_compat_1d::<$ty, _>(rhs) {
113 unsafe {
114 let (lhs_ptr, n, incx) =
115 blas_1d_params(self.ptr.as_ptr(), self.len(), self.strides()[0]);
116 let (rhs_ptr, _, incy) =
117 blas_1d_params(rhs.ptr.as_ptr(), rhs.len(), rhs.strides()[0]);
118 let ret = blas_sys::$func(
119 n,
120 lhs_ptr as *const $ty,
121 incx,
122 rhs_ptr as *const $ty,
123 incy,
124 );
125 return cast_as::<$ty, A>(&ret);
126 }
127 }
128 }};
129 }
130
131 dot! {f32, cblas_sdot};
132 dot! {f64, cblas_ddot};
133 }
134 self.dot_generic(rhs)
135 }
136}
137
/// Translate a 1D array's `(pointer, length, stride)` into the argument
/// triple a BLAS level-1 routine expects.
///
/// BLAS addresses a strided vector from its lowest-addressed element. That
/// element is `ptr` itself for a non-negative stride, but the logically *last*
/// element when the stride is negative. The length and stride are passed
/// through unchanged, cast to `blas_index`.
///
/// # Safety
///
/// For a negative `stride` and non-zero `len`, offsetting `ptr` by
/// `(len - 1) * stride` elements must stay inside the same allocation, i.e.
/// the arguments must describe a real array. The casts to `blas_index` are
/// unchecked; callers validate the ranges first (e.g. with `blas_compat_1d`).
#[cfg(feature = "blas")]
unsafe fn blas_1d_params<A>(ptr: *const A, len: usize, stride: isize) -> (*const A, blas_index, blas_index)
{
    // [x x x x]
    //        ^-- ptr, stride = -1
    //  ^-- low address = ptr + (len - 1) * stride
    let low_addr_ptr = if stride < 0 && len != 0 {
        ptr.offset((len - 1) as isize * stride)
    } else {
        ptr
    };
    (low_addr_ptr, len as blas_index, stride as blas_index)
}
157
/// Matrix multiplication / dot product between two arrays.
///
/// Implemented in this module for 1D·1D (scalar result), 1D·2D and 2D·1D
/// (1D array result), and 2D·2D (2D array result).
pub trait Dot<Rhs>
{
    /// The result of the operation: a scalar for two vectors, otherwise an
    /// owned array.
    type Output;
    /// Compute the product of `self` and `rhs`.
    fn dot(&self, rhs: &Rhs) -> Self::Output;
}
170
171impl<A, S, S2> Dot<ArrayBase<S2, Ix1>> for ArrayBase<S, Ix1>
172where
173 S: Data<Elem = A>,
174 S2: Data<Elem = A>,
175 A: LinalgScalar,
176{
177 type Output = A;
178
179 #[track_caller]
188 fn dot(&self, rhs: &ArrayBase<S2, Ix1>) -> A
189 {
190 self.dot_impl(rhs)
191 }
192}
193
194impl<A, S, S2> Dot<ArrayBase<S2, Ix2>> for ArrayBase<S, Ix1>
195where
196 S: Data<Elem = A>,
197 S2: Data<Elem = A>,
198 A: LinalgScalar,
199{
200 type Output = Array<A, Ix1>;
201
202 #[track_caller]
212 fn dot(&self, rhs: &ArrayBase<S2, Ix2>) -> Array<A, Ix1>
213 {
214 rhs.t().dot(self)
215 }
216}
217
218impl<A, S> ArrayBase<S, Ix2>
219where S: Data<Elem = A>
220{
221 #[track_caller]
251 pub fn dot<Rhs>(&self, rhs: &Rhs) -> <Self as Dot<Rhs>>::Output
252 where Self: Dot<Rhs>
253 {
254 Dot::dot(self, rhs)
255 }
256}
257
258impl<A, S, S2> Dot<ArrayBase<S2, Ix2>> for ArrayBase<S, Ix2>
259where
260 S: Data<Elem = A>,
261 S2: Data<Elem = A>,
262 A: LinalgScalar,
263{
264 type Output = Array2<A>;
265 fn dot(&self, b: &ArrayBase<S2, Ix2>) -> Array2<A>
266 {
267 let a = self.view();
268 let b = b.view();
269 let ((m, k), (k2, n)) = (a.dim(), b.dim());
270 if k != k2 || m.checked_mul(n).is_none() {
271 dot_shape_error(m, k, k2, n);
272 }
273
274 let lhs_s0 = a.strides()[0];
275 let rhs_s0 = b.strides()[0];
276 let column_major = lhs_s0 == 1 && rhs_s0 == 1;
277 let mut v = Vec::with_capacity(m * n);
279 let mut c;
280 unsafe {
281 v.set_len(m * n);
282 c = Array::from_shape_vec_unchecked((m, n).set_f(column_major), v);
283 }
284 mat_mul_impl(A::one(), &a, &b, A::zero(), &mut c.view_mut());
285 c
286 }
287}
288
/// Panic with a diagnostic for an invalid matrix product of an `m × k` and a
/// `k2 × n` operand.
///
/// If the `m × n` result would not fit in `isize`, that is reported first.
/// Otherwise the inner-dimension mismatch is reported.
///
/// `#[track_caller]` makes the reported panic location the caller's, and
/// transitively the user's when the caller is `#[track_caller]` too, instead
/// of this helper's body.
#[cold]
#[inline(never)]
#[track_caller]
fn dot_shape_error(m: usize, k: usize, k2: usize, n: usize) -> !
{
    match m.checked_mul(n) {
        Some(len) if len <= isize::MAX as usize => {}
        _ => panic!("ndarray: shape {} × {} overflows isize", m, n),
    }
    panic!(
        "ndarray: inputs {} × {} and {} × {} are not compatible for matrix multiplication",
        m, k, k2, n
    );
}
303
/// Panic with a diagnostic for `general_mat_mul` / `general_mat_vec_mul`.
///
/// The operands are `m × k` and `k2 × n`, and the output is `c1 × c2`.
///
/// `#[track_caller]` makes the reported panic location the caller's instead
/// of this helper's body.
#[cold]
#[inline(never)]
#[track_caller]
fn general_dot_shape_error(m: usize, k: usize, k2: usize, n: usize, c1: usize, c2: usize) -> !
{
    panic!("ndarray: inputs {} × {}, {} × {}, and output {} × {} are not compatible for matrix multiplication",
           m, k, k2, n, c1, c2);
}
311
312impl<A, S, S2> Dot<ArrayBase<S2, Ix1>> for ArrayBase<S, Ix2>
322where
323 S: Data<Elem = A>,
324 S2: Data<Elem = A>,
325 A: LinalgScalar,
326{
327 type Output = Array<A, Ix1>;
328 #[track_caller]
329 fn dot(&self, rhs: &ArrayBase<S2, Ix1>) -> Array<A, Ix1>
330 {
331 let ((m, a), n) = (self.dim(), rhs.dim());
332 if a != n {
333 dot_shape_error(m, a, n, 1);
334 }
335
336 unsafe {
338 let mut c = Array1::uninit(m);
339 general_mat_vec_mul_impl(A::one(), self, rhs, A::zero(), c.raw_view_mut().cast::<A>());
340 c.assume_init()
341 }
342 }
343}
344
345impl<A, S, D> ArrayBase<S, D>
346where
347 S: Data<Elem = A>,
348 D: Dimension,
349{
350 #[track_caller]
358 pub fn scaled_add<S2, E>(&mut self, alpha: A, rhs: &ArrayBase<S2, E>)
359 where
360 S: DataMut,
361 S2: Data<Elem = A>,
362 A: LinalgScalar,
363 E: Dimension,
364 {
365 self.zip_mut_with(rhs, move |y, &x| *y = *y + (alpha * x));
366 }
367}
368
369#[cfg(not(feature = "blas"))]
372use self::mat_mul_general as mat_mul_impl;
373
/// C ← α A B + β C, using BLAS `?gemm` when it is worthwhile.
///
/// BLAS is used only when all of these hold:
/// - the element type is `f32`, `f64`, `c32` or `c64`;
/// - at least one of m, n, k exceeds `GEMM_BLAS_CUTOFF`;
/// - all three operands have a BLAS-compatible layout (`get_blas_compatible_layout`).
///
/// Otherwise this falls back to `mat_mul_general`.
///
/// Shapes are only `debug_assert`ed here. Callers must pass an m × k `a`, a
/// k × n `b` and an m × n `c`.
#[cfg(feature = "blas")]
fn mat_mul_impl<A>(alpha: A, a: &ArrayView2<'_, A>, b: &ArrayView2<'_, A>, beta: A, c: &mut ArrayViewMut2<'_, A>)
where A: LinalgScalar
{
    let ((m, k), (k2, n)) = (a.dim(), b.dim());
    debug_assert_eq!(k, k2);
    if (m > GEMM_BLAS_CUTOFF || n > GEMM_BLAS_CUTOFF || k > GEMM_BLAS_CUTOFF)
        && (same_type::<A, f32>() || same_type::<A, f64>() || same_type::<A, c32>() || same_type::<A, c64>())
    {
        // Compute A B -> C. Each operand must be contiguous along one axis,
        // either row major ("C") or column major ("F"). The cblas layout is
        // chosen to match C's memory order, and A and B are flagged as
        // transposed whenever their own order differs from that layout.
        if let (Some(a_layout), Some(b_layout), Some(c_layout)) =
            (get_blas_compatible_layout(a), get_blas_compatible_layout(b), get_blas_compatible_layout(c))
        {
            let cblas_layout = c_layout.to_cblas_layout();
            let a_trans = a_layout.to_cblas_transpose_for(cblas_layout);
            let lda = blas_stride(&a, a_layout);

            let b_trans = b_layout.to_cblas_transpose_for(cblas_layout);
            let ldb = blas_stride(&b, b_layout);

            let ldc = blas_stride(&c, c_layout);

            // Real gemm routines take alpha/beta by value; complex ones take a
            // pointer to them.
            macro_rules! gemm_scalar_cast {
                (f32, $var:ident) => {
                    cast_as(&$var)
                };
                (f64, $var:ident) => {
                    cast_as(&$var)
                };
                (c32, $var:ident) => {
                    &$var as *const A as *const _
                };
                (c64, $var:ident) => {
                    &$var as *const A as *const _
                };
            }

            // Call the gemm for `$ty` and return, if `A` is `$ty`.
            macro_rules! gemm {
                ($ty:tt, $gemm:ident) => {
                    if same_type::<A, $ty>() {
                        unsafe {
                            blas_sys::$gemm(
                                cblas_layout,
                                a_trans,
                                b_trans,
                                m as blas_index,               // m, rows of Op(a)
                                n as blas_index,               // n, cols of Op(b)
                                k as blas_index,               // k, cols of Op(a)
                                gemm_scalar_cast!($ty, alpha), // alpha
                                a.ptr.as_ptr() as *const _,    // a
                                lda,                           // lda
                                b.ptr.as_ptr() as *const _,    // b
                                ldb,                           // ldb
                                gemm_scalar_cast!($ty, beta),  // beta
                                c.ptr.as_ptr() as *mut _,      // c
                                ldc,                           // ldc
                            );
                        }
                        return;
                    }
                };
            }

            gemm!(f32, cblas_sgemm);
            gemm!(f64, cblas_dgemm);
            gemm!(c32, cblas_cgemm);
            gemm!(c64, cblas_zgemm);

            // The outer condition only admits these four element types, so
            // one of the gemm! arms above has already returned.
            unreachable!()
        }
    }
    mat_mul_general(alpha, a, b, beta, c)
}
459
460fn mat_mul_general<A>(
462 alpha: A, lhs: &ArrayView2<'_, A>, rhs: &ArrayView2<'_, A>, beta: A, c: &mut ArrayViewMut2<'_, A>,
463) where A: LinalgScalar
464{
465 let ((m, k), (_, n)) = (lhs.dim(), rhs.dim());
466
467 let ap = lhs.as_ptr();
469 let bp = rhs.as_ptr();
470 let cp = c.as_mut_ptr();
471 let (rsc, csc) = (c.strides()[0], c.strides()[1]);
472 if same_type::<A, f32>() {
473 unsafe {
474 matrixmultiply::sgemm(
475 m,
476 k,
477 n,
478 cast_as(&alpha),
479 ap as *const _,
480 lhs.strides()[0],
481 lhs.strides()[1],
482 bp as *const _,
483 rhs.strides()[0],
484 rhs.strides()[1],
485 cast_as(&beta),
486 cp as *mut _,
487 rsc,
488 csc,
489 );
490 }
491 } else if same_type::<A, f64>() {
492 unsafe {
493 matrixmultiply::dgemm(
494 m,
495 k,
496 n,
497 cast_as(&alpha),
498 ap as *const _,
499 lhs.strides()[0],
500 lhs.strides()[1],
501 bp as *const _,
502 rhs.strides()[0],
503 rhs.strides()[1],
504 cast_as(&beta),
505 cp as *mut _,
506 rsc,
507 csc,
508 );
509 }
510 } else if same_type::<A, c32>() {
511 unsafe {
512 matrixmultiply::cgemm(
513 matrixmultiply::CGemmOption::Standard,
514 matrixmultiply::CGemmOption::Standard,
515 m,
516 k,
517 n,
518 complex_array(cast_as(&alpha)),
519 ap as *const _,
520 lhs.strides()[0],
521 lhs.strides()[1],
522 bp as *const _,
523 rhs.strides()[0],
524 rhs.strides()[1],
525 complex_array(cast_as(&beta)),
526 cp as *mut _,
527 rsc,
528 csc,
529 );
530 }
531 } else if same_type::<A, c64>() {
532 unsafe {
533 matrixmultiply::zgemm(
534 matrixmultiply::CGemmOption::Standard,
535 matrixmultiply::CGemmOption::Standard,
536 m,
537 k,
538 n,
539 complex_array(cast_as(&alpha)),
540 ap as *const _,
541 lhs.strides()[0],
542 lhs.strides()[1],
543 bp as *const _,
544 rhs.strides()[0],
545 rhs.strides()[1],
546 complex_array(cast_as(&beta)),
547 cp as *mut _,
548 rsc,
549 csc,
550 );
551 }
552 } else {
553 if c.is_empty() {
555 return;
556 }
557
558 if beta.is_zero() {
560 c.fill(beta);
561 }
562
563 let mut i = 0;
564 let mut j = 0;
565 loop {
566 unsafe {
567 let elt = c.uget_mut((i, j));
568 *elt =
569 *elt * beta + alpha * (0..k).fold(A::zero(), move |s, x| s + *lhs.uget((i, x)) * *rhs.uget((x, j)));
570 }
571 j += 1;
572 if j == n {
573 j = 0;
574 i += 1;
575 if i == m {
576 break;
577 }
578 }
579 }
580 }
581}
582
583#[track_caller]
595pub fn general_mat_mul<A, S1, S2, S3>(
596 alpha: A, a: &ArrayBase<S1, Ix2>, b: &ArrayBase<S2, Ix2>, beta: A, c: &mut ArrayBase<S3, Ix2>,
597) where
598 S1: Data<Elem = A>,
599 S2: Data<Elem = A>,
600 S3: DataMut<Elem = A>,
601 A: LinalgScalar,
602{
603 let ((m, k), (k2, n)) = (a.dim(), b.dim());
604 let (m2, n2) = c.dim();
605 if k != k2 || m != m2 || n != n2 {
606 general_dot_shape_error(m, k, k2, n, m2, n2);
607 } else {
608 mat_mul_impl(alpha, &a.view(), &b.view(), beta, &mut c.view_mut());
609 }
610}
611
612#[track_caller]
623#[allow(clippy::collapsible_if)]
624pub fn general_mat_vec_mul<A, S1, S2, S3>(
625 alpha: A, a: &ArrayBase<S1, Ix2>, x: &ArrayBase<S2, Ix1>, beta: A, y: &mut ArrayBase<S3, Ix1>,
626) where
627 S1: Data<Elem = A>,
628 S2: Data<Elem = A>,
629 S3: DataMut<Elem = A>,
630 A: LinalgScalar,
631{
632 unsafe { general_mat_vec_mul_impl(alpha, a, x, beta, y.raw_view_mut()) }
633}
634
635#[allow(clippy::collapsible_else_if)]
644unsafe fn general_mat_vec_mul_impl<A, S1, S2>(
645 alpha: A, a: &ArrayBase<S1, Ix2>, x: &ArrayBase<S2, Ix1>, beta: A, y: RawArrayViewMut<A, Ix1>,
646) where
647 S1: Data<Elem = A>,
648 S2: Data<Elem = A>,
649 A: LinalgScalar,
650{
651 let ((m, k), k2) = (a.dim(), x.dim());
652 let m2 = y.dim();
653 if k != k2 || m != m2 {
654 general_dot_shape_error(m, k, k2, 1, m2, 1);
655 } else {
656 #[cfg(feature = "blas")]
657 macro_rules! gemv {
658 ($ty:ty, $gemv:ident) => {
659 if same_type::<A, $ty>() {
660 if let Some(layout) = get_blas_compatible_layout(&a) {
661 if blas_compat_1d::<$ty, _>(&x) && blas_compat_1d::<$ty, _>(&y) {
662 let a_trans = CblasNoTrans;
667
668 let a_stride = blas_stride(&a, layout);
669 let cblas_layout = layout.to_cblas_layout();
670
671 let x_offset = offset_from_low_addr_ptr_to_logical_ptr(&x.dim, &x.strides);
673 let x_ptr = x.ptr.as_ptr().sub(x_offset);
674 let y_offset = offset_from_low_addr_ptr_to_logical_ptr(&y.dim, &y.strides);
675 let y_ptr = y.ptr.as_ptr().sub(y_offset);
676
677 let x_stride = x.strides()[0] as blas_index;
678 let y_stride = y.strides()[0] as blas_index;
679
680 blas_sys::$gemv(
681 cblas_layout,
682 a_trans,
683 m as blas_index, k as blas_index, cast_as(&alpha), a.ptr.as_ptr() as *const _, a_stride, x_ptr as *const _, x_stride,
690 cast_as(&beta), y_ptr as *mut _, y_stride,
693 );
694 return;
695 }
696 }
697 }
698 };
699 }
700 #[cfg(feature = "blas")]
701 gemv!(f32, cblas_sgemv);
702 #[cfg(feature = "blas")]
703 gemv!(f64, cblas_dgemv);
704
705 if beta.is_zero() {
708 Zip::from(a.outer_iter()).and(y).for_each(|row, elt| {
710 elt.write(row.dot(x) * alpha);
711 });
712 } else {
713 Zip::from(a.outer_iter()).and(y).for_each(|row, elt| {
714 *elt = *elt * beta + row.dot(x) * alpha;
715 });
716 }
717 }
718}
719
720pub fn kron<A, S1, S2>(a: &ArrayBase<S1, Ix2>, b: &ArrayBase<S2, Ix2>) -> Array<A, Ix2>
725where
726 S1: Data<Elem = A>,
727 S2: Data<Elem = A>,
728 A: LinalgScalar,
729{
730 let dimar = a.shape()[0];
731 let dimac = a.shape()[1];
732 let dimbr = b.shape()[0];
733 let dimbc = b.shape()[1];
734 let mut out: Array2<MaybeUninit<A>> = Array2::uninit((
735 dimar
736 .checked_mul(dimbr)
737 .expect("Dimensions of kronecker product output array overflows usize."),
738 dimac
739 .checked_mul(dimbc)
740 .expect("Dimensions of kronecker product output array overflows usize."),
741 ));
742 Zip::from(out.exact_chunks_mut((dimbr, dimbc)))
743 .and(a)
744 .for_each(|out, &a| {
745 Zip::from(out).and(b).for_each(|out, &b| {
746 *out = MaybeUninit::new(a * b);
747 })
748 });
749 unsafe { out.assume_init() }
750}
751
/// Return `true` if `A` and `B` are exactly the same type (by `TypeId`).
///
/// Used to pick type-specialized kernels at runtime in generic code.
#[inline(always)]
fn same_type<A: 'static, B: 'static>() -> bool
{
    let (lhs_id, rhs_id) = (TypeId::of::<A>(), TypeId::of::<B>());
    lhs_id == rhs_id
}
758
759fn cast_as<A: 'static + Copy, B: 'static + Copy>(a: &A) -> B
763{
764 assert!(same_type::<A, B>(), "expect type {} and {} to match",
765 std::any::type_name::<A>(), std::any::type_name::<B>());
766 unsafe { ::std::ptr::read(a as *const _ as *const B) }
767}
768
769#[inline]
771fn complex_array<A: 'static + Copy>(z: Complex<A>) -> [A; 2]
772{
773 [z.re, z.im]
774}
775
/// Whether the 1D array `a` can be passed to a BLAS level-1/2 routine as a
/// vector of `A`.
///
/// Requirements:
/// - the element type is `A`;
/// - the length fits in `blas_index`;
/// - the stride is nonzero and fits in `blas_index`.
#[cfg(feature = "blas")]
fn blas_compat_1d<A, S>(a: &ArrayBase<S, Ix1>) -> bool
where
    S: RawData,
    A: 'static,
    S::Elem: 'static,
{
    let stride = a.strides()[0];
    let index_range = blas_index::MIN as isize..=blas_index::MAX as isize;
    same_type::<A, S::Elem>()
        && a.len() <= blas_index::MAX as usize
        && stride != 0
        && index_range.contains(&stride)
}
795
/// Memory order of a 2D array, as seen by BLAS.
#[cfg(feature = "blas")]
#[derive(Copy, Clone)]
#[cfg_attr(test, derive(PartialEq, Eq, Debug))]
enum BlasOrder
{
    /// Row major ("C" order): axis 1 is the unit-stride axis.
    C,
    /// Column major ("F" order): axis 0 is the unit-stride axis.
    F,
}
804
#[cfg(feature = "blas")]
impl BlasOrder
{
    /// The opposite memory order, i.e. the order of the transposed array.
    fn transpose(self) -> Self
    {
        if let Self::C = self {
            Self::F
        } else {
            Self::C
        }
    }

    /// Axis whose stride is the BLAS leading dimension: axis 0 (the row
    /// stride) for C order, axis 1 (the column stride) for F order.
    #[inline]
    fn get_blas_lead_axis(self) -> usize
    {
        match self {
            Self::F => 1,
            Self::C => 0,
        }
    }

    /// The cblas layout describing this memory order.
    fn to_cblas_layout(self) -> CBLAS_LAYOUT
    {
        if let Self::F = self {
            CBLAS_LAYOUT::CblasColMajor
        } else {
            CBLAS_LAYOUT::CblasRowMajor
        }
    }

    /// Transpose flag for an operand with this memory order, when the call
    /// uses `for_layout`.
    ///
    /// An operand whose order matches the layout is passed untransposed;
    /// otherwise it is flagged as transposed.
    fn to_cblas_transpose_for(self, for_layout: CBLAS_LAYOUT) -> CBLAS_TRANSPOSE
    {
        let order_in_layout = match for_layout {
            CBLAS_LAYOUT::CblasColMajor => self.transpose(),
            CBLAS_LAYOUT::CblasRowMajor => self,
        };
        if let Self::C = order_in_layout {
            CblasNoTrans
        } else {
            CblasTrans
        }
    }
}
848}
849
/// Whether a 2D array with this shape and these strides can be handed to BLAS
/// in memory order `order`.
///
/// Requirements:
/// - the unit-stride axis (axis 1 for C order, axis 0 for F order) actually
///   has stride 1, unless it has length 1;
/// - both strides are positive and fit in `blas_index`;
/// - when the other axis has more than one element, its stride (the leading
///   dimension) is at least the length of the unit-stride axis, so that rows
///   (C) or columns (F) do not overlap;
/// - both dimensions fit in `blas_index`.
#[cfg(feature = "blas")]
fn is_blas_2d(dim: &Ix2, stride: &Ix2, order: BlasOrder) -> bool
{
    let (m, n) = dim.into_pattern();
    let s0 = stride[0] as isize;
    let s1 = stride[1] as isize;
    let (inner_stride, outer_stride, inner_dim, outer_dim) = match order {
        BlasOrder::C => (s1, s0, m, n),
        BlasOrder::F => (s0, s1, n, m),
    };
    let fits_index = |s: isize| s <= blas_index::MAX as isize && s >= blas_index::MIN as isize;

    let unit_stride_ok = inner_stride == 1 || outer_dim == 1;
    let strides_positive = s0 >= 1 && s1 >= 1;
    let strides_fit = fits_index(s0) && fits_index(s1);
    let leading_dim_ok = inner_dim <= 1 || (outer_stride as usize) >= outer_dim;
    let dims_fit = m <= blas_index::MAX as usize && n <= blas_index::MAX as usize;

    unit_stride_ok && strides_positive && strides_fit && leading_dim_ok && dims_fit
}
886
/// The memory order in which `a` can be passed to BLAS, if any.
///
/// Row major (C) is preferred when both orders qualify, e.g. for a matrix
/// with a length-1 axis.
#[cfg(feature = "blas")]
fn get_blas_compatible_layout<S>(a: &ArrayBase<S, Ix2>) -> Option<BlasOrder>
where S: Data
{
    [BlasOrder::C, BlasOrder::F]
        .iter()
        .copied()
        .find(|&order| is_blas_2d(&a.dim, &a.strides, order))
}
900
/// BLAS leading dimension (`lda`/`ldb`/`ldc`) of `a` in memory order `order`.
///
/// This is the stride of the lead axis. If the lead axis has length ≤ 1, its
/// stride is irrelevant to ndarray but BLAS still requires it to be at least
/// the other axis's length, so it is raised to that value.
///
/// The final cast is unchecked; `a` is expected to have passed
/// `is_blas_2d` for `order` already.
#[cfg(feature = "blas")]
fn blas_stride<S>(a: &ArrayBase<S, Ix2>, order: BlasOrder) -> blas_index
where S: Data
{
    let lead_axis = order.get_blas_lead_axis();
    let shape = a.shape();
    let lead_stride = a.strides()[lead_axis];
    let other_len = shape[1 - lead_axis] as isize;

    let leading_dim = if shape[lead_axis] > 1 {
        lead_stride
    } else {
        lead_stride.max(other_len)
    };
    leading_dim as blas_index
}
925
/// Test helper: `a` has element type `A` and is BLAS-compatible in row-major
/// (C) order.
#[cfg(test)]
#[cfg(feature = "blas")]
fn blas_row_major_2d<A, S>(a: &ArrayBase<S, Ix2>) -> bool
where
    S: Data,
    A: 'static,
    S::Elem: 'static,
{
    same_type::<A, S::Elem>() && is_blas_2d(&a.dim, &a.strides, BlasOrder::C)
}
939
/// Test helper: `a` has element type `A` and is BLAS-compatible in
/// column-major (F) order.
#[cfg(test)]
#[cfg(feature = "blas")]
fn blas_column_major_2d<A, S>(a: &ArrayBase<S, Ix2>) -> bool
where
    S: Data,
    A: 'static,
    S::Elem: 'static,
{
    same_type::<A, S::Elem>() && is_blas_2d(&a.dim, &a.strides, BlasOrder::F)
}
953
/// Unit tests for the BLAS layout checks (`is_blas_2d` and friends).
#[cfg(test)]
#[cfg(feature = "blas")]
mod blas_tests
{
    use super::*;

    // A default (row-major) 3 × 5 matrix is accepted as C order only.
    #[test]
    fn blas_row_major_2d_normal_matrix()
    {
        let m: Array2<f32> = Array2::zeros((3, 5));
        assert!(blas_row_major_2d::<f32, _>(&m));
        assert!(!blas_column_major_2d::<f32, _>(&m));
    }

    // A 1 × 5 matrix is accepted in both orders.
    #[test]
    fn blas_row_major_2d_row_matrix()
    {
        let m: Array2<f32> = Array2::zeros((1, 5));
        assert!(blas_row_major_2d::<f32, _>(&m));
        assert!(blas_column_major_2d::<f32, _>(&m));
    }

    // A 5 × 1 matrix is accepted in both orders.
    #[test]
    fn blas_row_major_2d_column_matrix()
    {
        let m: Array2<f32> = Array2::zeros((5, 1));
        assert!(blas_row_major_2d::<f32, _>(&m));
        assert!(blas_column_major_2d::<f32, _>(&m));
    }

    // The transposed view of a 1 × 5 matrix is accepted in both orders.
    #[test]
    fn blas_row_major_2d_transposed_row_matrix()
    {
        let m: Array2<f32> = Array2::zeros((1, 5));
        let m_t = m.t();
        assert!(blas_row_major_2d::<f32, _>(&m_t));
        assert!(blas_column_major_2d::<f32, _>(&m_t));
    }

    // The transposed view of a 5 × 1 matrix is accepted in both orders.
    #[test]
    fn blas_row_major_2d_transposed_column_matrix()
    {
        let m: Array2<f32> = Array2::zeros((5, 1));
        let m_t = m.t();
        assert!(blas_row_major_2d::<f32, _>(&m_t));
        assert!(blas_column_major_2d::<f32, _>(&m_t));
    }

    // A column-major ("f") 3 × 5 matrix is accepted as F order only.
    #[test]
    fn blas_column_major_2d_normal_matrix()
    {
        let m: Array2<f32> = Array2::zeros((3, 5).f());
        assert!(!blas_row_major_2d::<f32, _>(&m));
        assert!(blas_column_major_2d::<f32, _>(&m));
    }

    // Skipping rows of a row-major matrix keeps unit column stride: still C.
    #[test]
    fn blas_row_major_2d_skip_rows_ok()
    {
        let m: Array2<f32> = Array2::zeros((5, 5));
        let mv = m.slice(s![..;2, ..]);
        assert!(blas_row_major_2d::<f32, _>(&mv));
        assert!(!blas_column_major_2d::<f32, _>(&mv));
    }

    // Skipping columns of a row-major matrix breaks unit stride: rejected.
    #[test]
    fn blas_row_major_2d_skip_columns_fail()
    {
        let m: Array2<f32> = Array2::zeros((5, 5));
        let mv = m.slice(s![.., ..;2]);
        assert!(!blas_row_major_2d::<f32, _>(&mv));
        assert!(!blas_column_major_2d::<f32, _>(&mv));
    }

    // Skipping columns of a column-major matrix keeps unit row stride: still F.
    #[test]
    fn blas_col_major_2d_skip_columns_ok()
    {
        let m: Array2<f32> = Array2::zeros((5, 5).f());
        let mv = m.slice(s![.., ..;2]);
        assert!(blas_column_major_2d::<f32, _>(&mv));
        assert!(!blas_row_major_2d::<f32, _>(&mv));
    }

    // Skipping rows of a column-major matrix breaks unit stride: rejected.
    #[test]
    fn blas_col_major_2d_skip_rows_fail()
    {
        let m: Array2<f32> = Array2::zeros((5, 5).f());
        let mv = m.slice(s![..;2, ..]);
        assert!(!blas_column_major_2d::<f32, _>(&mv));
        assert!(!blas_row_major_2d::<f32, _>(&mv));
    }

    // An N × N view with row stride shorter than N (overlapping rows) must be
    // rejected. Row strides ≥ N must be accepted as C order.
    #[test]
    fn blas_too_short_stride()
    {
        const N: usize = 5;
        // Try row strides from 1 up to a little beyond N.
        const MAXSTRIDE: usize = N + 2;
        let mut data = [0; MAXSTRIDE * N];
        let mut iter = 0..data.len();
        data.fill_with(|| iter.next().unwrap());

        for stride in 1..=MAXSTRIDE {
            let m = ArrayView::from_shape((N, N).strides((stride, 1)), &data).unwrap();
            eprintln!("{:?}", m);

            if stride < N {
                assert_eq!(get_blas_compatible_layout(&m), None);
            } else {
                assert_eq!(get_blas_compatible_layout(&m), Some(BlasOrder::C));
            }
        }
    }
}