ENH: special: support_alternative_backends on Dask and jax.jit #22639
Conversation
```diff
 if is_marray(_xp):
-    data_args = [np.asarray(arg.data) for arg in array_args]
-    out = _f(*data_args, *other_args, **kwargs)
+    data_args = [arg.data for arg in array_args]
+    f = globals()[f_name]  # Allow nested wrapping
+    out = f(*data_args, *other_args, **kwargs)
     mask = functools.reduce(operator.or_, (arg.mask for arg in array_args))
     return _xp.asarray(out, mask=mask)

 elif is_dask(_xp):
     f = globals()[f_name]  # Allow nested wrapping
     # IMPORTANT: this works only because all ufuncs in this module
     # are elementwise. It would be a grave mistake to apply this to gufuncs,
     # as they would change their output depending on chunking!
     return _xp.map_blocks(
         # Hide other_args as well as kwargs such as dtype from map_blocks
         lambda *array_args: f(*array_args, *other_args, **kwargs),
         *array_args
     )

 else:
     assert array_args  # is_numpy(xp) is True
     device = xp_device(array_args[0])
     array_args = [np.asarray(arg) for arg in array_args]
     out = _f(*array_args, *other_args, **kwargs)
-    return _xp.asarray(out)
+    return _xp.asarray(out, device=device)
```
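To illustrate the "IMPORTANT" comment in the Dask branch: an elementwise kernel applied block-by-block gives the same result for any chunking, which is exactly what a reduction or gufunc would not do. A minimal check, using map_blocks directly rather than the wrapper above (assumes dask and SciPy are installed):

```python
import numpy as np
import dask.array as da
from scipy import special

data = np.linspace(0.0, 1.0, 8)
a = da.from_array(data, chunks=4)  # two chunks
b = da.from_array(data, chunks=2)  # four chunks

# Elementwise kernels are chunking-independent when applied per block.
ra = da.map_blocks(special.erf, a).compute()
rb = da.map_blocks(special.erf, b).compute()
np.testing.assert_allclose(ra, rb)
np.testing.assert_allclose(ra, special.erf(data))
```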
As I expect wrapper arrays to proliferate in the medium-term future (@lucascolley's https://github.com/quantity-dev/quantity-array will soon join the party), this pattern is not sustainable in the long run.
I'm unsure how to solve it. While it would be reasonable to encapsulate all backend-specific unwrapping and wrapping inside xpx.lazy_apply, there is a fundamental problem with metadata.
The example above, where the output marray mask is always the OR-reduction of the input masks, is a simple special case and not representative of the general problem.
For example, quantity-array would need to know what kind of operation each binary ufunc performs: a+b (need to check units for compatibility and scale arrays while unwrapping)? a*b? a/b? (each implies a different operation on the metadata and possibly a scaling of the data). E.g.:

- `lambda x: x + y, x=qa.asarray(1, "g"), y=qa.asarray(1, "kg")` must scale `y * 1000`; output units = g
- `lambda x: x + y, x=qa.asarray(1, "g"), y=qa.asarray(1, "m")` must crash
- `lambda x: x * y, x=qa.asarray(1, "g"), y=qa.asarray(1, "m")` output units = g*m
- `lambda x: x / y, x=qa.asarray(1, "g"), y=qa.asarray(1, "m")` output units = g/m
- `lambda x: x / y, x=qa.asarray(1, "km"), y=qa.asarray(1, "m/s")` must scale `x * 1000`; output units = m/s
Yes, I've thought about this myself a bit. I think quantity-array should provide a function which takes a callable and some standardised representation of how the output units are a function of the input units, and returns a Quantity with the output of that callable as the value and the desired units. I think this sort of standardised representation is within the scope of the broader quantity-dev work.
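As a rough illustration of that idea, a hypothetical helper could accept the callable plus a rule mapping input units to output units. Nothing below is real quantity-array API; `Quantity` and `apply_with_units` are invented names, sketched only to show the shape of such a standardised representation:

```python
from dataclasses import dataclass
import numpy as np

@dataclass
class Quantity:            # stand-in for a quantity-array type (hypothetical)
    value: np.ndarray
    units: str

def apply_with_units(func, args, units_rule):
    """Unwrap the values, apply func, and derive the output units from a rule."""
    values = [np.asarray(a.value) for a in args]
    out_units = units_rule(*(a.units for a in args))
    return Quantity(func(*values), out_units)

# e.g. multiplication: output units are the product of the input units
g = Quantity(np.asarray(2.0), "g")
m = Quantity(np.asarray(3.0), "m")
gm = apply_with_units(np.multiply, (g, m), lambda u, v: f"{u}*{v}")
# gm.value == 6.0, gm.units == "g*m"
```

This sketch deliberately omits the hard part crusaderky points out: for addition, the rule would also have to check compatibility and rescale the values before the callable is applied.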
Perhaps it would also be nice to have `marray.apply_func(f, (a, b), mask="or")` or similar.
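A sketch of what such a helper might look like. This is not existing marray API; `mxp` stands for a masked namespace and the signature is invented purely for illustration:

```python
import functools
import operator

def apply_func(mxp, f, args, mask="or"):
    """Apply an elementwise f to the .data of masked arrays and combine the masks."""
    out = f(*(a.data for a in args))
    if mask == "or":       # masked output wherever any input is masked
        out_mask = functools.reduce(operator.or_, (a.mask for a in args))
    else:                  # or accept a caller-supplied rule for trickier cases
        out_mask = mask(*(a.mask for a in args))
    return mxp.asarray(out, mask=out_mask)
```

A caller-supplied rule could cover operations whose mask propagation is not a plain OR, such as the ternary case discussed next.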
> Perhaps it would also be nice to have `marray.apply_func(f, (a, b), mask="or")` or similar.

I don't know the details of how mask propagation is supposed to work - could there be a ternary operation where the output mask is a more complex function, e.g. `a and b or c`?
Answering my own question by reading https://github.com/mdhaber/marray/blob/main/marray/__init__.py: every operation has its own arbitrary, point-by-point mask propagation rules.
Almost all special functions - and all the ones we are treating with support_alternative_backends right now - are elementwise functions, so we can restrict our attention to the scalar case. If we have a function f(a, b, c, d, ...) that really depends on all the parameters, and we are missing one of them, then we cannot compute the value of the function. This is where the mask behavior comes from.

It was not designed to cover all cases that might arise in scipy.special (e.g. logsumexp), and it wasn't meant to be extendable to other libraries; it just does what we've needed so far with marray support in SciPy.

Suppose we had a function like gamma(z, a=0, b=np.inf) that computes [...] marray.trim [...] we've assumed the latter, so we end up with the same mask propagation rule. And for simplicity, I think it's best to define it the same way for all elementwise functions, no exceptions - any masked input -> masked output.

Yes, maybe marray could use a helper for this. I haven't thought about it much, though.
> every operation has its own arbitrary, point-by-point mask propagation rules.

I'm not sure if I'm understanding correctly, but that's not how I would express it. All the elementwise functions are handled with this block in a loop. The rule is the same for all of them, and it's the same rule as here: if any of the arguments are masked, the output is masked.
I'm looking at xp.where:

```python
mask = xp.where(condition.data,
                condition.mask | x1.mask,
                condition.mask | x2.mask)
```

Do I understand correctly that a generic, agnostic `condition.mask | x1.mask | x2.mask` would not be wrong, and that the above is just more precise? In other words, that having a false positive on the mask is acceptable?
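That reading can be checked numerically: the OR of all three masks only ever adds masked entries relative to the where-specific rule, so (assuming false positives are indeed acceptable) the agnostic version is conservative rather than wrong. A small sketch of the check:

```python
import numpy as np

rng = np.random.default_rng(0)
cond = rng.random(5) > 0.5
cond_mask, m1, m2 = (rng.random(5) > 0.7 for _ in range(3))

precise = np.where(cond, cond_mask | m1, cond_mask | m2)
agnostic = cond_mask | m1 | m2
assert np.all(~precise | agnostic)  # precise implies agnostic: only false positives are added
```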
If there is a standards document for masked array behavior, I'd really like to know about it! But I'm not aware of one, so I don't think it's a matter of right and wrong, but of what's useful.

I think where fails the premise from above:

> If we have a function $f(a, b, c, d, ...)$ that really depends on all the parameters...

in that the output doesn't depend on all the parameters, even in principle. For instance, np.where(True, 1, x) doesn't depend on the value of x at all - even a NaN won't propagate. So right or wrong, I don't think sharing the mask there would be the most useful behavior, and I think users would eventually complain.
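To make the non-propagation point concrete in plain NumPy:

```python
import numpy as np

np.where(True, 1.0, np.nan)  # -> array(1.) ; the NaN in the untaken branch never propagates
```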
Certainly there are special functions for which, at particular values of one of the arguments, the value of one of the other arguments almost doesn't matter. But I think we'd still want to propagate a NaN in such cases, and I'd guess that we'd still propagate a mask, although depending on the situation, users might have a case to argue that we shouldn't.
We also wouldn't need to cover every possible pattern. Having arbitrarily many parameters admittedly introduces substantial complexity by itself, but my guess is that covering simple ops like an n-ary OR will handle most common cases. Then, for more complicated functions, we can continue doing what we are currently doing by defining the mask behaviour at the call site.
LGTM!
I was pinged here and in gh-22718, but I think Ralf is interested in these things (gh-22246), so I'll let him review/merge. I am probably fine with whatever is done with [...]
> Does this stand true for all [...]?

No. By "essentially all", I mean almost all, and that was intended to cover the statement about working elementwise. We discussed this in the context of [...]. In any case, the point was that we took advantage of this in the initial implementation. The scope was not all possible backends on all possible devices with all possible function signatures then. That is the problem faced in some sub-packages now, though, so it might be worth thinking about the general problem and applying that general solution here.
I still am, yes. In the language of signal and ndimage, the system in special has just a few delegators: [...]. It might make sense though to consider #20678 (comment), especially the output rewrapping. Does it make sense in scipy.special?
As a rule of thumb, it makes sense to use the same dispatch mechanism everywhere.

I hadn't thought about dispatching between GPU backends - that's a pretty neat thing to do. I would just be careful about dispatch priority (namely, when inside jax.jit I expect an Array API agnostic implementation to be much faster than CuPy or PyTorch).
+1 to this, with some fine print.

First, I rather strongly believe that SciPy's delegation story must be as straightforward as possible. If there are multiple options with potential performance implications, we should use one strategy by default and give users a way to override this choice explicitly. No more, not until we gain more collective experience with dispatch/delegation across the ecosystem.

Second, regarding scipy.special specifically: ISTM we should factor xsf development into consideration. Last I heard, the plan was for essentially all elementwise kernels to migrate to [...]
I imagine we will need to delegate to [...]

EDIT: I realise now that this might be the same as what you mean by an xsf-powered layer.
Of course.
I think that the grand unified dispatch design that is being discussed above should be a topic for a follow-up, not this PR.
(Title changed: "support_alternative_backends on Dask" → "support_alternative_backends on Dask and jax.jit".)
I don't think I ever read this at the time of the comment, sorry @mdhaber. That's probably referring to scientific-python/spatch#1, but it looks like that hasn't had any activity in a while. @betatim or @stefanv may be able to tell you more about scikit-image/scikit-image#7520 and the prospects for generalising further.
We implemented whole-function dispatching in NetworkX and scikit-image, and I suspect at some point we'll try to find commonalities and repackage it as spatch.
This is part of a set of three PRs, which can be merged in any order with minor conflicts:

- rel_entr [...] #22641
- support_alternative_backends on Dask and jax.jit #22639

Branch that contains all three: https://github.com/crusaderky/scipy/tree/special_staging
In this PR:

- scipy.special.log_ndtr, ndtr, ndtri, erf, erfc, i0, i0e, i1, i1e, gammaln, gammainc, gammaincc, logit, expit, entr, chdtr, chdtrc, betainc, betaincc, stdtr, stdtrit
- scipy.special.xlogy
- jax.jit too
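For context, the kind of usage this aims to enable looks roughly like the following; this is a hedged sketch, assuming JAX is installed, SciPy's experimental array API support is enabled via SCIPY_ARRAY_API=1, and this PR (plus its sibling PRs) is in place:

```python
import jax
import jax.numpy as jnp
from scipy import special

@jax.jit
def model(x):
    # With alternative-backend support, erf can trace through jax.jit
    # instead of forcing a conversion to NumPy.
    return special.erf(x)

model(jnp.linspace(0.0, 1.0, 4))
```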