Skip to content

ENH: Add support for inplace matrix multiplication #21120

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Mar 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions doc/release/upcoming_changes/21120.new_feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
Add support for inplace matrix multiplication
----------------------------------------------
It is now possible to perform inplace matrix multiplication
via the ``@=`` operator.

.. code-block:: python

>>> import numpy as np

>>> a = np.arange(6).reshape(3, 2)
>>> print(a)
[[0 1]
[2 3]
[4 5]]

>>> b = np.ones((2, 2), dtype=int)
>>> a @= b
>>> print(a)
[[1 1]
[5 5]
[9 9]]
14 changes: 13 additions & 1 deletion numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1928,7 +1928,6 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]):
def __neg__(self: NDArray[object_]) -> Any: ...

# Binary ops
# NOTE: `ndarray` does not implement `__imatmul__`
@overload
def __matmul__(self: NDArray[bool_], other: _ArrayLikeBool_co) -> NDArray[bool_]: ... # type: ignore[misc]
@overload
Expand Down Expand Up @@ -2515,6 +2514,19 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]):
@overload
def __ior__(self: NDArray[object_], other: Any) -> NDArray[object_]: ...

@overload
def __imatmul__(self: NDArray[bool_], other: _ArrayLikeBool_co) -> NDArray[bool_]: ...
@overload
def __imatmul__(self: NDArray[unsignedinteger[_NBit1]], other: _ArrayLikeUInt_co) -> NDArray[unsignedinteger[_NBit1]]: ...
@overload
def __imatmul__(self: NDArray[signedinteger[_NBit1]], other: _ArrayLikeInt_co) -> NDArray[signedinteger[_NBit1]]: ...
@overload
def __imatmul__(self: NDArray[floating[_NBit1]], other: _ArrayLikeFloat_co) -> NDArray[floating[_NBit1]]: ...
@overload
def __imatmul__(self: NDArray[complexfloating[_NBit1, _NBit1]], other: _ArrayLikeComplex_co) -> NDArray[complexfloating[_NBit1, _NBit1]]: ...
@overload
def __imatmul__(self: NDArray[object_], other: Any) -> NDArray[object_]: ...

def __dlpack__(self: NDArray[number[Any]], *, stream: None = ...) -> _PyCapsule: ...
def __dlpack_device__(self) -> tuple[int, L[0]]: ...

Expand Down
14 changes: 2 additions & 12 deletions numpy/array_api/_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -850,23 +850,13 @@ def __imatmul__(self: Array, other: Array, /) -> Array:
"""
Performs the operation __imatmul__.
"""
# Note: NumPy does not implement __imatmul__.

# matmul is not defined for scalars, but without this, we may get
# the wrong error message from asarray.
other = self._check_allowed_dtypes(other, "numeric", "__imatmul__")
if other is NotImplemented:
return other

# __imatmul__ can only be allowed when it would not change the shape
# of self.
other_shape = other.shape
if self.shape == () or other_shape == ():
raise ValueError("@= requires at least one dimension")
if len(other_shape) == 1 or other_shape[-1] != other_shape[-2]:
raise ValueError("@= cannot change the shape of the input array")
self._array[:] = self._array.__matmul__(other._array)
return self
res = self._array.__imatmul__(other._array)
return self.__class__._new(res)

def __rmatmul__(self: Array, other: Array, /) -> Array:
"""
Expand Down
72 changes: 65 additions & 7 deletions numpy/core/src/multiarray/number.c
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ static PyObject *
array_inplace_remainder(PyArrayObject *m1, PyObject *m2);
static PyObject *
array_inplace_power(PyArrayObject *a1, PyObject *o2, PyObject *NPY_UNUSED(modulo));
static PyObject *
array_inplace_matrix_multiply(PyArrayObject *m1, PyObject *m2);

/*
* Dictionary can contain any of the numeric operations, by name.
Expand Down Expand Up @@ -339,7 +341,6 @@ array_divmod(PyObject *m1, PyObject *m2)
return PyArray_GenericBinaryFunction(m1, m2, n_ops.divmod);
}

/* Need this to be version dependent on account of the slot check */
static PyObject *
array_matrix_multiply(PyObject *m1, PyObject *m2)
{
Expand All @@ -348,13 +349,70 @@ array_matrix_multiply(PyObject *m1, PyObject *m2)
}

static PyObject *
array_inplace_matrix_multiply(
PyArrayObject *NPY_UNUSED(m1), PyObject *NPY_UNUSED(m2))
array_inplace_matrix_multiply(PyArrayObject *self, PyObject *other)
{
PyErr_SetString(PyExc_TypeError,
"In-place matrix multiplication is not (yet) supported. "
"Use 'a = a @ b' instead of 'a @= b'.");
return NULL;
static PyObject *AxisError_cls = NULL;
npy_cache_import("numpy.exceptions", "AxisError", &AxisError_cls);
if (AxisError_cls == NULL) {
return NULL;
}

INPLACE_GIVE_UP_IF_NEEDED(self, other,
nb_inplace_matrix_multiply, array_inplace_matrix_multiply);

/*
* Unlike `matmul(a, b, out=a)` we ensure that the result is not broadcast
* if the result without `out` would have less dimensions than `a`.
* Since the signature of matmul is '(n?,k),(k,m?)->(n?,m?)' this is the
* case exactly when the second operand has both core dimensions.
*
* The error here will be confusing, but for now, we enforce this by
* passing the correct `axes=`.
*/
static PyObject *axes_1d_obj_kwargs = NULL;
static PyObject *axes_2d_obj_kwargs = NULL;
if (NPY_UNLIKELY(axes_1d_obj_kwargs == NULL)) {
axes_1d_obj_kwargs = Py_BuildValue(
"{s, [(i), (i, i), (i)]}", "axes", -1, -2, -1, -1);
if (axes_1d_obj_kwargs == NULL) {
return NULL;
}
}
if (NPY_UNLIKELY(axes_2d_obj_kwargs == NULL)) {
axes_2d_obj_kwargs = Py_BuildValue(
"{s, [(i, i), (i, i), (i, i)]}", "axes", -2, -1, -2, -1, -2, -1);
if (axes_2d_obj_kwargs == NULL) {
return NULL;
}
}

PyObject *args = PyTuple_Pack(3, self, other, self);
if (args == NULL) {
return NULL;
}
PyObject *kwargs;
if (PyArray_NDIM(self) == 1) {
kwargs = axes_1d_obj_kwargs;
}
else {
kwargs = axes_2d_obj_kwargs;
}
PyObject *res = PyObject_Call(n_ops.matmul, args, kwargs);
Py_DECREF(args);

if (res == NULL) {
/*
* AxisError should indicate that the axes argument didn't work out
* which should mean the second operand not being 2 dimensional.
*/
if (PyErr_ExceptionMatches(AxisError_cls)) {
PyErr_SetString(PyExc_ValueError,
"inplace matrix multiplication requires the first operand to "
"have at least one and the second at least two dimensions.");
}
}

return res;
}

/*
Expand Down
77 changes: 66 additions & 11 deletions numpy/core/tests/test_multiarray.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import collections.abc
import tempfile
import sys
Expand Down Expand Up @@ -3693,7 +3695,7 @@ def test_ufunc_binop_interaction(self):
'and': (np.bitwise_and, True, int),
'xor': (np.bitwise_xor, True, int),
'or': (np.bitwise_or, True, int),
'matmul': (np.matmul, False, float),
'matmul': (np.matmul, True, float),
# 'ge': (np.less_equal, False),
# 'gt': (np.less, False),
# 'le': (np.greater_equal, False),
Expand Down Expand Up @@ -7155,16 +7157,69 @@ def test_matmul_raises(self):
assert_raises(TypeError, self.matmul, np.void(b'abc'), np.void(b'abc'))
assert_raises(TypeError, self.matmul, np.arange(10), np.void(b'abc'))

def test_matmul_inplace():
# It would be nice to support in-place matmul eventually, but for now
# we don't have a working implementation, so better just to error out
# and nudge people to writing "a = a @ b".
a = np.eye(3)
b = np.eye(3)
assert_raises(TypeError, a.__imatmul__, b)
import operator
assert_raises(TypeError, operator.imatmul, a, b)
assert_raises(TypeError, exec, "a @= b", globals(), locals())

class TestMatmulInplace:
DTYPES = {}
for i in MatmulCommon.types:
for j in MatmulCommon.types:
if np.can_cast(j, i):
DTYPES[f"{i}-{j}"] = (np.dtype(i), np.dtype(j))

@pytest.mark.parametrize("dtype1,dtype2", DTYPES.values(), ids=DTYPES)
def test_basic(self, dtype1: np.dtype, dtype2: np.dtype) -> None:
a = np.arange(10).reshape(5, 2).astype(dtype1)
a_id = id(a)
b = np.ones((2, 2), dtype=dtype2)

ref = a @ b
a @= b

assert id(a) == a_id
assert a.dtype == dtype1
assert a.shape == (5, 2)
if dtype1.kind in "fc":
np.testing.assert_allclose(a, ref)
else:
np.testing.assert_array_equal(a, ref)

SHAPES = {
"2d_large": ((10**5, 10), (10, 10)),
"3d_large": ((10**4, 10, 10), (1, 10, 10)),
"1d": ((3,), (3,)),
"2d_1d": ((3, 3), (3,)),
"1d_2d": ((3,), (3, 3)),
"2d_broadcast": ((3, 3), (3, 1)),
"2d_broadcast_reverse": ((1, 3), (3, 3)),
"3d_broadcast1": ((3, 3, 3), (1, 3, 1)),
"3d_broadcast2": ((3, 3, 3), (1, 3, 3)),
"3d_broadcast3": ((3, 3, 3), (3, 3, 1)),
"3d_broadcast_reverse1": ((1, 3, 3), (3, 3, 3)),
"3d_broadcast_reverse2": ((3, 1, 3), (3, 3, 3)),
"3d_broadcast_reverse3": ((1, 1, 3), (3, 3, 3)),
}

@pytest.mark.parametrize("a_shape,b_shape", SHAPES.values(), ids=SHAPES)
def test_shapes(self, a_shape: tuple[int, ...], b_shape: tuple[int, ...]):
a_size = np.product(a_shape)
a = np.arange(a_size).reshape(a_shape).astype(np.float64)
a_id = id(a)

b_size = np.product(b_shape)
b = np.arange(b_size).reshape(b_shape)

ref = a @ b
if ref.shape != a_shape:
with pytest.raises(ValueError):
a @= b
return
else:
a @= b

assert id(a) == a_id
assert a.dtype.type == np.float64
assert a.shape == a_shape
np.testing.assert_allclose(a, ref)


def test_matmul_axes():
a = np.arange(3*4*5).reshape(3, 4, 5)
Expand Down
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy