Skip to content

Commit 9d72fab

Browse files
authored
Merge pull request #20119 from paldepind/rust/type-inference-assoc-type-tp
Rust: Type inference for impl trait types with type parameters
2 parents 37b508b + 92bce4e commit 9d72fab

File tree

9 files changed

+1320
-1148
lines changed

9 files changed

+1320
-1148
lines changed

rust/ql/.generated.list

Lines changed: 0 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

rust/ql/.gitattributes

Lines changed: 0 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
1-
// generated by codegen, remove this comment if you wish to edit this file
21
/**
32
* This module provides a hand-modifiable wrapper around the generated class `ImplTraitTypeRepr`.
43
*
54
* INTERNAL: Do not use.
65
*/
76

87
private import codeql.rust.elements.internal.generated.ImplTraitTypeRepr
8+
private import rust
99

1010
/**
1111
* INTERNAL: This module contains the customizable definition of `ImplTraitTypeRepr` and should not
1212
* be referenced directly.
1313
*/
1414
module Impl {
15+
// the following QLdoc is generated: if you need to edit it, do it in the schema file
1516
/**
1617
* An `impl Trait` type.
1718
*
@@ -21,5 +22,15 @@ module Impl {
2122
* // ^^^^^^^^^^^^^^^^^^^^^^^^^^
2223
* ```
2324
*/
24-
class ImplTraitTypeRepr extends Generated::ImplTraitTypeRepr { }
25+
class ImplTraitTypeRepr extends Generated::ImplTraitTypeRepr {
26+
/** Gets the function for which this impl trait type occurs, if any. */
27+
Function getFunction() {
28+
this.getParentNode*() = [result.getRetType().getTypeRepr(), result.getAParam().getTypeRepr()]
29+
}
30+
31+
/** Holds if this impl trait type occurs in the return type of a function. */
32+
predicate isInReturnPos() {
33+
this.getParentNode*() = this.getFunction().getRetType().getTypeRepr()
34+
}
35+
}
2536
}

rust/ql/lib/codeql/rust/internal/Type.qll

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,21 @@ newtype TType =
5252
TAssociatedTypeTypeParameter(TypeAlias t) { any(TraitItemNode trait).getAnAssocItem() = t } or
5353
TArrayTypeParameter() or
5454
TDynTraitTypeParameter(AstNode n) { dynTraitTypeParameter(_, n) } or
55+
TImplTraitTypeParameter(ImplTraitTypeRepr implTrait, TypeParam tp) {
56+
implTraitTypeParam(implTrait, _, tp)
57+
} or
5558
TRefTypeParameter() or
5659
TSelfTypeParameter(Trait t) or
5760
TSliceTypeParameter()
5861

62+
predicate implTraitTypeParam(ImplTraitTypeRepr implTrait, int i, TypeParam tp) {
63+
implTrait.isInReturnPos() and
64+
tp = implTrait.getFunction().getGenericParamList().getTypeParam(i) and
65+
// Only include type parameters of the function that occur inside the impl
66+
// trait type.
67+
exists(Path path | path.getParentNode*() = implTrait and resolvePath(path) = tp)
68+
}
69+
5970
/**
6071
* A type without type arguments.
6172
*
@@ -263,7 +274,12 @@ class ImplTraitType extends Type, TImplTraitType {
263274

264275
override TupleField getTupleField(int i) { none() }
265276

266-
override TypeParameter getTypeParameter(int i) { none() }
277+
override TypeParameter getTypeParameter(int i) {
278+
exists(TypeParam tp |
279+
implTraitTypeParam(impl, i, tp) and
280+
result = TImplTraitTypeParameter(impl, tp)
281+
)
282+
}
267283

268284
override string toString() { result = impl.toString() }
269285

@@ -302,7 +318,7 @@ class DynTraitType extends Type, TDynTraitType {
302318
class ImplTraitReturnType extends ImplTraitType {
303319
private Function function;
304320

305-
ImplTraitReturnType() { impl = function.getRetType().getTypeRepr() }
321+
ImplTraitReturnType() { impl.isInReturnPos() and function = impl.getFunction() }
306322

307323
override Function getFunction() { result = function }
308324
}
@@ -456,6 +472,21 @@ class DynTraitTypeParameter extends TypeParameter, TDynTraitTypeParameter {
456472
override Location getLocation() { result = n.getLocation() }
457473
}
458474

475+
class ImplTraitTypeParameter extends TypeParameter, TImplTraitTypeParameter {
476+
private TypeParam typeParam;
477+
private ImplTraitTypeRepr implTrait;
478+
479+
ImplTraitTypeParameter() { this = TImplTraitTypeParameter(implTrait, typeParam) }
480+
481+
TypeParam getTypeParam() { result = typeParam }
482+
483+
ImplTraitTypeRepr getImplTraitTypeRepr() { result = implTrait }
484+
485+
override string toString() { result = "impl(" + typeParam.toString() + ")" }
486+
487+
override Location getLocation() { result = typeParam.getLocation() }
488+
}
489+
459490
/** An implicit reference type parameter. */
460491
class RefTypeParameter extends TypeParameter, TRefTypeParameter {
461492
override string toString() { result = "&T" }
@@ -569,5 +600,7 @@ final class SelfTypeBoundTypeAbstraction extends TypeAbstraction, Name {
569600
}
570601

571602
final class ImplTraitTypeReprAbstraction extends TypeAbstraction, ImplTraitTypeRepr {
572-
override TypeParameter getATypeParameter() { none() }
603+
override TypeParameter getATypeParameter() {
604+
implTraitTypeParam(this, _, result.(TypeParamTypeParameter).getTypeParam())
605+
}
573606
}

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -83,42 +83,48 @@ private module Input1 implements InputSig1<Location> {
8383

8484
int getTypeParameterId(TypeParameter tp) {
8585
tp =
86-
rank[result](TypeParameter tp0, int kind, int id |
86+
rank[result](TypeParameter tp0, int kind, int id1, int id2 |
8787
tp0 instanceof ArrayTypeParameter and
8888
kind = 0 and
89-
id = 0
89+
id1 = 0 and
90+
id2 = 0
9091
or
9192
tp0 instanceof RefTypeParameter and
9293
kind = 0 and
93-
id = 1
94+
id1 = 0 and
95+
id2 = 1
9496
or
9597
tp0 instanceof SliceTypeParameter and
9698
kind = 0 and
97-
id = 2
99+
id1 = 0 and
100+
id2 = 2
98101
or
99102
kind = 1 and
100-
id =
103+
id1 = 0 and
104+
id2 =
101105
idOfTypeParameterAstNode([
102106
tp0.(DynTraitTypeParameter).getTypeParam().(AstNode),
103107
tp0.(DynTraitTypeParameter).getTypeAlias()
104108
])
105109
or
106110
kind = 2 and
107-
exists(AstNode node | id = idOfTypeParameterAstNode(node) |
111+
id1 = idOfTypeParameterAstNode(tp0.(ImplTraitTypeParameter).getImplTraitTypeRepr()) and
112+
id2 = idOfTypeParameterAstNode(tp0.(ImplTraitTypeParameter).getTypeParam())
113+
or
114+
kind = 3 and
115+
id1 = 0 and
116+
exists(AstNode node | id2 = idOfTypeParameterAstNode(node) |
108117
node = tp0.(TypeParamTypeParameter).getTypeParam() or
109118
node = tp0.(AssociatedTypeTypeParameter).getTypeAlias() or
110119
node = tp0.(SelfTypeParameter).getTrait() or
111120
node = tp0.(ImplTraitTypeTypeParameter).getImplTraitTypeRepr()
112121
)
113122
or
114-
exists(TupleTypeParameter ttp, int maxArity |
115-
maxArity = max(int i | i = any(TupleType tt).getArity()) and
116-
tp0 = ttp and
117-
kind = 3 and
118-
id = ttp.getTupleType().getArity() * maxArity + ttp.getIndex()
119-
)
123+
kind = 4 and
124+
id1 = tp0.(TupleTypeParameter).getTupleType().getArity() and
125+
id2 = tp0.(TupleTypeParameter).getIndex()
120126
|
121-
tp0 order by kind, id
127+
tp0 order by kind, id1, id2
122128
)
123129
}
124130
}

rust/ql/lib/codeql/rust/internal/TypeMention.qll

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,12 @@ class ImplTraitTypeReprMention extends TypeMention instanceof ImplTraitTypeRepr
258258
override Type resolveTypeAt(TypePath typePath) {
259259
typePath.isEmpty() and
260260
result.(ImplTraitType).getImplTraitTypeRepr() = this
261+
or
262+
exists(ImplTraitTypeParameter tp |
263+
this = tp.getImplTraitTypeRepr() and
264+
typePath = TypePath::singleton(tp) and
265+
result = TTypeParamTypeParameter(tp.getTypeParam())
266+
)
261267
}
262268
}
263269

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
multipleCallTargets
22
| dereference.rs:61:15:61:24 | e1.deref() |
3-
| main.rs:2253:13:2253:31 | ...::from(...) |
4-
| main.rs:2254:13:2254:31 | ...::from(...) |
5-
| main.rs:2255:13:2255:31 | ...::from(...) |
6-
| main.rs:2261:13:2261:31 | ...::from(...) |
7-
| main.rs:2262:13:2262:31 | ...::from(...) |
8-
| main.rs:2263:13:2263:31 | ...::from(...) |
3+
| main.rs:2278:13:2278:31 | ...::from(...) |
4+
| main.rs:2279:13:2279:31 | ...::from(...) |
5+
| main.rs:2280:13:2280:31 | ...::from(...) |
6+
| main.rs:2286:13:2286:31 | ...::from(...) |
7+
| main.rs:2287:13:2287:31 | ...::from(...) |
8+
| main.rs:2288:13:2288:31 | ...::from(...) |

rust/ql/test/library-tests/type-inference/main.rs

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1913,8 +1913,10 @@ mod async_ {
19131913
}
19141914

19151915
mod impl_trait {
1916+
#[derive(Copy, Clone)]
19161917
struct S1;
19171918
struct S2;
1919+
struct S3<T3>(T3);
19181920

19191921
trait Trait1 {
19201922
fn f1(&self) {} // Trait1f1
@@ -1946,6 +1948,13 @@ mod impl_trait {
19461948
}
19471949
}
19481950

1951+
impl<T: Clone> MyTrait<T> for S3<T> {
1952+
fn get_a(&self) -> T {
1953+
let S3(t) = self;
1954+
t.clone()
1955+
}
1956+
}
1957+
19491958
fn get_a_my_trait() -> impl MyTrait<S2> {
19501959
S1
19511960
}
@@ -1954,6 +1963,18 @@ mod impl_trait {
19541963
t.get_a() // $ target=MyTrait::get_a
19551964
}
19561965

1966+
fn get_a_my_trait2<T: Clone>(x: T) -> impl MyTrait<T> {
1967+
S3(x)
1968+
}
1969+
1970+
fn get_a_my_trait3<T: Clone>(x: T) -> Option<impl MyTrait<T>> {
1971+
Some(S3(x))
1972+
}
1973+
1974+
fn get_a_my_trait4<T: Clone>(x: T) -> (impl MyTrait<T>, impl MyTrait<T>) {
1975+
(S3(x.clone()), S3(x)) // $ target=clone
1976+
}
1977+
19571978
fn uses_my_trait2<A>(t: impl MyTrait<A>) -> A {
19581979
t.get_a() // $ target=MyTrait::get_a
19591980
}
@@ -1967,6 +1988,10 @@ mod impl_trait {
19671988
let a = get_a_my_trait(); // $ target=get_a_my_trait
19681989
let c = uses_my_trait2(a); // $ type=c:S2 target=uses_my_trait2
19691990
let d = uses_my_trait2(S1); // $ type=d:S2 target=uses_my_trait2
1991+
let e = get_a_my_trait2(S1).get_a(); // $ target=get_a_my_trait2 target=MyTrait::get_a type=e:S1
1992+
// For this function the `impl` type does not appear in the root of the return type
1993+
let f = get_a_my_trait3(S1).unwrap().get_a(); // $ target=get_a_my_trait3 target=unwrap target=MyTrait::get_a type=f:S1
1994+
let g = get_a_my_trait4(S1).0.get_a(); // $ target=get_a_my_trait4 target=MyTrait::get_a type=g:S1
19701995
}
19711996
}
19721997

@@ -2425,7 +2450,7 @@ mod tuples {
24252450

24262451
let pair = [1, 1].into(); // $ type=pair:(T_2) type=pair:0(2).i32 type=pair:1(2).i32 MISSING: target=into
24272452
match pair {
2428-
(0,0) => print!("unexpected"),
2453+
(0, 0) => print!("unexpected"),
24292454
_ => print!("expected"),
24302455
}
24312456
let x = pair.0; // $ type=x:i32

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy