Skip to content

Commit 7e7eedd

Browse files
authored
Infer type for partial generic type from assignment (python#8036)
Code like this no longer requires a type annotation: ``` a = [] if foo(): a = [1] ``` Work towards python#1055.
1 parent b7465de commit 7e7eedd

File tree

3 files changed

+69
-29
lines changed

3 files changed

+69
-29
lines changed

mypy/checker.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2042,6 +2042,7 @@ def check_assignment(self, lvalue: Lvalue, rvalue: Expression, infer_lvalue_type
20422042
self.check_assignment_to_multiple_lvalues(lvalue.items, rvalue, rvalue,
20432043
infer_lvalue_type)
20442044
else:
2045+
self.try_infer_partial_generic_type_from_assignment(lvalue, rvalue)
20452046
lvalue_type, index_lvalue, inferred = self.check_lvalue(lvalue)
20462047
# If we're assigning to __getattr__ or similar methods, check that the signature is
20472048
# valid.
@@ -2141,6 +2142,37 @@ def check_assignment(self, lvalue: Lvalue, rvalue: Expression, infer_lvalue_type
21412142
rvalue_type = remove_instance_last_known_values(rvalue_type)
21422143
self.infer_variable_type(inferred, lvalue, rvalue_type, rvalue)
21432144

2145+
def try_infer_partial_generic_type_from_assignment(self,
2146+
lvalue: Lvalue,
2147+
rvalue: Expression) -> None:
2148+
"""Try to infer a precise type for partial generic type from assignment.
2149+
2150+
Example where this happens:
2151+
2152+
x = []
2153+
if foo():
2154+
x = [1] # Infer List[int] as type of 'x'
2155+
"""
2156+
if (isinstance(lvalue, NameExpr)
2157+
and isinstance(lvalue.node, Var)
2158+
and isinstance(lvalue.node.type, PartialType)):
2159+
var = lvalue.node
2160+
typ = lvalue.node.type
2161+
if typ.type is None:
2162+
return
2163+
partial_types = self.find_partial_types(var)
2164+
if partial_types is None:
2165+
return
2166+
rvalue_type = self.expr_checker.accept(rvalue)
2167+
rvalue_type = get_proper_type(rvalue_type)
2168+
if isinstance(rvalue_type, Instance):
2169+
if rvalue_type.type == typ.type:
2170+
var.type = rvalue_type
2171+
del partial_types[var]
2172+
elif isinstance(rvalue_type, AnyType):
2173+
var.type = fill_typevars_with_any(typ.type)
2174+
del partial_types[var]
2175+
21442176
def check_compatibility_all_supers(self, lvalue: RefExpr, lvalue_type: Optional[Type],
21452177
rvalue: Expression) -> bool:
21462178
lvalue_node = lvalue.node

mypy/errors.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -598,14 +598,22 @@ def remove_duplicates(self, errors: List[ErrorTuple]) -> List[ErrorTuple]:
598598
i = 0
599599
while i < len(errors):
600600
dup = False
601+
# Use slightly special formatting for member conflicts reporting.
602+
conflicts_notes = False
603+
j = i - 1
604+
while j >= 0 and errors[j][0] == errors[i][0]:
605+
if errors[j][4].strip() == 'Got:':
606+
conflicts_notes = True
607+
j -= 1
601608
j = i - 1
602609
while (j >= 0 and errors[j][0] == errors[i][0] and
603610
errors[j][1] == errors[i][1]):
604611
if (errors[j][3] == errors[i][3] and
605612
# Allow duplicate notes in overload conflicts reporting.
606-
not (errors[i][3] == 'note' and
607-
errors[i][4].strip() in allowed_duplicates
608-
or errors[i][4].strip().startswith('def ')) and
613+
not ((errors[i][3] == 'note' and
614+
errors[i][4].strip() in allowed_duplicates)
615+
or (errors[i][4].strip().startswith('def ') and
616+
conflicts_notes)) and
609617
errors[j][4] == errors[i][4]): # ignore column
610618
dup = True
611619
break

test-data/unit/check-inference.test

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1367,34 +1367,29 @@ a = []
13671367
a.append(1)
13681368
a.append('') # E: Argument 1 to "append" of "list" has incompatible type "str"; expected "int"
13691369
[builtins fixtures/list.pyi]
1370-
[out]
13711370

13721371
[case testInferListInitializedToEmptyUsingUpdate]
13731372
a = []
13741373
a.extend([''])
13751374
a.append(0) # E: Argument 1 to "append" of "list" has incompatible type "int"; expected "str"
13761375
[builtins fixtures/list.pyi]
1377-
[out]
13781376

13791377
[case testInferListInitializedToEmptyAndNotAnnotated]
13801378
a = [] # E: Need type annotation for 'a' (hint: "a: List[<type>] = ...")
13811379
[builtins fixtures/list.pyi]
1382-
[out]
13831380

13841381
[case testInferListInitializedToEmptyAndReadBeforeAppend]
13851382
a = [] # E: Need type annotation for 'a' (hint: "a: List[<type>] = ...")
13861383
if a: pass
13871384
a.xyz # E: "List[Any]" has no attribute "xyz"
13881385
a.append('')
13891386
[builtins fixtures/list.pyi]
1390-
[out]
13911387

13921388
[case testInferListInitializedToEmptyAndIncompleteTypeInAppend]
13931389
a = [] # E: Need type annotation for 'a' (hint: "a: List[<type>] = ...")
13941390
a.append([])
13951391
a() # E: "List[Any]" not callable
13961392
[builtins fixtures/list.pyi]
1397-
[out]
13981393

13991394
[case testInferListInitializedToEmptyAndMultipleAssignment]
14001395
a, b = [], []
@@ -1403,15 +1398,13 @@ b.append('')
14031398
a() # E: "List[int]" not callable
14041399
b() # E: "List[str]" not callable
14051400
[builtins fixtures/list.pyi]
1406-
[out]
14071401

14081402
[case testInferListInitializedToEmptyInFunction]
14091403
def f() -> None:
14101404
a = []
14111405
a.append(1)
14121406
a.append('') # E: Argument 1 to "append" of "list" has incompatible type "str"; expected "int"
14131407
[builtins fixtures/list.pyi]
1414-
[out]
14151408

14161409
[case testInferListInitializedToEmptyAndNotAnnotatedInFunction]
14171410
def f() -> None:
@@ -1422,7 +1415,6 @@ def g() -> None: pass
14221415
a = []
14231416
a.append(1)
14241417
[builtins fixtures/list.pyi]
1425-
[out]
14261418

14271419
[case testInferListInitializedToEmptyAndReadBeforeAppendInFunction]
14281420
def f() -> None:
@@ -1431,15 +1423,13 @@ def f() -> None:
14311423
a.xyz # E: "List[Any]" has no attribute "xyz"
14321424
a.append('')
14331425
[builtins fixtures/list.pyi]
1434-
[out]
14351426

14361427
[case testInferListInitializedToEmptyInClassBody]
14371428
class A:
14381429
a = []
14391430
a.append(1)
14401431
a.append('') # E: Argument 1 to "append" of "list" has incompatible type "str"; expected "int"
14411432
[builtins fixtures/list.pyi]
1442-
[out]
14431433

14441434
[case testInferListInitializedToEmptyAndNotAnnotatedInClassBody]
14451435
class A:
@@ -1449,7 +1439,6 @@ class B:
14491439
a = []
14501440
a.append(1)
14511441
[builtins fixtures/list.pyi]
1452-
[out]
14531442

14541443
[case testInferListInitializedToEmptyInMethod]
14551444
class A:
@@ -1458,14 +1447,12 @@ class A:
14581447
a.append(1)
14591448
a.append('') # E: Argument 1 to "append" of "list" has incompatible type "str"; expected "int"
14601449
[builtins fixtures/list.pyi]
1461-
[out]
14621450

14631451
[case testInferListInitializedToEmptyAndNotAnnotatedInMethod]
14641452
class A:
14651453
def f(self) -> None:
14661454
a = [] # E: Need type annotation for 'a' (hint: "a: List[<type>] = ...")
14671455
[builtins fixtures/list.pyi]
1468-
[out]
14691456

14701457
[case testInferListInitializedToEmptyInMethodViaAttribute]
14711458
class A:
@@ -1475,7 +1462,6 @@ class A:
14751462
self.a.append(1)
14761463
self.a.append('')
14771464
[builtins fixtures/list.pyi]
1478-
[out]
14791465

14801466
[case testInferListInitializedToEmptyInClassBodyAndOverriden]
14811467
from typing import List
@@ -1490,57 +1476,49 @@ class B(A):
14901476
def x(self) -> List[int]: # E: Signature of "x" incompatible with supertype "A"
14911477
return [123]
14921478
[builtins fixtures/list.pyi]
1493-
[out]
14941479

14951480
[case testInferSetInitializedToEmpty]
14961481
a = set()
14971482
a.add(1)
14981483
a.add('') # E: Argument 1 to "add" of "set" has incompatible type "str"; expected "int"
14991484
[builtins fixtures/set.pyi]
1500-
[out]
15011485

15021486
[case testInferSetInitializedToEmptyUsingDiscard]
15031487
a = set()
15041488
a.discard('')
15051489
a.add(0) # E: Argument 1 to "add" of "set" has incompatible type "int"; expected "str"
15061490
[builtins fixtures/set.pyi]
1507-
[out]
15081491

15091492
[case testInferSetInitializedToEmptyUsingUpdate]
15101493
a = set()
15111494
a.update({0})
15121495
a.add('') # E: Argument 1 to "add" of "set" has incompatible type "str"; expected "int"
15131496
[builtins fixtures/set.pyi]
1514-
[out]
15151497

15161498
[case testInferDictInitializedToEmpty]
15171499
a = {}
15181500
a[1] = ''
15191501
a() # E: "Dict[int, str]" not callable
15201502
[builtins fixtures/dict.pyi]
1521-
[out]
15221503

15231504
[case testInferDictInitializedToEmptyUsingUpdate]
15241505
a = {}
15251506
a.update({'': 42})
15261507
a() # E: "Dict[str, int]" not callable
15271508
[builtins fixtures/dict.pyi]
1528-
[out]
15291509

15301510
[case testInferDictInitializedToEmptyUsingUpdateError]
15311511
a = {} # E: Need type annotation for 'a' (hint: "a: Dict[<type>, <type>] = ...")
15321512
a.update([1, 2]) # E: Argument 1 to "update" of "dict" has incompatible type "List[int]"; expected "Mapping[Any, Any]"
15331513
a() # E: "Dict[Any, Any]" not callable
15341514
[builtins fixtures/dict.pyi]
1535-
[out]
15361515

15371516
[case testInferDictInitializedToEmptyAndIncompleteTypeInUpdate]
15381517
a = {} # E: Need type annotation for 'a' (hint: "a: Dict[<type>, <type>] = ...")
15391518
a[1] = {}
15401519
b = {} # E: Need type annotation for 'b' (hint: "b: Dict[<type>, <type>] = ...")
15411520
b[{}] = 1
15421521
[builtins fixtures/dict.pyi]
1543-
[out]
15441522

15451523
[case testInferDictInitializedToEmptyAndUpdatedFromMethod]
15461524
map = {}
@@ -1557,20 +1535,42 @@ def add():
15571535
[case testSpecialCaseEmptyListInitialization]
15581536
def f(blocks: Any): # E: Name 'Any' is not defined \
15591537
# N: Did you forget to import it from "typing"? (Suggestion: "from typing import Any")
1560-
to_process = [] # E: Need type annotation for 'to_process' (hint: "to_process: List[<type>] = ...")
1538+
to_process = []
15611539
to_process = list(blocks)
15621540
[builtins fixtures/list.pyi]
1563-
[out]
15641541

15651542
[case testSpecialCaseEmptyListInitialization2]
15661543
def f(blocks: object):
1567-
to_process = [] # E: Need type annotation for 'to_process' (hint: "to_process: List[<type>] = ...")
1544+
to_process = []
15681545
to_process = list(blocks) # E: No overload variant of "list" matches argument type "object" \
15691546
# N: Possible overload variant: \
15701547
# N: def [T] __init__(self, x: Iterable[T]) -> List[T] \
15711548
# N: <1 more non-matching overload not shown>
15721549
[builtins fixtures/list.pyi]
1573-
[out]
1550+
1551+
[case testInferListInitializedToEmptyAndAssigned]
1552+
a = []
1553+
if bool():
1554+
a = [1]
1555+
reveal_type(a) # N: Revealed type is 'builtins.list[builtins.int*]'
1556+
1557+
def f():
1558+
return [1]
1559+
b = []
1560+
if bool():
1561+
b = f()
1562+
reveal_type(b) # N: Revealed type is 'builtins.list[Any]'
1563+
1564+
d = {}
1565+
if bool():
1566+
d = {1: 'x'}
1567+
reveal_type(d) # N: Revealed type is 'builtins.dict[builtins.int*, builtins.str*]'
1568+
1569+
dd = {} # E: Need type annotation for 'dd' (hint: "dd: Dict[<type>, <type>] = ...")
1570+
if bool():
1571+
dd = [1] # E: Incompatible types in assignment (expression has type "List[int]", variable has type "Dict[Any, Any]")
1572+
reveal_type(dd) # N: Revealed type is 'builtins.dict[Any, Any]'
1573+
[builtins fixtures/dict.pyi]
15741574

15751575

15761576
-- Inferring types of variables first initialized to None (partial types)

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy