Skip to content

Commit 114f250

Browse files
add match-case support for unions
1 parent cb2b850 commit 114f250

File tree

2 files changed

+240
-43
lines changed

2 files changed

+240
-43
lines changed

Lib/test/test_patma.py

Lines changed: 219 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import dis
55
import enum
66
import inspect
7-
from re import I
87
import sys
98
import unittest
109
from test import support
@@ -16,6 +15,13 @@ class Point:
1615
y: int
1716

1817

18+
@dataclasses.dataclass
19+
class Point3D:
20+
x: int
21+
y: int
22+
z: int
23+
24+
1925
class TestCompiler(unittest.TestCase):
2026

2127
def test_refleaks(self):
@@ -2891,11 +2897,81 @@ class B(A): ...
28912897

28922898
def test_patma_union_type(self):
28932899
IntOrStr = int | str
2894-
x = 0
2895-
match x:
2900+
w = None
2901+
match 0:
28962902
case IntOrStr():
2897-
x = 1
2898-
self.assertEqual(x, 1)
2903+
w = 0
2904+
self.assertEqual(w, 0)
2905+
2906+
def test_patma_union_no_match(self):
2907+
StrOrBytes = str | bytes
2908+
w = None
2909+
match 0:
2910+
case StrOrBytes():
2911+
w = 0
2912+
self.assertIsNone(w)
2913+
2914+
def test_union_type_positional_subpattern(self):
2915+
IntOrStr = int | str
2916+
w = None
2917+
match 0:
2918+
case IntOrStr(y):
2919+
w = y
2920+
self.assertEqual(w, 0)
2921+
2922+
def test_union_type_keyword_subpattern(self):
2923+
EitherPoint = Point | Point3D
2924+
p = Point(x=1, y=2)
2925+
w = None
2926+
match p:
2927+
case EitherPoint(x=1, y=2):
2928+
w = 0
2929+
self.assertEqual(w, 0)
2930+
2931+
def test_patma_union_arg(self):
2932+
p = Point(x=1, y=2)
2933+
IntOrStr = int | str
2934+
w = None
2935+
match p:
2936+
case Point(IntOrStr(), IntOrStr()):
2937+
w = 0
2938+
self.assertEqual(w, 0)
2939+
2940+
def test_patma_union_kwarg(self):
2941+
p = Point(x=1, y=2)
2942+
IntOrStr = int | str
2943+
w = None
2944+
match p:
2945+
case Point(x=IntOrStr(), y=IntOrStr()):
2946+
w = 0
2947+
self.assertEqual(w, 0)
2948+
2949+
def test_patma_union_arg_no_match(self):
2950+
p = Point(x=1, y=2)
2951+
StrOrBytes = str | bytes
2952+
w = None
2953+
match p:
2954+
case Point(StrOrBytes(), StrOrBytes()):
2955+
w = 0
2956+
self.assertIsNone(w)
2957+
2958+
def test_patma_union_kwarg_no_match(self):
2959+
p = Point(x=1, y=2)
2960+
StrOrBytes = str | bytes
2961+
w = None
2962+
match p:
2963+
case Point(x=StrOrBytes(), y=StrOrBytes()):
2964+
w = 0
2965+
self.assertIsNone(w)
2966+
2967+
def test_union_type_match_second_member(self):
2968+
EitherPoint = Point | Point3D
2969+
p = Point3D(x=1, y=2, z=3)
2970+
w = None
2971+
match p:
2972+
case EitherPoint(x=1, y=2, z=3):
2973+
w = 0
2974+
self.assertEqual(w, 0)
28992975

29002976

29012977
class TestSyntaxErrors(unittest.TestCase):
@@ -3239,8 +3315,28 @@ def test_mapping_pattern_duplicate_key_edge_case3(self):
32393315
pass
32403316
""")
32413317

3318+
32423319
class TestTypeErrors(unittest.TestCase):
32433320

3321+
def test_generic_type(self):
3322+
t = list[str]
3323+
w = None
3324+
with self.assertRaises(TypeError):
3325+
match ["s"]:
3326+
case t():
3327+
w = 0
3328+
self.assertIsNone(w)
3329+
3330+
def test_legacy_generic_type(self):
3331+
from typing import List
3332+
t = List[str]
3333+
w = None
3334+
with self.assertRaises(TypeError):
3335+
match ["s"]:
3336+
case t():
3337+
w = 0
3338+
self.assertIsNone(w)
3339+
32443340
def test_accepts_positional_subpatterns_0(self):
32453341
class Class:
32463342
__match_args__ = ()
@@ -3350,6 +3446,124 @@ def test_class_pattern_not_type(self):
33503446
w = 0
33513447
self.assertIsNone(w)
33523448

3449+
def test_class_or_union_not_specialform(self):
3450+
from typing import Literal
3451+
name = type(Literal).__name__
3452+
msg = rf"called match pattern must be a class or types.UnionType \(got {name}\)"
3453+
w = None
3454+
with self.assertRaisesRegex(TypeError, msg):
3455+
match 1:
3456+
case Literal():
3457+
w = 0
3458+
self.assertIsNone(w)
3459+
3460+
def test_legacy_union_type(self):
3461+
from typing import Union
3462+
IntOrStr = Union[int, str]
3463+
name = type(IntOrStr).__name__
3464+
msg = rf"called match pattern must be a class or types.UnionType \(got {name}\)"
3465+
w = None
3466+
with self.assertRaisesRegex(TypeError, msg):
3467+
match 1:
3468+
case IntOrStr():
3469+
w = 0
3470+
self.assertIsNone(w)
3471+
3472+
def test_expanded_union_mirrors_isinstance_success(self):
3473+
ListOfInt = list[int]
3474+
t = int | ListOfInt
3475+
try: # get the isinstance result
3476+
reference = isinstance(1, t)
3477+
except TypeError as exc:
3478+
reference = exc
3479+
3480+
try: # get the match-case result
3481+
match 1:
3482+
case int() | ListOfInt():
3483+
result = True
3484+
case _:
3485+
result = False
3486+
except TypeError as exc:
3487+
result = exc
3488+
3489+
# we should ge the same result
3490+
self.assertIs(result, True)
3491+
self.assertIs(reference, True)
3492+
3493+
def test_expanded_union_mirrors_isinstance_failure(self):
3494+
ListOfInt = list[int]
3495+
t = ListOfInt | int
3496+
3497+
try: # get the isinstance result
3498+
reference = isinstance(1, t)
3499+
except TypeError as exc:
3500+
reference = exc
3501+
3502+
try: # get the match-case result
3503+
match 1:
3504+
case ListOfInt() | int():
3505+
result = True
3506+
case _:
3507+
result = False
3508+
except TypeError as exc:
3509+
result = exc
3510+
3511+
# we should ge the same result
3512+
self.assertIsInstance(result, TypeError)
3513+
self.assertIsInstance(reference, TypeError)
3514+
3515+
def test_union_type_mirrors_isinstance_success(self):
3516+
t = int | list[int]
3517+
3518+
try: # get the isinstance result
3519+
reference = isinstance(1, t)
3520+
except TypeError as exc:
3521+
reference = exc
3522+
3523+
try: # get the match-case result
3524+
match 1:
3525+
case t():
3526+
result = True
3527+
case _:
3528+
result = False
3529+
except TypeError as exc:
3530+
result = exc
3531+
3532+
# we should ge the same result
3533+
self.assertIs(result, True)
3534+
self.assertIs(reference, True)
3535+
3536+
def test_union_type_mirrors_isinstance_failure(self):
3537+
t = list[int] | int
3538+
3539+
try: # get the isinstance result
3540+
reference = isinstance(1, t)
3541+
except TypeError as exc:
3542+
reference = exc
3543+
3544+
try: # get the match-case result
3545+
match 1:
3546+
case t():
3547+
result = True
3548+
case _:
3549+
result = False
3550+
except TypeError as exc:
3551+
result = exc
3552+
3553+
# we should ge the same result
3554+
self.assertIsInstance(result, TypeError)
3555+
self.assertIsInstance(reference, TypeError)
3556+
3557+
def test_generic_union_type(self):
3558+
from collections.abc import Sequence, Set
3559+
t = Sequence[str] | Set[str]
3560+
w = None
3561+
with self.assertRaises(TypeError):
3562+
match ["s"]:
3563+
case t():
3564+
w = 0
3565+
self.assertIsNone(w)
3566+
33533567
def test_regular_protocol(self):
33543568
from typing import Protocol
33553569
class P(Protocol): ...
@@ -3379,31 +3593,6 @@ class A:
33793593
w = 0
33803594
self.assertIsNone(w)
33813595

3382-
def test_union_type_postional_subpattern(self):
3383-
IntOrStr = int | str
3384-
x = 1
3385-
w = None
3386-
with self.assertRaises(TypeError):
3387-
match x:
3388-
case IntOrStr(x):
3389-
w = 0
3390-
self.assertEqual(x, 1)
3391-
self.assertIsNone(w)
3392-
3393-
def test_union_type_keyword_subpattern(self):
3394-
@dataclasses.dataclass
3395-
class Point2:
3396-
x: int
3397-
y: int
3398-
EitherPoint = Point | Point2
3399-
x = Point(x=1, y=2)
3400-
w = None
3401-
with self.assertRaises(TypeError):
3402-
match x:
3403-
case EitherPoint(x=1, y=2):
3404-
w = 0
3405-
self.assertIsNone(w)
3406-
34073596

34083597
class TestValueErrors(unittest.TestCase):
34093598

Python/ceval.c

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -726,26 +726,34 @@ PyObject*
726726
_PyEval_MatchClass(PyThreadState *tstate, PyObject *subject, PyObject *type,
727727
Py_ssize_t nargs, PyObject *kwargs)
728728
{
729-
if (!PyType_Check(type) && !_PyUnion_Check(type)) {
730-
const char *e = "called match pattern must be a class or a union";
731-
_PyErr_Format(tstate, PyExc_TypeError, e);
729+
// Recurse on unions.
730+
if (_PyUnion_Check(type)) {
731+
// get union members
732+
PyObject *members = _Py_union_args(type);
733+
const Py_ssize_t n = PyTuple_GET_SIZE(members);
734+
735+
// iterate over union members and return first match
736+
for (Py_ssize_t i = 0; i < n; i++) {
737+
PyObject *member = PyTuple_GET_ITEM(members, i);
738+
PyObject *attrs = _PyEval_MatchClass(tstate, subject, member, nargs, kwargs);
739+
// match found
740+
if (attrs != NULL) {
741+
return attrs;
742+
}
743+
}
744+
// no match found
745+
return NULL;
746+
}
747+
if (!PyType_Check(type)) {
748+
const char *e = "called match pattern must be a class or types.UnionType (got %s)";
749+
_PyErr_Format(tstate, PyExc_TypeError, e, Py_TYPE(type)->tp_name);
732750
return NULL;
733751
}
734752
assert(PyTuple_CheckExact(kwargs));
735753
// First, an isinstance check:
736754
if (PyObject_IsInstance(subject, type) <= 0) {
737755
return NULL;
738756
}
739-
// Subpatterns are not supported for union types:
740-
if (_PyUnion_Check(type)) {
741-
// Return error if any positional or keyword arguments are given:
742-
if (nargs || PyTuple_GET_SIZE(kwargs)) {
743-
const char *e = "union types do not support sub-patterns";
744-
_PyErr_Format(tstate, PyExc_TypeError, e);
745-
return NULL;
746-
}
747-
return PyTuple_New(0);
748-
}
749757
// So far so good:
750758
PyObject *seen = PySet_New(NULL);
751759
if (seen == NULL) {

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy