@@ -381,41 +381,40 @@ def test_plot(fig_test, fig_ref):
381381 fig_test.subplots().plot([1, 3, 5])
382382 fig_ref.subplots().plot([0, 1, 2], [1, 3, 5])
383383 """
384-
384+ POSITIONAL_OR_KEYWORD = inspect . Parameter . POSITIONAL_OR_KEYWORD
385385 def decorator (func ):
386386 import pytest
387387
388388 _ , result_dir = _image_directories (func )
389389
390- if len (inspect .signature (func ).parameters ) == 2 :
391- # Free-standing function.
392- @pytest .mark .parametrize ("ext" , extensions )
393- def wrapper (ext ):
394- fig_test = plt .figure ("test" )
395- fig_ref = plt .figure ("reference" )
396- func (fig_test , fig_ref )
397- test_image_path = result_dir / (func .__name__ + "." + ext )
398- ref_image_path = (
399- result_dir / (func .__name__ + "-expected." + ext ))
400- fig_test .savefig (test_image_path )
401- fig_ref .savefig (ref_image_path )
402- _raise_on_image_difference (
403- ref_image_path , test_image_path , tol = tol )
404-
405- elif len (inspect .signature (func ).parameters ) == 3 :
406- # Method.
407- @pytest .mark .parametrize ("ext" , extensions )
408- def wrapper (self , ext ):
409- fig_test = plt .figure ("test" )
410- fig_ref = plt .figure ("reference" )
411- func (self , fig_test , fig_ref )
412- test_image_path = result_dir / (func .__name__ + "." + ext )
413- ref_image_path = (
414- result_dir / (func .__name__ + "-expected." + ext ))
415- fig_test .savefig (test_image_path )
416- fig_ref .savefig (ref_image_path )
417- _raise_on_image_difference (
418- ref_image_path , test_image_path , tol = tol )
390+ @pytest .mark .parametrize ("ext" , extensions )
391+ def wrapper (* args , ext , ** kwargs ):
392+ fig_test = plt .figure ("test" )
393+ fig_ref = plt .figure ("reference" )
394+ func (* args , fig_test = fig_test , fig_ref = fig_ref , ** kwargs )
395+ test_image_path = result_dir / (func .__name__ + "." + ext )
396+ ref_image_path = result_dir / (
397+ func .__name__ + "-expected." + ext
398+ )
399+ fig_test .savefig (test_image_path )
400+ fig_ref .savefig (ref_image_path )
401+ _raise_on_image_difference (
402+ ref_image_path , test_image_path , tol = tol
403+ )
404+
405+ sig = inspect .signature (func )
406+ new_sig = sig .replace (
407+ parameters = ([param
408+ for param in sig .parameters .values ()
409+ if param .name not in {"fig_test" , "fig_ref" }]
410+ + [inspect .Parameter ("ext" , POSITIONAL_OR_KEYWORD )])
411+ )
412+ wrapper .__signature__ = new_sig
413+
414+ # reach a bit into pytest internals to hoist the marks from
415+ # our wrapped function
416+ new_marks = getattr (func , "pytestmark" , []) + wrapper .pytestmark
417+ wrapper .pytestmark = new_marks
419418
420419 return wrapper
421420
0 commit comments