MAINT: stats: ensure functions work on non-default device #22856
Conversation
Thanks Matt. This looks pretty good to me already.

```python
if is_torch(xp):
    devices = xp.__array_namespace_info__().devices()
    # open an issue about this - cannot branch based on `any`/`all`?
    return (device for device in devices if device.type != 'meta')
```
Not sure what to do about this:

```python
from scipy._lib.array_api_compat import torch as xp

if xp.any(xp.asarray([True, False], device='meta')):
    print('success!')
# RuntimeError: Tensor.item() cannot be called on meta tensors
```
Related to data-apis/array-api#945, it looks like. The `meta` backend is brand new and WIP, I think - do we need to worry about it here if CI is green?
I'm guessing not many existing tests exercise it because only tests that use the `devices` fixture see it, right? In any case, it causes several of these new tests to fail.

I don't think I can make all the functions work with it in this PR, so currently I used this line to skip it everywhere. But I'm open to thoughts on a more targeted way to skip or xfail the offending tests.
> I'm guessing not many existing tests exercise it because only tests that use the `devices` fixture see it, right?

Yes, that seems right. And I think the choice you made here for skipping at the scipy-global level is the right approach for now.
Supporting the torch `meta` device requires:

- adding it to `array_api_compat.is_lazy_array`
- switching scipy to use array-api-extra's `xp_assert_equal` et al. to benefit from "ENH: support PyTorch `device='meta'`" (data-apis/array-api-extra#300).
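For illustration, here is a minimal sketch (not SciPy's or array-api-extra's actual implementation; the helper names are hypothetical) of what treating `meta` tensors as value-less could mean for an assertion helper: compare only shape and dtype, and skip the value comparison that would otherwise trigger `Tensor.item()`:

```python
import torch

def _is_meta(x):
    # hypothetical helper: 'meta' tensors carry shape/dtype but no values
    return isinstance(x, torch.Tensor) and x.device.type == "meta"

def assert_equal_sketch(actual, desired):
    # metadata can always be compared
    assert actual.shape == desired.shape
    assert actual.dtype == desired.dtype
    # materializing values (e.g. Tensor.item()) raises on meta tensors,
    # so skip the value comparison in that case
    if _is_meta(actual) or _is_meta(desired):
        return
    torch.testing.assert_close(actual, desired)

assert_equal_sketch(torch.ones(3, device="meta"), torch.ones(3, device="meta"))
```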
```python
dtype = xp_result_type(lmb, data, force_floating=True, xp=xp)
data = xp.asarray(data, dtype=dtype)
```
It would be nice if `xp_promote` could convert Python scalar `lmb` to an array on the right device, but `xp_promote` uses our `_asarray`, which doesn't support `device`. I'd like to leave it as-is in this PR, but it's something to keep in mind in gh-22049.
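A rough sketch of the desired behaviour, using `array_api_strict` and a hypothetical helper name (not SciPy's actual `xp_promote`): place the Python scalar on the same device as the data it will be combined with.

```python
import array_api_strict as xp

def scalar_on_same_device(lmb, data):
    # hypothetical helper: convert the Python scalar `lmb` to an array
    # that lives on the same device as `data`
    return xp.asarray(lmb, dtype=data.dtype, device=data.device)

# array_api_strict exposes extra fake devices; the second entry is assumed
# here to be one of its non-default test devices
device1 = xp.__array_namespace_info__().devices()[1]
data = xp.asarray([1.0, 2.0], device=device1)
assert scalar_on_same_device(0.5, data).device == data.device
```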
Removed changes to
I started reviewing again and then noticed merge conflicts - heavy traffic on this code. I resolved them and then added a couple of commits to fix/skip a few failures. The
An issue would be great. My current impression is that this is low-priority, and we should leave it completely disabled until we have dealt with JAX/Dask more. Because the PyTorch
Thanks @mdhaber. I think this looks good in its current form, let's squash-merge when CI comes back green I'd say.
```python
if (is_array_api_strict(xp) or is_torch(xp)) and method == 'harrell-davis':
    pytest.skip("'harrell-davis' not currently not supported on GPU.")
```
typo

```diff
 if (is_array_api_strict(xp) or is_torch(xp)) and method == 'harrell-davis':
-    pytest.skip("'harrell-davis' not currently not supported on GPU.")
+    pytest.skip("'harrell-davis' not currently supported on GPU.")
```
However, would this not be more accurately written as:

```python
@pytest.mark.parametrize('method', [
    'inverted_cdf', 'averaged_inverted_cdf', 'closest_observation',
    'interpolated_inverted_cdf', 'hazen', 'weibull', 'linear', 'median_unbiased',
    'normal_unbiased',
    pytest.param('harrell-davis', marks=[pytest.mark.skip_xp_backends(cpu_only=True)]),
])
```
That would come with the default message "no array-agnostic implementation or delegation available for this backend and device".
Good point. It's easier to write it with `if ...` rather than `pytest.param` (no need to remember unusual `pytest` syntax), but the `pytest.param` method is a bit more concise and gives better messages.

I'm happy for this to go in either way though, before it accumulates conflicts again.
> no need to remember unusual pytest syntax

I guess I've reviewed enough of Guido's PRs that it seems normal now 😅

I'll push the change and merge.
Actually, it looks like `cpu_only=True` does not skip array-api-strict's `device1`, as that comes via the `devices` fixture, not the `xp` fixture.

@crusaderky is there a better solution already available here?
What you're asking for is a flag for all backends that have devices other than CPU. Which doesn't exist, and it's a tad too niche IMHO.

Worth pointing out that, on a non-pixi deployment with CUDA installed,

```python
@pytest.mark.skip_xp_backends(cpu_only=True)
def test1(xp, devices):
    ...
```

will in fact run on torch GPU and JAX GPU.
I drafted a PR that uses `pytest.param(..., marks=skip_xp_backends(...))`, but while it's more idiomatic I found it to be a lot less readable.
merged to avoid conflicts, hopefully #22856 (comment) can be addressed in a follow-up

```python
@pytest.mark.parametrize('dtype', dtypes)
def test_one_in_one_out(fun, kwargs, dtype, xp, devices):
    if is_dask(xp) and fun == stats.variation:
        pytest.skip("dtype inference failing in xpx.apply_where")
```
@mdhaber this is quite worrisome:

```
E TypeError("Multiple namespaces for array inputs: {<module 'scipy._lib.array_api_compat.numpy' [...], <module 'scipy._lib.array_api_compat.dask.array' [...]
```
I'm going to look into it, but it wasn't worth holding up the whole PR. Dask can be very creative about how to fail with code that works for other backends. There are other things I find more worrisome.
> Dask can be very creative about how to fail with code that works for other backends.

Yeah, I looked at this, and it's very much not obvious what is going on there. All the input variables seemed to be Dask arrays of the same dtype and device, and internally somewhere, one is converted to a numpy array.
@crusaderky I looked and I can't see what's wrong with the inputs to `apply_where`, but the error is happening inside that.

In the call to `map_blocks`:
https://github.com/data-apis/array-api-extra/blob/dd7650aa737f4ded37e5d406067f117e05035ec3/src/array_api_extra/_lib/_funcs.py#L142

- `cond` is a Dask array
- `f1` is a function that operates on Dask arrays and returns Dask arrays (defined in `variation`)
- `f2` is `None`
- `args_` is a tuple of two Dask arrays
- `f1(*args_)` is a Dask array
- `meta_xp` is SciPy's array API compat NumPy

Then if we go inside the call to `_apply_where`, all the array arguments are NumPy arrays, but `f1` is still a function that is supposed to operate on Dask arrays. So `temp1` is a Dask array, and in `return at(out, cond).set(temp1)`, it's trying to set elements of a NumPy `out` array with elements of a Dask `temp1` array.
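For context, a minimal standalone illustration (not the failing SciPy code) of why the arrays inside the mapped function are NumPy arrays: Dask's `map_blocks` calls the wrapped function on materialized NumPy chunks, so a callable written to operate only on Dask arrays sees a different array type than its author expected.

```python
import dask.array as da

def f1(block):
    # map_blocks hands this function NumPy chunks (plus a size-0 "meta" array
    # for dtype inference), not Dask arrays
    print(type(block))
    return block + 1

x = da.ones(6, chunks=3)
x.map_blocks(f1).compute()  # prints <class 'numpy.ndarray'> for each block
```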
Does the callable `f1` need to accept any type of array and return any type of array? I guess so:

> On Dask, `f1` and `f2` are applied to the individual chunks and should use functions from the namespace of the chunks.

I can probably fix that.
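A sketch of what such a fix could look like (a hypothetical callable, not the actual `variation` internals): instead of closing over the outer namespace, the callable looks up the namespace of the arrays it actually receives, so it works on NumPy chunks under Dask's `map_blocks` and on Dask (or other backend) arrays elsewhere.

```python
from scipy._lib.array_api_compat import array_namespace

def f1(sd, mean):
    # namespace-agnostic callable: use the namespace of the arrays that were
    # actually passed in, rather than the caller's xp
    xp = array_namespace(sd, mean)
    # e.g. compute a ratio using functions from that namespace
    return xp.abs(sd / mean)
```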
> Does the callable `f1` need to accept any type of array and return any type of array?

Yes. See data-apis/array-api-extra#196
* WIP: stats: ensure functions work on non-default device [skip ci]
* MAINT: stats: post-merge fixups
* TST: stats: add dtype/device tests
* STY: stats: fix lint issues
* TST: stats: address test failures
* TST: stats: add tests of quantile and directional_stats
* MAINT: stats: skip marray tests
* MAINT: stats: address test failures
* STY: stats: fix lint
* MAINT: unpin marray to use latest release; run marray tests
* MAINT: stats: adjustments per review
* MAINT: stats: some more device/dtype fixes
* STY: fix a linter complaint
* [skip ci]
* [skip ci]

---------

Co-authored-by: Ralf Gommers <ralf.gommers@gmail.com>
Co-authored-by: Lucas Colley <lucas.colley8@gmail.com>
Reference issue

Toward gh-22680

What does this implement/fix?

Makes adjustments needed to ensure that `stats` functions work when arrays are not on the default device. Adds tests for device and dtype preservation.

Additional information

To do:

- `device` to `marray` `asarray`
- preserve device in `special/_support_alternative_backends.py`? A lot of work has been done on `_support_alternative_backends` since opening this; I don't think I need to do anything here.
- ensure common device in `xp_promote`? Sidestepped. Our private `_asarray` needs to support `device`.
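As an illustration of that last point, a rough sketch (hypothetical name and signature, not the actual SciPy helper) of what letting the private `_asarray` accept a `device` argument could look like:

```python
from scipy._lib.array_api_compat import array_namespace

def _asarray_sketch(x, dtype=None, device=None, xp=None):
    # hypothetical: forward `device` to the namespace's asarray so that
    # converted inputs stay on the caller's device rather than falling back
    # to the backend's default device
    xp = array_namespace(x) if xp is None else xp
    return xp.asarray(x, dtype=dtype, device=device)
```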