Skip to content

Commit 90418df

Browse files
committed
Make reachability code understand chained comparisons
Currently, our reachability code does not understand how to parse comparisons like `a == b == c`: the `find_isinstance_check` method only attempts to analyze comparisons that contain a single `==`, `is`, or `in` operator. This pull request generalizes that logic so we can support any arbitrary number of comparisons. It also along the way unifies the logic we have for handling `is` and `==` checks: the latter check is now just treated a weaker variation of the former. (Expressions containing `==` may do arbitrary things if the underlying operands contain custom `__eq__` methods.) As a side-effect, this PR adds support for the following: x: Optional[str] if x is 'some-string': # Previously, the revealed type would be Union[str, None] # Now, the revealed type is just 'str' reveal_type(x) else: reveal_type(x) # N: Revealed type is 'Union[builtins.str, None]' We previously supported this narrowing logic when doing equality checks (e.g. when doing `if x == 'some-string'`). As a second side-effect, this PR adds support for the following: class Foo(Enum): A = 1 B = 2 y: Foo if y == Foo.A: reveal_type(y) # N: Revealed type is 'Literal[Foo.A]' else: reveal_type(y) # N: Revealed type is 'Literal[Foo.B]' We previously supported this kind of narrowing only when doing identity checks (e.g. `if y is Foo.A`). To avoid any bad interactions with custom `__eq__` methods, we enable this narrowing check only if both operands do not define custom `__eq__` methods.
1 parent e818a96 commit 90418df

File tree

5 files changed

+298
-59
lines changed

5 files changed

+298
-59
lines changed

mypy/checker.py

Lines changed: 150 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -3536,67 +3536,61 @@ def find_isinstance_check(self, node: Expression
35363536
vartype = type_map[expr]
35373537
return self.conditional_callable_type_map(expr, vartype)
35383538
elif isinstance(node, ComparisonExpr):
3539-
operand_types = [coerce_to_literal(type_map[expr])
3540-
for expr in node.operands if expr in type_map]
3541-
3542-
is_not = node.operators == ['is not']
3543-
if (is_not or node.operators == ['is']) and len(operand_types) == len(node.operands):
3544-
if_vars = {} # type: TypeMap
3545-
else_vars = {} # type: TypeMap
3546-
3547-
for i, expr in enumerate(node.operands):
3548-
var_type = operand_types[i]
3549-
other_type = operand_types[1 - i]
3550-
3551-
if literal(expr) == LITERAL_TYPE and is_singleton_type(other_type):
3552-
# This should only be true at most once: there should be
3553-
# exactly two elements in node.operands and if the 'other type' is
3554-
# a singleton type, it by definition does not need to be narrowed:
3555-
# it already has the most precise type possible so does not need to
3556-
# be narrowed/included in the output map.
3557-
#
3558-
# TODO: Generalize this to handle the case where 'other_type' is
3559-
# a union of singleton types.
3539+
operand_types = []
3540+
for expr in node.operands:
3541+
if expr not in type_map:
3542+
return {}, {}
3543+
operand_types.append(coerce_to_literal(type_map[expr]))
3544+
3545+
type_maps = []
3546+
for i, (operator, left_expr, right_expr) in enumerate(node.pairwise()):
3547+
left_type = operand_types[i]
3548+
right_type = operand_types[i + 1]
3549+
3550+
if_map = {} # type: TypeMap
3551+
else_map = {} # type: TypeMap
3552+
if operator in {'in', 'not in'}:
3553+
right_item_type = builtin_item_type(right_type)
3554+
if right_item_type is None or is_optional(right_item_type):
3555+
continue
3556+
if (isinstance(right_item_type, Instance)
3557+
and right_item_type.type.fullname() == 'builtins.object'):
3558+
continue
3559+
3560+
if (is_optional(left_type) and literal(left_expr) == LITERAL_TYPE
3561+
and not is_literal_none(left_expr) and
3562+
is_overlapping_erased_types(left_type, right_item_type)):
3563+
if_map, else_map = {left_expr: remove_optional(left_type)}, {}
3564+
else:
3565+
continue
3566+
elif operator in {'==', '!='}:
3567+
if_map, else_map = self.narrow_given_equality(
3568+
left_expr, left_type, right_expr, right_type, assume_identity=False)
3569+
elif operator in {'is', 'is not'}:
3570+
if_map, else_map = self.narrow_given_equality(
3571+
left_expr, left_type, right_expr, right_type, assume_identity=True)
3572+
else:
3573+
continue
35603574

3561-
if isinstance(other_type, LiteralType) and other_type.is_enum_literal():
3562-
fallback_name = other_type.fallback.type.fullname()
3563-
var_type = try_expanding_enum_to_union(var_type, fallback_name)
3575+
if operator in {'not in', '!=', 'is not'}:
3576+
if_map, else_map = else_map, if_map
35643577

3565-
target_type = [TypeRange(other_type, is_upper_bound=False)]
3566-
if_vars, else_vars = conditional_type_map(expr, var_type, target_type)
3567-
break
3578+
type_maps.append((if_map, else_map))
35683579

3569-
if is_not:
3570-
if_vars, else_vars = else_vars, if_vars
3571-
return if_vars, else_vars
3572-
# Check for `x == y` where x is of type Optional[T] and y is of type T
3573-
# or a type that overlaps with T (or vice versa).
3574-
elif node.operators == ['==']:
3575-
first_type = type_map[node.operands[0]]
3576-
second_type = type_map[node.operands[1]]
3577-
if is_optional(first_type) != is_optional(second_type):
3578-
if is_optional(first_type):
3579-
optional_type, comp_type = first_type, second_type
3580-
optional_expr = node.operands[0]
3581-
else:
3582-
optional_type, comp_type = second_type, first_type
3583-
optional_expr = node.operands[1]
3584-
if is_overlapping_erased_types(optional_type, comp_type):
3585-
return {optional_expr: remove_optional(optional_type)}, {}
3586-
elif node.operators in [['in'], ['not in']]:
3587-
expr = node.operands[0]
3588-
left_type = type_map[expr]
3589-
right_type = builtin_item_type(type_map[node.operands[1]])
3590-
right_ok = right_type and (not is_optional(right_type) and
3591-
(not isinstance(right_type, Instance) or
3592-
right_type.type.fullname() != 'builtins.object'))
3593-
if (right_type and right_ok and is_optional(left_type) and
3594-
literal(expr) == LITERAL_TYPE and not is_literal_none(expr) and
3595-
is_overlapping_erased_types(left_type, right_type)):
3596-
if node.operators == ['in']:
3597-
return {expr: remove_optional(left_type)}, {}
3598-
if node.operators == ['not in']:
3599-
return {}, {expr: remove_optional(left_type)}
3580+
if len(type_maps) == 0:
3581+
return {}, {}
3582+
elif len(type_maps) == 1:
3583+
return type_maps[0]
3584+
else:
3585+
# Comparisons like 'a == b == c is d' is the same thing as
3586+
# '(a == b) and (b == c) and (c is d)'. So after generating each
3587+
# individual comparison's typemaps, we "and" them together here.
3588+
# (Also see comments below where we handle the 'and' OpExpr.)
3589+
final_if_map, final_else_map = type_maps[0]
3590+
for if_map, else_map in type_maps[1:]:
3591+
final_if_map = and_conditional_maps(final_if_map, if_map)
3592+
final_else_map = or_conditional_maps(final_else_map, else_map)
3593+
return final_if_map, final_else_map
36003594
elif isinstance(node, RefExpr):
36013595
# Restrict the type of the variable to True-ish/False-ish in the if and else branches
36023596
# respectively
@@ -3630,6 +3624,78 @@ def find_isinstance_check(self, node: Expression
36303624
# Not a supported isinstance check
36313625
return {}, {}
36323626

3627+
def narrow_given_equality(self,
3628+
left_expr: Expression,
3629+
left_type: Type,
3630+
right_expr: Expression,
3631+
right_type: Type,
3632+
assume_identity: bool,
3633+
) -> Tuple[TypeMap, TypeMap]:
3634+
"""Assuming that the given 'left' and 'right' exprs are equal to each other, try
3635+
producing TypeMaps refining the types of either the left or right exprs (or neither,
3636+
if we can't learn anything from the comparison).
3637+
3638+
For more details about what TypeMaps are, see the docstring in find_isinstance_check.
3639+
3640+
If 'assume_identity' is true, assume that this comparison was done using an
3641+
identity comparison (left_expr is right_expr), not just an equality comparison
3642+
(left_expr == right_expr). Identity checks are not overridable, so we can infer
3643+
more information in that case.
3644+
"""
3645+
3646+
# For the sake of simplicity, we currently attempt inferring a more precise type
3647+
# for just one of the two variables.
3648+
comparisons = [
3649+
(left_expr, left_type, right_type),
3650+
(right_expr, right_type, left_type),
3651+
]
3652+
3653+
for expr, expr_type, other_type in comparisons:
3654+
# The 'expr' isn't an expression that we can refine the type of. Skip
3655+
# attempting to refine this expr.
3656+
if literal(expr) != LITERAL_TYPE:
3657+
continue
3658+
3659+
# Case 1: If the 'other_type' is a singleton (only one value has
3660+
# the specified type), attempt to narrow 'expr_type' to just that
3661+
# singleton type.
3662+
if is_singleton_type(other_type):
3663+
if isinstance(other_type, LiteralType) and other_type.is_enum_literal():
3664+
if not assume_identity:
3665+
# Our checks need to be more conservative if the operand is
3666+
# '==' or '!=': all bets are off if either of the two operands
3667+
# has a custom `__eq__` or `__ne__` method.
3668+
#
3669+
# So, we permit this check to succeed only if 'other_type' does
3670+
# not define custom equality logic
3671+
if not uses_default_equality_checks(expr_type):
3672+
continue
3673+
if not uses_default_equality_checks(other_type.fallback):
3674+
continue
3675+
fallback_name = other_type.fallback.type.fullname()
3676+
expr_type = try_expanding_enum_to_union(expr_type, fallback_name)
3677+
3678+
target_type = [TypeRange(other_type, is_upper_bound=False)]
3679+
return conditional_type_map(expr, expr_type, target_type)
3680+
3681+
# Case 2: Given expr_type=Union[A, None] and other_type=A, narrow to just 'A'.
3682+
#
3683+
# Note: This check is actually strictly speaking unsafe: stripping away the 'None'
3684+
# would be unsound in the case where A defines an '__eq__' method that always
3685+
# returns 'True', for example.
3686+
#
3687+
# We implement this check partly for backwards-compatibility reasons and partly
3688+
# because those kinds of degenerate '__eq__' implementations are probably rare
3689+
# enough that this is fine in practice.
3690+
#
3691+
# We could also probably generalize this block to strip away *any* singleton type,
3692+
# if we were fine with a bit more unsoundness.
3693+
if is_optional(expr_type) and not is_optional(other_type):
3694+
if is_overlapping_erased_types(expr_type, other_type):
3695+
return {expr: remove_optional(expr_type)}, {}
3696+
3697+
return {}, {}
3698+
36333699
#
36343700
# Helpers
36353701
#
@@ -4505,6 +4571,32 @@ def is_private(node_name: str) -> bool:
45054571
return node_name.startswith('__') and not node_name.endswith('__')
45064572

45074573

4574+
def uses_default_equality_checks(typ: Type) -> bool:
4575+
"""Returns 'true' if we know for certain that the given type is using
4576+
the default __eq__ and __ne__ checks defined in 'builtins.object'.
4577+
We can use this information to make more aggressive inferences when
4578+
analyzing things like equality checks.
4579+
4580+
When in doubt, this function will conservatively bias towards
4581+
returning False.
4582+
"""
4583+
if isinstance(typ, UnionType):
4584+
return all(map(uses_default_equality_checks, typ.items))
4585+
# TODO: Generalize this so it'll handle other types with fallbacks
4586+
if isinstance(typ, LiteralType):
4587+
typ = typ.fallback
4588+
if isinstance(typ, Instance):
4589+
typeinfo = typ.type
4590+
eq_sym = typeinfo.get('__eq__')
4591+
ne_sym = typeinfo.get('__ne__')
4592+
if eq_sym is None or ne_sym is None:
4593+
return False
4594+
return (eq_sym.fullname == 'builtins.object.__eq__'
4595+
and ne_sym.fullname == 'builtins.object.__ne__')
4596+
else:
4597+
return False
4598+
4599+
45084600
def is_singleton_type(typ: Type) -> bool:
45094601
"""Returns 'true' if this type is a "singleton type" -- if there exists
45104602
exactly only one runtime value associated with this type.

mypy/nodes.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1718,13 +1718,21 @@ class ComparisonExpr(Expression):
17181718

17191719
def __init__(self, operators: List[str], operands: List[Expression]) -> None:
17201720
super().__init__()
1721+
assert len(operators) + 1 == len(operands)
17211722
self.operators = operators
17221723
self.operands = operands
17231724
self.method_types = []
17241725

17251726
def accept(self, visitor: ExpressionVisitor[T]) -> T:
17261727
return visitor.visit_comparison_expr(self)
17271728

1729+
def pairwise(self) -> Iterator[Tuple[str, Expression, Expression]]:
1730+
"""If this comparison expr is "a < b is c == d", yields the sequence
1731+
("<", a, b), ("is", b, c), ("==", c, d)
1732+
"""
1733+
for i, operator in enumerate(self.operators):
1734+
yield operator, self.operands[i], self.operands[i + 1]
1735+
17281736

17291737
class SliceExpr(Expression):
17301738
"""Slice expression (e.g. 'x:y', 'x:', '::2' or ':').

test-data/unit/check-enum.test

Lines changed: 127 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -611,7 +611,7 @@ main:2: note: Revealed type is 'builtins.int'
611611
[out2]
612612
main:2: note: Revealed type is 'builtins.str'
613613

614-
[case testEnumReachabilityChecksBasic]
614+
[case testEnumReachabilityChecksBasicIdentity]
615615
from enum import Enum
616616
from typing_extensions import Literal
617617

@@ -659,6 +659,54 @@ else:
659659
reveal_type(y) # No output here: this branch is unreachable
660660
[builtins fixtures/bool.pyi]
661661

662+
[case testEnumReachabilityChecksBasicEquality]
663+
from enum import Enum
664+
from typing_extensions import Literal
665+
666+
class Foo(Enum):
667+
A = 1
668+
B = 2
669+
C = 3
670+
671+
x: Literal[Foo.A, Foo.B, Foo.C]
672+
if x == Foo.A:
673+
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
674+
elif x == Foo.B:
675+
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]'
676+
elif x == Foo.C:
677+
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.C]'
678+
else:
679+
reveal_type(x) # No output here: this branch is unreachable
680+
681+
if Foo.A == x:
682+
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
683+
elif Foo.B == x:
684+
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]'
685+
elif Foo.C == x:
686+
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.C]'
687+
else:
688+
reveal_type(x) # No output here: this branch is unreachable
689+
690+
y: Foo
691+
if y == Foo.A:
692+
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
693+
elif y == Foo.B:
694+
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]'
695+
elif y == Foo.C:
696+
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.C]'
697+
else:
698+
reveal_type(y) # No output here: this branch is unreachable
699+
700+
if Foo.A == y:
701+
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
702+
elif Foo.B == y:
703+
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]'
704+
elif Foo.C == y:
705+
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.C]'
706+
else:
707+
reveal_type(y) # No output here: this branch is unreachable
708+
[builtins fixtures/bool.pyi]
709+
662710
[case testEnumReachabilityChecksIndirect]
663711
from enum import Enum
664712
from typing_extensions import Literal, Final
@@ -854,3 +902,81 @@ def process(response: Union[str, Reason] = '') -> str:
854902
return 'PROCESSED: ' + response
855903

856904
[builtins fixtures/primitives.pyi]
905+
906+
[case testEnumReachabilityDisabledGivenCustomEquality]
907+
from typing import Union
908+
from enum import Enum
909+
910+
class Parent(Enum):
911+
def __ne__(self, other: object) -> bool: return True
912+
913+
class Foo(Enum):
914+
A = 1
915+
B = 2
916+
def __eq__(self, other: object) -> bool: return True
917+
918+
class Bar(Parent):
919+
A = 1
920+
B = 2
921+
922+
class Ok(Enum):
923+
A = 1
924+
B = 2
925+
926+
x: Foo
927+
if x is Foo.A:
928+
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
929+
else:
930+
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]'
931+
932+
if x == Foo.A:
933+
reveal_type(x) # N: Revealed type is '__main__.Foo'
934+
else:
935+
reveal_type(x) # N: Revealed type is '__main__.Foo'
936+
937+
y: Bar
938+
if y is Bar.A:
939+
reveal_type(y) # N: Revealed type is 'Literal[__main__.Bar.A]'
940+
else:
941+
reveal_type(y) # N: Revealed type is 'Literal[__main__.Bar.B]'
942+
943+
if y == Bar.A:
944+
reveal_type(y) # N: Revealed type is '__main__.Bar'
945+
else:
946+
reveal_type(y) # N: Revealed type is '__main__.Bar'
947+
948+
z1: Union[Bar, Ok]
949+
if z1 is Ok.A:
950+
reveal_type(z1) # N: Revealed type is 'Literal[__main__.Ok.A]'
951+
else:
952+
reveal_type(z1) # N: Revealed type is 'Union[__main__.Bar, Literal[__main__.Ok.B]]'
953+
954+
z2: Union[Bar, Ok]
955+
if z2 == Ok.A:
956+
reveal_type(z2) # N: Revealed type is 'Union[__main__.Bar, __main__.Ok]'
957+
else:
958+
reveal_type(z2) # N: Revealed type is 'Union[__main__.Bar, __main__.Ok]'
959+
[builtins fixtures/primitives.pyi]
960+
961+
[case testEnumReachabilityWithChaining]
962+
from enum import Enum
963+
class Foo(Enum):
964+
A = 1
965+
B = 2
966+
967+
x: Foo
968+
y: Foo
969+
if x is Foo.A is y:
970+
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
971+
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
972+
else:
973+
reveal_type(x) # N: Revealed type is '__main__.Foo'
974+
reveal_type(y) # N: Revealed type is '__main__.Foo'
975+
976+
if x == Foo.A == y:
977+
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
978+
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
979+
else:
980+
reveal_type(x) # N: Revealed type is '__main__.Foo'
981+
reveal_type(y) # N: Revealed type is '__main__.Foo'
982+
[builtins fixtures/primitives.pyi]

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy