ENH: Add support for inplace matrix multiplication #21120
Conversation
(force-pushed from f32cdf6 to 83be62b)
Could you add a test for a large matrix? I am slightly worried about this only accidentally working. The problem is that if the input is exactly identical to the output, we probably pass this on into lapack (einsum?) directly.
I just added a dedicated test case in 1e62199; even when upping the size of the matrix from […]
Thanks, I'd like to have the test in either case; I will have a closer look at the code. I have to check if and when the iterator makes a copy anyway. If it does (which is plausible), writing […]
Sorry, my bad, I missed/forgot that we of course already have a […]
Can you add a test for what happens when `a @ b` and `a` are not the same shape? Especially cases where a broadcast could happen, e.g. if `a` is `(n, n)` and `b` is `(n, 1)`.
Great point, and potentially tricky :/. I remember fixing some things around this and making it more strict, but we did not disallow the output to broadcast the inputs at that time. This means that this is still incorrect:

There could be an argument to allow the output to broadcast in some cases, but when inplace ops are used for a gufunc we are on very thin ice. EDIT: Since it may not be clear, Eric's example is fine, because it happens in the core dimensions, but we currently do allow the output to broadcast in the non-core dimensions. In the above example, the input has more core dimensions than the output, so an input core dimension […]
Currently it raises an exception, rather than automatically broadcasting the output. In this sense it's somewhat unique compared to the other inplace operations, but then again […]. All in all, I think I might prefer the current (more explicit) behavior:

```python
In [1]: a = np.arange(9).reshape(3, 3)
   ...: b = np.ones((3, 1), dtype=int)

In [2]: a @= b
---------------------------------------------------------------------------
ValueError: matmul: Output operand 0 has a mismatch in its core dimension 1, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 3 is different from 1)
```
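The same core-dimension check can be triggered through the underlying ufunc call rather than the operator; a minimal sketch (assuming a NumPy version in which `matmul` is a ufunc with `out` support):

```python
import numpy as np

a = np.arange(9).reshape(3, 3)
b = np.ones((3, 1), dtype=int)

# a @ b has shape (3, 1), so asking matmul to write the result into the
# (3, 3) array `a` mismatches the gufunc's output core dimensions.
try:
    np.matmul(a, b, out=a)
except ValueError as exc:
    print("raised:", exc)
```

This is the explicit error discussed above, as opposed to silently broadcasting the output.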
On an unrelated note: either sphinx or one of its dependencies is causing issues in CircleCI:

```
Traceback (most recent call last):
  File "/home/circleci/repo/venv/lib/python3.8/site-packages/sphinx/cmd/build.py", line 280, in build_main
    app.build(args.force_all, filenames)
  File "/home/circleci/repo/venv/lib/python3.8/site-packages/sphinx/application.py", line 343, in build
    self.builder.build_update()
  File "/home/circleci/repo/venv/lib/python3.8/site-packages/sphinx/builders/__init__.py", line 293, in build_update
    self.build(to_build,
  File "/home/circleci/repo/venv/lib/python3.8/site-packages/sphinx/builders/__init__.py", line 307, in build
    updated_docnames = set(self.read())
  File "/home/circleci/repo/venv/lib/python3.8/site-packages/sphinx/builders/__init__.py", line 412, in read
    self._read_parallel(docnames, nproc=self.app.parallel)
  File "/home/circleci/repo/venv/lib/python3.8/site-packages/sphinx/builders/__init__.py", line 463, in _read_parallel
    tasks.join()
  File "/home/circleci/repo/venv/lib/python3.8/site-packages/sphinx/util/parallel.py", line 108, in join
    if not self._join_one():
  File "/home/circleci/repo/venv/lib/python3.8/site-packages/sphinx/util/parallel.py", line 129, in _join_one
    raise SphinxParallelError(*result)
sphinx.errors.SphinxParallelError: AttributeError: 'PosixPath' object has no attribute 'readlink'
```
Marking for triage-review for now. I am not quite sure what to do about my example where the number of core dimensions of input and output differs; can we ignore it?
I would be in favor of ignoring it and just letting it raise (i.e. the current behavior). There are a few other functions that reduce the dimensionality or otherwise alter the output shape w.r.t. the input shape (e.g. `mean`):

```python
In [1]: a = np.arange(9).reshape(3, 3)
   ...: b = np.ones((3, 1), dtype=int)

In [2]: b.mean(out=a)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
...
ValueError: output parameter for reduction operation add has the wrong number of dimensions: Found 2 but expected 0
```
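For contrast, the same reduction succeeds when the output has the reduced dimensionality; a small sketch (assuming current NumPy reduction semantics):

```python
import numpy as np

b = np.ones((3, 1), dtype=int)

# mean over all axes reduces to a 0-d result, so a 0-d out works:
out = np.empty((), dtype=float)
b.mean(out=out)
print(out)  # mean of an all-ones array is 1.0

# ...while a 2-d out has the wrong number of dimensions and raises,
# matching the ValueError shown above.
try:
    b.mean(out=np.empty((3, 3)))
except ValueError as exc:
    print("raised:", exc)
```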
@BvB93 sorry, I should have repeated my example:

```python
In [1]: arr = np.array([1.0, 2.0])

In [2]: arr @ arr
Out[2]: 5.0

In [3]: arr @= arr

In [4]: arr
Out[4]: array([5., 5.])
```
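A short reproduction of why this case is problematic: the 1-D `@` 1-D product contracts away both core dimensions, so the result has a different shape than either input.

```python
import numpy as np

arr = np.array([1.0, 2.0])

# `arr @ arr` contracts both core dimensions: the result is the 0-d
# scalar 1*1 + 2*2 = 5.0, so there is no (2,)-shaped value that an
# in-place update could meaningfully write back into `arr`.
result = arr @ arr
print(result)
```

The out-of-place form keeps the shape change explicit, which is why broadcasting the scalar back into `arr` (as in `Out[4]` above) is surprising.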
numpy/core/src/multiarray/number.c (Outdated)

```c
    }

    INPLACE_GIVE_UP_IF_NEEDED(m1, m2_array,
            nb_inplace_matrix_multiply, array_inplace_matrix_multiply);
```
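The `INPLACE_GIVE_UP_IF_NEEDED` machinery mirrors Python's own binary-operator protocol: when `__imatmul__` returns `NotImplemented`, the interpreter falls back to the out-of-place `__matmul__` followed by rebinding. A toy illustration (the `Mat` class is purely hypothetical, not NumPy code):

```python
class Mat:
    """Toy stand-in for an array-like type (purely hypothetical)."""

    def __init__(self, label):
        self.label = label

    def __imatmul__(self, other):
        # "Give up": the interpreter then retries out of place.
        return NotImplemented

    def __matmul__(self, other):
        return Mat(f"({self.label} @ {other.label})")


a = Mat("a")
b = Mat("b")
a @= b            # falls back to a = a.__matmul__(b)
print(a.label)    # "(a @ b)"
```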
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am a bit worried whether this must not happen before any conversion (and what is correct in general). But I can try to dig into that later, also.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, it is always done first, so let's do it first here also.
close/reopen to regenerate CI logs
I pulled the nuclear option on this, because the tests noticed the array coercion. This means that the errors are pretty terrible, maybe most impressive:

but at least the […]
By "nuclear" you mean call back out to Python with […]
Yeah, that is what I mean; the initial code converted the […]. We can improve the generic message a little, but that will still probably be a bit confusing. I suspect getting a nice one would actually require manually replacing it here (maybe pragmatically assuming that the other side isn't a super weird object).
Something like "matmul operand 2 has 1 dimension, but 2 dimensions are required (axes tuple has length 2)." (EDIT: add the ufunc name, we should give it.)
OK, now based on changing the other error to use AxisError, replacing it here now (a bit of a best effort, since you could do evil things with nested object arrays). That removes the worst of the errors (which referred to the […])
… in the array-api
…utput Add special casing for `1d @ 1d` and `2d @ 1d` ops.
In principle, this is probably still not 100% correct always, since we convert the other object to an array upfront (to get the error). But the alternative solution using `axes=` seems tricky as well...
This uses the `axes` argument. Arguably the most correct version since we use the full ufunc machinery, so we can't mess up with random conversions (which the test suite actually did notice, at least for an array subclass). OTOH, for now the errors are ridiculously unhelpful, some of which could be fixed in the gufunc code (at least made _better_).
In theory, an object matmul could be caught here, but let's assume that doesn't happen...
It appears that assertion is not true for bad `__array_ufunc__` or similar implementations. And we have at least one test that runs into it.
I think we should put this in and work out better error messages as they arise. For the case above, the message is now:

[…]

A casting error also seems clear enough:

[…]
I've finished this from my side, if you are happy with the new solution.
This PR adds support for inplace matrix multiplication via the `@=` operator.

From what I gather, `@=` was made illegal back in the day when `np.matmul` was still a normal function, as the latter lacked support for something akin to an `out` parameter, and `a @= b` would thus be silently expanded to `a = a @ b`. This limitation was eventually resolved when matmul was converted into a ufunc, removing the issue previously plaguing `@=`.

Examples
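A minimal sketch of the semantics this PR enables. It is spelled here via the underlying ufunc call (`np.matmul` with `out=`), which also runs on NumPy versions predating the PR; with the PR applied, `a @= b` is the equivalent operator spelling:

```python
import numpy as np

a = np.arange(9.0).reshape(3, 3)
b = np.eye(3) * 2.0
expected = a @ b  # reference result computed out of place

# In-place: the product is written back into `a` without rebinding it.
# NumPy's overlap detection handles the aliasing of input and output.
np.matmul(a, b, out=a)
print(np.array_equal(a, expected))  # True
```

Note that, per the discussion above, this only works when the result shape matches `a` (square `b` in the 2-D case); shape-changing products raise rather than broadcast.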