diff --git a/Doc/reference/compound_stmts.rst b/Doc/reference/compound_stmts.rst index e95fa3a6424e23..36bd911c05f09a 100644 --- a/Doc/reference/compound_stmts.rst +++ b/Doc/reference/compound_stmts.rst @@ -1098,6 +1098,11 @@ The same keyword should not be repeated in class patterns. The following is the logical flow for matching a class pattern against a subject value: +#. If ``name_or_attr`` is a union type, apply the subsequent steps in order to + each of its members, returning the first successful match or raising the first + encountered exception. + This mirrors the behavior of :func:`isinstance` with union types. + #. If ``name_or_attr`` is not an instance of the builtin :class:`type` , raise :exc:`TypeError`. diff --git a/Lib/test/test_patma.py b/Lib/test/test_patma.py index 5d0857b059ea23..cb6662d3eaaff1 100644 --- a/Lib/test/test_patma.py +++ b/Lib/test/test_patma.py @@ -15,6 +15,13 @@ class Point: y: int +@dataclasses.dataclass +class Point3D: + x: int + y: int + z: int + + class TestCompiler(unittest.TestCase): def test_refleaks(self): @@ -2888,6 +2895,84 @@ class B(A): ... h = 1 self.assertEqual(h, 1) + def test_patma_union_type(self): + IntOrStr = int | str + w = None + match 0: + case IntOrStr(): + w = 0 + self.assertEqual(w, 0) + + def test_patma_union_no_match(self): + StrOrBytes = str | bytes + w = None + match 0: + case StrOrBytes(): + w = 0 + self.assertIsNone(w) + + def test_union_type_positional_subpattern(self): + IntOrStr = int | str + w = None + match 0: + case IntOrStr(y): + w = y + self.assertEqual(w, 0) + + def test_union_type_keyword_subpattern(self): + EitherPoint = Point | Point3D + p = Point(x=1, y=2) + w = None + match p: + case EitherPoint(x=1, y=2): + w = 0 + self.assertEqual(w, 0) + + def test_patma_union_arg(self): + p = Point(x=1, y=2) + IntOrStr = int | str + w = None + match p: + case Point(IntOrStr(), IntOrStr()): + w = 0 + self.assertEqual(w, 0) + + def test_patma_union_kwarg(self): + p = Point(x=1, y=2) + IntOrStr = int | str + w = None + match p: + case Point(x=IntOrStr(), y=IntOrStr()): + w = 0 + self.assertEqual(w, 0) + + def test_patma_union_arg_no_match(self): + p = Point(x=1, y=2) + StrOrBytes = str | bytes + w = None + match p: + case Point(StrOrBytes(), StrOrBytes()): + w = 0 + self.assertIsNone(w) + + def test_patma_union_kwarg_no_match(self): + p = Point(x=1, y=2) + StrOrBytes = str | bytes + w = None + match p: + case Point(x=StrOrBytes(), y=StrOrBytes()): + w = 0 + self.assertIsNone(w) + + def test_union_type_match_second_member(self): + EitherPoint = Point | Point3D + p = Point3D(x=1, y=2, z=3) + w = None + match p: + case EitherPoint(x=1, y=2, z=3): + w = 0 + self.assertEqual(w, 0) + class TestSyntaxErrors(unittest.TestCase): @@ -3230,8 +3315,28 @@ def test_mapping_pattern_duplicate_key_edge_case3(self): pass """) + class TestTypeErrors(unittest.TestCase): + def test_generic_type(self): + t = list[str] + w = None + with self.assertRaises(TypeError): + match ["s"]: + case t(): + w = 0 + self.assertIsNone(w) + + def test_legacy_generic_type(self): + from typing import List + t = List[str] + w = None + with self.assertRaises(TypeError): + match ["s"]: + case t(): + w = 0 + self.assertIsNone(w) + def test_accepts_positional_subpatterns_0(self): class Class: __match_args__ = () @@ -3341,6 +3446,121 @@ def test_class_pattern_not_type(self): w = 0 self.assertIsNone(w) + def test_class_or_union_not_specialform(self): + from typing import Literal + name = type(Literal).__name__ + msg = rf"called match pattern must be a class or a union of classes \(got {name}\)" + w = None + with self.assertRaisesRegex(TypeError, msg): + match 1: + case Literal(): + w = 0 + self.assertIsNone(w) + + def test_typing_union(self): + from typing import Union + IntOrStr = Union[int, str] # identical to int | str since gh-105499 + w = False + match 1: + case IntOrStr(): + w = True + self.assertIs(w, True) + + def test_expanded_union_mirrors_isinstance_success(self): + ListOfInt = list[int] + t = int | ListOfInt + try: # get the isinstance result + reference = isinstance(1, t) + except TypeError as exc: + reference = exc + + try: # get the match-case result + match 1: + case int() | ListOfInt(): + result = True + case _: + result = False + except TypeError as exc: + result = exc + + # we should ge the same result + self.assertIs(result, True) + self.assertIs(reference, True) + + def test_expanded_union_mirrors_isinstance_failure(self): + ListOfInt = list[int] + t = ListOfInt | int + + try: # get the isinstance result + reference = isinstance(1, t) + except TypeError as exc: + reference = exc + + try: # get the match-case result + match 1: + case ListOfInt() | int(): + result = True + case _: + result = False + except TypeError as exc: + result = exc + + # we should ge the same result + self.assertIsInstance(result, TypeError) + self.assertIsInstance(reference, TypeError) + + def test_union_type_mirrors_isinstance_success(self): + t = int | list[int] + + try: # get the isinstance result + reference = isinstance(1, t) + except TypeError as exc: + reference = exc + + try: # get the match-case result + match 1: + case t(): + result = True + case _: + result = False + except TypeError as exc: + result = exc + + # we should ge the same result + self.assertIs(result, True) + self.assertIs(reference, True) + + def test_union_type_mirrors_isinstance_failure(self): + t = list[int] | int + + try: # get the isinstance result + reference = isinstance(1, t) + except TypeError as exc: + reference = exc + + try: # get the match-case result + match 1: + case t(): + result = True + case _: + result = False + except TypeError as exc: + result = exc + + # we should ge the same result + self.assertIsInstance(result, TypeError) + self.assertIsInstance(reference, TypeError) + + def test_generic_union_type(self): + from collections.abc import Sequence, Set + t = Sequence[str] | Set[str] + w = None + with self.assertRaises(TypeError): + match ["s"]: + case t(): + w = 0 + self.assertIsNone(w) + def test_regular_protocol(self): from typing import Protocol class P(Protocol): ... diff --git a/Python/ceval.c b/Python/ceval.c index 291e753dec0ce5..fd102ea4516aab 100644 --- a/Python/ceval.c +++ b/Python/ceval.c @@ -39,6 +39,7 @@ #include "pycore_template.h" // _PyTemplate_Build() #include "pycore_traceback.h" // _PyTraceBack_FromFrame #include "pycore_tuple.h" // _PyTuple_ITEMS() +#include "pycore_unionobject.h" // _PyUnion_Check() #include "pycore_uop_ids.h" // Uops #include "dictobject.h" @@ -725,9 +726,27 @@ PyObject* _PyEval_MatchClass(PyThreadState *tstate, PyObject *subject, PyObject *type, Py_ssize_t nargs, PyObject *kwargs) { + // Recurse on unions. + if (_PyUnion_Check(type)) { + // get union members + PyObject *members = _Py_union_args(type); + const Py_ssize_t n = PyTuple_GET_SIZE(members); + + // iterate over union members and return first match + for (Py_ssize_t i = 0; i < n; i++) { + PyObject *member = PyTuple_GET_ITEM(members, i); + PyObject *attrs = _PyEval_MatchClass(tstate, subject, member, nargs, kwargs); + // match found + if (attrs != NULL) { + return attrs; + } + } + // no match found + return NULL; + } if (!PyType_Check(type)) { - const char *e = "called match pattern must be a class"; - _PyErr_Format(tstate, PyExc_TypeError, e); + const char *e = "called match pattern must be a class or a union of classes (got %s)"; + _PyErr_Format(tstate, PyExc_TypeError, e, Py_TYPE(type)->tp_name); return NULL; } assert(PyTuple_CheckExact(kwargs));
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: