diff --git a/Lib/annotationlib.py b/Lib/annotationlib.py index 0a67742a2b3081..abc82213e044e6 100644 --- a/Lib/annotationlib.py +++ b/Lib/annotationlib.py @@ -639,28 +639,38 @@ def get_annotations( if eval_str and format != Format.VALUE: raise ValueError("eval_str=True is only supported with format=Format.VALUE") - # For VALUE format, we look at __annotations__ directly. - if format != Format.VALUE: - annotate = get_annotate_function(obj) - if annotate is not None: - ann = call_annotate_function(annotate, format, owner=obj) - if not isinstance(ann, dict): - raise ValueError(f"{obj!r}.__annotate__ returned a non-dict") - return dict(ann) - - if isinstance(obj, type): - try: - ann = _BASE_GET_ANNOTATIONS(obj) - except AttributeError: - # For static types, the descriptor raises AttributeError. - return {} - else: - ann = getattr(obj, "__annotations__", None) - if ann is None: - return {} - - if not isinstance(ann, dict): - raise ValueError(f"{obj!r}.__annotations__ is neither a dict nor None") + match format: + case Format.VALUE: + # For VALUE, we only look at __annotations__ + ann = _get_dunder_annotations(obj) + case Format.FORWARDREF: + # For FORWARDREF, we use __annotations__ if it exists + try: + ann = _get_dunder_annotations(obj) + except NameError: + pass + else: + return dict(ann) + + # But if __annotations__ threw a NameError, we try calling __annotate__ + ann = _get_and_call_annotate(obj, format) + if ann is not None: + return ann + + # If that didn't work either, we have a very weird object: evaluating + # __annotations__ threw NameError and there is no __annotate__. In that case, + # we fall back to trying __annotations__ again. + return dict(_get_dunder_annotations(obj)) + case Format.SOURCE: + # For SOURCE, we try to call __annotate__ + ann = _get_and_call_annotate(obj, format) + if ann is not None: + return ann + # But if we didn't get it, we use __annotations__ instead. + ann = _get_dunder_annotations(obj) + return ann + case _: + raise ValueError(f"Unsupported format {format!r}") if not ann: return {} @@ -725,3 +735,30 @@ def get_annotations( for key, value in ann.items() } return return_value + + +def _get_and_call_annotate(obj, format): + annotate = get_annotate_function(obj) + if annotate is not None: + ann = call_annotate_function(annotate, format, owner=obj) + if not isinstance(ann, dict): + raise ValueError(f"{obj!r}.__annotate__ returned a non-dict") + return dict(ann) + return None + + +def _get_dunder_annotations(obj): + if isinstance(obj, type): + try: + ann = _BASE_GET_ANNOTATIONS(obj) + except AttributeError: + # For static types, the descriptor raises AttributeError. + return {} + else: + ann = getattr(obj, "__annotations__", None) + if ann is None: + return {} + + if not isinstance(ann, dict): + raise ValueError(f"{obj!r}.__annotations__ is neither a dict nor None") + return dict(ann) diff --git a/Lib/test/test_annotationlib.py b/Lib/test/test_annotationlib.py index dd8ceb55a411fb..62f57c0a83d4e2 100644 --- a/Lib/test/test_annotationlib.py +++ b/Lib/test/test_annotationlib.py @@ -705,17 +705,97 @@ def f(x: int): self.assertEqual(annotationlib.get_annotations(f), {"x": int}) self.assertEqual( - annotationlib.get_annotations(f, format=annotationlib.Format.FORWARDREF), + annotationlib.get_annotations(f, format=Format.FORWARDREF), {"x": int}, ) f.__annotations__["x"] = str # The modification is reflected in VALUE (the default) self.assertEqual(annotationlib.get_annotations(f), {"x": str}) - # ... but not in FORWARDREF, which uses __annotate__ + # ... and also in FORWARDREF, which tries __annotations__ if available self.assertEqual( - annotationlib.get_annotations(f, format=annotationlib.Format.FORWARDREF), - {"x": int}, + annotationlib.get_annotations(f, format=Format.FORWARDREF), + {"x": str}, + ) + # ... but not in SOURCE which always uses __annotate__ + self.assertEqual( + annotationlib.get_annotations(f, format=Format.SOURCE), + {"x": "int"}, + ) + + def test_non_dict_annotations(self): + class WeirdAnnotations: + @property + def __annotations__(self): + return "not a dict" + + wa = WeirdAnnotations() + for format in Format: + with ( + self.subTest(format=format), + self.assertRaisesRegex( + ValueError, r".*__annotations__ is neither a dict nor None" + ), + ): + annotationlib.get_annotations(wa, format=format) + + def test_annotations_on_custom_object(self): + class HasAnnotations: + @property + def __annotations__(self): + return {"x": int} + + ha = HasAnnotations() + self.assertEqual( + annotationlib.get_annotations(ha, format=Format.VALUE), {"x": int} + ) + self.assertEqual( + annotationlib.get_annotations(ha, format=Format.FORWARDREF), {"x": int} + ) + + # TODO(gh-124412): This should return {'x': 'int'} instead. + self.assertEqual( + annotationlib.get_annotations(ha, format=Format.SOURCE), {"x": int} + ) + + def test_raising_annotations_on_custom_object(self): + class HasRaisingAnnotations: + @property + def __annotations__(self): + return {"x": undefined} + + hra = HasRaisingAnnotations() + + with self.assertRaises(NameError): + annotationlib.get_annotations(hra, format=Format.VALUE) + + with self.assertRaises(NameError): + annotationlib.get_annotations(hra, format=Format.FORWARDREF) + + undefined = float + self.assertEqual( + annotationlib.get_annotations(hra, format=Format.VALUE), {"x": float} + ) + + def test_forwardref_prefers_annotations(self): + class HasBoth: + @property + def __annotations__(self): + return {"x": int} + + @property + def __annotate__(self): + return lambda format: {"x": str} + + hb = HasBoth() + self.assertEqual( + annotationlib.get_annotations(hb, format=Format.VALUE), {"x": int} + ) + self.assertEqual( + annotationlib.get_annotations(hb, format=Format.FORWARDREF), {"x": int} + ) + self.assertEqual( + annotationlib.get_annotations(hb, format=Format.SOURCE), {"x": str} ) def test_pep695_generic_class_with_future_annotations(self):
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: