@@ -145,19 +145,23 @@ class SuggestionEngine:
145
145
"""Engine for finding call sites and suggesting signatures."""
146
146
147
147
def __init__ (self , fgmanager : FineGrainedBuildManager ,
148
+ * ,
148
149
json : bool ,
149
150
no_errors : bool = False ,
150
151
no_any : bool = False ,
151
- try_text : bool = False ) -> None :
152
+ try_text : bool = False ,
153
+ flex_any : Optional [float ] = None ) -> None :
152
154
self .fgmanager = fgmanager
153
155
self .manager = fgmanager .manager
154
156
self .plugin = self .manager .plugin
155
157
self .graph = fgmanager .graph
156
158
157
159
self .give_json = json
158
160
self .no_errors = no_errors
159
- self .no_any = no_any
160
161
self .try_text = try_text
162
+ self .flex_any = flex_any
163
+ if no_any :
164
+ self .flex_any = 1.0
161
165
162
166
self .max_guesses = 16
163
167
@@ -285,13 +289,13 @@ def get_callsites(self, func: FuncDef) -> Tuple[List[Callsite], List[str]]:
285
289
286
290
return collector_plugin .mystery_hits , errors
287
291
288
- def filter_options (self , guesses : List [CallableType ]) -> List [CallableType ]:
292
+ def filter_options (self , guesses : List [CallableType ], is_method : bool ) -> List [CallableType ]:
289
293
"""Apply any configured filters to the possible guesses.
290
294
291
- Currently the only option is disabling Anys ."""
295
+ Currently the only option is filtering based on Any prevalance ."""
292
296
return [
293
297
t for t in guesses
294
- if not self .no_any or not callable_has_any ( t )
298
+ if self .flex_any is None or self . any_score_callable ( t , is_method ) >= self . flex_any
295
299
]
296
300
297
301
def find_best (self , func : FuncDef , guesses : List [CallableType ]) -> Tuple [CallableType , int ]:
@@ -329,7 +333,7 @@ def get_suggestion(self, function: str) -> str:
329
333
self .get_trivial_type (node ),
330
334
self .get_default_arg_types (graph [mod ], node ),
331
335
callsites )
332
- guesses = self .filter_options (guesses )
336
+ guesses = self .filter_options (guesses , is_method )
333
337
if len (guesses ) > self .max_guesses :
334
338
raise SuggestionFailure ("Too many possibilities!" )
335
339
best , _ = self .find_best (node , guesses )
@@ -344,7 +348,7 @@ def get_suggestion(self, function: str) -> str:
344
348
ret_types = [NoneType ()]
345
349
346
350
guesses = [best .copy_modified (ret_type = t ) for t in ret_types ]
347
- guesses = self .filter_options (guesses )
351
+ guesses = self .filter_options (guesses , is_method )
348
352
best , errors = self .find_best (node , guesses )
349
353
350
354
if self .no_errors and errors :
@@ -511,14 +515,16 @@ def format_callable(self,
511
515
def format_type (self , cur_module : Optional [str ], typ : Type ) -> str :
512
516
return typ .accept (TypeFormatter (cur_module , self .graph ))
513
517
514
- def score_type (self , t : Type ) -> int :
518
+ def score_type (self , t : Type , arg_pos : bool ) -> int :
515
519
"""Generate a score for a type that we use to pick which type to use.
516
520
517
521
Lower is better, prefer non-union/non-any types. Don't penalize optionals.
518
522
"""
519
523
t = get_proper_type (t )
520
524
if isinstance (t , AnyType ):
521
525
return 20
526
+ if arg_pos and isinstance (t , NoneType ):
527
+ return 20
522
528
if isinstance (t , UnionType ):
523
529
if any (isinstance (x , AnyType ) for x in t .items ):
524
530
return 20
@@ -529,7 +535,41 @@ def score_type(self, t: Type) -> int:
529
535
return 0
530
536
531
537
def score_callable (self , t : CallableType ) -> int :
532
- return sum ([self .score_type (x ) for x in t .arg_types ])
538
+ return (sum ([self .score_type (x , arg_pos = True ) for x in t .arg_types ]) +
539
+ self .score_type (t .ret_type , arg_pos = False ))
540
+
541
+ def any_score_type (self , ut : Type , arg_pos : bool ) -> float :
542
+ """Generate a very made up number representing the Anyness of a type.
543
+
544
+ Higher is better, 1.0 is max
545
+ """
546
+ t = get_proper_type (ut )
547
+ if isinstance (t , AnyType ) and t .type_of_any != TypeOfAny .special_form :
548
+ return 0
549
+ if isinstance (t , NoneType ) and arg_pos :
550
+ return 0.5
551
+ if isinstance (t , UnionType ):
552
+ if any (isinstance (x , AnyType ) for x in t .items ):
553
+ return 0.5
554
+ if any (has_any_type (x ) for x in t .items ):
555
+ return 0.25
556
+ if has_any_type (t ):
557
+ return 0.5
558
+
559
+ return 1.0
560
+
561
+ def any_score_callable (self , t : CallableType , is_method : bool ) -> float :
562
+ # Ignore the first argument of methods
563
+ scores = [self .any_score_type (x , arg_pos = True ) for x in t .arg_types [int (is_method ):]]
564
+ # Return type counts twice (since it spreads type information), unless it is
565
+ # None in which case it does not count at all. (Though it *does* still count
566
+ # if there are no arguments.)
567
+ if not isinstance (get_proper_type (t .ret_type ), NoneType ) or not scores :
568
+ ret = self .any_score_type (t .ret_type , arg_pos = False )
569
+ scores += [ret , ret ]
570
+
571
+ # print(scores, t)
572
+ return sum (scores ) / len (scores )
533
573
534
574
535
575
class TypeFormatter (TypeStrVisitor ):
@@ -611,13 +651,6 @@ def count_errors(msgs: List[str]) -> int:
611
651
return len ([x for x in msgs if ' error: ' in x ])
612
652
613
653
614
- def callable_has_any (t : CallableType ) -> int :
615
- # We count a bare None in argument position as Any, since
616
- # pyannotate turns it into Optional[Any]
617
- return (any (isinstance (at , NoneType ) for at in get_proper_types (t .arg_types ))
618
- or has_any_type (t ))
619
-
620
-
621
654
T = TypeVar ('T' )
622
655
623
656
0 commit comments