diff --git a/mypy/checker.py b/mypy/checker.py index 45a8b2e24bb9..9ea3a9b902d9 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -2,6 +2,7 @@ import itertools import fnmatch +import sys from contextlib import contextmanager from typing import ( @@ -3535,21 +3536,34 @@ def find_isinstance_check(self, node: Expression vartype = type_map[expr] return self.conditional_callable_type_map(expr, vartype) elif isinstance(node, ComparisonExpr): - # Check for `x is None` and `x is not None`. + operand_types = [coerce_to_literal(type_map[expr]) + for expr in node.operands if expr in type_map] + is_not = node.operators == ['is not'] - if any(is_literal_none(n) for n in node.operands) and ( - is_not or node.operators == ['is']): + if (is_not or node.operators == ['is']) and len(operand_types) == len(node.operands): if_vars = {} # type: TypeMap else_vars = {} # type: TypeMap - for expr in node.operands: - if (literal(expr) == LITERAL_TYPE and not is_literal_none(expr) - and expr in type_map): + + for i, expr in enumerate(node.operands): + var_type = operand_types[i] + other_type = operand_types[1 - i] + + if literal(expr) == LITERAL_TYPE and is_singleton_type(other_type): # This should only be true at most once: there should be - # two elements in node.operands, and at least one of them - # should represent a None. - vartype = type_map[expr] - none_typ = [TypeRange(NoneType(), is_upper_bound=False)] - if_vars, else_vars = conditional_type_map(expr, vartype, none_typ) + # exactly two elements in node.operands and if the 'other type' is + # a singleton type, it by definition does not need to be narrowed: + # it already has the most precise type possible so does not need to + # be narrowed/included in the output map. + # + # TODO: Generalize this to handle the case where 'other_type' is + # a union of singleton types. + + if isinstance(other_type, LiteralType) and other_type.is_enum_literal(): + fallback_name = other_type.fallback.type.fullname() + var_type = try_expanding_enum_to_union(var_type, fallback_name) + + target_type = [TypeRange(other_type, is_upper_bound=False)] + if_vars, else_vars = conditional_type_map(expr, var_type, target_type) break if is_not: @@ -4489,3 +4503,78 @@ def is_overlapping_types_no_promote(left: Type, right: Type) -> bool: def is_private(node_name: str) -> bool: """Check if node is private to class definition.""" return node_name.startswith('__') and not node_name.endswith('__') + + +def is_singleton_type(typ: Type) -> bool: + """Returns 'true' if this type is a "singleton type" -- if there exists + exactly only one runtime value associated with this type. + + That is, given two values 'a' and 'b' that have the same type 't', + 'is_singleton_type(t)' returns True if and only if the expression 'a is b' is + always true. + + Currently, this returns True when given NoneTypes and enum LiteralTypes. + + Note that other kinds of LiteralTypes cannot count as singleton types. For + example, suppose we do 'a = 100000 + 1' and 'b = 100001'. It is not guaranteed + that 'a is b' will always be true -- some implementations of Python will end up + constructing two distinct instances of 100001. + """ + # TODO: Also make this return True if the type is a bool LiteralType. + # Also make this return True if the type corresponds to ... (ellipsis) or NotImplemented? + return isinstance(typ, NoneType) or (isinstance(typ, LiteralType) and typ.is_enum_literal()) + + +def try_expanding_enum_to_union(typ: Type, target_fullname: str) -> Type: + """Attempts to recursively expand any enum Instances with the given target_fullname + into a Union of all of its component LiteralTypes. + + For example, if we have: + + class Color(Enum): + RED = 1 + BLUE = 2 + YELLOW = 3 + + class Status(Enum): + SUCCESS = 1 + FAILURE = 2 + UNKNOWN = 3 + + ...and if we call `try_expanding_enum_to_union(Union[Color, Status], 'module.Color')`, + this function will return Literal[Color.RED, Color.BLUE, Color.YELLOW, Status]. + """ + if isinstance(typ, UnionType): + new_items = [try_expanding_enum_to_union(item, target_fullname) + for item in typ.items] + return UnionType.make_simplified_union(new_items) + elif isinstance(typ, Instance) and typ.type.is_enum and typ.type.fullname() == target_fullname: + new_items = [] + for name, symbol in typ.type.names.items(): + if not isinstance(symbol.node, Var): + continue + new_items.append(LiteralType(name, typ)) + # SymbolTables are really just dicts, and dicts are guaranteed to preserve + # insertion order only starting with Python 3.7. So, we sort these for older + # versions of Python to help make tests deterministic. + # + # We could probably skip the sort for Python 3.6 since people probably run mypy + # only using CPython, but we might as well for the sake of full correctness. + if sys.version_info < (3, 7): + new_items.sort(key=lambda lit: lit.value) + return UnionType.make_simplified_union(new_items) + else: + return typ + + +def coerce_to_literal(typ: Type) -> Type: + """Recursively converts any Instances that have a last_known_value into the + corresponding LiteralType. + """ + if isinstance(typ, UnionType): + new_items = [coerce_to_literal(item) for item in typ.items] + return UnionType.make_simplified_union(new_items) + elif isinstance(typ, Instance) and typ.last_known_value: + return typ.last_known_value + else: + return typ diff --git a/test-data/unit/check-enum.test b/test-data/unit/check-enum.test index 431f0c9b241f..9f015f24986c 100644 --- a/test-data/unit/check-enum.test +++ b/test-data/unit/check-enum.test @@ -610,3 +610,247 @@ class SomeEnum(Enum): main:2: note: Revealed type is 'builtins.int' [out2] main:2: note: Revealed type is 'builtins.str' + +[case testEnumReachabilityChecksBasic] +from enum import Enum +from typing_extensions import Literal + +class Foo(Enum): + A = 1 + B = 2 + C = 3 + +x: Literal[Foo.A, Foo.B, Foo.C] +if x is Foo.A: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' +elif x is Foo.B: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]' +elif x is Foo.C: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.C]' +else: + reveal_type(x) # No output here: this branch is unreachable + +if Foo.A is x: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' +elif Foo.B is x: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]' +elif Foo.C is x: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.C]' +else: + reveal_type(x) # No output here: this branch is unreachable + +y: Foo +if y is Foo.A: + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' +elif y is Foo.B: + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]' +elif y is Foo.C: + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.C]' +else: + reveal_type(y) # No output here: this branch is unreachable + +if Foo.A is y: + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' +elif Foo.B is y: + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]' +elif Foo.C is y: + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.C]' +else: + reveal_type(y) # No output here: this branch is unreachable +[builtins fixtures/bool.pyi] + +[case testEnumReachabilityChecksIndirect] +from enum import Enum +from typing_extensions import Literal, Final + +class Foo(Enum): + A = 1 + B = 2 + C = 3 + +def accepts_foo_a(x: Literal[Foo.A]) -> None: ... + +x: Foo +y: Literal[Foo.A] +z: Final = Foo.A + +if x is y: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' +else: + reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]' + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' +if y is x: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' +else: + reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]' + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' + +if x is z: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' + reveal_type(z) # N: Revealed type is '__main__.Foo' + accepts_foo_a(z) +else: + reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]' + reveal_type(z) # N: Revealed type is '__main__.Foo' + accepts_foo_a(z) +if z is x: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' + reveal_type(z) # N: Revealed type is '__main__.Foo' + accepts_foo_a(z) +else: + reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]' + reveal_type(z) # N: Revealed type is '__main__.Foo' + accepts_foo_a(z) + +if y is z: + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' + reveal_type(z) # N: Revealed type is '__main__.Foo' + accepts_foo_a(z) +else: + reveal_type(y) # No output: this branch is unreachable + reveal_type(z) # No output: this branch is unreachable +if z is y: + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' + reveal_type(z) # N: Revealed type is '__main__.Foo' + accepts_foo_a(z) +else: + reveal_type(y) # No output: this branch is unreachable + reveal_type(z) # No output: this branch is unreachable +[builtins fixtures/bool.pyi] + +[case testEnumReachabilityNoNarrowingForUnionMessiness] +from enum import Enum +from typing_extensions import Literal + +class Foo(Enum): + A = 1 + B = 2 + C = 3 + +x: Foo +y: Literal[Foo.A, Foo.B] +z: Literal[Foo.B, Foo.C] + +# For the sake of simplicity, no narrowing is done when the narrower type is a Union. +if x is y: + reveal_type(x) # N: Revealed type is '__main__.Foo' + reveal_type(y) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B]]' +else: + reveal_type(x) # N: Revealed type is '__main__.Foo' + reveal_type(y) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B]]' + +if y is z: + reveal_type(y) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B]]' + reveal_type(z) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]' +else: + reveal_type(y) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B]]' + reveal_type(z) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]' +[builtins fixtures/bool.pyi] + +[case testEnumReachabilityWithNone] +# flags: --strict-optional +from enum import Enum +from typing import Optional + +class Foo(Enum): + A = 1 + B = 2 + C = 3 + +x: Optional[Foo] +if x: + reveal_type(x) # N: Revealed type is '__main__.Foo' +else: + reveal_type(x) # N: Revealed type is 'Union[__main__.Foo, None]' + +if x is not None: + reveal_type(x) # N: Revealed type is '__main__.Foo' +else: + reveal_type(x) # N: Revealed type is 'None' + +if x is Foo.A: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' +else: + reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C], None]' +[builtins fixtures/bool.pyi] + +[case testEnumReachabilityWithMultipleEnums] +from enum import Enum +from typing import Union +from typing_extensions import Literal + +class Foo(Enum): + A = 1 + B = 2 +class Bar(Enum): + A = 1 + B = 2 + +x1: Union[Foo, Bar] +if x1 is Foo.A: + reveal_type(x1) # N: Revealed type is 'Literal[__main__.Foo.A]' +else: + reveal_type(x1) # N: Revealed type is 'Union[Literal[__main__.Foo.B], __main__.Bar]' + +x2: Union[Foo, Bar] +if x2 is Bar.A: + reveal_type(x2) # N: Revealed type is 'Literal[__main__.Bar.A]' +else: + reveal_type(x2) # N: Revealed type is 'Union[__main__.Foo, Literal[__main__.Bar.B]]' + +x3: Union[Foo, Bar] +if x3 is Foo.A or x3 is Bar.A: + reveal_type(x3) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Bar.A]]' +else: + reveal_type(x3) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Bar.B]]' + +[builtins fixtures/bool.pyi] + +[case testEnumReachabilityPEP484Example1] +# flags: --strict-optional +from typing import Union +from typing_extensions import Final +from enum import Enum + +class Empty(Enum): + token = 0 +_empty: Final = Empty.token + +def func(x: Union[int, None, Empty] = _empty) -> int: + boom = x + 42 # E: Unsupported left operand type for + ("None") \ + # E: Unsupported left operand type for + ("Empty") \ + # N: Left operand is of type "Union[int, None, Empty]" + if x is _empty: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Empty.token]' + return 0 + elif x is None: + reveal_type(x) # N: Revealed type is 'None' + return 1 + else: # At this point typechecker knows that x can only have type int + reveal_type(x) # N: Revealed type is 'builtins.int' + return x + 2 +[builtins fixtures/primitives.pyi] + +[case testEnumReachabilityPEP484Example2] +from typing import Union +from enum import Enum + +class Reason(Enum): + timeout = 1 + error = 2 + +def process(response: Union[str, Reason] = '') -> str: + if response is Reason.timeout: + reveal_type(response) # N: Revealed type is 'Literal[__main__.Reason.timeout]' + return 'TIMEOUT' + elif response is Reason.error: + reveal_type(response) # N: Revealed type is 'Literal[__main__.Reason.error]' + return 'ERROR' + else: + # response can be only str, all other possible values exhausted + reveal_type(response) # N: Revealed type is 'builtins.str' + return 'PROCESSED: ' + response + +[builtins fixtures/primitives.pyi]
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: