ENH add from_cv_results in PrecisionRecallDisplay (single Display) #30508


Open · wants to merge 23 commits into main

Conversation

lucyleeow (Member)

Reference Issues/PRs

Follows on from #30399

What does this implement/fix? Explain your changes.

Proof of concept of adding multi-curve plotting to PrecisionRecallDisplay.

  • A lot of the code is similar to that in ENH add from_cv_results in RocCurveDisplay (single RocCurveDisplay) #30399, so we can definitely factorize it out, though small intricacies may make that complex.
  • The plot method is complex because it handles both single and multiple curves and does a lot more checking, as the user is able to call it outside of the from_estimator and from_predictions methods.

Detailed discussion of the problems is in the review comments.

Any other comments?

cc @glemaitre @jeremiedbb


github-actions bot commented Dec 19, 2024

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 33009c8. Link to the linter CI: here

@lucyleeow (Member, Author) left a comment

I forgot to mention: I think I would like to decide on the order of parameters for these display classes and their methods. They seem to have a lot of overlap and it would be great if they could be consistent.

I know that this would not matter when using the methods, but it would be nice for the API documentation page if they were consistent.


estimator_name : str, default=None
    Name of estimator. If None, then the estimator name is not shown.
curve_name : str or list of str, default=None
@lucyleeow (Member, Author):

I thought curve_name is a more generalizable term (vs estimator_name), especially with CV multi-curves, where we want to name each curve by its fold number.

Changing this name will mean that we must change _validate_plot_params and thus all other classes that use _BinaryClassifierCurveDisplayMixin.

I note that the parameter is named differently here (PrecisionRecallDisplay init) vs in the from_predictions and from_estimator methods (where it's called name). I'm not sure if this was accidental or meant to distinguish it from the methods' name parameter.

Comment on lines 230 to 250
# If multi-curve, ensure all args are of the right length
req_multi = [
    input for input in (self.precision, self.recall) if isinstance(input, list)
]
if req_multi and ((len(req_multi) != 2) or len({len(arg) for arg in req_multi}) > 1):
    raise ValueError(
        "When plotting multiple precision-recall curves, `self.precision` "
        "and `self.recall` should both be lists of the same length."
    )
elif self.average_precision is not None:
    default_line_kwargs["label"] = f"AP = {self.average_precision:0.2f}"
elif name is not None:
    default_line_kwargs["label"] = name
n_multi = len(self.precision) if req_multi else None
if req_multi:
    for name, param in zip(
        ["self.average_precision", "`name` or `self.curve_name`"],
        (self.average_precision, name_),
    ):
        if not ((isinstance(param, list) and len(param) != n_multi) or param is not None):
            raise ValueError(
                f"For multi precision-recall curves, {name} must either be "
                "a list of the same length as `self.precision` and "
                "`self.recall`, or None."
            )
@lucyleeow (Member, Author):

I struggled to come up with a nice way to do this. The checks we need are:

  • precision and recall both need to be lists of the same length, or both need to be single ndarrays
  • for multi-curve, average_precision and name can each either be a list of the same length, or None.

This latter point is important: previously I simply checked that all 4 parameters were of the same length if they were lists. I didn't check that, in the multi-curve situation, the 2 optional parameters needed to be None if they were not lists.

Suggestions welcome for making this nicer.

The good part though is that this is easily factorized out and can be generalised for all similar displays.
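
Just to illustrate, a minimal sketch of what a shared helper could look like (the helper name `_validate_multi_curve_args` and its exact signature are made up here, not the actual implementation):

```python
def _validate_multi_curve_args(required, optional):
    """Validate multi-curve display arguments (hypothetical helper).

    `required` maps names to values that must either all be single arrays or
    all be lists of the same length; `optional` maps names to values that must
    be lists of that same length, or None. Returns the number of curves, or
    None for the single-curve case.
    """
    list_args = {key: value for key, value in required.items() if isinstance(value, list)}
    if list_args and (
        len(list_args) != len(required)
        or len({len(value) for value in list_args.values()}) > 1
    ):
        raise ValueError(
            f"{', '.join(required)} should either all be single arrays, "
            "or all be lists of the same length."
        )
    if not list_args:
        return None  # single-curve case
    n_curves = len(next(iter(list_args.values())))
    for name, value in optional.items():
        if value is not None and (not isinstance(value, list) or len(value) != n_curves):
            raise ValueError(
                f"For multiple curves, {name} must be a list of the same "
                f"length as {', '.join(required)}, or None."
            )
    return n_curves
```

It could then be called with the precision/recall arrays as the required args and the average precision values / curve names as the optional ones.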

Comment on lines 272 to 274
name_ = [name_] * n_multi if name_ is None else name_
average_precision_ = (
    [None] * n_multi if self.average_precision is None else self.average_precision
@lucyleeow (Member, Author):

I don't like this, but could not immediately think of a better way to do it.
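
One possible alternative, just as a sketch (the helper name is made up): a tiny broadcast helper so the intent reads more clearly.

```python
def _broadcast_to_folds(value, n_folds):
    # Hypothetical helper: repeat a non-list value (e.g. None or a single name)
    # so that per-fold iteration can always assume a list of length `n_folds`.
    return value if isinstance(value, list) else [value] * n_folds


n_multi = 3
name_ = _broadcast_to_folds(None, n_multi)                           # [None, None, None]
average_precision_ = _broadcast_to_folds([0.8, 0.7, 0.9], n_multi)  # unchanged
```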

)
# Note `pos_label` cannot be `None` (default=1), unlike other metrics
# such as roc_auc
average_precision = average_precision_score(
@lucyleeow (Member, Author):

Note pos_label cannot be None here (default=1), unlike other metrics such as roc_auc.
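
A small aside to illustrate that asymmetry (this is standard scikit-learn behaviour, not new in this PR): average_precision_score defaults pos_label to 1 rather than None, so with string labels it must be passed explicitly, whereas roc_auc_score has no pos_label parameter at all.

```python
from sklearn.metrics import average_precision_score

y_true = ["neg", "pos", "neg", "pos"]
y_score = [0.1, 0.8, 0.3, 0.6]

# The default pos_label=1 is not a valid label here, so it has to be given explicitly
average_precision_score(y_true, y_score, pos_label="pos")
```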

Comment on lines 836 to 838
precision_all.append(precision)
recall_all.append(recall)
ap_all.append(average_precision)
@lucyleeow (Member, Author):

I don't like this, but I'm not sure about the zip approach suggested in #30399 (comment), as you've got to unpack at the end 🤔
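
For reference, a rough sketch of the zip-then-unpack pattern being weighed up (dummy data standing in for the per-fold results the loop would compute):

```python
import numpy as np

# One (precision, recall, average_precision) tuple per CV fold, as the loop would produce
fold_results = [
    (np.array([1.0, 0.5, 0.3]), np.array([0.0, 0.5, 1.0]), 0.75),
    (np.array([1.0, 0.6, 0.4]), np.array([0.0, 0.5, 1.0]), 0.80),
]

# A single unpack at the end replaces appending to three separate lists
precision_all, recall_all, ap_all = (list(values) for values in zip(*fold_results))
```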

@lucyleeow (Member, Author) left a comment

Some notes on review suggestions, namely making all the multi-curve params (precisions, recalls, etc.) lists of ndarrays.

Also realised we did not need a separate plot_single_curve function, as most of the complexity was in _get_line_kwargs.

Comment on lines 201 to 203
if fold_line_kws is None:
    fold_line_kws = [
        {"alpha": 0.5, "color": "tab:blue", "linestyle": "--"}
@lucyleeow (Member, Author):

If we make all multi-curve lines the same appearance, the legend will not be relevant.
Maybe we should not specify a color?

[image: example precision-recall plot]

@lucyleeow (Member, Author) commented Jan 14, 2025:

Decided that we should not specify a single colour, because indeed the legend would be useless.

Comment on lines 255 to 258
names : str, default=None
    Names of each precision-recall curve for labeling. If `None`, use
    name provided at `PrecisionRecallDisplay` initialization. If not
    provided at initialization, no labeling is shown.
@lucyleeow (Member, Author):

It seems reasonable that if we change the name parameter in the class init, we should change it here too, especially as we don't advocate that people use plot directly.

@lucyleeow (Member, Author):

Discussed this with @glemaitre and decided that it is okay to change to names. We should however make it clear what this is setting - the label of the curve in the legend.

The problem use case we thought about was someone creating a plot and display object, then wanting to add one curve to it using plot; names would not make sense there. However, it would be difficult for us to manage the legend in such a case, so we decided that it would be up to the user to manage the legend themselves.

Comment on lines +347 to +348
if len(self.line_) == 1:
    self.line_ = self.line_[0]
@lucyleeow (Member, Author):

Should line_ always be a list or should we do this to be backwards compatible?

@lucyleeow (Member, Author):

We decided that we should deprecate line_ and add lines_.
We'll add a getter such that if you try to access line_ you get a warning and the first item of lines_; line_ will be removed in 2 releases.
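
A rough sketch of what that getter could look like (the class here is a toy stand-in, not the actual PrecisionRecallDisplay wiring):

```python
import warnings


class DisplayWithDeprecatedLine:
    """Toy stand-in illustrating the `line_` -> `lines_` deprecation pattern."""

    def __init__(self, lines):
        self.lines_ = lines  # list of matplotlib Line2D artists, one per curve

    @property
    def line_(self):
        warnings.warn(
            "`line_` is deprecated and will be removed in a future release; "
            "use `lines_` instead.",
            FutureWarning,
        )
        return self.lines_[0]
```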

@lucyleeow (Member, Author):

Just wanted to document here that we discussed a potential enhancement for comparing between estimators, where you have CV results from several estimators (so several fold curves for each estimator). Potentially this could be added as a separate function, where you pass the display object and the estimators desired. Not planned, just a potential addition in the future.

@jeremiedbb (Member):

Hey, I think that you can revive this PR now that the roc curve is merged. Let's try to reuse code from the other PR if possible :)

@lucyleeow (Member, Author):

Thanks @jeremiedbb !

I think @glemaitre mentioned there was some discussion about what to do with the 'chance' level (average precision). In the current PR I have calculated a single average precision (AP) for all the data. I think others suggested that we should calculate the average precision for each fold, which I can see is more accurate, but I am concerned about the visualization appearance.

Here I have used 5 CV splits, plotting the chance level for each fold, and colouring each precision-recall curve/chance line pair the same colour:

[image: precision-recall curves for 5 CV folds with per-fold chance level lines]

Some concerns about the visualization:

  • it would not be unusual for APs to be the same, as has occurred above. The orange and blue lines have the same AP, and we can see that this results in a single brown-ish line.
  • in the ROC curve display, we decided to colour all CV folds the same colour by default, as we thought this would be most appropriate:
    • as the number of folds increases, it would be hard to distinguish each line individually
    • we assume we are interested in comparing the results between estimators, thus it makes sense to colour all the CV results from one estimator the same colour

I will have more of a think about a better solution for this.

@glemaitre (Member) commented Jun 13, 2025:

So for the "Chance level" I would consider all lines to have the same color (and a lower alpha), and in the legend have a single entry showing the mean + std. I would think that is enough. Also, it is easy to link a chance level line with its PR curve because they meet when the recall is 1.0.
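
Something along these lines could build that single legend entry (just a sketch of the labelling; it assumes the per-fold prevalences, i.e. the chance-level APs, have already been computed):

```python
import numpy as np

# Prevalence of the positive class in each CV fold (the chance-level AP per fold)
prevalence_per_fold = np.array([0.49, 0.51, 0.50, 0.48, 0.52])

chance_label = (
    f"Chance level (AP = {prevalence_per_fold.mean():0.2f} "
    f"± {prevalence_per_fold.std():0.2f})"
)
# Attach `label=chance_label` to only one of the chance-level lines so the legend
# shows a single entry; the remaining lines are drawn with label=None and low alpha.
```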

@lucyleeow (Member, Author) left a comment

Here is what the plot looks like with the defaults, and with plot_chance_level set to True:

Code
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_validate
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import PrecisionRecallDisplay

import matplotlib.pyplot as plt

# Generate sample data
X, y = make_classification(n_samples=1000, n_classes=2, n_informative=5, random_state=42)
clf = RandomForestClassifier(random_state=42)

cv_results = cross_validate(
    clf, X, y, cv=5,
    return_estimator=True,
    return_indices=True,
)

# Plot Precision-Recall curve using from_cv_results
disp = PrecisionRecallDisplay.from_cv_results(
    cv_results, X, y, name="RandomForest", plot_chance_level=True
)
plt.show()

[image: resulting plot with default settings and plot_chance_level=True]

The alpha for the chance line is 0.3. Prevalence seems to be pretty much the same for all CV folds (which may not be unusual?), so they mostly overlap.

@@ -135,6 +135,8 @@ def _validate_curve_kwargs(
    legend_metric,
    legend_metric_name,
    curve_kwargs,
    default_curve_kwargs=None,
    removed_version="1.9",
@lucyleeow (Member, Author):

I mostly added a default here because I wanted this parameter to be last. Happy to change.

Comment on lines -247 to +265
_validate_style_kwargs({"label": label}, curve_kwargs[fold_idx])
_validate_style_kwargs(
    {"label": label, **default_curve_kwargs_}, curve_kwargs[fold_idx]
)
@lucyleeow (Member, Author) commented Jun 18, 2025:

Argh, GitHub lost my comment.

Previously, if you passed any curve_kwargs, it would override all default kwargs.

Now, if an individual curve_kwargs entry changes the same parameter as a default kwarg, only that parameter is overridden; all other default kwargs are still used. E.g., if the user sets color to red in curve_kwargs, only the default color is overridden. The other parameters (e.g., "alpha": 0.5, "linestyle": "--") are still used.

I initially wanted to implement this in RocCurveDisplay, but just went with 'override all defaults' because it was easier.

I think it is more likely that if a user, e.g., sets the curve color to red, they still want the other default kwargs (i.e., they only want to change the color).
In particular, I changed this because it is necessary for precision-recall: I think we always want the default "drawstyle": "steps-post" (to prevent interpolation), unless the user specifically changes it.

(If we decide we are happy with this change, I should probably add a what's new entry for RocCurveDisplay.)
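
In effect the merge now behaves like a plain dict update, where user-supplied keys win but untouched defaults survive (a simplified illustration, not the exact _validate_style_kwargs code):

```python
default_curve_kwargs = {"alpha": 0.5, "linestyle": "--", "drawstyle": "steps-post"}
user_curve_kwargs = {"color": "red"}  # the user only wants to change the colour

merged = {**default_curve_kwargs, **user_curve_kwargs}
# {'alpha': 0.5, 'linestyle': '--', 'drawstyle': 'steps-post', 'color': 'red'}
# Previously, passing any curve_kwargs at all dropped every default instead.
```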

@lucyleeow lucyleeow marked this pull request as ready for review June 18, 2025 12:29
Comment on lines +848 to +853
precision_folds, recall_folds, ap_folds, prevalence_pos_label_folds = (
[],
[],
[],
[],
)
@lucyleeow (Member, Author):

Lint did this, and I don't think it looks great, but I have no suggestions on how to fix it 🤷

@lucyleeow (Member, Author) commented Jun 20, 2025:

pos_label checking in display methods

I only realised this when looking at the test:

def test_precision_recall_display_string_labels(pyplot):

When y is composed of string labels:

  • from_predictions raises an error if pos_label is not explicitly passed (via _check_pos_label_consistency). This makes sense, as we cannot guess what pos_label should be.
  • from_estimator does not raise an error because we default to estimator.classes_[1] (_get_response_values_binary does this).

I think it is reasonable for from_cv_results to also default to estimator.classes_[-1] (this is indeed what we have in the docstring, but it is NOT what we are doing in main). This case is a bit more complicated than from_estimator because we have the problem that not every class may be present in each split (see #29558) - thus we could end up with different pos_labels. Still thinking this through, but I think I would be happy to check that, if pos_label is not explicitly passed, it has been inferred to be the same for every split. WDYT @glemaitre ?

Edit: Actually, I think all estimators would raise an error if there are fewer than 2 classes, so we can just leave it to the estimator.
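
For reference, a minimal sketch of the per-split consistency check floated above (written before the edit; it assumes cv_results comes from cross_validate(..., return_estimator=True) and that pos_label was not passed):

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_validate

X, y = make_classification(n_samples=200, random_state=0)
cv_results = cross_validate(LogisticRegression(), X, y, cv=3, return_estimator=True)

# Infer the positive label from each fitted estimator and make sure they all agree
inferred = {estimator.classes_[-1] for estimator in cv_results["estimator"]}
if len(inferred) > 1:
    raise ValueError(
        "pos_label could not be inferred consistently across CV splits "
        f"(got {sorted(inferred)}); please pass `pos_label` explicitly."
    )
pos_label = inferred.pop()
```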

Comment on lines +362 to +364
# y_multi[y_multi == 1] = 2
# with pytest.raises(ValueError, match=r"y takes value in \{0, 2\}"):
# display_class.from_cv_results(cv_results, X, y_multi)
@lucyleeow (Member, Author) commented Jun 20, 2025:

I've realised this is a very weird edge case that I was testing. The estimator in cv_results has been fitted on a different y than what we have passed (y_multi), i.e., estimator.classes_ will have different classes than np.unique(y).

Then when we pass values to the metric, pos_label is not present in y_true. Interestingly, average_precision_score checks this in:

present_labels = np.unique(y_true).tolist()
if y_type == "binary":
    if len(present_labels) == 2 and pos_label not in present_labels:
        raise ValueError(
            f"pos_label={pos_label} is not a valid label. It should be "
            f"one of {present_labels}"
        )

but none of precision_recall_curve, roc_curve and auc check this.

I am not sure if this is something we should be checking, and if so, whether it should be left to the metric functions (to also avoid duplicating the check)?

@lucyleeow (Member, Author):

I realise that both roc_curve and precision_recall_curve will give a warning if pos_label is not in y_true, which gets ignored in tests. I can simply update this test to check that the correct warning is raised.
