
ENH: special: support_alternative_backends on Dask and jax.jit #22639


Merged - 11 commits - Apr 4, 2025

Conversation

@crusaderky (Contributor) commented Mar 6, 2025

This is part of a set of three PRs, which can be merged in any order with minor conflicts:

Branch that contains all three: https://github.com/crusaderky/scipy/tree/special_staging

In this PR

  • Add no-materialization Dask support for scipy.special.log_ndtr, ndtr, ndtri, erf, erfc, i0, i0e, i1, i1e, gammaln, gammainc, gammaincc, logit, expit, entr, chdtr, chdtrc, betainc, betaincc, stdtr, stdtrit
  • Suppress warnings in Dask for scipy.special.xlogy
  • Test that all the above functions are compatible with jax.jit too
  • Dask and marray arrays wrapping non-NumPy objects (e.g. CuPy) should now work with these functions. This feature remains untested and shouldn't be advertised for now, as it requires fundamental groundwork before it can be offered to the public (XREF [WIP] ENH: dask+cupy, dask+sparse etc. namespaces data-apis/array-api-compat#270) CC @mdhaber
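The reason an elementwise function can be applied chunk by chunk without materializing the array can be sketched with plain NumPy (`map_blocks_sketch` is a toy stand-in for illustration, not Dask's API):

```python
import numpy as np

def map_blocks_sketch(f, x, n_chunks):
    """Toy stand-in for dask.array.map_blocks: apply f per chunk, reassemble."""
    chunks = np.array_split(x, n_chunks)
    return np.concatenate([f(c) for c in chunks])

x = np.linspace(-3.0, 3.0, 10)
# For an elementwise f, per-chunk application equals whole-array application,
# which is why no materialization (and no rechunking) is needed.
assert np.allclose(map_blocks_sketch(np.tanh, x, 3), np.tanh(x))
```

This is exactly the property that fails for gufuncs and reductions, where the result of a chunk depends on data outside it.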

@github-actions github-actions bot added scipy.special enhancement A new feature or improvement labels Mar 6, 2025
Comment on lines 44 to 68

(review excerpt rendered as a diff; "-" marks removed lines, "+" marks added lines)

      if is_marray(_xp):
    -     data_args = [np.asarray(arg.data) for arg in array_args]
    -     out = _f(*data_args, *other_args, **kwargs)
    +     data_args = [arg.data for arg in array_args]
    +     f = globals()[f_name]  # Allow nested wrapping
    +     out = f(*data_args, *other_args, **kwargs)
          mask = functools.reduce(operator.or_, (arg.mask for arg in array_args))
          return _xp.asarray(out, mask=mask)

      elif is_dask(_xp):
          f = globals()[f_name]  # Allow nested wrapping
          # IMPORTANT: this works only because all ufuncs in this module
          # are elementwise. It would be a grave mistake to apply this to gufuncs,
          # as they would change their output depending on chunking!
          return _xp.map_blocks(
              # Hide other_args as well as kwargs such as dtype from map_blocks
              lambda *array_args: f(*array_args, *other_args, **kwargs),
              *array_args
          )

      else:
          assert array_args  # is_numpy(xp) is True
    +     device = xp_device(array_args[0])
          array_args = [np.asarray(arg) for arg in array_args]
          out = _f(*array_args, *other_args, **kwargs)
    -     return _xp.asarray(out)
    +     return _xp.asarray(out, device=device)
@crusaderky (Contributor, Author) commented Mar 6, 2025
As I expect wrapper arrays to proliferate in the medium term future (@lucascolley's https://github.com/quantity-dev/quantity-array will soon join the party), this pattern is not sustainable in the long run.

I'm unsure about how to solve it. While it would be reasonable to encapsulate inside xpx.lazy_apply all backend-specific unwrapping and wrapping, there is a fundamental problem with metadata.

The example above - where the output marray mask is always the OR-reduction of the input masks - is a simple special case, not representative of the general problem.
For example, quantity-array would need to know what kind of operation each binary ufunc performs: a+b (check unit compatibility and scale arrays while unwrapping)? a*b? a/b? Each implies a different operation on the metadata and possibly a different scaling of the data.

e.g.

  • lambda x: x + y, x=qa.asarray(1, "g"), y=qa.asarray(1, "kg") must scale y * 1000; output units=g
  • lambda x: x + y, x=qa.asarray(1, "g"), y=qa.asarray(1, "m") must crash
  • lambda x: x * y, x=qa.asarray(1, "g"), y=qa.asarray(1, "m") output units = g*m
  • lambda x: x / y, x=qa.asarray(1, "g"), y=qa.asarray(1, "m") output units = g/m
  • lambda x: x / y, x=qa.asarray(1, "km"), y=qa.asarray(1, "m/s") must scale x*1000; output units = m/s
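The bullets above can be condensed into a toy sketch. Everything here (`Qty`, `DIM`, the unit table) is hypothetical, not quantity-array's actual API; the point is that a generic unwrap/apply/rewrap helper cannot infer these rules without per-operation information:

```python
# Hypothetical toy unit model -- not quantity-array's API.
DIM = {"g": ("mass", 1.0), "kg": ("mass", 1000.0),
       "m": ("length", 1.0), "km": ("length", 1000.0)}

class Qty:
    def __init__(self, value, unit):
        self.value, self.unit = value, unit

    def __add__(self, other):
        # Addition: dimensions must match; scale `other` into self's unit.
        d1, s1 = DIM[self.unit]
        d2, s2 = DIM[other.unit]
        if d1 != d2:
            raise ValueError(f"cannot add {self.unit} and {other.unit}")
        return Qty(self.value + other.value * s2 / s1, self.unit)

    def __mul__(self, other):
        # Multiplication: values multiply, units compose.
        return Qty(self.value * other.value, f"{self.unit}*{other.unit}")

assert (Qty(1, "g") + Qty(1, "kg")).value == 1001.0  # kg scaled by 1000
assert (Qty(1, "g") * Qty(1, "m")).unit == "g*m"
```

Addition, multiplication, and division each need their own metadata rule, so no single OR-reduction-style default exists.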

(Member) commented:
Yes, I've thought about this myself a bit. I think quantity-array should provide a function which takes a callable and some standardised representation of how the output units are a function of the input units, and returns a Quantity with the output of that callable as the value and the desired units. I think this sort of standardised representation is within the scope of the broader quantity-dev work.

@lucascolley (Member) commented Mar 6, 2025

Perhaps it would also be nice to have marray.apply_func(f, (a, b), mask="or") or similar.

@crusaderky (Contributor, Author):

Perhaps it would also be nice to have marray.apply_func(f, (a, b), mask="or") or similar.

I don't know the details of how mask propagation is supposed to work - could there be a ternary operation where the output mask is a more complex function, e.g. a and b or c?

@crusaderky (Contributor, Author):

Answering my own question by reading https://github.com/mdhaber/marray/blob/main/marray/__init__.py: every operation has its own arbitrary, point-by-point mask propagation rules.

@mdhaber (Contributor) commented Mar 6, 2025

Almost all special functions - and all the ones we are treating with support_alternative_backends right now - are elementwise functions, so we can restrict our attention to the scalar case. If we have a function f(a, b, c, d, ...) that really depends on all the parameters, and we are missing one of them, then we cannot compute the value of the function. This is where the mask behavior comes from.

It was not designed to cover all cases that might arise in scipy.special (e.g. logsumexp), and it wasn't meant to be extendable to other libraries; it just does what we've needed so far with marray support in SciPy.

Suppose we had a function like gamma(z, a=0, b=np.inf) that computes $\gamma = \int_a^b t^{z-1} \exp(-t) dt$. Then we could have a choice to make about what a masked value means. Does that mean to use the default value, or does that mean that there is really an unknown value? I think for marray.trim we've assumed the latter, so we end up with the same mask propagation rule. And for simplicity, I think it's best to define it the same way for all elementwise functions, no exceptions - any masked input -> masked output.

Yes, maybe marray could use a helper for this. I haven't thought about it much, though.
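The "any masked input -> masked output" rule for elementwise functions can be sketched as follows (`masked_elementwise` is a hypothetical helper, not marray's API):

```python
import functools
import operator
import numpy as np

def masked_elementwise(f, *args):
    """Hypothetical helper: each arg is a (data, mask) pair.

    The output mask is the OR-reduction of the input masks: a point
    is masked iff any input is masked at that point.
    """
    data = (d for d, _ in args)
    mask = functools.reduce(operator.or_, (m for _, m in args))
    return f(*data), mask

out, mask = masked_elementwise(
    np.add,
    (np.array([1.0, 2.0]), np.array([False, True])),
    (np.array([3.0, 4.0]), np.array([False, False])),
)
assert mask.tolist() == [False, True]
```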

@mdhaber (Contributor) commented Mar 6, 2025

every operation has its own arbitrary, point-by-point mask propagation rules.

I'm not sure if I'm understanding correctly, but that's not how I would express it. All the elementwise functions are handled with this block in a loop.

https://github.com/mdhaber/marray/blob/500d1027a8d633fecf298fc29527daca7edb08eb/marray/__init__.py#L305-L337

The rule is the same for all of them, and it's the same rule as here: if any of the arguments are masked, the output is masked.

@crusaderky (Contributor, Author):

I'm looking at xp.where:

        mask = xp.where(condition.data,
                        condition.mask | x1.mask,
                        condition.mask | x2.mask)

Do I understand correctly that a generic, agnostic condition.mask | x1.mask | x2.mask would not be wrong, and that the above is just more precise? In other words, that having a false-positive on the mask is acceptable?

(Contributor) commented:

If there is a standards document for masked array behavior, I'd really like to know about it! But I'm not aware of one, so I don't think it's a matter of right and wrong, but of what's useful.

I think where fails the premise from above:

If we have a function $f(a, b, c, d, ...)$ that really depends on all the parameters...

in that the output doesn't depend on all the parameters, even in principle. For instance, np.where(True, 1, x) doesn't depend on the value of x at all - even a NaN won't propagate. So right or wrong, I don't think sharing the mask there would be the most useful behavior, and I think users would eventually complain.

Certainly there are special functions for which, at particular values of one of the arguments, the value of one of the other arguments almost doesn't matter. But I think we'd still want to propagate a NaN in such cases, and I'd guess that we'd still propagate a mask, although depending on the situation, users might have a case to argue that we shouldn't.
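This can be checked numerically: the coarse, operation-agnostic OR mask is a superset of where's precise rule, marking points masked even where the masked branch is never selected.

```python
import numpy as np

cond   = np.array([True, False])
m_cond = np.array([False, False])
m_x1   = np.array([False, True])   # x1 masked at position 1
m_x2   = np.array([False, False])  # x2 fully unmasked

# Precise rule: only the mask of the selected branch matters.
precise = np.where(cond, m_cond | m_x1, m_cond | m_x2)
# Coarse, operation-agnostic rule: OR everything.
coarse = m_cond | m_x1 | m_x2

assert precise.tolist() == [False, False]
assert coarse.tolist() == [False, True]   # false positive at position 1
# Values behave the same way: the unselected branch never propagates.
assert np.where(True, 1.0, np.nan) == 1.0
```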

(Member) commented:

We also wouldn't need to cover every possible pattern. Having arbitrarily many parameters admittedly introduces substantial complexity by itself, but my guess is that covering simple ops like n-bit OR will cover most common cases. Then for more complicated functions, we can continue doing what we are currently doing by defining the mask behaviour at the call site.

@lucascolley lucascolley added the array types Items related to array API support and input array validation (see gh-18286) label Mar 6, 2025
@crusaderky crusaderky force-pushed the dask_special branch 5 times, most recently from eeedf08 to be3573b Compare March 18, 2025 23:17
@crusaderky crusaderky force-pushed the dask_special branch 3 times, most recently from 670dcfd to 3239bc3 Compare March 19, 2025 13:04
@crusaderky crusaderky requested a review from rgommers as a code owner March 19, 2025 13:04
@crusaderky crusaderky force-pushed the dask_special branch 2 times, most recently from 9402618 to 3700e79 Compare March 19, 2025 14:40
@crusaderky crusaderky marked this pull request as draft March 19, 2025 16:42
@crusaderky crusaderky force-pushed the dask_special branch 2 times, most recently from 9971bbd to e2ad933 Compare March 20, 2025 09:46
@lucascolley (Member) left a comment:

LGTM!

@lucascolley lucascolley added this to the 1.16.0 milestone Mar 31, 2025
@lucascolley lucascolley removed the request for review from rgommers March 31, 2025 22:10
@mdhaber (Contributor) commented Apr 1, 2025

I was pinged here and in gh-22718, but think Ralf is interested in these things (gh-22246), so I'll let him review/merge. I am probably fine with whatever is done with support_alternative_backends. My initial goal was to get alternative backend support working well enough in special to allow some progress in stats. For that, the code has served its purpose. But things have changed/progressed a lot since then, so it's good to take a fresh look at the problem. Some things that might be worth keeping in mind (and perhaps you already took these things into account - great!):

  • When data is on the CPU, the backend does not have a special function implementation, and there is no generic implementation in terms of Array API calls, the existing support_alternative_backends system automatically does a no-copy conversion to NumPy, evaluates the special function with the compiled SciPy function, then converts the result back to the original backend. Two things with that:
    • When data is on the GPU, we had talked about doing something similar with CuPy.
    • It would be nice to do those no-copy conversions with from_dlpack, or whatever is the "right" way. That was one of the motivations of gh-22049 and preceding PRs.
  • Looks like @crusaderky already did it, but here's something I wanted to have at some point: "Dask and marray arrays wrapping around non-numpy objects - e.g. cupy - should now work with these functions." Before, I think special functions on marrays were always evaluated with NumPy (if possible, otherwise error). Now, I think this means that if marray is wrapping a CuPy array and CuPy has the special function, it should do the special function evaluation with CuPy.
  • The support_alternative_backends systems in signal and ndimage ended up looking pretty different. We had talked about a common system for all modules, so if the special one is being rewritten, it might be a good time to revisit that. (CC @ev-br in case you're still interested.) I think @lucascolley even mentioned that NetworkX and others were working on some sort of common system, and we should use that? (Do you know what I'm talking about, Lucas? Is that still relevant, or do we still need to roll our own system?) special is relatively simple because almost all arguments are arrays and almost everything works elementwise, but I thought it would be worth mentioning.
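The no-copy round trip mentioned above can be demonstrated with NumPy's own DLPack support (requires NumPy >= 1.23; this only shows the protocol, not SciPy's dispatch machinery):

```python
import numpy as np

x = np.arange(5.0)
y = np.from_dlpack(x)  # import via the DLPack protocol, zero-copy

# Both arrays view the same buffer: no data was copied.
assert np.shares_memory(x, y)
```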

@crusaderky (Contributor, Author):

special is relatively simple because essentially all arguments are arrays and everything works elementwise

Does this stand true for all special functions, though? For example, logsumexp (a bad example, as it is implemented agnostically on top of the array API) performs reductions.

@mdhaber (Contributor) commented Apr 1, 2025

Does this stand true for all special functions though?

No. By "essentially all", I mean almost all, and that was intended to cover the statement about working elementwise. We discussed this in the context of marray support in special, but I will clarify the wording above to avoid confusion.

In any case, the point was that we took advantage of this in the initial implementation. The scope was not all possible backends on all possible devices with all possible function signatures then. That is the problem faced in some sub-packages now, though, so it might be worth thinking about the general problem and applying that general solution here.

@ev-br (Member) commented Apr 1, 2025

The support_alternative_backends systems in signal and ndimage ended up looking pretty different. We had talked about a common system for all modules, so if the special one is being rewritten, it might be a good time to revisit that. (CC @ev-br in case you're still interested.)

I still am, yes. In the language of signal and ndimage, the system in special has just a few delegators: def single_arg_signature(x, *args, **kwds): return array_namespace(x) and def two_arg_signature(x, y, *args, **kwds): return array_namespace(x, y).

It might make sense though to consider #20678 (comment), esp the output rewrapping. Does it make sense in scipy.special?

@crusaderky (Contributor, Author):

The support_alternative_backends systems in signal and ndimage ended up looking pretty different. We had talked about a common system for all modules, so if the special one is being rewritten, it might be a good time to revisit that. (CC @ev-br in case you're still interested.)

I still am, yes. In the language of signal and ndimage, the system in special has just a few delegators: def single_arg_signature(x, *args, **kwds): return array_namespace(x) and def two_arg_signature(x, y, *args, **kwds): return array_namespace(x, y).

It might make sense though to consider #20678 (comment), esp the output rewrapping. Does it make sense in scipy.special?

As a rule of thumb it makes sense to use the same dispatch mechanism everywhere.
The important caveat is that (most of) scipy.special is elementwise. This means that a few blanket statements are possible:

  • You know in advance the output shape will be the broadcast_shapes of the inputs (important for Dask and JAX)
  • You know in advance that each point of the output is a function of the matching point of the input and nothing more (important for MArray and Dask)
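The first guarantee can be checked directly with NumPy:

```python
import numpy as np

# The output shape of any elementwise op is computable from the input
# shapes alone -- no data is needed, so lazy backends (Dask, JAX under
# jit) can build the graph up front.
assert np.broadcast_shapes((4, 1), (3,)) == (4, 3)
assert np.broadcast_shapes((2, 1, 5), (1, 3, 5)) == (2, 3, 5)
```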

I hadn't thought about dispatching between GPU backends - that's a pretty neat thing to do. I just would be careful about dispatch priority (namely, when inside jax.jit I expect an Array API agnostic implementation to be much faster than CuPy or PyTorch).

@ev-br (Member) commented Apr 1, 2025

As a rule of thumb it makes sense to use the same dispatch mechanism everywhere.

+1 to this, with some fine print.

First, I rather strongly believe that scipy's delegation story must be as straightforward as possible. If there are multiple options with potential perf implications, we should use one strategy by default and give users a way to override this choice explicitly. No more, not until we gain more collective experience with dispatch/delegation across the ecosystem.

Second, with scipy.special specifically: ISTM we should factor xsf development into consideration. Last I heard, the plan was for essentially all elementwise kernels to migrate to xsf, which will be a dual CPU/GPU C++ codebase. Then, if all relevant libraries grew an xsf-powered layer, all SciPy would need to do is delegate to that layer?

@lucascolley (Member) commented Apr 1, 2025

Last I heard, the plan was for essentially all elementwise kernels to migrate to xsf, which will be a dual CPU / GPU C++ codebase. Then if all relevant libraries grew an xsf-powered layer, all scipy needs doing is to delegate to that layer?

I imagine we will need to delegate to jax.scipy.special to have any hope of things working with jax.jit where a pure Python kernel would kill performance.

EDIT: I realise now that this might be the same as what you mean by an xsf-powered layer

@rgommers (Member) commented Apr 1, 2025

xsf isn't relevant here at all, I think: it's a C++ library that will be internal to whatever array library uses it, so one cannot delegate to it. It doesn't matter whether a somelibrary.special.func is backed by xsf or not, only whether special.func exists in somelibrary.

@ev-br (Member) commented Apr 1, 2025

xsf isn't relevant here at all I think, it's a C++ library that will be internal to whatever array library uses it, one cannot delegate to that.

Of course: one would delegate to cupyx.scipy.special, and that it's powered by xsf is an implementation detail.
All this assumes the plans of the xsf authors did not change, given that the last CuPy PR I see is from October 2024, cupy/cupy#8620.

@crusaderky (Contributor, Author):

I think that the grand unified dispatch design that is being discussed above should be a topic for a follow-up, not this PR.
There are several big benefits that merging this PR as-is delivers today, listed in the opening comment.

@lucascolley lucascolley merged commit 7545258 into scipy:main Apr 4, 2025
40 of 41 checks passed
@crusaderky crusaderky deleted the dask_special branch April 4, 2025 14:35
@crusaderky crusaderky changed the title ENH: special: support_alternative_backends on Dask ENH: special: support_alternative_backends on Dask and jax.jit Apr 10, 2025
@lucascolley (Member):

I think @lucascolley even mentioned that NetworkX and others were working on some sort of common system, and we should use that? (Do you know what I'm talking about, Lucas? Is that still relevant, or do we still need to roll our own system?)

I don't think I ever read this at the time of the comment, sorry @mdhaber. That's probably referring to scientific-python/spatch#1, but it looks like that hasn't had any activity in a while. @betatim or @stefanv may be able to tell you more about scikit-image/scikit-image#7520 and the prospects for generalising further.

@stefanv (Member) commented May 2, 2025

We implemented whole-function dispatching in NetworkX and scikit-image, and I suspect at some point we'll try to find commonalities and repackage it as spatch.

scikit-image/scikit-image#7520

Labels
array types Items related to array API support and input array validation (see gh-18286) enhancement A new feature or improvement scipy.special
6 participants