-
-
Notifications
You must be signed in to change notification settings - Fork 3k
Refine parent type when narrowing "lookup" expressions #7917
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
14fa27f
b707f3d
d183f1c
2c2506d
0c33147
68b4801
c2b2cce
0c53aad
b4f35ea
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
This diff adds support for the following pattern:

```python
from enum import Enum
from typing import List, Union
from typing_extensions import Literal

class Key(Enum):
    A = 1
    B = 2

class Foo:
    key: Literal[Key.A]
    blah: List[int]

class Bar:
    key: Literal[Key.B]
    something: List[str]

x: Union[Foo, Bar]
if x.key is Key.A:
    reveal_type(x)  # Revealed type is 'Foo'
else:
    reveal_type(x)  # Revealed type is 'Bar'
```

In short, when we do `x.key is Key.A`, we "propagate" the information we discovered about `x.key` up one level to refine the type of `x`.

We perform this propagation only when `x` is a Union and only when we are doing member or index lookups into instances, typeddicts, namedtuples, and tuples. For indexing operations, we have one additional limitation: we *must* use a literal expression in order for narrowing to work at all. Using Literal types or Final instances won't work; see #7905 for more details.

To put it another way, this adds support for tagged unions, I guess.

This more or less resolves #7344. We currently don't have support for narrowing based on string or int literals, but that's a separate issue and should be resolved by #7169 (which I resumed work on earlier this week).
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,7 +6,7 @@ | |
|
||
from typing import ( | ||
Dict, Set, List, cast, Tuple, TypeVar, Union, Optional, NamedTuple, Iterator, Iterable, | ||
Sequence | ||
Mapping, Sequence | ||
) | ||
from typing_extensions import Final | ||
|
||
|
@@ -43,11 +43,13 @@ | |
) | ||
import mypy.checkexpr | ||
from mypy.checkmember import ( | ||
analyze_descriptor_access, type_object_type, | ||
analyze_member_access, analyze_descriptor_access, type_object_type, | ||
) | ||
from mypy.typeops import ( | ||
map_type_from_supertype, bind_self, erase_to_bound, make_simplified_union, | ||
erase_def_to_union_or_bound, erase_to_union_or_bound, | ||
erase_def_to_union_or_bound, erase_to_union_or_bound, coerce_to_literal, | ||
try_getting_str_literals_from_type, try_getting_int_literals_from_type, | ||
tuple_fallback, is_singleton_type, try_expanding_enum_to_union, | ||
true_only, false_only, function_type, | ||
) | ||
from mypy import message_registry | ||
|
@@ -72,9 +74,6 @@ | |
from mypy.plugin import Plugin, CheckerPluginInterface | ||
from mypy.sharedparse import BINARY_MAGIC_METHODS | ||
from mypy.scope import Scope | ||
from mypy.typeops import ( | ||
tuple_fallback, coerce_to_literal, is_singleton_type, try_expanding_enum_to_union | ||
) | ||
from mypy import state, errorcodes as codes | ||
from mypy.traverser import has_return_statement, all_return_statements | ||
from mypy.errorcodes import ErrorCode | ||
|
@@ -3709,6 +3708,12 @@ def find_isinstance_check(self, node: Expression | |
|
||
Guaranteed to not return None, None. (But may return {}, {}) | ||
""" | ||
if_map, else_map = self.find_isinstance_check_helper(node) | ||
new_if_map = self.propagate_up_typemap_info(self.type_map, if_map) | ||
new_else_map = self.propagate_up_typemap_info(self.type_map, else_map) | ||
return new_if_map, new_else_map | ||
|
||
def find_isinstance_check_helper(self, node: Expression) -> Tuple[TypeMap, TypeMap]: | ||
type_map = self.type_map | ||
if is_true_literal(node): | ||
return {}, None | ||
|
@@ -3835,28 +3840,185 @@ def find_isinstance_check(self, node: Expression | |
else None) | ||
return if_map, else_map | ||
elif isinstance(node, OpExpr) and node.op == 'and': | ||
left_if_vars, left_else_vars = self.find_isinstance_check(node.left) | ||
right_if_vars, right_else_vars = self.find_isinstance_check(node.right) | ||
left_if_vars, left_else_vars = self.find_isinstance_check_helper(node.left) | ||
right_if_vars, right_else_vars = self.find_isinstance_check_helper(node.right) | ||
|
||
# (e1 and e2) is true if both e1 and e2 are true, | ||
# and false if at least one of e1 and e2 is false. | ||
return (and_conditional_maps(left_if_vars, right_if_vars), | ||
or_conditional_maps(left_else_vars, right_else_vars)) | ||
elif isinstance(node, OpExpr) and node.op == 'or': | ||
left_if_vars, left_else_vars = self.find_isinstance_check(node.left) | ||
right_if_vars, right_else_vars = self.find_isinstance_check(node.right) | ||
left_if_vars, left_else_vars = self.find_isinstance_check_helper(node.left) | ||
right_if_vars, right_else_vars = self.find_isinstance_check_helper(node.right) | ||
|
||
# (e1 or e2) is true if at least one of e1 or e2 is true, | ||
# and false if both e1 and e2 are false. | ||
return (or_conditional_maps(left_if_vars, right_if_vars), | ||
and_conditional_maps(left_else_vars, right_else_vars)) | ||
elif isinstance(node, UnaryExpr) and node.op == 'not': | ||
left, right = self.find_isinstance_check(node.expr) | ||
left, right = self.find_isinstance_check_helper(node.expr) | ||
return right, left | ||
|
||
# Not a supported isinstance check | ||
return {}, {} | ||
|
||
def propagate_up_typemap_info(self, | ||
existing_types: Mapping[Expression, Type], | ||
new_types: TypeMap) -> TypeMap: | ||
"""Attempts refining parent expressions of any MemberExpr or IndexExprs in new_types. | ||
|
||
Specifically, this function accepts two mappings of expression to original types: | ||
the original mapping (existing_types), and a new mapping (new_types) intended to | ||
update the original. | ||
|
||
This function iterates through new_types and attempts to use the information to try | ||
refining any parent types that happen to be unions. | ||
|
||
For example, suppose there are two types "A = Tuple[int, int]" and "B = Tuple[str, str]". | ||
Next, suppose that 'new_types' specifies the expression 'foo[0]' has a refined type | ||
of 'int' and that 'foo' was previously deduced to be of type Union[A, B]. | ||
|
||
Then, this function will observe that since A[0] is an int and B[0] is not, the type of | |
'foo' can be further refined from Union[A, B] into just A. | |
|
||
We perform this kind of "parent narrowing" for member lookup expressions and indexing | ||
expressions into tuples, namedtuples, and typeddicts. This narrowing is also performed | ||
only once, for the immediate parents of any "lookup" expressions in `new_types`. | ||
|
||
We return the newly refined map. This map is guaranteed to be a superset of 'new_types'. | ||
""" | ||
if new_types is None: | ||
return None | ||
output_map = {} | ||
for expr, expr_type in new_types.items(): | ||
# The original inferred type should always be present in the output map, of course | ||
output_map[expr] = expr_type | ||
|
||
# Next, try using this information to refine the parent type, if applicable. | ||
# Note that we currently refine just the immediate parent. | ||
# | ||
# TODO: Should we also try recursively refining any parents of the parents? | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we should. This was one of the first things I thought of when looking at the PR. Supporting only immediate parents may look arbitrary to a user and potentially cause confusions. Imagine a simple situation: class Model(Generic[T]):
attr: T
class A:
model: Model[int]
class B:
model: Model[str]
x: Union[A, B]
if isinstance(x.model.attr, int):
... # I would want 'x' to be an 'A' here I think we should just cycle up until the expression is |
||
# | ||
# One quick-and-dirty way of doing this would be to have the caller repeatedly run | ||
# this function until we reach fixpoint; another way would be to modify | ||
# 'refine_parent_type' to run in a loop. Both approaches seem expensive though. | ||
new_mapping = self.refine_parent_type(existing_types, expr, expr_type) | ||
for parent_expr, proposed_parent_type in new_mapping.items(): | ||
# We don't try inferring anything if we've already inferred something for | ||
# the parent expression. | ||
# TODO: Consider picking the narrower type instead of always discarding this? | ||
if parent_expr in new_types: | ||
continue | ||
output_map[parent_expr] = proposed_parent_type | ||
return output_map | ||
|
||
def refine_parent_type(self, | ||
existing_types: Mapping[Expression, Type], | ||
expr: Expression, | ||
expr_type: Type) -> Mapping[Expression, Type]: | ||
"""Checks if the given expr is a 'lookup operation' into a union and refines the parent type | ||
based on the 'expr_type'. | ||
|
||
For more details about what a 'lookup operation' is and how we use the expr_type to refine | ||
the parent type, see the docstring in 'propagate_up_typemap_info'. | ||
""" | ||
|
||
# First, check if this expression is one that's attempting to | ||
# "lookup" some key in the parent type. If so, save the parent type | ||
# and create a function that will try replaying the same lookup | |
# operation against arbitrary types. | ||
if isinstance(expr, MemberExpr): | ||
parent_expr = expr.expr | ||
parent_type = existing_types.get(parent_expr) | ||
member_name = expr.name | ||
|
||
def replay_lookup(new_parent_type: ProperType) -> Optional[Type]: | ||
msg_copy = self.msg.clean_copy() | ||
msg_copy.disable_count = 0 | ||
member_type = analyze_member_access( | ||
name=member_name, | ||
typ=new_parent_type, | ||
context=parent_expr, | ||
is_lvalue=False, | ||
is_super=False, | ||
is_operator=False, | ||
msg=msg_copy, | ||
original_type=new_parent_type, | ||
chk=self, | ||
in_literal_context=False, | ||
) | ||
if msg_copy.is_errors(): | ||
return None | ||
else: | ||
return member_type | ||
elif isinstance(expr, IndexExpr): | ||
parent_expr = expr.base | ||
parent_type = existing_types.get(parent_expr) | ||
|
||
index_type = existing_types.get(expr.index) | ||
if index_type is None: | ||
return {} | ||
|
||
str_literals = try_getting_str_literals_from_type(index_type) | ||
if str_literals is not None: | ||
def replay_lookup(new_parent_type: ProperType) -> Optional[Type]: | ||
if not isinstance(new_parent_type, TypedDictType): | ||
return None | ||
try: | ||
assert str_literals is not None | ||
member_types = [new_parent_type.items[key] for key in str_literals] | ||
except KeyError: | ||
return None | ||
return make_simplified_union(member_types) | ||
else: | ||
int_literals = try_getting_int_literals_from_type(index_type) | ||
if int_literals is not None: | ||
def replay_lookup(new_parent_type: ProperType) -> Optional[Type]: | ||
Michael0x2a marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if not isinstance(new_parent_type, TupleType): | ||
return None | ||
try: | ||
assert int_literals is not None | ||
member_types = [new_parent_type.items[key] for key in int_literals] | ||
except IndexError: | ||
return None | ||
return make_simplified_union(member_types) | ||
else: | ||
return {} | ||
else: | ||
return {} | ||
|
||
# If we somehow didn't previously derive the parent type, abort: | ||
# something went wrong at an earlier stage. | ||
if parent_type is None: | ||
return {} | ||
|
||
# We currently only try refining the parent type if it's a Union. | ||
parent_type = get_proper_type(parent_type) | ||
if not isinstance(parent_type, UnionType): | ||
return {} | ||
|
||
# Take each element in the parent union and replay the original lookup procedure | ||
# to figure out which parents are compatible. | ||
new_parent_types = [] | ||
for item in parent_type.items: | ||
item = get_proper_type(item) | ||
member_type = replay_lookup(item) | ||
if member_type is None: | ||
# We were unable to obtain the member type. So, we give up on refining this | ||
# parent type entirely. | ||
return {} | ||
|
||
if is_overlapping_types(member_type, expr_type): | ||
new_parent_types.append(item) | ||
|
||
# If none of the parent types overlap (if we derived an empty union), something | ||
# went wrong. We should never hit this case, but deriving the uninhabited type or | ||
# reporting an error both seem unhelpful. So we abort. | ||
if not new_parent_types: | ||
return {} | ||
|
||
return {parent_expr: make_simplified_union(new_parent_types)} | ||
|
||
# | ||
# Helpers | ||
# | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,17 +5,17 @@ | |
since these may assume that MROs are ready. | ||
""" | ||
|
||
from typing import cast, Optional, List, Sequence, Set | ||
from typing import cast, Optional, List, Sequence, Set, TypeVar, Type as TypingType | ||
import sys | ||
|
||
from mypy.types import ( | ||
TupleType, Instance, FunctionLike, Type, CallableType, TypeVarDef, Overloaded, | ||
TypeVarType, UninhabitedType, FormalArgument, UnionType, NoneType, | ||
TypeVarType, UninhabitedType, FormalArgument, UnionType, NoneType, TypedDictType, | ||
AnyType, TypeOfAny, TypeType, ProperType, LiteralType, get_proper_type, get_proper_types, | ||
copy_type, TypeAliasType | ||
) | ||
from mypy.nodes import ( | ||
FuncBase, FuncItem, OverloadedFuncDef, TypeInfo, TypeVar, ARG_STAR, ARG_STAR2, ARG_POS, | ||
FuncBase, FuncItem, OverloadedFuncDef, TypeInfo, ARG_STAR, ARG_STAR2, ARG_POS, | ||
Expression, StrExpr, Var | ||
) | ||
from mypy.maptype import map_instance_to_supertype | ||
|
@@ -43,6 +43,25 @@ def tuple_fallback(typ: TupleType) -> Instance: | |
return Instance(info, [join_type_list(typ.items)]) | ||
|
||
|
||
def try_getting_instance_fallback(typ: ProperType) -> Optional[Instance]:
    """Return the Instance that backs ``typ``, or None when there is none.

    An Instance is returned unchanged, a TupleType yields the result of
    tuple_fallback(), and TypedDictType, FunctionLike and LiteralType yield
    their ``fallback`` attribute. Every other kind of type yields None.
    """
    if isinstance(typ, Instance):
        return typ
    if isinstance(typ, TupleType):
        return tuple_fallback(typ)
    # These three kinds of type all expose their fallback the same way.
    if isinstance(typ, (TypedDictType, FunctionLike, LiteralType)):
        return typ.fallback
    return None
|
||
|
||
def type_object_type_from_function(signature: FunctionLike, | ||
info: TypeInfo, | ||
def_info: TypeInfo, | ||
|
@@ -475,27 +494,66 @@ def try_getting_str_literals(expr: Expression, typ: Type) -> Optional[List[str]] | |
2. 'typ' is a LiteralType containing a string | ||
3. 'typ' is a UnionType containing only LiteralType of strings | ||
""" | ||
typ = get_proper_type(typ) | ||
|
||
if isinstance(expr, StrExpr): | ||
return [expr.value] | ||
|
||
# TODO: See if we can eliminate this function and call the below one directly | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What kind of problems does this cause? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK, let's leave this out for a separate PR. |
||
return try_getting_str_literals_from_type(typ) | ||
|
||
|
||
def try_getting_str_literals_from_type(typ: Type) -> Optional[List[str]]:
    """Extract the underlying strings from a string Literal type.

    Accepts a type that is a string Literal or a union of string Literals
    and returns the list of their values; returns None for anything else.

    For example, given the type 'Literal["foo", "bar"]' as input, this
    function returns the list of strings ["foo", "bar"].
    """
    return try_getting_literals_from_type(
        typ,
        target_literal_type=str,
        target_fullname="builtins.str",
    )
|
||
|
||
def try_getting_int_literals_from_type(typ: Type) -> Optional[List[int]]:
    """Extract the underlying ints from an int Literal type.

    Accepts a type that is an int Literal or a union of int Literals and
    returns the list of their values; returns None for anything else.

    For example, given the type 'Literal[1, 2, 3]' as input, this function
    returns the list of ints [1, 2, 3].
    """
    return try_getting_literals_from_type(
        typ,
        target_literal_type=int,
        target_fullname="builtins.int",
    )
|
||
|
||
T = TypeVar('T') | ||
|
||
|
||
def try_getting_literals_from_type(typ: Type, | ||
target_literal_type: TypingType[T], | ||
target_fullname: str) -> Optional[List[T]]: | ||
"""If the given type corresponds to a Literal or a union of Literals | |
whose underlying values all correspond to the given target type, | |
returns a list of those underlying values. Otherwise, | |
returns None. | |
""" | ||
typ = get_proper_type(typ) | ||
|
||
if isinstance(typ, Instance) and typ.last_known_value is not None: | ||
possible_literals = [typ.last_known_value] # type: List[Type] | ||
elif isinstance(typ, UnionType): | ||
possible_literals = list(typ.items) | ||
else: | ||
possible_literals = [typ] | ||
|
||
strings = [] | ||
literals = [] # type: List[T] | ||
for lit in get_proper_types(possible_literals): | ||
if isinstance(lit, LiteralType) and lit.fallback.type.fullname() == 'builtins.str': | ||
if isinstance(lit, LiteralType) and lit.fallback.type.fullname() == target_fullname: | ||
val = lit.value | ||
assert isinstance(val, str) | ||
strings.append(val) | ||
if isinstance(val, target_literal_type): | ||
literals.append(val) | ||
else: | ||
return None | ||
else: | ||
return None | ||
return strings | ||
return literals | ||
|
||
|
||
def get_enum_values(typ: Instance) -> List[str]: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is now outdated.