Skip to content

Commit 8238746

Browse files
authored
Merge pull request #20084 from paldepind/rust/type-inference-trait-object
Rust: Implement type inference for trait objects/`dyn` types
2 parents 5da7ae8 + b3dc6cb commit 8238746

File tree

11 files changed

+640
-311
lines changed

11 files changed

+640
-311
lines changed

rust/ql/.generated.list

Lines changed: 0 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

rust/ql/.gitattributes

Lines changed: 0 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

rust/ql/lib/codeql/rust/elements/internal/DynTraitTypeReprImpl.qll

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
// generated by codegen, remove this comment if you wish to edit this file
21
/**
32
* This module provides a hand-modifiable wrapper around the generated class `DynTraitTypeRepr`.
43
*
@@ -12,6 +11,10 @@ private import codeql.rust.elements.internal.generated.DynTraitTypeRepr
1211
* be referenced directly.
1312
*/
1413
module Impl {
14+
private import rust
15+
private import codeql.rust.internal.PathResolution as PathResolution
16+
17+
// the following QLdoc is generated: if you need to edit it, do it in the schema file
1518
/**
1619
* A dynamic trait object type.
1720
*
@@ -21,5 +24,16 @@ module Impl {
2124
* // ^^^^^^^^^
2225
* ```
2326
*/
24-
class DynTraitTypeRepr extends Generated::DynTraitTypeRepr { }
27+
class DynTraitTypeRepr extends Generated::DynTraitTypeRepr {
28+
/** Gets the trait that this trait object refers to. */
29+
pragma[nomagic]
30+
Trait getTrait() {
31+
result =
32+
PathResolution::resolvePath(this.getTypeBoundList()
33+
.getBound(0)
34+
.getTypeRepr()
35+
.(PathTypeRepr)
36+
.getPath())
37+
}
38+
}
2539
}

rust/ql/lib/codeql/rust/internal/Type.qll

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,15 @@ newtype TType =
2424
TArrayType() or // todo: add size?
2525
TRefType() or // todo: add mut?
2626
TImplTraitType(ImplTraitTypeRepr impl) or
27+
TDynTraitType(Trait t) { t = any(DynTraitTypeRepr dt).getTrait() } or
2728
TSliceType() or
2829
TTupleTypeParameter(int arity, int i) { exists(TTuple(arity)) and i in [0 .. arity - 1] } or
2930
TTypeParamTypeParameter(TypeParam t) or
3031
TAssociatedTypeTypeParameter(TypeAlias t) { any(TraitItemNode trait).getAnAssocItem() = t } or
3132
TArrayTypeParameter() or
33+
TDynTraitTypeParameter(TypeParam tp) {
34+
tp = any(DynTraitTypeRepr dt).getTrait().getGenericParamList().getATypeParam()
35+
} or
3236
TRefTypeParameter() or
3337
TSelfTypeParameter(Trait t) or
3438
TSliceTypeParameter()
@@ -247,6 +251,26 @@ class ImplTraitType extends Type, TImplTraitType {
247251
override Location getLocation() { result = impl.getLocation() }
248252
}
249253

254+
class DynTraitType extends Type, TDynTraitType {
255+
Trait trait;
256+
257+
DynTraitType() { this = TDynTraitType(trait) }
258+
259+
override StructField getStructField(string name) { none() }
260+
261+
override TupleField getTupleField(int i) { none() }
262+
263+
override DynTraitTypeParameter getTypeParameter(int i) {
264+
result = TDynTraitTypeParameter(trait.getGenericParamList().getTypeParam(i))
265+
}
266+
267+
Trait getTrait() { result = trait }
268+
269+
override string toString() { result = "dyn " + trait.getName().toString() }
270+
271+
override Location getLocation() { result = trait.getLocation() }
272+
}
273+
250274
/**
251275
* An [impl Trait in return position][1] type, for example:
252276
*
@@ -381,6 +405,18 @@ class ArrayTypeParameter extends TypeParameter, TArrayTypeParameter {
381405
override Location getLocation() { result instanceof EmptyLocation }
382406
}
383407

408+
class DynTraitTypeParameter extends TypeParameter, TDynTraitTypeParameter {
409+
private TypeParam typeParam;
410+
411+
DynTraitTypeParameter() { this = TDynTraitTypeParameter(typeParam) }
412+
413+
TypeParam getTypeParam() { result = typeParam }
414+
415+
override string toString() { result = "dyn(" + typeParam.toString() + ")" }
416+
417+
override Location getLocation() { result = typeParam.getLocation() }
418+
}
419+
384420
/** An implicit reference type parameter. */
385421
class RefTypeParameter extends TypeParameter, TRefTypeParameter {
386422
override string toString() { result = "&T" }
@@ -465,6 +501,13 @@ final class ImplTypeAbstraction extends TypeAbstraction, Impl {
465501
}
466502
}
467503

504+
final class DynTypeAbstraction extends TypeAbstraction, DynTraitTypeRepr {
505+
override TypeParameter getATypeParameter() {
506+
result.(TypeParamTypeParameter).getTypeParam() =
507+
this.getTrait().getGenericParamList().getATypeParam()
508+
}
509+
}
510+
468511
final class TraitTypeAbstraction extends TypeAbstraction, Trait {
469512
override TypeParameter getATypeParameter() {
470513
result.(TypeParamTypeParameter).getTypeParam() = this.getGenericParamList().getATypeParam()

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,9 @@ private module Input1 implements InputSig1<Location> {
9797
id = 2
9898
or
9999
kind = 1 and
100+
id = idOfTypeParameterAstNode(tp0.(DynTraitTypeParameter).getTypeParam())
101+
or
102+
kind = 2 and
100103
exists(AstNode node | id = idOfTypeParameterAstNode(node) |
101104
node = tp0.(TypeParamTypeParameter).getTypeParam() or
102105
node = tp0.(AssociatedTypeTypeParameter).getTypeAlias() or
@@ -107,7 +110,7 @@ private module Input1 implements InputSig1<Location> {
107110
exists(TupleTypeParameter ttp, int maxArity |
108111
maxArity = max(int i | i = any(TupleType tt).getArity()) and
109112
tp0 = ttp and
110-
kind = 2 and
113+
kind = 3 and
111114
id = ttp.getTupleType().getArity() * maxArity + ttp.getIndex()
112115
)
113116
|
@@ -189,6 +192,14 @@ private module Input2 implements InputSig2 {
189192
condition = impl and
190193
constraint = impl.getTypeBoundList().getABound().getTypeRepr()
191194
)
195+
or
196+
// a `dyn Trait` type implements `Trait`. See the comment on
197+
// `DynTypeBoundListMention` for further details.
198+
exists(DynTraitTypeRepr object |
199+
abs = object and
200+
condition = object.getTypeBoundList() and
201+
constraint = object.getTrait()
202+
)
192203
}
193204
}
194205

@@ -1715,10 +1726,16 @@ private Function getMethodFromImpl(MethodCall mc) {
17151726

17161727
bindingset[trait, name]
17171728
pragma[inline_late]
1718-
private Function getTraitMethod(ImplTraitReturnType trait, string name) {
1729+
private Function getImplTraitMethod(ImplTraitReturnType trait, string name) {
17191730
result = getMethodSuccessor(trait.getImplTraitTypeRepr(), name)
17201731
}
17211732

1733+
bindingset[traitObject, name]
1734+
pragma[inline_late]
1735+
private Function getDynTraitMethod(DynTraitType traitObject, string name) {
1736+
result = getMethodSuccessor(traitObject.getTrait(), name)
1737+
}
1738+
17221739
pragma[nomagic]
17231740
private Function resolveMethodCallTarget(MethodCall mc) {
17241741
// The method comes from an `impl` block targeting the type of the receiver.
@@ -1729,7 +1746,10 @@ private Function resolveMethodCallTarget(MethodCall mc) {
17291746
result = getTypeParameterMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
17301747
or
17311748
// The type of the receiver is an `impl Trait` type.
1732-
result = getTraitMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
1749+
result = getImplTraitMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
1750+
or
1751+
// The type of the receiver is a trait object `dyn Trait` type.
1752+
result = getDynTraitMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
17331753
}
17341754

17351755
pragma[nomagic]
@@ -2073,6 +2093,13 @@ private module Debug {
20732093
result = resolveCallTarget(c)
20742094
}
20752095

2096+
predicate debugConditionSatisfiesConstraint(
2097+
TypeAbstraction abs, TypeMention condition, TypeMention constraint
2098+
) {
2099+
abs = getRelevantLocatable() and
2100+
Input2::conditionSatisfiesConstraint(abs, condition, constraint)
2101+
}
2102+
20762103
predicate debugInferImplicitSelfType(SelfParam self, TypePath path, Type t) {
20772104
self = getRelevantLocatable() and
20782105
t = inferImplicitSelfType(self, path)

rust/ql/lib/codeql/rust/internal/TypeMention.qll

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,3 +309,64 @@ class SelfTypeParameterMention extends TypeMention instanceof Name {
309309
result = TSelfTypeParameter(trait)
310310
}
311311
}
312+
313+
class DynTraitTypeReprMention extends TypeMention instanceof DynTraitTypeRepr {
314+
private DynTraitType dynType;
315+
316+
DynTraitTypeReprMention() {
317+
// This excludes `DynTraitTypeRepr` elements where `getTrait` is not
318+
// defined, i.e., where path resolution can't find a trait.
319+
dynType.getTrait() = super.getTrait()
320+
}
321+
322+
override Type resolveTypeAt(TypePath path) {
323+
path.isEmpty() and
324+
result = dynType
325+
or
326+
exists(DynTraitTypeParameter tp, TypePath path0, TypePath suffix |
327+
tp = dynType.getTypeParameter(_) and
328+
path = TypePath::cons(tp, suffix) and
329+
result = super.getTypeBoundList().getBound(0).getTypeRepr().(TypeMention).resolveTypeAt(path0) and
330+
path0.isCons(TTypeParamTypeParameter(tp.getTypeParam()), suffix)
331+
)
332+
}
333+
}
334+
335+
// We want a type of the form `dyn Trait` to implement `Trait`. If `Trait` has
336+
// type parameters then `dyn Trait` has equivalent type parameters and the
337+
// implementation should be abstracted over them.
338+
//
339+
// Intuitively we want something to the effect of:
340+
// ```
341+
// impl<A, B, ..> Trait<A, B, ..> for (dyn Trait)<A, B, ..>
342+
// ```
343+
// To achieve this:
344+
// - `DynTypeAbstraction` is an abstraction over type parameters of the trait.
345+
// - `DynTypeBoundListMention` (this class) is a type mention which has `dyn
346+
// Trait` at the root and which for every type parameter of `dyn Trait` has the
347+
// corresponding type parameter of the trait.
348+
// - `TraitMention` (which is used for other things as well) is a type mention
349+
// for the trait applied to its own type parameters.
350+
//
351+
// We arbitrarily use the `TypeBoundList` inside `DynTraitTypeRepr` to encode
352+
// this type mention, since it doesn't syntactically appear in the AST. This
353+
// works because there is a one-to-one correspondence between a trait object and
354+
// its list of type bounds.
355+
class DynTypeBoundListMention extends TypeMention instanceof TypeBoundList {
356+
private Trait trait;
357+
358+
DynTypeBoundListMention() {
359+
exists(DynTraitTypeRepr dyn | this = dyn.getTypeBoundList() and trait = dyn.getTrait())
360+
}
361+
362+
override Type resolveTypeAt(TypePath path) {
363+
path.isEmpty() and
364+
result.(DynTraitType).getTrait() = trait
365+
or
366+
exists(TypeParam param |
367+
param = trait.getGenericParamList().getATypeParam() and
368+
path = TypePath::singleton(TDynTraitTypeParameter(param)) and
369+
result = TTypeParamTypeParameter(param)
370+
)
371+
}
372+
}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
---
2+
category: minorAnalysis
3+
---
4+
* Type inference now supports trait objects, i.e., `dyn Trait` types.
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
// Test cases for type inference and method resolution with `dyn` types
2+
3+
use std::fmt::Debug;
4+
5+
trait MyTrait1 {
6+
// MyTrait1::m
7+
fn m(&self) -> String;
8+
}
9+
10+
trait GenericGet<A> {
11+
// GenericGet::get
12+
fn get(&self) -> A;
13+
}
14+
15+
#[derive(Clone, Debug)]
16+
struct MyStruct {
17+
value: i32,
18+
}
19+
20+
impl MyTrait1 for MyStruct {
21+
// MyStruct1::m
22+
fn m(&self) -> String {
23+
format!("MyTrait1: {}", self.value) // $ fieldof=MyStruct
24+
}
25+
}
26+
27+
#[derive(Clone, Debug)]
28+
struct GenStruct<A: Clone + Debug> {
29+
value: A,
30+
}
31+
32+
impl<A: Clone + Debug> GenericGet<A> for GenStruct<A> {
33+
// GenStruct<A>::get
34+
fn get(&self) -> A {
35+
self.value.clone() // $ fieldof=GenStruct target=clone
36+
}
37+
}
38+
39+
fn get_a<A, G: GenericGet<A> + ?Sized>(a: &G) -> A {
40+
a.get() // $ target=GenericGet::get
41+
}
42+
43+
fn get_box_trait<A: Clone + Debug + 'static>(a: A) -> Box<dyn GenericGet<A>> {
44+
Box::new(GenStruct { value: a }) // $ target=new
45+
}
46+
47+
fn test_basic_dyn_trait(obj: &dyn MyTrait1) {
48+
let _result = (*obj).m(); // $ target=deref target=MyTrait1::m type=_result:String
49+
}
50+
51+
fn test_generic_dyn_trait(obj: &dyn GenericGet<String>) {
52+
let _result1 = (*obj).get(); // $ target=deref target=GenericGet::get type=_result1:String
53+
let _result2 = get_a(obj); // $ target=get_a type=_result2:String
54+
}
55+
56+
fn test_poly_dyn_trait() {
57+
let obj = get_box_trait(true); // $ target=get_box_trait
58+
let _result = (*obj).get(); // $ target=deref target=GenericGet::get type=_result:bool
59+
}
60+
61+
pub fn test() {
62+
test_basic_dyn_trait(&MyStruct { value: 42 }); // $ target=test_basic_dyn_trait
63+
test_generic_dyn_trait(&GenStruct {
64+
value: "".to_string(),
65+
}); // $ target=test_generic_dyn_trait
66+
test_poly_dyn_trait(); // $ target=test_poly_dyn_trait
67+
}

rust/ql/test/library-tests/type-inference/main.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2292,8 +2292,6 @@ mod loops {
22922292
}
22932293
}
22942294

2295-
mod dereference;
2296-
22972295
mod explicit_type_args {
22982296
struct S1<T>(T);
22992297

@@ -2461,6 +2459,9 @@ mod closures {
24612459
}
24622460
}
24632461

2462+
mod dereference;
2463+
mod dyn_type;
2464+
24642465
fn main() {
24652466
field_access::f(); // $ target=f
24662467
method_impl::f(); // $ target=f
@@ -2491,5 +2492,6 @@ fn main() {
24912492
dereference::test(); // $ target=test
24922493
pattern_matching::test_all_patterns(); // $ target=test_all_patterns
24932494
pattern_matching_experimental::box_patterns(); // $ target=box_patterns
2494-
closures::f() // $ target=f
2495+
closures::f(); // $ target=f
2496+
dyn_type::test(); // $ target=test
24952497
}

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy