Skip to content

FEA Add array API support for GaussianMixture #30777

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 103 commits into from
Jun 19, 2025

Conversation

lesteve
Copy link
Member

@lesteve lesteve commented Feb 6, 2025

Working on it with @StefanieSenger.

Link to TODO

@lesteve lesteve marked this pull request as draft February 6, 2025 14:26
Copy link

github-actions bot commented Feb 6, 2025

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: d46840b. Link to the linter CI: here

@StefanieSenger StefanieSenger self-requested a review February 14, 2025 09:28
Copy link
Contributor

@OmarManzoor OmarManzoor left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall looks good. I left a few comments



# TODO What is the expected behavior when weights init
# and X are not in the same namespace/device?
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is not resolved yet. Can we remove the commented out code?

@OmarManzoor
Copy link
Contributor

@lesteve Just one test failing and that has to do with array api strict on device and float32. Maybe we need to increase the tolerance further for this specific scenario.

@lesteve
Copy link
Member Author

lesteve commented Jun 19, 2025

My honest impression is that these tests are fragile on float32 data but I don't really know if there is much we can do to improve the situation ...

Even for array-api-strict the results are different because of the difference between scipy.linalg.choleksy and numpy.linalg.cholesky and between scipy.linalg.triangular_solve and numpy.linalg.solve.

On a GPU VM I also saw some test failures (a few more than in the CI actually) and raised the atol and rtol a bit to get them to pass locally. I trigger another run of the CUDA CI, let's see what happens 🤞.

@OmarManzoor
Copy link
Contributor

I don't think we can do much with trying to improve array-api-strict tests for float32 especially with respect to accuracy. As long as array-api-strict works generally I think that should be sufficient.

Copy link
Contributor

@OmarManzoor OmarManzoor left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thank you for the work done in this PR @lesteve and @StefanieSenger

@OmarManzoor OmarManzoor merged commit cc526ee into scikit-learn:main Jun 19, 2025
40 checks passed
@github-project-automation github-project-automation bot moved this from In Progress to Done in Array API Jun 19, 2025
@lesteve lesteve deleted the gmm-array-api branch June 19, 2025 12:18
@lesteve
Copy link
Member Author

lesteve commented Jun 20, 2025

Thanks for the reviews @OmarManzoor and @ogrisel!

One of the remaining question in the old and long TODO list: should we implement __sklearn_tags__ to tell that GaussianMixture has array_api_support?

PCA ___sklearn_tags__ does this currently and always sets array_api = True although array API support is implemented for some values of the parameters, not sure whether this is expected or not:

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags.transformer_tags.preserves_dtype = ["float64", "float32"]
tags.array_api_support = True
tags.input_tags.sparse = self.svd_solver in (
"auto",
"arpack",
"covariance_eigh",
)
return tags

I am guessing the array_api tags is only used for the common tests right now, right?

@ogrisel
Copy link
Member

ogrisel commented Jun 20, 2025

Good questions:

  • indeed, we could make PCA only return the tags.array_api_support = True when the solver supports array API inputs.
  • similarly for GaussianMixture (depending on the choice of the init).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
Status: Done
Development

Successfully merging this pull request may close these issues.

4 participants
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy