@@ -7317,11 +7317,33 @@ def test_dot_equivalent(self, args):
7317
7317
r3 = np .matmul (args [0 ].copy (), args [1 ].copy ())
7318
7318
assert_equal (r1 , r3 )
7319
7319
7320
- # matrix matrix, issue 29164
7321
- if [len (args [0 ].shape ), len (args [1 ].shape )] == [2 , 2 ]:
7322
- out_f = np .zeros ((r2 .shape [0 ] * 2 , r2 .shape [1 ] * 2 ), order = 'F' )
7323
- r4 = np .matmul (* args , out = out_f [::2 , ::2 ])
7324
- assert_equal (r2 , r4 )
7320
+ # issue 29164 with extra checks
7321
+ @pytest .mark .parametrize ('dtype' , (
7322
+ np .float32 , np .float64 , np .complex64 , np .complex128
7323
+ ))
7324
+ def test_dot_equivalent_matrix_matrix_blastypes (self , dtype ):
7325
+ modes = list (itertools .product (['C' , 'F' ], [True , False ]))
7326
+
7327
+ def apply_mode (m , mode ):
7328
+ order , is_contiguous = mode
7329
+ if is_contiguous :
7330
+ return m .copy () if order == 'C' else m .T .copy ().T
7331
+
7332
+ retval = np .zeros (
7333
+ (m .shape [0 ] * 2 , m .shape [1 ] * 2 ), dtype = m .dtype , order = order
7334
+ )[::2 , ::2 ]
7335
+ retval [...] = m
7336
+ return retval
7337
+
7338
+ is_complex = np .issubdtype (dtype , np .complexfloating )
7339
+ m1 = self .m1 .astype (dtype ) + (1j if is_complex else 0 )
7340
+ m2 = self .m2 .astype (dtype ) + (1j if is_complex else 0 )
7341
+ dot_res = np .dot (m1 , m2 )
7342
+ mo = np .zeros_like (dot_res )
7343
+
7344
+ for mode in itertools .product (* [modes ]* 3 ):
7345
+ m1_ , m2_ , mo_ = [apply_mode (* x ) for x in zip ([m1 , m2 , mo ], mode )]
7346
+ assert_equal (np .matmul (m1_ , m2_ , out = mo_ ), dot_res )
7325
7347
7326
7348
def test_matmul_object (self ):
7327
7349
import fractions
0 commit comments