ENH: spatial.transform: add array API standard support #22777
Conversation
At first glance, this is a compromise I'm not particularly keen on. I'm definitely in the camp of error handling being a measure of correctness for which performance shouldn't take priority. That said, the main point in saying that is just to make sure reviewers consider the tradeoffs carefully. It may be sensible to spell out our policy on this somewhere in a tracking issue or something like that, if there isn't already a consensus I'm not aware of. I made a similar comment about "batching" using NaNs instead of hard errors recently, but this is a separate scenario I think.

I suppose the very existence of very popular lazy execution libraries is a sort of counterpoint: there is substantial demand for performance over error handling. Some of the lazy libraries have eager modes for better error handling/debugging though, right? Anyway, this is something I'm worried about.
@crusaderky may have words to add here
Please follow the pattern from #22342:

I'm going to write a proper guide in the documentation soon.
It's not a matter of performance. Lazy backends cannot materialize the array contents with
A big +1 for spelling out the policy and documenting it in a centralized location. #22781
Just to make sure I understand correctly: you mean you want to exclusively use the cython implementation for performing actual computations and add a conversion wrapper for frameworks other than numpy?
I was about to remark the same. The NaN vs eager raise discussion is not about performance. It's just not possible to raise with some lazy frameworks. The question is how consistent error handling needs to be between frameworks.
right, and this PR addresses "an array API agnostic variant is not feasible (or delayed to a future date)" for
I'm not sure I follow. Why would we want to add lazy backend support if not for performance reasons? The fact that we're contemplating adding the lazy backend support implies that there's a performance benefit to doing so. If performance is not involved here, what is the true motivation?

What I'm saying is that if we want a single common code path for lazy and eager backends, that would mean that we sacrifice error handling for the eager case, which I'm not fond of. If we specialize code paths for lazy and eager and provide them with different error handling, that would seem to have different drawbacks that somewhat move away from the original array API implementation ideals (though I think specialized code paths have proliferated somewhat regardless). I also don't really understand the design of the

In short, if you're not compromising, shimming or adjusting the error handling of the eager libraries at all, then I'm less worried. If you are making any adjustments to the error handling behavior of the eager libraries, then I maintain that this is ultimately a performance (support for lazy backends) vs. error handling matter.
```python
# Make sure we can transpose quat
quat = xpx.atleast_nd(quat, ndim=2, xp=xp)
K = (weights * quat.T) @ quat
_, v = xp.linalg.eigh(K)
```
Strictly, all uses of `xp.linalg` should be wrapped in `if hasattr(xp, "linalg")`. Alternatively, if no fallback implementation is easy, we can just declare that the module requires `xp` to implement the `linalg` extension. In that case, it is probably better to add a higher-level check for an early exit.
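A sketch of the guard suggested above (the function name and the fallback are illustrative, not code from this PR; `linalg` is an optional extension in the array API standard, so its presence must be checked):

```python
import numpy as np

def eigh_with_guard(K, xp):
    """Illustrative guard for the optional array API linalg extension."""
    if hasattr(xp, "linalg"):
        # The namespace implements the linalg extension: use it directly.
        return xp.linalg.eigh(K)
    # Fallback sketch only: round-tripping through numpy defeats the purpose
    # for GPU/lazy arrays, which is why an early high-level exit (or simply
    # requiring the extension) may be the better policy in practice.
    w, v = np.linalg.eigh(np.asarray(K))
    return xp.asarray(w), xp.asarray(v)
```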
Namely, because this function may not get a substantial performance benefit from running on Dask or JAX, but everything else in the user's workflow before and after it could.
Specifically for error handling, we don't want a single common code path because error handling is valuable for eager backends and impossible for lazy ones.
It is important to point out that we are talking about two code paths, total: one for eager and one for lazy. These are two very broad categories of backends.
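A minimal sketch of what the two-path policy could look like (all names here are invented for illustration; in practice the lazy branch would use value-independent operations rather than a Python `if`):

```python
import math

def as_unit(values: list[float], lazy: bool) -> list[float]:
    """Normalize a vector, following the two-path policy sketched above:
    eager backends raise on invalid input, lazy backends emit NaN."""
    norm = math.sqrt(sum(v * v for v in values))
    if norm == 0.0:
        if not lazy:
            # Eager path: raising is possible and valuable for debugging.
            raise ValueError("found zero norm vector")
        # Lazy path: raising at trace time is impossible, so the invalid
        # result is encoded as NaN and propagates through the graph.
        return [math.nan] * len(values)
    return [v / norm for v in values]
```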
We're working to remove them, at least from the production code. Unit tests are harder to simplify as there are frequently quirks in behaviour when this or that backend diverges from the Standard.
They don't, by design. Another "potentially lazy" backend is ndonnx (not yet supported by scipy), which I understand is intended mainly for ML training, where it switches to lazy mode.
Understood, I think my main concern evaporates then, at least for this specific case. I was worried both were being unified to the detriment of eager error handling to keep the maintenance burden under control.
I have now added full support for

I'll try to create some benchmarks that compare the new implementation for different frameworks and devices to the old one next. Once we have some numbers and the performance looks okay, we can maybe discuss what's left to do before this PR can be merged.
One frequent issue that is coming up while I'm going through the benchmarks: some functions may work just fine on the CPU, but will raise device errors when the original
We're still thinking about how best to test devices in gh-22680. In array-api-extra, we use an enum: https://github.com/data-apis/array-api-extra/blob/main/src/array_api_extra/_lib/_backends.py. But in SciPy we also have the environment variable
If the important part for now is just that the same tests are run on the GPU, that should be covered by https://github.com/scipy/scipy/blob/main/.github/workflows/gpu-ci.yml
Here are the benchmarks for

There is a performance penalty for determining the array framework etc. that is mainly noticeable at sample sizes from 1 to 10. Nevertheless, this should be a very common use case, so we should probably take that seriously and try to close the gap. I'm also not a cython expert, so if someone has more expertise in that area, feel free to jump in with any ideas on how to improve typing etc. for more performance / less overhead. Some samples:
I will try to produce the same benchmark for
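The fixed dispatch cost discussed above can be isolated with a toy micro-benchmark along these lines (everything here is illustrative; it is not the scipy benchmark code):

```python
import timeit

def kernel(values):
    # Stand-in for a fast compiled kernel; cost grows with input size.
    return [v * 2.0 for v in values]

_REGISTRY = {list: kernel}  # toy backend registry keyed by input type

def dispatched(values):
    # Adds a fixed per-call cost (type lookup + indirection), similar in
    # spirit to selecting a backend before running the computation.
    return _REGISTRY[type(values)](values)

# The dispatch overhead is roughly constant per call, so it matters most
# for tiny inputs (sample sizes 1-10) and washes out for large ones.
for n in (1, 10, 1000):
    data = [1.0] * n
    direct = timeit.timeit(lambda: kernel(data), number=5_000)
    routed = timeit.timeit(lambda: dispatched(data), number=5_000)
    print(f"n={n}: direct={direct:.4f}s dispatched={routed:.4f}s")
```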
You mean like this 😃 data-apis/array-api-compat#308
This is why I think the backend design makes a lot of sense. It should be very easy to separate either backend out at a later time into its own package if you wish to. Adding backends back in is then just a matter of installing
The PR above is now in scipy main; if you merge from main and rerun the benchmarks you should see its impact.
I find this a rather difficult PR to review, since big chunks of code are split out to new files and I'm unfamiliar with the array API & its conventions.
However, I am familiar with the existing tests here, and they cover the behavior of this module very thoroughly. I reviewed the test files and, barring the comments I left, I can confirm that the test suite remains functionally unchanged. This serves to prove that the code changes here have not modified the expected behavior.
```
assert_allclose(
exp_coords = xp.asarray(
    [
        [-2.01041204, -0.52983629, 0.65773501, 0.10386614, 0.05855009, 0.54959179],
```
I don't have a strong preference on the formatting of arrays, but it should be consistent
See the isolated testing PR #22939. I strongly agree on this one. It would be great to have at least a consistent formatter for `scipy.spatial.transform`.
```
        ],
    ]
)
```
```python
# The tolerance is set to 1e-8, because xp_assert_close compares the absolute difference instead
```
Was this test otherwise failing? The previous 1e-12 should still be a bounding absolute error to compare against here, and not including rtol would only lower the expected error. If the results are changing enough to trip the limit, it's something we should flag and dig into. (Changed in some other places as well).
> The previous 1e-12 should still be a bounding absolute error to compare against here, and not including rtol would only lower the expected error

I'm not sure I follow the logic here. Yes, adjusting the implicit `rtol` lowers the expected error. Hence the tests fail, as the implementation no longer meets the stricter expectation.
scipy/scipy/_lib/_array_api.py, lines 366 to 373 in 9c02107:

```python
floating = xp.isdtype(actual.dtype, ('real floating', 'complex floating'))
if rtol is None and floating:
    # multiplier of 4 is used as for `np.float64` this puts the default `rtol`
    # roughly half way between sqrt(eps) and the default for
    # `numpy.testing.assert_allclose`, 1e-7
    rtol = xp.finfo(actual.dtype).eps**0.5 * 4
elif rtol is None:
    rtol = 1e-7
```
`1e-12` was not acting as a bounding absolute error, due to the implicit `rtol`. So I don't think the change implies that the results are becoming less precise (of course, we should still check). The important question is what mixture of `atol` and `rtol` would be most appropriate here.
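To make the quoted default concrete: for `float64`, the implicit `rtol` is `eps**0.5 * 4 ≈ 6e-8`, so for values of order 1 the relative term dwarfs an `atol` of `1e-12`. A quick check (the sample value is taken from the expected coordinates quoted earlier in this thread):

```python
import math

eps = 2.220446049250313e-16  # machine epsilon for float64
rtol = math.sqrt(eps) * 4    # the implicit default shown in the snippet above
atol = 1e-12                 # the nominal absolute tolerance under discussion

desired = 0.65773501         # a typical hardcoded expected coordinate
# assert_allclose-style criterion: |actual - desired| <= atol + rtol * |desired|
effective_tol = atol + rtol * abs(desired)

# The relative term dominates by more than four orders of magnitude,
# so the comparison is far looser than the nominal atol of 1e-12.
assert effective_tol > 1e-8
assert rtol * abs(desired) > 1e4 * atol
```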
I think the code comment here seems inaccurate, though: the linked lines show that `rtol` is set to a nonzero value when the argument is omitted, so it isn't just comparing the absolute difference.
Ah, yeah, I was getting things backwards; it should be fine as-is. I think the correct comparison for hardcoded digits, as done here, is an atol equal to the resolution of the hardcoded number, i.e. 8 decimal places would compare to an atol of 1e-8 as in the new code, with rtol set to 0 (which, per the linked snippet, is not being done).
If swapping `assert_allclose` for `xp_assert_close` is nontrivial, it should be moved to a separate PR which can become a dependency of this one. `xp_assert_close` runs on non-array-API numpy too.

Basically, everything that makes sense on its own as a preparatory step for this should be moved out, in order to reduce the size of this PR as much as possible.
yes, agreed
See #22939
```python
@pytest.mark.thread_unsafe
@pytest.mark.parametrize("seq_tuple", permutations("xyz"))
@pytest.mark.parametrize("intrinsic", (False, True))
def test_as_euler_degenerate_compare_algorithms(seq_tuple, intrinsic):
```
Seconded
BTW @amacati, it looks like you aren't the only one interested in using this! magpylib/magpylib#792
Sorry for taking so long, I had a lot going on recently. I have now split up the testing into a separate PR, #22939. The PR is optimized to reduce the number of changed lines, but still comes in at ~+2000/-1200 lines. I don't think I can improve significantly on that; it's just a lot of code that has changed. The plan for moving forward (in this order) is:
All reviews can focus on these isolated PRs to make things less overwhelming. Once these PRs are merged, we can close this PR (probably without merging, as its parts have been merged individually). @ maintainers, please let me know if that works for you. Splitting this PR up into stages without breaking the

Also, I know that
sounds like a plan! We can add to the list: converting the tests to use the
Ah yes, but this is already completed in #22939 and ready for review, so I didn't list it
Did you mean to clear all the commits here?
I thought it might make sense to keep this thread open for the general discussion until everything is merged from smaller PRs, and then close this PR without merging.
Now that I looked at https://docs.jax.dev/en/latest/jax.scipy.html, apparently jax has a native implementation of
No, not yet. It would definitely be interesting to compare against those functions that are implemented in the jax version (it only supports a subset).
General overview

This PR adds array API support for the `scipy.spatial.transform` module as a step towards #18867. It allows users to create objects from the module with arrays from frameworks like `torch`, `jax` etc., and leverages these frameworks for performing the underlying operations.

Since work started when `Rotation` was the primary part of the module, `RigidTransform` has not yet been ported. However, it would be great to discuss some of the general design decisions before moving forward with the implementation of `RigidTransform`.

Architecture
Motivation
Previously, `Rotation` was a cython class centered around the rotation data represented as a quaternion. With the addition of support for the array API we can no longer use this cython code path as the universal implementation, as some arrays from other frameworks cannot be passed into cython (e.g. tensors on a GPU, jax tracers etc.).

Since the cython implementation can significantly speed up calculations with numpy arrays, and numpy remains the primary use case for the foreseeable future, we do not want to add too much overhead for the array API support.
Leveraging backends
This PR proposes an architecture that separates the implementation into:

- a `Rotation` class that handles the interface and data storage
- backends that implement the actual operations

These backends are functional implementations of `Rotation`'s methods. Each time a `Rotation` is created from arrays, we check if a backend specifically for this array type has been added. If so, we use the specialized backend for all operations. Otherwise, we fall back to the new array API backend that is compatible with any array API compatible array type.

This design allows us to nicely integrate the previous, optimized cython implementation by rewriting it into a backend. We register this cython backend for numpy arrays by default, which means that numpy arrays still enjoy the speedup from compiled cython code. In addition, this also gives us a straightforward path to include optimized implementations for individual frameworks if desired, and could even be exposed to other modules that may want to register their own, specialized implementations. Here is an example of how this works in the new implementation:
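The registration mechanism described above might look roughly like the following sketch (names such as `register_backend`, `select_backend` and the toy backend classes are hypothetical, not this PR's actual API):

```python
# Hypothetical registry mapping concrete array types to backend objects.
_BACKENDS: dict[type, object] = {}

class ArrayAPIBackend:
    """Generic fallback backend, usable with any array API compatible input."""
    name = "array_api"

class CythonBackend:
    """Stand-in for the optimized compiled backend registered for numpy."""
    name = "cython"

def register_backend(array_type: type, backend: object) -> None:
    """Register a specialized backend for a given array type."""
    _BACKENDS[array_type] = backend

def select_backend(arr: object) -> object:
    """Pick the specialized backend if one was registered, else fall back."""
    return _BACKENDS.get(type(arr), ArrayAPIBackend())

# A dummy array type stands in for np.ndarray in this sketch.
class DummyNumpyArray:
    pass

register_backend(DummyNumpyArray, CythonBackend())
```

With this shape, numpy-backed rotations hit the compiled path while every other array API compatible input routes to the generic backend.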
Compiling functions
Some frameworks like `jax` heavily rely on jit compiled code with statically known shapes and value-independent control flow. The generic array API backend is written (and tested) to comply with these restrictions. However, this has two implications:

Current limitations and issues
spatial.transform.RigidTransform support

Just like the previous `Rotation` implementation, `RigidTransform` is currently a cython class that relies on `Rotation` being one as well. The solution is to also convert `RigidTransform` into a Python class as interface, and use the same backend mechanic as with `Rotation`. Since the current PR already makes significant code changes with many design decisions, it makes sense to first discuss and review these before moving on to converting `RigidTransform` as well. The final draft may then also include full `RigidTransform` support.

For now, this version of `spatial.transform` INTENTIONALLY DISABLES `RigidTransform` to let the build succeed. Hence, all tests for `RigidTransform` are expected to fail.

array API version
The current implementation uses the latest array API standard 2024.12 with `array-api-strict>=2.3` and `array-api-compat>=1.11`. This is necessary for advanced indexing.

Randomness
The array API does not define a standard for generating random numbers. This is probably out of scope anyway, as different frameworks have different models for randomness (compare e.g. `numpy` generators vs `jax` PRNG keys). Therefore, random constructors always return the numpy version of rotations.

Rotation splines
Rotation splines currently use `scipy.interpolate`, which is not array API compatible. Therefore, the current implementation errors on creating splines from a `Rotation` that is not backed by a numpy array.

Open issues
`xp.roll` needs to be clarified by the array API and `array-api-compat` needs to be updated for torch, or `array-api-strict` requires additional checks for `xp.roll`. See data-apis/array-api#914

Open design decisions
Rewriting `Rotation` with pure array API compatible code that may also be compiled by lazy frameworks requires some deviations from the original cython implementation. Sections that deviate and require some form of discussion or decision are marked with `# DECISION: ...` (or sometimes with `# TODO: ...`) to make it easier for reviewers to find critical points for discussion. They mostly also add a reason why the current decision was made, and may include alternatives.
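As an illustration of the kind of deviation involved (an invented example, not code from this PR): value-dependent branching must be replaced with value-independent operations such as `where`, so that jit/lazy backends can trace the function. Here `numpy` stands in for a generic array API namespace `xp`:

```python
import numpy as np  # numpy stands in for any array API namespace `xp`

def normalize_quat_eager(q):
    # Eager style: data-dependent branch. This cannot be jit-traced, because
    # `norm == 0` cannot be evaluated on a tracer at compile time.
    norm = np.linalg.norm(q)
    if norm == 0:
        raise ValueError("zero-norm quaternion")
    return q / norm

def normalize_quat_traceable(q):
    # Traceable style: no data-dependent control flow. Invalid inputs
    # propagate as NaN instead of raising eagerly.
    norm = np.linalg.norm(q)
    safe_norm = np.where(norm == 0, 1.0, norm)  # avoid division by zero
    return np.where(norm == 0, np.nan, q / safe_norm)
```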
Tests
Currently, all tests for rotations have been rewritten in terms of the array API and pass locally with `numpy`, `array-api-strict`, `torch` and `jax`.

Performance
I plan on benchmarking the current implementation for various workloads similar to the benchmarks in #22500. Ideally this includes comparing it against the current version to measure the overhead of going through the backend. This may have to wait until I have some more time at my disposal though. In case this should be integrated into the airspeed suite, I'd be grateful for guidance and/or help with that.
TODOs
Reference issue
Closes #22500