Skip to content

Commit 3acbf3f

Browse files
sobolevnmsullivan
andauthored
Adds get_function_signature_hook (python#9102)
This PR introduces get_function_signature_hook that behaves the similar way as get_method_signature_hook. Closes python#9101 Co-authored-by: Michael Sullivan <sully@msully.net>
1 parent e4131a5 commit 3acbf3f

File tree

4 files changed

+118
-21
lines changed

4 files changed

+118
-21
lines changed

mypy/checkexpr.py

Lines changed: 50 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,11 @@
5656
from mypy.util import split_module_names
5757
from mypy.typevars import fill_typevars
5858
from mypy.visitor import ExpressionVisitor
59-
from mypy.plugin import Plugin, MethodContext, MethodSigContext, FunctionContext
59+
from mypy.plugin import (
60+
Plugin,
61+
MethodContext, MethodSigContext,
62+
FunctionContext, FunctionSigContext,
63+
)
6064
from mypy.typeops import (
6165
tuple_fallback, make_simplified_union, true_only, false_only, erase_to_union_or_bound,
6266
function_type, callable_type, try_getting_str_literals, custom_special_method,
@@ -730,12 +734,15 @@ def apply_function_plugin(self,
730734
callee.arg_names, formal_arg_names,
731735
callee.ret_type, formal_arg_exprs, context, self.chk))
732736

733-
def apply_method_signature_hook(
737+
def apply_signature_hook(
734738
self, callee: FunctionLike, args: List[Expression],
735-
arg_kinds: List[int], context: Context,
736-
arg_names: Optional[Sequence[Optional[str]]], object_type: Type,
737-
signature_hook: Callable[[MethodSigContext], CallableType]) -> FunctionLike:
738-
"""Apply a plugin hook that may infer a more precise signature for a method."""
739+
arg_kinds: List[int],
740+
arg_names: Optional[Sequence[Optional[str]]],
741+
hook: Callable[
742+
[List[List[Expression]], CallableType],
743+
CallableType,
744+
]) -> FunctionLike:
745+
"""Helper to apply a signature hook for either a function or method"""
739746
if isinstance(callee, CallableType):
740747
num_formals = len(callee.arg_kinds)
741748
formal_to_actual = map_actuals_to_formals(
@@ -746,19 +753,40 @@ def apply_method_signature_hook(
746753
for formal, actuals in enumerate(formal_to_actual):
747754
for actual in actuals:
748755
formal_arg_exprs[formal].append(args[actual])
749-
object_type = get_proper_type(object_type)
750-
return signature_hook(
751-
MethodSigContext(object_type, formal_arg_exprs, callee, context, self.chk))
756+
return hook(formal_arg_exprs, callee)
752757
else:
753758
assert isinstance(callee, Overloaded)
754759
items = []
755760
for item in callee.items():
756-
adjusted = self.apply_method_signature_hook(
757-
item, args, arg_kinds, context, arg_names, object_type, signature_hook)
761+
adjusted = self.apply_signature_hook(
762+
item, args, arg_kinds, arg_names, hook)
758763
assert isinstance(adjusted, CallableType)
759764
items.append(adjusted)
760765
return Overloaded(items)
761766

767+
def apply_function_signature_hook(
768+
self, callee: FunctionLike, args: List[Expression],
769+
arg_kinds: List[int], context: Context,
770+
arg_names: Optional[Sequence[Optional[str]]],
771+
signature_hook: Callable[[FunctionSigContext], CallableType]) -> FunctionLike:
772+
"""Apply a plugin hook that may infer a more precise signature for a function."""
773+
return self.apply_signature_hook(
774+
callee, args, arg_kinds, arg_names,
775+
(lambda args, sig:
776+
signature_hook(FunctionSigContext(args, sig, context, self.chk))))
777+
778+
def apply_method_signature_hook(
779+
self, callee: FunctionLike, args: List[Expression],
780+
arg_kinds: List[int], context: Context,
781+
arg_names: Optional[Sequence[Optional[str]]], object_type: Type,
782+
signature_hook: Callable[[MethodSigContext], CallableType]) -> FunctionLike:
783+
"""Apply a plugin hook that may infer a more precise signature for a method."""
784+
pobject_type = get_proper_type(object_type)
785+
return self.apply_signature_hook(
786+
callee, args, arg_kinds, arg_names,
787+
(lambda args, sig:
788+
signature_hook(MethodSigContext(pobject_type, args, sig, context, self.chk))))
789+
762790
def transform_callee_type(
763791
self, callable_name: Optional[str], callee: Type, args: List[Expression],
764792
arg_kinds: List[int], context: Context,
@@ -779,13 +807,17 @@ def transform_callee_type(
779807
(if appropriate) before the signature is passed to check_call.
780808
"""
781809
callee = get_proper_type(callee)
782-
if (callable_name is not None
783-
and object_type is not None
784-
and isinstance(callee, FunctionLike)):
785-
signature_hook = self.plugin.get_method_signature_hook(callable_name)
786-
if signature_hook:
787-
return self.apply_method_signature_hook(
788-
callee, args, arg_kinds, context, arg_names, object_type, signature_hook)
810+
if callable_name is not None and isinstance(callee, FunctionLike):
811+
if object_type is not None:
812+
method_sig_hook = self.plugin.get_method_signature_hook(callable_name)
813+
if method_sig_hook:
814+
return self.apply_method_signature_hook(
815+
callee, args, arg_kinds, context, arg_names, object_type, method_sig_hook)
816+
else:
817+
function_sig_hook = self.plugin.get_function_signature_hook(callable_name)
818+
if function_sig_hook:
819+
return self.apply_function_signature_hook(
820+
callee, args, arg_kinds, context, arg_names, function_sig_hook)
789821

790822
return callee
791823

mypy/plugin.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,16 @@ def final_iteration(self) -> bool:
365365
('is_check', bool) # Is this invocation for checking whether the config matches
366366
])
367367

368+
# A context for a function signature hook that infers a better signature for a
369+
# function. Note that argument types aren't available yet. If you need them,
370+
# you have to use a method hook instead.
371+
FunctionSigContext = NamedTuple(
372+
'FunctionSigContext', [
373+
('args', List[List[Expression]]), # Actual expressions for each formal argument
374+
('default_signature', CallableType), # Original signature of the method
375+
('context', Context), # Relevant location context (e.g. for error messages)
376+
('api', CheckerPluginInterface)])
377+
368378
# A context for a function hook that infers the return type of a function with
369379
# a special signature.
370380
#
@@ -395,7 +405,7 @@ def final_iteration(self) -> bool:
395405
# TODO: document ProperType in the plugin changelog/update issue.
396406
MethodSigContext = NamedTuple(
397407
'MethodSigContext', [
398-
('type', ProperType), # Base object type for method call
408+
('type', ProperType), # Base object type for method call
399409
('args', List[List[Expression]]), # Actual expressions for each formal argument
400410
('default_signature', CallableType), # Original signature of the method
401411
('context', Context), # Relevant location context (e.g. for error messages)
@@ -407,7 +417,7 @@ def final_iteration(self) -> bool:
407417
# This is very similar to FunctionContext (only differences are documented).
408418
MethodContext = NamedTuple(
409419
'MethodContext', [
410-
('type', ProperType), # Base object type for method call
420+
('type', ProperType), # Base object type for method call
411421
('arg_types', List[List[Type]]), # List of actual caller types for each formal argument
412422
# see FunctionContext for details about names and kinds
413423
('arg_kinds', List[List[int]]),
@@ -421,7 +431,7 @@ def final_iteration(self) -> bool:
421431
# A context for an attribute type hook that infers the type of an attribute.
422432
AttributeContext = NamedTuple(
423433
'AttributeContext', [
424-
('type', ProperType), # Type of object with attribute
434+
('type', ProperType), # Type of object with attribute
425435
('default_attr_type', Type), # Original attribute type
426436
('context', Context), # Relevant location context (e.g. for error messages)
427437
('api', CheckerPluginInterface)])
@@ -533,6 +543,22 @@ def func(x: Other[int]) -> None:
533543
"""
534544
return None
535545

546+
def get_function_signature_hook(self, fullname: str
547+
) -> Optional[Callable[[FunctionSigContext], CallableType]]:
548+
"""Adjust the signature a function.
549+
550+
This method is called before type checking a function call. Plugin
551+
may infer a better type for the function.
552+
553+
from lib import Class, do_stuff
554+
555+
do_stuff(42)
556+
Class()
557+
558+
This method will be called with 'lib.do_stuff' and then with 'lib.Class'.
559+
"""
560+
return None
561+
536562
def get_function_hook(self, fullname: str
537563
) -> Optional[Callable[[FunctionContext], Type]]:
538564
"""Adjust the return type of a function call.
@@ -721,6 +747,10 @@ def get_type_analyze_hook(self, fullname: str
721747
) -> Optional[Callable[[AnalyzeTypeContext], Type]]:
722748
return self._find_hook(lambda plugin: plugin.get_type_analyze_hook(fullname))
723749

750+
def get_function_signature_hook(self, fullname: str
751+
) -> Optional[Callable[[FunctionSigContext], CallableType]]:
752+
return self._find_hook(lambda plugin: plugin.get_function_signature_hook(fullname))
753+
724754
def get_function_hook(self, fullname: str
725755
) -> Optional[Callable[[FunctionContext], Type]]:
726756
return self._find_hook(lambda plugin: plugin.get_function_hook(fullname))

test-data/unit/check-custom-plugin.test

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -721,3 +721,12 @@ Cls().attr = "foo" # E: Incompatible types in assignment (expression has type "
721721
[file mypy.ini]
722722
\[mypy]
723723
plugins=<ROOT>/test-data/unit/plugins/descriptor.py
724+
725+
[case testFunctionSigPluginFile]
726+
# flags: --config-file tmp/mypy.ini
727+
728+
def dynamic_signature(arg1: str) -> str: ...
729+
reveal_type(dynamic_signature(1)) # N: Revealed type is 'builtins.int'
730+
[file mypy.ini]
731+
\[mypy]
732+
plugins=<ROOT>/test-data/unit/plugins/function_sig_hook.py
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from mypy.plugin import CallableType, CheckerPluginInterface, FunctionSigContext, Plugin
2+
from mypy.types import Instance, Type
3+
4+
class FunctionSigPlugin(Plugin):
5+
def get_function_signature_hook(self, fullname):
6+
if fullname == '__main__.dynamic_signature':
7+
return my_hook
8+
return None
9+
10+
def _str_to_int(api: CheckerPluginInterface, typ: Type) -> Type:
11+
if isinstance(typ, Instance):
12+
if typ.type.fullname == 'builtins.str':
13+
return api.named_generic_type('builtins.int', [])
14+
elif typ.args:
15+
return typ.copy_modified(args=[_str_to_int(api, t) for t in typ.args])
16+
17+
return typ
18+
19+
def my_hook(ctx: FunctionSigContext) -> CallableType:
20+
return ctx.default_signature.copy_modified(
21+
arg_types=[_str_to_int(ctx.api, t) for t in ctx.default_signature.arg_types],
22+
ret_type=_str_to_int(ctx.api, ctx.default_signature.ret_type),
23+
)
24+
25+
def plugin(version):
26+
return FunctionSigPlugin

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy