diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 24f0c8c85d61..2b042b4c0c7c 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -233,6 +233,8 @@ "builtins.memoryview", } +POISON_KEY: Final = (-1,) + class TooManyUnions(Exception): """Indicates that we need to stop splitting unions in an attempt @@ -356,7 +358,12 @@ def __init__( self._arg_infer_context_cache = None + self.overload_stack_depth = 0 + self._args_cache: dict[tuple[int, ...], list[Type]] = {} + def reset(self) -> None: + assert self.overload_stack_depth == 0 + assert not self._args_cache self.resolved_type = {} def visit_name_expr(self, e: NameExpr) -> Type: @@ -1613,9 +1620,10 @@ def check_call( object_type, ) elif isinstance(callee, Overloaded): - return self.check_overload_call( - callee, args, arg_kinds, arg_names, callable_name, object_type, context - ) + with self.overload_context(): + return self.check_overload_call( + callee, args, arg_kinds, arg_names, callable_name, object_type, context + ) elif isinstance(callee, AnyType) or not self.chk.in_checked_function(): return self.check_any_type_call(args, callee) elif isinstance(callee, UnionType): @@ -1678,6 +1686,14 @@ def check_call( else: return self.msg.not_callable(callee, context), AnyType(TypeOfAny.from_error) + @contextmanager + def overload_context(self) -> Iterator[None]: + self.overload_stack_depth += 1 + yield + self.overload_stack_depth -= 1 + if self.overload_stack_depth == 0: + self._args_cache.clear() + def check_callable_call( self, callee: CallableType, @@ -1935,20 +1951,40 @@ def analyze_type_type_callee(self, item: ProperType, context: Context) -> Type: self.msg.unsupported_type_type(item, context) return AnyType(TypeOfAny.from_error) - def infer_arg_types_in_empty_context(self, args: list[Expression]) -> list[Type]: + def infer_arg_types_in_empty_context( + self, args: list[Expression], *, allow_cache: bool + ) -> list[Type]: """Infer argument expression types in an empty context. In short, we basically recurse on each argument without considering in what context the argument was called. """ + # We can only use this hack locally while checking a single nested overloaded + # call. This saves a lot of rechecking, but is not generally safe. Cache is + # pruned upon leaving the outermost overload. + can_cache = ( + allow_cache + and POISON_KEY not in self._args_cache + and not any(isinstance(t, TempNode) for t in args) + ) + key = tuple(map(id, args)) + if can_cache and key in self._args_cache: + return self._args_cache[key] res: list[Type] = [] - for arg in args: - arg_type = self.accept(arg) - if has_erased_component(arg_type): - res.append(NoneType()) - else: - res.append(arg_type) + with self.msg.filter_errors(filter_errors=True, save_filtered_errors=True) as w: + for arg in args: + arg_type = self.accept(arg) + if has_erased_component(arg_type): + res.append(NoneType()) + else: + res.append(arg_type) + + if w.has_new_errors(): + self.msg.add_errors(w.filtered_errors()) + elif can_cache: + # Do not cache if new diagnostics were emitted: they may impact parent overload + self._args_cache[key] = res return res def infer_more_unions_for_recursive_type(self, type_context: Type) -> bool: @@ -2712,7 +2748,7 @@ def check_overload_call( """Checks a call to an overloaded function.""" # Normalize unpacked kwargs before checking the call. callee = callee.with_unpacked_kwargs() - arg_types = self.infer_arg_types_in_empty_context(args) + arg_types = self.infer_arg_types_in_empty_context(args, allow_cache=True) # Step 1: Filter call targets to remove ones where the argument counts don't match plausible_targets = self.plausible_overload_call_targets( arg_types, arg_kinds, arg_names, callee @@ -2921,17 +2957,16 @@ def infer_overload_return_type( for typ in plausible_targets: assert self.msg is self.chk.msg - with self.msg.filter_errors() as w: - with self.chk.local_type_map() as m: - ret_type, infer_type = self.check_call( - callee=typ, - args=args, - arg_kinds=arg_kinds, - arg_names=arg_names, - context=context, - callable_name=callable_name, - object_type=object_type, - ) + with self.msg.filter_errors() as w, self.chk.local_type_map() as m: + ret_type, infer_type = self.check_call( + callee=typ, + args=args, + arg_kinds=arg_kinds, + arg_names=arg_names, + context=context, + callable_name=callable_name, + object_type=object_type, + ) is_match = not w.has_new_errors() if is_match: # Return early if possible; otherwise record info, so we can @@ -3307,7 +3342,7 @@ def apply_generic_arguments( ) def check_any_type_call(self, args: list[Expression], callee: Type) -> tuple[Type, Type]: - self.infer_arg_types_in_empty_context(args) + self.infer_arg_types_in_empty_context(args, allow_cache=False) callee = get_proper_type(callee) if isinstance(callee, AnyType): return ( @@ -3478,6 +3513,7 @@ def visit_op_expr(self, e: OpExpr) -> Type: return self.strfrm_checker.check_str_interpolation(e.left, e.right) if isinstance(e.left, StrExpr): return self.strfrm_checker.check_str_interpolation(e.left, e.right) + left_type = self.accept(e.left) proper_left_type = get_proper_type(left_type) @@ -4350,6 +4386,9 @@ def check_list_multiply(self, e: OpExpr) -> Type: return result def visit_assignment_expr(self, e: AssignmentExpr) -> Type: + if self.overload_stack_depth > 0: + # Poison cache when we encounter assignments in overloads - they affect the binder. + self._args_cache[POISON_KEY] = [] value = self.accept(e.value) self.chk.check_assignment(e.target, e.value) self.chk.check_final(e) @@ -5405,6 +5444,9 @@ def find_typeddict_context( def visit_lambda_expr(self, e: LambdaExpr) -> Type: """Type check lambda expression.""" + if self.overload_stack_depth > 0: + # Poison cache when we encounter lambdas - it isn't safe to cache their types. + self._args_cache[POISON_KEY] = [] self.chk.check_default_args(e, body_is_trivial=False) inferred_type, type_override = self.infer_lambda_type_using_context(e) if not inferred_type: pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy