Skip to content

Commit bd94bcb

Browse files
authored
Try simple-minded call expression cache (#19505)
This gives a modest 1% improvement on self-check (compiled), but it gives almost 40% on `mypy -c "import colour"`. Some comments: * I only cache `CallExpr`, `ListExpr`, and `TupleExpr`, this is not very principled, I found this as a best balance between rare cases like `colour`, and more common cases like self-check. * Caching is fragile within lambdas, so I simply disable it, it rarely matters anyway. * I cache both messages and the type map, surprisingly the latter only affects couple test cases, but I still do this generally for peace of mind. * It looks like there are only three things that require cache invalidation: binder, partial types, and deferrals. In general, this is a bit scary (as this a major change), but also perf improvements for slow libraries are very tempting.
1 parent e40c36c commit bd94bcb

File tree

6 files changed

+84
-5
lines changed

6 files changed

+84
-5
lines changed

mypy/binder.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,10 @@ def __init__(self, options: Options) -> None:
138138
# flexible inference of variable types (--allow-redefinition-new).
139139
self.bind_all = options.allow_redefinition_new
140140

141+
# This tracks any externally visible changes in binder to invalidate
142+
# expression caches when needed.
143+
self.version = 0
144+
141145
def _get_id(self) -> int:
142146
self.next_id += 1
143147
return self.next_id
@@ -158,6 +162,7 @@ def push_frame(self, conditional_frame: bool = False) -> Frame:
158162
return f
159163

160164
def _put(self, key: Key, type: Type, from_assignment: bool, index: int = -1) -> None:
165+
self.version += 1
161166
self.frames[index].types[key] = CurrentType(type, from_assignment)
162167

163168
def _get(self, key: Key, index: int = -1) -> CurrentType | None:
@@ -185,6 +190,7 @@ def put(self, expr: Expression, typ: Type, *, from_assignment: bool = True) -> N
185190
self._put(key, typ, from_assignment)
186191

187192
def unreachable(self) -> None:
193+
self.version += 1
188194
self.frames[-1].unreachable = True
189195

190196
def suppress_unreachable_warnings(self) -> None:

mypy/checker.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,6 @@ def reset(self) -> None:
449449
self.binder = ConditionalTypeBinder(self.options)
450450
self._type_maps[1:] = []
451451
self._type_maps[0].clear()
452-
self.temp_type_map = None
453452
self.expr_checker.reset()
454453
self.deferred_nodes = []
455454
self.partial_types = []
@@ -3024,6 +3023,8 @@ def visit_block(self, b: Block) -> None:
30243023
break
30253024
else:
30263025
self.accept(s)
3026+
# Clear expression cache after each statement to avoid unlimited growth.
3027+
self.expr_checker.expr_cache.clear()
30273028

30283029
def should_report_unreachable_issues(self) -> bool:
30293030
return (
@@ -4005,7 +4006,7 @@ def check_multi_assignment_from_union(
40054006
for t, lv in zip(transposed, self.flatten_lvalues(lvalues)):
40064007
# We can access _type_maps directly since temporary type maps are
40074008
# only created within expressions.
4008-
t.append(self._type_maps[0].pop(lv, AnyType(TypeOfAny.special_form)))
4009+
t.append(self._type_maps[-1].pop(lv, AnyType(TypeOfAny.special_form)))
40094010
union_types = tuple(make_simplified_union(col) for col in transposed)
40104011
for expr, items in assignments.items():
40114012
# Bind a union of types collected in 'assignments' to every expression.
@@ -4664,6 +4665,8 @@ def replace_partial_type(
46644665
) -> None:
46654666
"""Replace the partial type of var with a non-partial type."""
46664667
var.type = new_type
4668+
# Updating a partial type should invalidate expression caches.
4669+
self.binder.version += 1
46674670
del partial_types[var]
46684671
if self.options.allow_redefinition_new:
46694672
# When using --allow-redefinition-new, binder tracks all types of

mypy/checkexpr.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from mypy.checkmember import analyze_member_access, has_operator
2020
from mypy.checkstrformat import StringFormatterChecker
2121
from mypy.erasetype import erase_type, remove_instance_last_known_values, replace_meta_vars
22-
from mypy.errors import ErrorWatcher, report_internal_error
22+
from mypy.errors import ErrorInfo, ErrorWatcher, report_internal_error
2323
from mypy.expandtype import (
2424
expand_type,
2525
expand_type_by_instance,
@@ -355,9 +355,15 @@ def __init__(
355355
type_state.infer_polymorphic = not self.chk.options.old_type_inference
356356

357357
self._arg_infer_context_cache = None
358+
self.expr_cache: dict[
359+
tuple[Expression, Type | None],
360+
tuple[int, Type, list[ErrorInfo], dict[Expression, Type]],
361+
] = {}
362+
self.in_lambda_expr = False
358363

359364
def reset(self) -> None:
360365
self.resolved_type = {}
366+
self.expr_cache.clear()
361367

362368
def visit_name_expr(self, e: NameExpr) -> Type:
363369
"""Type check a name expression.
@@ -5402,6 +5408,8 @@ def find_typeddict_context(
54025408

54035409
def visit_lambda_expr(self, e: LambdaExpr) -> Type:
54045410
"""Type check lambda expression."""
5411+
old_in_lambda = self.in_lambda_expr
5412+
self.in_lambda_expr = True
54055413
self.chk.check_default_args(e, body_is_trivial=False)
54065414
inferred_type, type_override = self.infer_lambda_type_using_context(e)
54075415
if not inferred_type:
@@ -5422,6 +5430,7 @@ def visit_lambda_expr(self, e: LambdaExpr) -> Type:
54225430
ret_type = self.accept(e.expr(), allow_none_return=True)
54235431
fallback = self.named_type("builtins.function")
54245432
self.chk.return_types.pop()
5433+
self.in_lambda_expr = old_in_lambda
54255434
return callable_type(e, fallback, ret_type)
54265435
else:
54275436
# Type context available.
@@ -5434,6 +5443,7 @@ def visit_lambda_expr(self, e: LambdaExpr) -> Type:
54345443
self.accept(e.expr(), allow_none_return=True)
54355444
ret_type = self.chk.lookup_type(e.expr())
54365445
self.chk.return_types.pop()
5446+
self.in_lambda_expr = old_in_lambda
54375447
return replace_callable_return_type(inferred_type, ret_type)
54385448

54395449
def infer_lambda_type_using_context(
@@ -5978,6 +5988,24 @@ def accept(
59785988
typ = self.visit_conditional_expr(node, allow_none_return=True)
59795989
elif allow_none_return and isinstance(node, AwaitExpr):
59805990
typ = self.visit_await_expr(node, allow_none_return=True)
5991+
# Deeply nested generic calls can deteriorate performance dramatically.
5992+
# Although in most cases caching makes little difference, in worst case
5993+
# it avoids exponential complexity.
5994+
# We cannot use cache inside lambdas, because they skip immediate type
5995+
# context, and use enclosing one, see infer_lambda_type_using_context().
5996+
# TODO: consider using cache for more expression kinds.
5997+
elif isinstance(node, (CallExpr, ListExpr, TupleExpr)) and not (
5998+
self.in_lambda_expr or self.chk.current_node_deferred
5999+
):
6000+
if (node, type_context) in self.expr_cache:
6001+
binder_version, typ, messages, type_map = self.expr_cache[(node, type_context)]
6002+
if binder_version == self.chk.binder.version:
6003+
self.chk.store_types(type_map)
6004+
self.msg.add_errors(messages)
6005+
else:
6006+
typ = self.accept_maybe_cache(node, type_context=type_context)
6007+
else:
6008+
typ = self.accept_maybe_cache(node, type_context=type_context)
59816009
else:
59826010
typ = node.accept(self)
59836011
except Exception as err:
@@ -6008,6 +6036,21 @@ def accept(
60086036
self.in_expression = False
60096037
return result
60106038

6039+
def accept_maybe_cache(self, node: Expression, type_context: Type | None = None) -> Type:
6040+
binder_version = self.chk.binder.version
6041+
# Micro-optimization: inline local_type_map() as it is somewhat slow in mypyc.
6042+
type_map: dict[Expression, Type] = {}
6043+
self.chk._type_maps.append(type_map)
6044+
with self.msg.filter_errors(filter_errors=True, save_filtered_errors=True) as msg:
6045+
typ = node.accept(self)
6046+
messages = msg.filtered_errors()
6047+
if binder_version == self.chk.binder.version and not self.chk.current_node_deferred:
6048+
self.expr_cache[(node, type_context)] = (binder_version, typ, messages, type_map)
6049+
self.chk._type_maps.pop()
6050+
self.chk.store_types(type_map)
6051+
self.msg.add_errors(messages)
6052+
return typ
6053+
60116054
def named_type(self, name: str) -> Instance:
60126055
"""Return an instance type with type given by the name and no type
60136056
arguments. Alias for TypeChecker.named_type.

mypy/errors.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,7 @@ class Errors:
390390
# in some cases to avoid reporting huge numbers of errors.
391391
seen_import_error = False
392392

393-
_watchers: list[ErrorWatcher] = []
393+
_watchers: list[ErrorWatcher]
394394

395395
def __init__(
396396
self,
@@ -421,6 +421,7 @@ def initialize(self) -> None:
421421
self.scope = None
422422
self.target_module = None
423423
self.seen_import_error = False
424+
self._watchers = []
424425

425426
def reset(self) -> None:
426427
self.initialize()
@@ -931,7 +932,8 @@ def prefer_simple_messages(self) -> bool:
931932
if self.file in self.ignored_files:
932933
# Errors ignored, so no point generating fancy messages
933934
return True
934-
for _watcher in self._watchers:
935+
if self._watchers:
936+
_watcher = self._watchers[-1]
935937
if _watcher._filter is True and _watcher._filtered is None:
936938
# Errors are filtered
937939
return True

test-data/unit/check-overloading.test

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6801,3 +6801,26 @@ class D(Generic[T]):
68016801
a: D[str] # E: Type argument "str" of "D" must be a subtype of "C"
68026802
reveal_type(a.f(1)) # N: Revealed type is "builtins.int"
68036803
reveal_type(a.f("x")) # N: Revealed type is "builtins.str"
6804+
6805+
[case testMultiAssignFromUnionInOverloadCached]
6806+
from typing import Iterable, overload, Union, Optional
6807+
6808+
@overload
6809+
def always_bytes(str_or_bytes: None) -> None: ...
6810+
@overload
6811+
def always_bytes(str_or_bytes: Union[str, bytes]) -> bytes: ...
6812+
def always_bytes(str_or_bytes: Union[None, str, bytes]) -> Optional[bytes]:
6813+
pass
6814+
6815+
class Headers:
6816+
def __init__(self, iter: Iterable[tuple[bytes, bytes]]) -> None: ...
6817+
6818+
headers: Union[Headers, dict[Union[str, bytes], Union[str, bytes]], Iterable[tuple[bytes, bytes]]]
6819+
6820+
if isinstance(headers, dict):
6821+
headers = Headers(
6822+
(always_bytes(k), always_bytes(v)) for k, v in headers.items()
6823+
)
6824+
6825+
reveal_type(headers) # N: Revealed type is "Union[__main__.Headers, typing.Iterable[tuple[builtins.bytes, builtins.bytes]]]"
6826+
[builtins fixtures/isinstancelist.pyi]

test-data/unit/fixtures/isinstancelist.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ class bool(int): pass
2626
class str:
2727
def __add__(self, x: str) -> str: pass
2828
def __getitem__(self, x: int) -> str: pass
29+
class bytes: pass
2930

3031
T = TypeVar('T')
3132
KT = TypeVar('KT')
@@ -52,6 +53,7 @@ class dict(Mapping[KT, VT]):
5253
def __setitem__(self, k: KT, v: VT) -> None: pass
5354
def __iter__(self) -> Iterator[KT]: pass
5455
def update(self, a: Mapping[KT, VT]) -> None: pass
56+
def items(self) -> Iterable[Tuple[KT, VT]]: pass
5557

5658
class set(Generic[T]):
5759
def __iter__(self) -> Iterator[T]: pass

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy