From cb2b850d4589445d3d7319da824e7b5f4ab13c75 Mon Sep 17 00:00:00 2001 From: Thomas M Kehrenberg Date: Thu, 2 May 2024 22:10:45 +0200 Subject: [PATCH 1/7] Allow the use of unions as match patterns --- Lib/test/test_patma.py | 34 ++++++++++++++++++++++++++++++++++ Python/ceval.c | 15 +++++++++++++-- 2 files changed, 47 insertions(+), 2 deletions(-) diff --git a/Lib/test/test_patma.py b/Lib/test/test_patma.py index 5d0857b059ea23..bf849c673d4ab5 100644 --- a/Lib/test/test_patma.py +++ b/Lib/test/test_patma.py @@ -4,6 +4,7 @@ import dis import enum import inspect +from re import I import sys import unittest from test import support @@ -2888,6 +2889,14 @@ class B(A): ... h = 1 self.assertEqual(h, 1) + def test_patma_union_type(self): + IntOrStr = int | str + x = 0 + match x: + case IntOrStr(): + x = 1 + self.assertEqual(x, 1) + class TestSyntaxErrors(unittest.TestCase): @@ -3370,6 +3379,31 @@ class A: w = 0 self.assertIsNone(w) + def test_union_type_postional_subpattern(self): + IntOrStr = int | str + x = 1 + w = None + with self.assertRaises(TypeError): + match x: + case IntOrStr(x): + w = 0 + self.assertEqual(x, 1) + self.assertIsNone(w) + + def test_union_type_keyword_subpattern(self): + @dataclasses.dataclass + class Point2: + x: int + y: int + EitherPoint = Point | Point2 + x = Point(x=1, y=2) + w = None + with self.assertRaises(TypeError): + match x: + case EitherPoint(x=1, y=2): + w = 0 + self.assertIsNone(w) + class TestValueErrors(unittest.TestCase): diff --git a/Python/ceval.c b/Python/ceval.c index 291e753dec0ce5..24df578f158d30 100644 --- a/Python/ceval.c +++ b/Python/ceval.c @@ -39,6 +39,7 @@ #include "pycore_template.h" // _PyTemplate_Build() #include "pycore_traceback.h" // _PyTraceBack_FromFrame #include "pycore_tuple.h" // _PyTuple_ITEMS() +#include "pycore_unionobject.h" // _PyUnion_Check() #include "pycore_uop_ids.h" // Uops #include "dictobject.h" @@ -725,8 +726,8 @@ PyObject* _PyEval_MatchClass(PyThreadState *tstate, PyObject *subject, PyObject *type, Py_ssize_t nargs, PyObject *kwargs) { - if (!PyType_Check(type)) { - const char *e = "called match pattern must be a class"; + if (!PyType_Check(type) && !_PyUnion_Check(type)) { + const char *e = "called match pattern must be a class or a union"; _PyErr_Format(tstate, PyExc_TypeError, e); return NULL; } @@ -735,6 +736,16 @@ _PyEval_MatchClass(PyThreadState *tstate, PyObject *subject, PyObject *type, if (PyObject_IsInstance(subject, type) <= 0) { return NULL; } + // Subpatterns are not supported for union types: + if (_PyUnion_Check(type)) { + // Return error if any positional or keyword arguments are given: + if (nargs || PyTuple_GET_SIZE(kwargs)) { + const char *e = "union types do not support sub-patterns"; + _PyErr_Format(tstate, PyExc_TypeError, e); + return NULL; + } + return PyTuple_New(0); + } // So far so good: PyObject *seen = PySet_New(NULL); if (seen == NULL) { From 114f250a8adf73b7f8b382ef32ebf582c8536a22 Mon Sep 17 00:00:00 2001 From: Randolf Scholz Date: Mon, 6 May 2024 14:26:57 +0200 Subject: [PATCH 2/7] add match-case support for unions --- Lib/test/test_patma.py | 249 ++++++++++++++++++++++++++++++++++++----- Python/ceval.c | 34 +++--- 2 files changed, 240 insertions(+), 43 deletions(-) diff --git a/Lib/test/test_patma.py b/Lib/test/test_patma.py index bf849c673d4ab5..c16f09a583fbba 100644 --- a/Lib/test/test_patma.py +++ b/Lib/test/test_patma.py @@ -4,7 +4,6 @@ import dis import enum import inspect -from re import I import sys import unittest from test import support @@ -16,6 +15,13 @@ class Point: y: int +@dataclasses.dataclass +class Point3D: + x: int + y: int + z: int + + class TestCompiler(unittest.TestCase): def test_refleaks(self): @@ -2891,11 +2897,81 @@ class B(A): ... def test_patma_union_type(self): IntOrStr = int | str - x = 0 - match x: + w = None + match 0: case IntOrStr(): - x = 1 - self.assertEqual(x, 1) + w = 0 + self.assertEqual(w, 0) + + def test_patma_union_no_match(self): + StrOrBytes = str | bytes + w = None + match 0: + case StrOrBytes(): + w = 0 + self.assertIsNone(w) + + def test_union_type_positional_subpattern(self): + IntOrStr = int | str + w = None + match 0: + case IntOrStr(y): + w = y + self.assertEqual(w, 0) + + def test_union_type_keyword_subpattern(self): + EitherPoint = Point | Point3D + p = Point(x=1, y=2) + w = None + match p: + case EitherPoint(x=1, y=2): + w = 0 + self.assertEqual(w, 0) + + def test_patma_union_arg(self): + p = Point(x=1, y=2) + IntOrStr = int | str + w = None + match p: + case Point(IntOrStr(), IntOrStr()): + w = 0 + self.assertEqual(w, 0) + + def test_patma_union_kwarg(self): + p = Point(x=1, y=2) + IntOrStr = int | str + w = None + match p: + case Point(x=IntOrStr(), y=IntOrStr()): + w = 0 + self.assertEqual(w, 0) + + def test_patma_union_arg_no_match(self): + p = Point(x=1, y=2) + StrOrBytes = str | bytes + w = None + match p: + case Point(StrOrBytes(), StrOrBytes()): + w = 0 + self.assertIsNone(w) + + def test_patma_union_kwarg_no_match(self): + p = Point(x=1, y=2) + StrOrBytes = str | bytes + w = None + match p: + case Point(x=StrOrBytes(), y=StrOrBytes()): + w = 0 + self.assertIsNone(w) + + def test_union_type_match_second_member(self): + EitherPoint = Point | Point3D + p = Point3D(x=1, y=2, z=3) + w = None + match p: + case EitherPoint(x=1, y=2, z=3): + w = 0 + self.assertEqual(w, 0) class TestSyntaxErrors(unittest.TestCase): @@ -3239,8 +3315,28 @@ def test_mapping_pattern_duplicate_key_edge_case3(self): pass """) + class TestTypeErrors(unittest.TestCase): + def test_generic_type(self): + t = list[str] + w = None + with self.assertRaises(TypeError): + match ["s"]: + case t(): + w = 0 + self.assertIsNone(w) + + def test_legacy_generic_type(self): + from typing import List + t = List[str] + w = None + with self.assertRaises(TypeError): + match ["s"]: + case t(): + w = 0 + self.assertIsNone(w) + def test_accepts_positional_subpatterns_0(self): class Class: __match_args__ = () @@ -3350,6 +3446,124 @@ def test_class_pattern_not_type(self): w = 0 self.assertIsNone(w) + def test_class_or_union_not_specialform(self): + from typing import Literal + name = type(Literal).__name__ + msg = rf"called match pattern must be a class or types.UnionType \(got {name}\)" + w = None + with self.assertRaisesRegex(TypeError, msg): + match 1: + case Literal(): + w = 0 + self.assertIsNone(w) + + def test_legacy_union_type(self): + from typing import Union + IntOrStr = Union[int, str] + name = type(IntOrStr).__name__ + msg = rf"called match pattern must be a class or types.UnionType \(got {name}\)" + w = None + with self.assertRaisesRegex(TypeError, msg): + match 1: + case IntOrStr(): + w = 0 + self.assertIsNone(w) + + def test_expanded_union_mirrors_isinstance_success(self): + ListOfInt = list[int] + t = int | ListOfInt + try: # get the isinstance result + reference = isinstance(1, t) + except TypeError as exc: + reference = exc + + try: # get the match-case result + match 1: + case int() | ListOfInt(): + result = True + case _: + result = False + except TypeError as exc: + result = exc + + # we should ge the same result + self.assertIs(result, True) + self.assertIs(reference, True) + + def test_expanded_union_mirrors_isinstance_failure(self): + ListOfInt = list[int] + t = ListOfInt | int + + try: # get the isinstance result + reference = isinstance(1, t) + except TypeError as exc: + reference = exc + + try: # get the match-case result + match 1: + case ListOfInt() | int(): + result = True + case _: + result = False + except TypeError as exc: + result = exc + + # we should ge the same result + self.assertIsInstance(result, TypeError) + self.assertIsInstance(reference, TypeError) + + def test_union_type_mirrors_isinstance_success(self): + t = int | list[int] + + try: # get the isinstance result + reference = isinstance(1, t) + except TypeError as exc: + reference = exc + + try: # get the match-case result + match 1: + case t(): + result = True + case _: + result = False + except TypeError as exc: + result = exc + + # we should ge the same result + self.assertIs(result, True) + self.assertIs(reference, True) + + def test_union_type_mirrors_isinstance_failure(self): + t = list[int] | int + + try: # get the isinstance result + reference = isinstance(1, t) + except TypeError as exc: + reference = exc + + try: # get the match-case result + match 1: + case t(): + result = True + case _: + result = False + except TypeError as exc: + result = exc + + # we should ge the same result + self.assertIsInstance(result, TypeError) + self.assertIsInstance(reference, TypeError) + + def test_generic_union_type(self): + from collections.abc import Sequence, Set + t = Sequence[str] | Set[str] + w = None + with self.assertRaises(TypeError): + match ["s"]: + case t(): + w = 0 + self.assertIsNone(w) + def test_regular_protocol(self): from typing import Protocol class P(Protocol): ... @@ -3379,31 +3593,6 @@ class A: w = 0 self.assertIsNone(w) - def test_union_type_postional_subpattern(self): - IntOrStr = int | str - x = 1 - w = None - with self.assertRaises(TypeError): - match x: - case IntOrStr(x): - w = 0 - self.assertEqual(x, 1) - self.assertIsNone(w) - - def test_union_type_keyword_subpattern(self): - @dataclasses.dataclass - class Point2: - x: int - y: int - EitherPoint = Point | Point2 - x = Point(x=1, y=2) - w = None - with self.assertRaises(TypeError): - match x: - case EitherPoint(x=1, y=2): - w = 0 - self.assertIsNone(w) - class TestValueErrors(unittest.TestCase): diff --git a/Python/ceval.c b/Python/ceval.c index 24df578f158d30..a5d6f37aa11beb 100644 --- a/Python/ceval.c +++ b/Python/ceval.c @@ -726,9 +726,27 @@ PyObject* _PyEval_MatchClass(PyThreadState *tstate, PyObject *subject, PyObject *type, Py_ssize_t nargs, PyObject *kwargs) { - if (!PyType_Check(type) && !_PyUnion_Check(type)) { - const char *e = "called match pattern must be a class or a union"; - _PyErr_Format(tstate, PyExc_TypeError, e); + // Recurse on unions. + if (_PyUnion_Check(type)) { + // get union members + PyObject *members = _Py_union_args(type); + const Py_ssize_t n = PyTuple_GET_SIZE(members); + + // iterate over union members and return first match + for (Py_ssize_t i = 0; i < n; i++) { + PyObject *member = PyTuple_GET_ITEM(members, i); + PyObject *attrs = _PyEval_MatchClass(tstate, subject, member, nargs, kwargs); + // match found + if (attrs != NULL) { + return attrs; + } + } + // no match found + return NULL; + } + if (!PyType_Check(type)) { + const char *e = "called match pattern must be a class or types.UnionType (got %s)"; + _PyErr_Format(tstate, PyExc_TypeError, e, Py_TYPE(type)->tp_name); return NULL; } assert(PyTuple_CheckExact(kwargs)); @@ -736,16 +754,6 @@ _PyEval_MatchClass(PyThreadState *tstate, PyObject *subject, PyObject *type, if (PyObject_IsInstance(subject, type) <= 0) { return NULL; } - // Subpatterns are not supported for union types: - if (_PyUnion_Check(type)) { - // Return error if any positional or keyword arguments are given: - if (nargs || PyTuple_GET_SIZE(kwargs)) { - const char *e = "union types do not support sub-patterns"; - _PyErr_Format(tstate, PyExc_TypeError, e); - return NULL; - } - return PyTuple_New(0); - } // So far so good: PyObject *seen = PySet_New(NULL); if (seen == NULL) { From 2f9aa386a31f71276dee415557fe92993c20b97b Mon Sep 17 00:00:00 2001 From: Randolf Scholz Date: Mon, 21 Jul 2025 11:35:48 +0200 Subject: [PATCH 3/7] Updated error string --- Lib/test/test_patma.py | 4 ++-- Python/ceval.c | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Lib/test/test_patma.py b/Lib/test/test_patma.py index c16f09a583fbba..52b5aa62d1f116 100644 --- a/Lib/test/test_patma.py +++ b/Lib/test/test_patma.py @@ -3449,7 +3449,7 @@ def test_class_pattern_not_type(self): def test_class_or_union_not_specialform(self): from typing import Literal name = type(Literal).__name__ - msg = rf"called match pattern must be a class or types.UnionType \(got {name}\)" + msg = rf"called match pattern must be a class or typing.Union of classes \(got {name}\)" w = None with self.assertRaisesRegex(TypeError, msg): match 1: @@ -3461,7 +3461,7 @@ def test_legacy_union_type(self): from typing import Union IntOrStr = Union[int, str] name = type(IntOrStr).__name__ - msg = rf"called match pattern must be a class or types.UnionType \(got {name}\)" + msg = rf"called match pattern must be a class or typing.Union of classes \(got {name}\)" w = None with self.assertRaisesRegex(TypeError, msg): match 1: diff --git a/Python/ceval.c b/Python/ceval.c index a5d6f37aa11beb..65d69ffe352ec5 100644 --- a/Python/ceval.c +++ b/Python/ceval.c @@ -745,7 +745,7 @@ _PyEval_MatchClass(PyThreadState *tstate, PyObject *subject, PyObject *type, return NULL; } if (!PyType_Check(type)) { - const char *e = "called match pattern must be a class or types.UnionType (got %s)"; + const char *e = "called match pattern must be a class or typing.Union of classes (got %s)"; _PyErr_Format(tstate, PyExc_TypeError, e, Py_TYPE(type)->tp_name); return NULL; } From 4829f23480950b5c831acbdb32695fd307321dcc Mon Sep 17 00:00:00 2001 From: Randolf Scholz Date: Mon, 21 Jul 2025 11:38:15 +0200 Subject: [PATCH 4/7] changed test_legacy_union to test_typing_union. Changed behavior due to gh-105499 --- Lib/test/test_patma.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/Lib/test/test_patma.py b/Lib/test/test_patma.py index 52b5aa62d1f116..acbf6aafadcddd 100644 --- a/Lib/test/test_patma.py +++ b/Lib/test/test_patma.py @@ -3457,17 +3457,14 @@ def test_class_or_union_not_specialform(self): w = 0 self.assertIsNone(w) - def test_legacy_union_type(self): + def test_typing_union(self): from typing import Union - IntOrStr = Union[int, str] - name = type(IntOrStr).__name__ - msg = rf"called match pattern must be a class or typing.Union of classes \(got {name}\)" - w = None - with self.assertRaisesRegex(TypeError, msg): - match 1: - case IntOrStr(): - w = 0 - self.assertIsNone(w) + IntOrStr = Union[int, str] # identical to int | str since gh-105499 + w = False + match 1: + case IntOrStr(): + w = True + self.assertIs(w, True) def test_expanded_union_mirrors_isinstance_success(self): ListOfInt = list[int] From 65f2a6437bbf69817b5d3ddd40ef98446a83c857 Mon Sep 17 00:00:00 2001 From: Randolf Scholz Date: Mon, 21 Jul 2025 14:38:11 +0200 Subject: [PATCH 5/7] Update Python/ceval.c Co-authored-by: sobolevn --- Python/ceval.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Python/ceval.c b/Python/ceval.c index 65d69ffe352ec5..fd102ea4516aab 100644 --- a/Python/ceval.c +++ b/Python/ceval.c @@ -745,7 +745,7 @@ _PyEval_MatchClass(PyThreadState *tstate, PyObject *subject, PyObject *type, return NULL; } if (!PyType_Check(type)) { - const char *e = "called match pattern must be a class or typing.Union of classes (got %s)"; + const char *e = "called match pattern must be a class or a union of classes (got %s)"; _PyErr_Format(tstate, PyExc_TypeError, e, Py_TYPE(type)->tp_name); return NULL; } From 3e12030d370c09c65640f867919bf9e504ed2007 Mon Sep 17 00:00:00 2001 From: Randolf Scholz Date: Mon, 21 Jul 2025 15:13:40 +0200 Subject: [PATCH 6/7] fixed test to match new error message --- Lib/test/test_patma.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Lib/test/test_patma.py b/Lib/test/test_patma.py index acbf6aafadcddd..cb6662d3eaaff1 100644 --- a/Lib/test/test_patma.py +++ b/Lib/test/test_patma.py @@ -3449,7 +3449,7 @@ def test_class_pattern_not_type(self): def test_class_or_union_not_specialform(self): from typing import Literal name = type(Literal).__name__ - msg = rf"called match pattern must be a class or typing.Union of classes \(got {name}\)" + msg = rf"called match pattern must be a class or a union of classes \(got {name}\)" w = None with self.assertRaisesRegex(TypeError, msg): match 1: From 0febb806bca93ad8fadc550c13857b70c83beedb Mon Sep 17 00:00:00 2001 From: Randolf Scholz Date: Mon, 21 Jul 2025 15:20:44 +0200 Subject: [PATCH 7/7] added to docs --- Doc/reference/compound_stmts.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/Doc/reference/compound_stmts.rst b/Doc/reference/compound_stmts.rst index e95fa3a6424e23..36bd911c05f09a 100644 --- a/Doc/reference/compound_stmts.rst +++ b/Doc/reference/compound_stmts.rst @@ -1098,6 +1098,11 @@ The same keyword should not be repeated in class patterns. The following is the logical flow for matching a class pattern against a subject value: +#. If ``name_or_attr`` is a union type, apply the subsequent steps in order to + each of its members, returning the first successful match or raising the first + encountered exception. + This mirrors the behavior of :func:`isinstance` with union types. + #. If ``name_or_attr`` is not an instance of the builtin :class:`type` , raise :exc:`TypeError`. pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy