diff --git a/lib/matplotlib/testing/decorators.py b/lib/matplotlib/testing/decorators.py index 7965d45e4f16..7e9621cb3b1d 100644 --- a/lib/matplotlib/testing/decorators.py +++ b/lib/matplotlib/testing/decorators.py @@ -381,41 +381,40 @@ def test_plot(fig_test, fig_ref): fig_test.subplots().plot([1, 3, 5]) fig_ref.subplots().plot([0, 1, 2], [1, 3, 5]) """ - + POSITIONAL_OR_KEYWORD = inspect.Parameter.POSITIONAL_OR_KEYWORD def decorator(func): import pytest _, result_dir = _image_directories(func) - if len(inspect.signature(func).parameters) == 2: - # Free-standing function. - @pytest.mark.parametrize("ext", extensions) - def wrapper(ext): - fig_test = plt.figure("test") - fig_ref = plt.figure("reference") - func(fig_test, fig_ref) - test_image_path = result_dir / (func.__name__ + "." + ext) - ref_image_path = ( - result_dir / (func.__name__ + "-expected." + ext)) - fig_test.savefig(test_image_path) - fig_ref.savefig(ref_image_path) - _raise_on_image_difference( - ref_image_path, test_image_path, tol=tol) - - elif len(inspect.signature(func).parameters) == 3: - # Method. - @pytest.mark.parametrize("ext", extensions) - def wrapper(self, ext): - fig_test = plt.figure("test") - fig_ref = plt.figure("reference") - func(self, fig_test, fig_ref) - test_image_path = result_dir / (func.__name__ + "." + ext) - ref_image_path = ( - result_dir / (func.__name__ + "-expected." + ext)) - fig_test.savefig(test_image_path) - fig_ref.savefig(ref_image_path) - _raise_on_image_difference( - ref_image_path, test_image_path, tol=tol) + @pytest.mark.parametrize("ext", extensions) + def wrapper(*args, ext, **kwargs): + fig_test = plt.figure("test") + fig_ref = plt.figure("reference") + func(*args, fig_test=fig_test, fig_ref=fig_ref, **kwargs) + test_image_path = result_dir / (func.__name__ + "." + ext) + ref_image_path = result_dir / ( + func.__name__ + "-expected." + ext + ) + fig_test.savefig(test_image_path) + fig_ref.savefig(ref_image_path) + _raise_on_image_difference( + ref_image_path, test_image_path, tol=tol + ) + + sig = inspect.signature(func) + new_sig = sig.replace( + parameters=([param + for param in sig.parameters.values() + if param.name not in {"fig_test", "fig_ref"}] + + [inspect.Parameter("ext", POSITIONAL_OR_KEYWORD)]) + ) + wrapper.__signature__ = new_sig + + # reach a bit into pytest internals to hoist the marks from + # our wrapped function + new_marks = getattr(func, "pytestmark", []) + wrapper.pytestmark + wrapper.pytestmark = new_marks return wrapper diff --git a/lib/matplotlib/tests/test_testing.py b/lib/matplotlib/tests/test_testing.py index 55415abb6a04..d1273ec1ba46 100644 --- a/lib/matplotlib/tests/test_testing.py +++ b/lib/matplotlib/tests/test_testing.py @@ -1,7 +1,17 @@ import warnings import pytest +from matplotlib.testing.decorators import check_figures_equal -@pytest.mark.xfail(strict=True, - reason="testing that warnings fail tests") + +@pytest.mark.xfail( + strict=True, reason="testing that warnings fail tests" +) def test_warn_to_fail(): warnings.warn("This should fail the test") + + +@pytest.mark.parametrize("a", [1]) +@check_figures_equal(extensions=["png"]) +@pytest.mark.parametrize("b", [1]) +def test_parametrize_with_check_figure_equal(a, fig_ref, b, fig_test): + assert a == b
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: