
ENH: Add support for inplace matrix multiplication #21120


Merged: 11 commits merged into numpy:main on Mar 26, 2023

Conversation

BvB93 (Member) commented Feb 25, 2022

This PR adds support for inplace matrix multiplication via the @= operator.

From what I gather, @= was made illegal back in the day when np.matmul was still a normal function: the latter lacked support for something akin to an out parameter, so a @= b would have been silently expanded to a = a @ b. This limitation was eventually resolved when matmul was converted into a ufunc, removing the issue that previously plagued @=.

Examples

>>> import numpy as np
>>> a = np.arange(6).reshape(3, 2)
>>> print(a)
[[0 1]
 [2 3]
 [4 5]]

>>> b = np.ones((2, 2), dtype=int)
>>> a @= b
>>> print(a)
[[1 1]
 [5 5]
 [9 9]]
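
For reference, a @= b now corresponds roughly to calling the matmul ufunc with the left operand as the output buffer. The snippet below is only an illustration of that equivalence, not the literal implementation (which, as discussed later in this thread, goes through the full ufunc machinery with an explicit axes argument):

>>> a = np.arange(6).reshape(3, 2)
>>> b = np.ones((2, 2), dtype=int)
>>> np.matmul(a, b, out=a)  # roughly what `a @= b` does
array([[1, 1],
       [5, 5],
       [9, 9]])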

BvB93 force-pushed the matmul branch 2 times, most recently from f32cdf6 to 83be62b on February 25, 2022 at 17:28
seberg (Member) commented Feb 25, 2022

Could you add a test for a large matrix? I am slightly worried about this only accidentally working. The problem is that if the input is exactly identical to the output, we probably pass this on into lapack (einsum?) directly.
Even after a quick check, I am not sure that the matmul inner-loop is actually able to deal with overlap just fine.

BvB93 (Member, Author) commented Feb 25, 2022

Could you add a test for a large matrix? I am slightly worried about this only accidentally working.

I just added a dedicated test case in 1e62199; even when upping the size of the matrix from 10**6 to 10**8 I'm still getting the expected result. I assume these results are somewhat promising?
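
(Not the actual test from 1e62199, just a minimal sketch of the kind of overlap check meant here: compute a reference result from a copy, then verify that the in-place form, where the output buffer is also the first input, matches it. The size is kept small for illustration, and it requires a NumPy build that includes this PR.)

import numpy as np

rng = np.random.default_rng(0)
a = rng.random((500, 500))
b = rng.random((500, 500))

expected = a.copy() @ b  # out-of-place reference, no input/output overlap
a @= b                   # in-place form: `a` is both input and output buffer
np.testing.assert_allclose(a, expected)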

seberg (Member) commented Feb 25, 2022

Thanks. I'd like to have the test in either case; I will have a closer look at the code.

I have to check if and when the iterator makes a copy anyway. If it does (which is plausible), writing a @= b vs. a = a @ b has no advantage (in fact it is probably worse), but it will be safe. If it does not copy, things get a bit tricky, because matmul has a bunch of different code paths that may be taken, and we have to check/test all of them.
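
(A rough, purely illustrative way to check the "no advantage, probably worse" part once the operator exists; timings are machine-dependent and this is not a benchmark from the PR.)

import timeit
import numpy as np

a = np.random.rand(512, 512)
b = np.random.rand(512, 512) / 512  # scaled so repeated products stay bounded

# `a @= b` requires a NumPy build with this PR; it still copies internally
# (see the next comment), so it is not expected to beat `a = a @ b`.
t_inplace = timeit.timeit("a @= b", globals={"a": a.copy(), "b": b}, number=50)
t_outofplace = timeit.timeit("a = a @ b", globals={"a": a.copy(), "b": b}, number=50)
print(t_inplace, t_outofplace)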

seberg (Member) commented Feb 25, 2022

Sorry, my bad: I missed/forgot that we of course already have a matmul special case that always forces the copy. So this is safe, but slow and not very useful.

eric-wieser (Member) left a review comment

Can you add a test for what happens when a @ b and a are not the same shape? Especially cases where a broadcast could happen, if a is (n, n) and b is (n, 1).

seberg (Member) commented Feb 25, 2022

Especially cases where a broadcast could happen, if a is (n, n) and b is (n, 1).

Great point, and potentially tricky :/. I remember fixing some things around this and making it more strict, but at the time we did not disallow the output from broadcasting the inputs. This means that this is still incorrect:

In [4]: arr = np.array([1., 2.])

In [5]: np.matmul(arr, arr, out=arr)
Out[5]: array([5., 5.])

There could be an argument to allow the output to broadcast in some cases, but when inplace ops are used for a gufunc we are on very thin ice.

EDIT: Since it may not be clear: Eric's example is fine, because it happens in the core dimensions, but we currently do allow the output to broadcast in the non-core dimensions. In the above example, the input has more core dimensions than the output, so an input core dimension becomes an outer output dimension.

BvB93 (Member, Author) commented Feb 25, 2022

Can you add a test for what happens when a @ b and a are not the same shape? Especially cases where a broadcast could happen, if a is (n, n) and b is (n, 1).

Currently it raises an exception, rather than automatically broadcasting the output. In this sense it's somewhat unique compared to the other inplace operations, but then again __matmul__ is also the only dunder that can produce an array that is smaller in size than the original.

All in all, I think I might prefer the current (more explicit) ValueError over an implicit broadcast (this sounds like something that could easily bite someone in the ass).

In [1]: a = np.arange(9).reshape(3, 3)
   ...: b = np.ones((3, 1), dtype=int)

In [2]: a @= b
---------------------------------------------------------------------------
ValueError: matmul: Output operand 0 has a mismatch in its core dimension 1, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 3 is different from 1)
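
(For completeness, the explicit out-of-place spelling still works for such shape-changing products, which is the alternative to any implicit broadcast; continuing the session above as an illustration:)

In [3]: a = a @ b  # rebinds `a` to a new (3, 1) array instead of writing in place

In [4]: a
Out[4]: 
array([[ 3],
       [12],
       [21]])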

BvB93 (Member, Author) commented Feb 25, 2022

On an unrelated note: either Sphinx or one of its dependencies is causing issues in CircleCI.
Based on the exception, a package requiring Python >= 3.9 is being installed on Python 3.8.

Traceback (most recent call last):
  File "/home/circleci/repo/venv/lib/python3.8/site-packages/sphinx/cmd/build.py", line 280, in build_main
    app.build(args.force_all, filenames)
  File "/home/circleci/repo/venv/lib/python3.8/site-packages/sphinx/application.py", line 343, in build
    self.builder.build_update()
  File "/home/circleci/repo/venv/lib/python3.8/site-packages/sphinx/builders/__init__.py", line 293, in build_update
    self.build(to_build,
  File "/home/circleci/repo/venv/lib/python3.8/site-packages/sphinx/builders/__init__.py", line 307, in build
    updated_docnames = set(self.read())
  File "/home/circleci/repo/venv/lib/python3.8/site-packages/sphinx/builders/__init__.py", line 412, in read
    self._read_parallel(docnames, nproc=self.app.parallel)
  File "/home/circleci/repo/venv/lib/python3.8/site-packages/sphinx/builders/__init__.py", line 463, in _read_parallel
    tasks.join()
  File "/home/circleci/repo/venv/lib/python3.8/site-packages/sphinx/util/parallel.py", line 108, in join
    if not self._join_one():
  File "/home/circleci/repo/venv/lib/python3.8/site-packages/sphinx/util/parallel.py", line 129, in _join_one
    raise SphinxParallelError(*result)
sphinx.errors.SphinxParallelError: AttributeError: 'PosixPath' object has no attribute 'readlink'

seberg added the "triage review" label (Issue/PR to be discussed at the next triage meeting) on Feb 26, 2022
seberg (Member) commented Feb 26, 2022

Marking for triage-review for now. I am not quite sure what to do about my example where the number of core dimensions of input and output differs; can we ignore it?
A possible work-around to reject it (without modifying the general case) would be to pass axes=[(-1,), None, (-1,)] to ensure that input and output have the same number of core dimensions (or (-2, -1), of course). However, it seems like it is not possible to pass None for the second input, and I am not sure how hard it would be to add that functionality (it feels like it shouldn't be hard, but I'm not sure).

BvB93 (Member, Author) commented Mar 2, 2022

I am not quite sure what to do about my example where the number of core dimensions of input and output differs; can we ignore it?

I would be in favor of ignoring it and just letting it raise (i.e. the current behavior). There are a few other functions that reduce the dimensionality or otherwise alter the output shape w.r.t. the input shape (e.g. np.mean), and those raise a ValueError as well, even though the result could in principle be broadcast.

In [1]: a = np.arange(9).reshape(3, 3)
   ...: b = np.ones((3, 1), dtype=int)

In [2]: b.mean(out=a)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
  ...
ValueError: output parameter for reduction operation add has the wrong number of dimensions: Found 2 but expected 0

seberg (Member) commented Mar 2, 2022

@BvB93 sorry, I should have repeated my example:

In [1]: arr = np.array([1.0, 2.0])

In [2]: arr @ arr
Out[2]: 5.0

In [3]: arr @= arr

In [4]: arr
Out[4]: array([5., 5.])

Review thread on this diff hunk:

}

    INPLACE_GIVE_UP_IF_NEEDED(m1, m2_array,
            nb_inplace_matrix_multiply, array_inplace_matrix_multiply);

A reviewer (Member) commented:

I am a bit worried that this may need to happen before any conversion (and about what is correct in general). But I can try to dig into that later as well.

A reviewer (Member) replied:

Yes, it is always done first, so let's do it first here also.

seberg removed the "triage review" label on Mar 9, 2022
charris closed this on Apr 6, 2022
charris reopened this on Apr 6, 2022
mattip (Member) commented Nov 17, 2022

close/reopen to regenerate CI logs

mattip closed this on Nov 17, 2022
mattip reopened this on Nov 17, 2022
seberg (Member) commented Nov 22, 2022

I pulled the nuclear option on this, because the tests noticed the array coercion. This means that the errors are pretty terrible; maybe the most impressive one:

In [2]: arr = np.ones((3, 3))
In [3]: arr @= np.ones(3)
ValueError: axes item 1 should be a tuple with a single element, or an integer

but at least the ValueError part is technically correct :cough:.

mattip (Member) commented Nov 24, 2022

By "nuclear" you mean call back out to python with np.matmul(a, b, out=a, axes=[(-2, -1), (-2, -1), (-2, -1)]). How do you think we could get a better error message?

seberg (Member) commented Nov 24, 2022

Yeah, that is what I mean: the initial code converted b to an array in a @= b, which makes it easy to check things up-front.

We can improve the generic message a little, but it will probably still be a bit confusing. I suspect getting a nice one would actually require manually replacing it here (maybe pragmatically assuming that the other side isn't a super weird object).

seberg (Member) commented Nov 24, 2022

Something like "matmul operand 2 has 1 dimension, but 2 dimensions are required (axes tuple has length 2)." (EDIT: add the ufunc name, we should give it.)

seberg (Member) commented Nov 28, 2022

OK, this is now based on changing the other error to use AxisError and replacing it here (a bit of a best effort, since you could do evil things with nested object arrays). That removes the worst of the errors (which referred to an axes argument the user never even supplied).

Bas van Beek and others added 11 commits on December 2, 2022 at 00:29:
…utput

Add special casing for `1d @ 1d` and `2d @ 1d` ops.
In principle, this is probably still not 100% correct always, since
we convert the other object to an array upfront (to get the
error).  But the alternative solution using `axes=` seems tricky
as well...
This uses the `axes` argument.  Arguably the most correct version since
we use the full ufunc machinery, so we can't mess up with random conversions
(which the test suite actually did notice, at least for an array subclass).

OTOH, for now the errors are ridiculously unhelpful, some of which could
be fixed in the gufunc code (at least made _better_).
In theory, an object matmul could be caught here, but let's assume
that doesn't happen...
It appears that assertion is not true for bad __array_ufunc__
or similar implementations.  And we have at least one test that runs
into it.
seberg added the "triage review" label on Jan 20, 2023
mattip (Member) commented Mar 26, 2023

I think we should put this in and work out better error messages as they arise. For the case above, the message is now:

>>> arr = np.ones((3, 3))
>>> arr @= np.ones(3)
Traceback (most recent call last):
  File "<console>", line 1, in <module>
ValueError: inplace matrix multiplication requires the first operand to have \
    at least one and the second at least two dimensions.

A casting error also seems clear enough:

>>> arr = np.ones((3, 3), dtype=int)
>>> arr @= np.ones((3,3))
Traceback (most recent call last):
  File "<console>", line 1, in <module>
numpy.core._exceptions._UFuncOutputCastingError: Cannot cast ufunc 'matmul' output \
    from dtype('float64') to dtype('int64') with casting rule 'same_kind'
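
(A hedged aside, not from the PR: because @= writes back into the existing integer buffer, the float64 result cannot be cast under the default same_kind rule; making the left operand floating point up front, or falling back to the out-of-place form, avoids this.)

>>> arr = np.ones((3, 3), dtype=int)
>>> arr = arr @ np.ones((3, 3))  # rebinding instead of in-place; result is float64
>>> arr.dtype
dtype('float64')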

seberg (Member) commented Mar 26, 2023

This is finished from my side, if you are happy with the new solution.

mattip merged commit a37978a into numpy:main on Mar 26, 2023
mattip (Member) commented Mar 26, 2023

Thanks @BvB93 and @seberg.

seberg mentioned this pull request on Jun 20, 2023.