
MAINT: stats: ensure functions work on non-default device #22856


Merged
merged 17 commits into from
Jun 1, 2025

Conversation

@mdhaber (Contributor) commented Apr 17, 2025

Reference issue

Toward gh-22680

What does this implement/fix?

Makes adjustments needed to ensure that stats functions work when arrays are not on the default device. Adds tests for device and dtype preservation.
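The device/dtype preservation checks follow a common pattern; a minimal sketch of that kind of check (hypothetical helper name, not the PR's actual test code):

```python
import numpy as np

def check_device_dtype_preserved(func, x):
    # Hedged sketch of a device/dtype-preservation check: the result
    # should keep the input's dtype and, when the array advertises a
    # device (NumPy 2.x exposes `.device`), stay on that device.
    res = func(x)
    assert res.dtype == x.dtype
    if hasattr(x, "device"):
        assert res.device == x.device

x = np.asarray([1.0, 2.0, 3.0], dtype=np.float32)
check_device_dtype_preserved(lambda a: a - a.mean(), x)
```
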

Additional information

To do:

  • add device to marray asarray
  • preserve device in special/_support_alternative_backends.py? A lot of work has been done on _support_alternative_backends since opening this; I don't think I need to do anything here.
  • ensure common device in xp_promote? Sidestepped. Our private _asarray needs to support device.
  • add tests for remaining functions

@mdhaber mdhaber requested a review from rgommers as a code owner April 17, 2025 22:54
@mdhaber mdhaber marked this pull request as draft April 17, 2025 22:54
@github-actions bot added the scipy.stats, scipy._lib, and Meson (Items related to the introduction of Meson as the new build system for SciPy) labels Apr 17, 2025
@mdhaber mdhaber removed the request for review from rgommers April 17, 2025 22:54
@rgommers (Member):

Thanks Matt. This looks pretty good to me already.

@mdhaber mdhaber changed the title from "WIP: stats: ensure functions work on non-default device" to "MAINT: stats: ensure functions work on non-default device" May 21, 2025
@mdhaber mdhaber removed the Meson (Items related to the introduction of Meson as the new build system for SciPy) label May 21, 2025
@mdhaber mdhaber marked this pull request as ready for review May 21, 2025 21:19
@lucascolley lucascolley added the maintenance (Items related to regular maintenance tasks) and array types (Items related to array API support and input array validation, see gh-18286) labels May 21, 2025
if is_torch(xp):
devices = xp.__array_namespace_info__().devices()
# open an issue about this - cannot branch based on `any`/`all`?
return (device for device in devices if device.type != 'meta')
@mdhaber (author):

Not sure what to do about this:

from scipy._lib.array_api_compat import torch as xp
if xp.any(xp.asarray([True, False], device='meta')): 
    print('success!')
# RuntimeError: Tensor.item() cannot be called on meta tensors

Member:

Related to data-apis/array-api#945 it looks like. 'meta' backend is brand new and WIP I think - do we need to worry about it here if CI is green?

@mdhaber (author):

I'm guessing not many existing tests exercise it because only tests that use the devices fixture see it, right? In any case, it causes several of these new tests to fail.

I don't think I can make all the functions work with it in this PR, so currently I used this line to skip it everywhere. But I'm open to thoughts on a more targeted way to skip or xfail the offending tests.

Member:

I'm guessing not many existing tests exercise it because only tests that use the devices fixture see it, right?

Yes, that seems right. And I think the choice you made here for skipping at the scipy-global level is the right approach for now.

Contributor:

Supporting the torch meta device requires

  1. adding it to array_api_compat.is_lazy_array
  2. switching SciPy to use array-api-extra's xp_assert_equal et al. to benefit from data-apis/array-api-extra#300 (ENH: support PyTorch device='meta').

Comment on lines +992 to +993
dtype = xp_result_type(lmb, data, force_floating=True, xp=xp)
data = xp.asarray(data, dtype=dtype)
@mdhaber (author):

It would be nice if xp_promote could convert Python scalar lmb to an array on the right device, but xp_promote uses our _asarray, which doesn't support device. I'd like to leave it as-is in this PR, but it's something to keep in mind in gh-22049.
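A device-aware promotion could look roughly like the following sketch (`asarray_like` is a hypothetical name, not SciPy's private `_asarray`; it only forwards `device` when the template array advertises one, since NumPy 2.x exposes `.device` but older arrays may not):

```python
import numpy as np

def asarray_like(scalar, template, xp=np):
    # Hedged sketch: place a Python scalar on the same device and dtype
    # as an existing array. `device` is forwarded only if the template
    # actually has one (NumPy 2.x reports "cpu").
    device = getattr(template, "device", None)
    kwargs = {"device": device} if device is not None else {}
    return xp.asarray(scalar, dtype=template.dtype, **kwargs)

data = np.asarray([1.0, 2.0, 3.0], dtype=np.float64)
lmb = asarray_like(0.5, data)
```
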

@mdhaber mdhaber requested review from lucascolley and rgommers May 22, 2025 02:26
@mdhaber mdhaber requested review from larsoner and andyfaff as code owners May 22, 2025 05:37
@mdhaber (author) commented May 22, 2025

Removed changes to distance.py and added the one missing use of device to bartlett. Since meta is new and essentially untested, I can open an issue for it. If we want to support it, I think we need a helper function to look for it (like is_meta_array, since it will require special cases like lazy arrays do), and I think the test skip/xp_capabilities system needs to be updated to deal with it.
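A sketch of what such a helper might look like (`is_meta_array` is the hypothetical name from the comment above, not an existing SciPy function):

```python
def is_meta_array(x):
    # Hypothetical helper sketch: detect PyTorch tensors on the 'meta'
    # device, which carry shape/dtype metadata but no data, so any
    # operation that materializes values (e.g. Tensor.item()) raises.
    device = getattr(x, "device", None)
    return getattr(device, "type", None) == "meta"
```
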

@rgommers rgommers added this to the 1.17.0 milestone May 30, 2025
@rgommers (Member):

I started reviewing again and then noticed merge conflicts - heavy traffic on this code. I resolved them and then added a couple of commits to fix/skip a few failures. The array-api-strict/torch ones were easy to figure out, the dask ones were obscure so I added skips for them.

@rgommers (Member):

Since meta is new and essentially untested, I can open an issue for it. If we want to support it, I think we need a helper function to look for it (like is_meta_array, since it will require special cases like lazy arrays do), and I think the test skip/xp_capabilities system needs to be updated to deal with it.

An issue would be great. My current impression is that this is low-priority, and we should leave it completely disabled until we have dealt with JAX/Dask more. Because the PyTorch 'meta' device has similar requirements to them.

@rgommers (Member) left a comment:

Thanks @mdhaber. I think this looks good in its current form, let's squash-merge when CI comes back green I'd say.

Comment on lines 94 to 95
if (is_array_api_strict(xp) or is_torch(xp)) and method == 'harrell-davis':
pytest.skip("'harrell-davis' not currently not supported on GPU.")
Member:

typo

Suggested change
if (is_array_api_strict(xp) or is_torch(xp)) and method == 'harrell-davis':
pytest.skip("'harrell-davis' not currently not supported on GPU.")
if (is_array_api_strict(xp) or is_torch(xp)) and method == 'harrell-davis':
pytest.skip("'harrell-davis' not currently supported on GPU.")

However, would this not be more accurately written as:

@pytest.mark.parametrize('method', [
    'inverted_cdf', 'averaged_inverted_cdf', 'closest_observation',
    'interpolated_inverted_cdf', 'hazen', 'weibull', 'linear', 'median_unbiased',
    'normal_unbiased',
    pytest.param('harrell-davis', marks=[pytest.mark.skip_xp_backends(cpu_only=True)]),
])

That would come with the default message "no array-agnostic implementation or delegation available for this backend and device".

Member:

Good point. It's easier to write it with if ... rather than pytest.param (no need to remember unusual pytest syntax), but the pytest.param method is a bit more concise and gives better messages.

I'm happy for this to go in either way though, before it accumulates conflicts again.

Member:

no need to remember unusual pytest syntax

I guess I've reviewed enough of Guido's PRs that it seems normal now 😅

I'll push the change and merge.

Member:

actually, it looks like cpu_only=True does not skip array-api-strict's device1 as that comes via the devices fixture, not the xp fixture.

@crusaderky is there a better solution already available here?

Contributor:

What you're asking for is a flag for all backends that have devices other than CPU. Which doesn't exist, and it's a tad too niche IMHO.
Worth pointing out that, on a non-pixi deployment with CUDA installed,

@pytest.mark.skip_xp_backends(cpu_only=True)
def test1(xp, devices):
    ...

will in fact run on torch GPU and JAX GPU.

Contributor:

I drafted a PR that uses pytest.param(..., marks=skip_xp_backends(...)), but while it's more idiomatic I found it to be a lot less readable.

@lucascolley lucascolley merged commit 55a0b1d into scipy:main Jun 1, 2025
@lucascolley (Member):

merged to avoid conflicts, hopefully #22856 (comment) can be addressed in a follow-up

@pytest.mark.parametrize('dtype', dtypes)
def test_one_in_one_out(fun, kwargs, dtype, xp, devices):
if is_dask(xp) and fun == stats.variation:
pytest.skip("dtype inference failing in xpx.apply_where")
Contributor:

@mdhaber this is quite worrisome:

E TypeError("Multiple namespaces for array inputs: {<module 'scipy._lib.array_api_compat.numpy' [...], <module 'scipy._lib.array_api_compat.dask.array' [...]

@mdhaber (author) commented Jun 2, 2025:

I'm going to look into it, but it wasn't worth holding up the whole PR. Dask can be very creative about how to fail with code that works for other backends. There are other things I find more worrisome.

Member:

Dask can be very creative about how to fail with code that works for other backends.

Yeah I looked at this, and it's very much not obvious what is going on there. All the input variables seemed to be Dask arrays of the same dtype and device, and internally somewhere, one is converted to a numpy array.

@mdhaber (author) commented Jun 2, 2025:

@crusaderky I looked and I can't see what's wrong with the inputs to apply_where, but the error is happening inside that.

In the call to map_blocks:
https://github.com/data-apis/array-api-extra/blob/dd7650aa737f4ded37e5d406067f117e05035ec3/src/array_api_extra/_lib/_funcs.py#L142

  • cond is a Dask array
  • f1 is a function that operates on Dask arrays and returns Dask arrays (defined in variation)
  • f2 is None
  • args_ is a tuple of two Dask arrays
  • f1(*args_) is a Dask array
  • meta_xp is SciPy's array API compat NumPy

Then if we go inside the call to _apply_where, all the array arguments are NumPy arrays, but f1 is still a function that is supposed to operate on Dask arrays. So temp1 is a Dask array, and in return at(out, cond).set(temp1), it's trying to set elements of a NumPy out array with elements of a Dask temp1 array.

@mdhaber (author) commented Jun 2, 2025:

Does the callable f1 need to accept any type of array and return any type of array? I guess so:

On Dask, f1 and f2 are applied to the individual chunks and should use functions from the namespace of the chunks.

I can probably fix that.
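The fix described — having the callback derive its namespace from whatever arrays it receives — could be sketched as follows (the `_namespace` stand-in approximates `array_api_compat.array_namespace`; this is not the actual `variation` code):

```python
import numpy as np

def _namespace(x):
    # Minimal stand-in for array_api_compat.array_namespace: NumPy 2.x
    # arrays implement __array_namespace__; fall back to numpy itself.
    get_xp = getattr(x, "__array_namespace__", None)
    return get_xp() if get_xp is not None else np

def f1(std, mean):
    # Namespace-agnostic callback for apply_where: works whether it is
    # handed Dask arrays or the NumPy chunks that map_blocks passes in,
    # because it looks up the namespace of its actual inputs.
    xp = _namespace(std)
    return xp.divide(std, mean)

out = f1(np.asarray([1.0, 2.0]), np.asarray([2.0, 4.0]))
```
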

Contributor:

Does the callable f1 need to accept any type of array and return any type of array?

Yes. See data-apis/array-api-extra#196

jpmikhail pushed a commit to jpmikhail/scipy that referenced this pull request Jun 3, 2025
* WIP: stats: ensure functions work on non-default device

[skip ci]

* MAINT: stats: post-merge fixups

* TST: stats: add dtype/device tests

* STY: stats: fix lint issues

* TST: stats: address test failures

* TST: stats: add tests of quantile and directional_stats

* MAINT: stats: skip marray tests

* MAINT: stats: address test failures

* STY: stats: fix lint

* MAINT: unpin marray to use latest release; run marray tests

* MAINT: stats: adjustments per review

* MAINT: stats: some more device/dtype fixes

* STY: fix a linter complaint

* [skip ci]

* [skip ci]

---------

Co-authored-by: Ralf Gommers <ralf.gommers@gmail.com>
Co-authored-by: Lucas Colley <lucas.colley8@gmail.com>