From 90418dfa909c3139089f2309ee0cc33c3b4c114d Mon Sep 17 00:00:00 2001 From: Michael Lee Date: Mon, 8 Jul 2019 00:53:21 -0700 Subject: [PATCH] Make reachability code understand chained comparisons Currently, our reachability code does not understand how to parse comparisons like `a == b == c`: the `find_isinstance_check` method only attempts to analyze comparisons that contain a single `==`, `is`, or `in` operator. This pull request generalizes that logic so we can support any arbitrary number of comparisons. It also along the way unifies the logic we have for handling `is` and `==` checks: the latter check is now just treated a weaker variation of the former. (Expressions containing `==` may do arbitrary things if the underlying operands contain custom `__eq__` methods.) As a side-effect, this PR adds support for the following: x: Optional[str] if x is 'some-string': # Previously, the revealed type would be Union[str, None] # Now, the revealed type is just 'str' reveal_type(x) else: reveal_type(x) # N: Revealed type is 'Union[builtins.str, None]' We previously supported this narrowing logic when doing equality checks (e.g. when doing `if x == 'some-string'`). As a second side-effect, this PR adds support for the following: class Foo(Enum): A = 1 B = 2 y: Foo if y == Foo.A: reveal_type(y) # N: Revealed type is 'Literal[Foo.A]' else: reveal_type(y) # N: Revealed type is 'Literal[Foo.B]' We previously supported this kind of narrowing only when doing identity checks (e.g. `if y is Foo.A`). To avoid any bad interactions with custom `__eq__` methods, we enable this narrowing check only if both operands do not define custom `__eq__` methods. --- mypy/checker.py | 208 +++++++++++++++++++++-------- mypy/nodes.py | 8 ++ test-data/unit/check-enum.test | 128 +++++++++++++++++- test-data/unit/check-optional.test | 12 ++ test-data/unit/check-tuples.test | 1 + 5 files changed, 298 insertions(+), 59 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 9ea3a9b902d9..a7281cb3b40b 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -3536,67 +3536,61 @@ def find_isinstance_check(self, node: Expression vartype = type_map[expr] return self.conditional_callable_type_map(expr, vartype) elif isinstance(node, ComparisonExpr): - operand_types = [coerce_to_literal(type_map[expr]) - for expr in node.operands if expr in type_map] - - is_not = node.operators == ['is not'] - if (is_not or node.operators == ['is']) and len(operand_types) == len(node.operands): - if_vars = {} # type: TypeMap - else_vars = {} # type: TypeMap - - for i, expr in enumerate(node.operands): - var_type = operand_types[i] - other_type = operand_types[1 - i] - - if literal(expr) == LITERAL_TYPE and is_singleton_type(other_type): - # This should only be true at most once: there should be - # exactly two elements in node.operands and if the 'other type' is - # a singleton type, it by definition does not need to be narrowed: - # it already has the most precise type possible so does not need to - # be narrowed/included in the output map. - # - # TODO: Generalize this to handle the case where 'other_type' is - # a union of singleton types. + operand_types = [] + for expr in node.operands: + if expr not in type_map: + return {}, {} + operand_types.append(coerce_to_literal(type_map[expr])) + + type_maps = [] + for i, (operator, left_expr, right_expr) in enumerate(node.pairwise()): + left_type = operand_types[i] + right_type = operand_types[i + 1] + + if_map = {} # type: TypeMap + else_map = {} # type: TypeMap + if operator in {'in', 'not in'}: + right_item_type = builtin_item_type(right_type) + if right_item_type is None or is_optional(right_item_type): + continue + if (isinstance(right_item_type, Instance) + and right_item_type.type.fullname() == 'builtins.object'): + continue + + if (is_optional(left_type) and literal(left_expr) == LITERAL_TYPE + and not is_literal_none(left_expr) and + is_overlapping_erased_types(left_type, right_item_type)): + if_map, else_map = {left_expr: remove_optional(left_type)}, {} + else: + continue + elif operator in {'==', '!='}: + if_map, else_map = self.narrow_given_equality( + left_expr, left_type, right_expr, right_type, assume_identity=False) + elif operator in {'is', 'is not'}: + if_map, else_map = self.narrow_given_equality( + left_expr, left_type, right_expr, right_type, assume_identity=True) + else: + continue - if isinstance(other_type, LiteralType) and other_type.is_enum_literal(): - fallback_name = other_type.fallback.type.fullname() - var_type = try_expanding_enum_to_union(var_type, fallback_name) + if operator in {'not in', '!=', 'is not'}: + if_map, else_map = else_map, if_map - target_type = [TypeRange(other_type, is_upper_bound=False)] - if_vars, else_vars = conditional_type_map(expr, var_type, target_type) - break + type_maps.append((if_map, else_map)) - if is_not: - if_vars, else_vars = else_vars, if_vars - return if_vars, else_vars - # Check for `x == y` where x is of type Optional[T] and y is of type T - # or a type that overlaps with T (or vice versa). - elif node.operators == ['==']: - first_type = type_map[node.operands[0]] - second_type = type_map[node.operands[1]] - if is_optional(first_type) != is_optional(second_type): - if is_optional(first_type): - optional_type, comp_type = first_type, second_type - optional_expr = node.operands[0] - else: - optional_type, comp_type = second_type, first_type - optional_expr = node.operands[1] - if is_overlapping_erased_types(optional_type, comp_type): - return {optional_expr: remove_optional(optional_type)}, {} - elif node.operators in [['in'], ['not in']]: - expr = node.operands[0] - left_type = type_map[expr] - right_type = builtin_item_type(type_map[node.operands[1]]) - right_ok = right_type and (not is_optional(right_type) and - (not isinstance(right_type, Instance) or - right_type.type.fullname() != 'builtins.object')) - if (right_type and right_ok and is_optional(left_type) and - literal(expr) == LITERAL_TYPE and not is_literal_none(expr) and - is_overlapping_erased_types(left_type, right_type)): - if node.operators == ['in']: - return {expr: remove_optional(left_type)}, {} - if node.operators == ['not in']: - return {}, {expr: remove_optional(left_type)} + if len(type_maps) == 0: + return {}, {} + elif len(type_maps) == 1: + return type_maps[0] + else: + # Comparisons like 'a == b == c is d' is the same thing as + # '(a == b) and (b == c) and (c is d)'. So after generating each + # individual comparison's typemaps, we "and" them together here. + # (Also see comments below where we handle the 'and' OpExpr.) + final_if_map, final_else_map = type_maps[0] + for if_map, else_map in type_maps[1:]: + final_if_map = and_conditional_maps(final_if_map, if_map) + final_else_map = or_conditional_maps(final_else_map, else_map) + return final_if_map, final_else_map elif isinstance(node, RefExpr): # Restrict the type of the variable to True-ish/False-ish in the if and else branches # respectively @@ -3630,6 +3624,78 @@ def find_isinstance_check(self, node: Expression # Not a supported isinstance check return {}, {} + def narrow_given_equality(self, + left_expr: Expression, + left_type: Type, + right_expr: Expression, + right_type: Type, + assume_identity: bool, + ) -> Tuple[TypeMap, TypeMap]: + """Assuming that the given 'left' and 'right' exprs are equal to each other, try + producing TypeMaps refining the types of either the left or right exprs (or neither, + if we can't learn anything from the comparison). + + For more details about what TypeMaps are, see the docstring in find_isinstance_check. + + If 'assume_identity' is true, assume that this comparison was done using an + identity comparison (left_expr is right_expr), not just an equality comparison + (left_expr == right_expr). Identity checks are not overridable, so we can infer + more information in that case. + """ + + # For the sake of simplicity, we currently attempt inferring a more precise type + # for just one of the two variables. + comparisons = [ + (left_expr, left_type, right_type), + (right_expr, right_type, left_type), + ] + + for expr, expr_type, other_type in comparisons: + # The 'expr' isn't an expression that we can refine the type of. Skip + # attempting to refine this expr. + if literal(expr) != LITERAL_TYPE: + continue + + # Case 1: If the 'other_type' is a singleton (only one value has + # the specified type), attempt to narrow 'expr_type' to just that + # singleton type. + if is_singleton_type(other_type): + if isinstance(other_type, LiteralType) and other_type.is_enum_literal(): + if not assume_identity: + # Our checks need to be more conservative if the operand is + # '==' or '!=': all bets are off if either of the two operands + # has a custom `__eq__` or `__ne__` method. + # + # So, we permit this check to succeed only if 'other_type' does + # not define custom equality logic + if not uses_default_equality_checks(expr_type): + continue + if not uses_default_equality_checks(other_type.fallback): + continue + fallback_name = other_type.fallback.type.fullname() + expr_type = try_expanding_enum_to_union(expr_type, fallback_name) + + target_type = [TypeRange(other_type, is_upper_bound=False)] + return conditional_type_map(expr, expr_type, target_type) + + # Case 2: Given expr_type=Union[A, None] and other_type=A, narrow to just 'A'. + # + # Note: This check is actually strictly speaking unsafe: stripping away the 'None' + # would be unsound in the case where A defines an '__eq__' method that always + # returns 'True', for example. + # + # We implement this check partly for backwards-compatibility reasons and partly + # because those kinds of degenerate '__eq__' implementations are probably rare + # enough that this is fine in practice. + # + # We could also probably generalize this block to strip away *any* singleton type, + # if we were fine with a bit more unsoundness. + if is_optional(expr_type) and not is_optional(other_type): + if is_overlapping_erased_types(expr_type, other_type): + return {expr: remove_optional(expr_type)}, {} + + return {}, {} + # # Helpers # @@ -4505,6 +4571,32 @@ def is_private(node_name: str) -> bool: return node_name.startswith('__') and not node_name.endswith('__') +def uses_default_equality_checks(typ: Type) -> bool: + """Returns 'true' if we know for certain that the given type is using + the default __eq__ and __ne__ checks defined in 'builtins.object'. + We can use this information to make more aggressive inferences when + analyzing things like equality checks. + + When in doubt, this function will conservatively bias towards + returning False. + """ + if isinstance(typ, UnionType): + return all(map(uses_default_equality_checks, typ.items)) + # TODO: Generalize this so it'll handle other types with fallbacks + if isinstance(typ, LiteralType): + typ = typ.fallback + if isinstance(typ, Instance): + typeinfo = typ.type + eq_sym = typeinfo.get('__eq__') + ne_sym = typeinfo.get('__ne__') + if eq_sym is None or ne_sym is None: + return False + return (eq_sym.fullname == 'builtins.object.__eq__' + and ne_sym.fullname == 'builtins.object.__ne__') + else: + return False + + def is_singleton_type(typ: Type) -> bool: """Returns 'true' if this type is a "singleton type" -- if there exists exactly only one runtime value associated with this type. diff --git a/mypy/nodes.py b/mypy/nodes.py index 44f5fe610039..0ee5f798641f 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -1718,6 +1718,7 @@ class ComparisonExpr(Expression): def __init__(self, operators: List[str], operands: List[Expression]) -> None: super().__init__() + assert len(operators) + 1 == len(operands) self.operators = operators self.operands = operands self.method_types = [] @@ -1725,6 +1726,13 @@ def __init__(self, operators: List[str], operands: List[Expression]) -> None: def accept(self, visitor: ExpressionVisitor[T]) -> T: return visitor.visit_comparison_expr(self) + def pairwise(self) -> Iterator[Tuple[str, Expression, Expression]]: + """If this comparison expr is "a < b is c == d", yields the sequence + ("<", a, b), ("is", b, c), ("==", c, d) + """ + for i, operator in enumerate(self.operators): + yield operator, self.operands[i], self.operands[i + 1] + class SliceExpr(Expression): """Slice expression (e.g. 'x:y', 'x:', '::2' or ':'). diff --git a/test-data/unit/check-enum.test b/test-data/unit/check-enum.test index 9f015f24986c..72975fbe0daa 100644 --- a/test-data/unit/check-enum.test +++ b/test-data/unit/check-enum.test @@ -611,7 +611,7 @@ main:2: note: Revealed type is 'builtins.int' [out2] main:2: note: Revealed type is 'builtins.str' -[case testEnumReachabilityChecksBasic] +[case testEnumReachabilityChecksBasicIdentity] from enum import Enum from typing_extensions import Literal @@ -659,6 +659,54 @@ else: reveal_type(y) # No output here: this branch is unreachable [builtins fixtures/bool.pyi] +[case testEnumReachabilityChecksBasicEquality] +from enum import Enum +from typing_extensions import Literal + +class Foo(Enum): + A = 1 + B = 2 + C = 3 + +x: Literal[Foo.A, Foo.B, Foo.C] +if x == Foo.A: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' +elif x == Foo.B: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]' +elif x == Foo.C: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.C]' +else: + reveal_type(x) # No output here: this branch is unreachable + +if Foo.A == x: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' +elif Foo.B == x: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]' +elif Foo.C == x: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.C]' +else: + reveal_type(x) # No output here: this branch is unreachable + +y: Foo +if y == Foo.A: + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' +elif y == Foo.B: + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]' +elif y == Foo.C: + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.C]' +else: + reveal_type(y) # No output here: this branch is unreachable + +if Foo.A == y: + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' +elif Foo.B == y: + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]' +elif Foo.C == y: + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.C]' +else: + reveal_type(y) # No output here: this branch is unreachable +[builtins fixtures/bool.pyi] + [case testEnumReachabilityChecksIndirect] from enum import Enum from typing_extensions import Literal, Final @@ -854,3 +902,81 @@ def process(response: Union[str, Reason] = '') -> str: return 'PROCESSED: ' + response [builtins fixtures/primitives.pyi] + +[case testEnumReachabilityDisabledGivenCustomEquality] +from typing import Union +from enum import Enum + +class Parent(Enum): + def __ne__(self, other: object) -> bool: return True + +class Foo(Enum): + A = 1 + B = 2 + def __eq__(self, other: object) -> bool: return True + +class Bar(Parent): + A = 1 + B = 2 + +class Ok(Enum): + A = 1 + B = 2 + +x: Foo +if x is Foo.A: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' +else: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]' + +if x == Foo.A: + reveal_type(x) # N: Revealed type is '__main__.Foo' +else: + reveal_type(x) # N: Revealed type is '__main__.Foo' + +y: Bar +if y is Bar.A: + reveal_type(y) # N: Revealed type is 'Literal[__main__.Bar.A]' +else: + reveal_type(y) # N: Revealed type is 'Literal[__main__.Bar.B]' + +if y == Bar.A: + reveal_type(y) # N: Revealed type is '__main__.Bar' +else: + reveal_type(y) # N: Revealed type is '__main__.Bar' + +z1: Union[Bar, Ok] +if z1 is Ok.A: + reveal_type(z1) # N: Revealed type is 'Literal[__main__.Ok.A]' +else: + reveal_type(z1) # N: Revealed type is 'Union[__main__.Bar, Literal[__main__.Ok.B]]' + +z2: Union[Bar, Ok] +if z2 == Ok.A: + reveal_type(z2) # N: Revealed type is 'Union[__main__.Bar, __main__.Ok]' +else: + reveal_type(z2) # N: Revealed type is 'Union[__main__.Bar, __main__.Ok]' +[builtins fixtures/primitives.pyi] + +[case testEnumReachabilityWithChaining] +from enum import Enum +class Foo(Enum): + A = 1 + B = 2 + +x: Foo +y: Foo +if x is Foo.A is y: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' +else: + reveal_type(x) # N: Revealed type is '__main__.Foo' + reveal_type(y) # N: Revealed type is '__main__.Foo' + +if x == Foo.A == y: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' +else: + reveal_type(x) # N: Revealed type is '__main__.Foo' + reveal_type(y) # N: Revealed type is '__main__.Foo' +[builtins fixtures/primitives.pyi] diff --git a/test-data/unit/check-optional.test b/test-data/unit/check-optional.test index 8f5313870e27..7e17f8787ea2 100644 --- a/test-data/unit/check-optional.test +++ b/test-data/unit/check-optional.test @@ -485,6 +485,10 @@ if x == '': reveal_type(x) # N: Revealed type is 'builtins.str' else: reveal_type(x) # N: Revealed type is 'Union[builtins.str, None]' +if x is '': + reveal_type(x) # N: Revealed type is 'builtins.str' +else: + reveal_type(x) # N: Revealed type is 'Union[builtins.str, None]' [builtins fixtures/ops.pyi] [case testInferEqualsNotOptionalWithUnion] @@ -494,6 +498,10 @@ if x == '': reveal_type(x) # N: Revealed type is 'Union[builtins.str, builtins.int]' else: reveal_type(x) # N: Revealed type is 'Union[builtins.str, builtins.int, None]' +if x is '': + reveal_type(x) # N: Revealed type is 'Union[builtins.str, builtins.int]' +else: + reveal_type(x) # N: Revealed type is 'Union[builtins.str, builtins.int, None]' [builtins fixtures/ops.pyi] [case testInferEqualsNotOptionalWithOverlap] @@ -503,6 +511,10 @@ if x == object(): reveal_type(x) # N: Revealed type is 'Union[builtins.str, builtins.int]' else: reveal_type(x) # N: Revealed type is 'Union[builtins.str, builtins.int, None]' +if x is object(): + reveal_type(x) # N: Revealed type is 'Union[builtins.str, builtins.int]' +else: + reveal_type(x) # N: Revealed type is 'Union[builtins.str, builtins.int, None]' [builtins fixtures/ops.pyi] [case testInferEqualsStillOptionalWithNoOverlap] diff --git a/test-data/unit/check-tuples.test b/test-data/unit/check-tuples.test index e7f240e91926..4b1abe22347b 100644 --- a/test-data/unit/check-tuples.test +++ b/test-data/unit/check-tuples.test @@ -1196,6 +1196,7 @@ x = y reveal_type(x) # N: Revealed type is 'Tuple[builtins.int, builtins.int]' [case testTupleOverlapDifferentTuples] +# flags: --strict-optional from typing import Optional, Tuple class A: pass class B: pass pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy