Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 75 additions & 48 deletions rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll
Original file line number Diff line number Diff line change
Expand Up @@ -778,13 +778,6 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
prefix1 = TypePath::singleton(getArrayTypeParameter()) and
prefix2.isEmpty()
or
exists(Struct s |
n2 = [n1.(RangeExpr).getStart(), n1.(RangeExpr).getEnd()] and
prefix1 = TypePath::singleton(TTypeParamTypeParameter(s.getGenericParamList().getATypeParam())) and
prefix2.isEmpty() and
s = getRangeType(n1)
)
or
exists(ClosureExpr ce, int index |
n1 = ce and
n2 = ce.getParam(index).getPat() and
Expand Down Expand Up @@ -829,6 +822,12 @@ private predicate lubCoercion(AstNode parent, AstNode child, TypePath prefix) {
bodyReturns(parent, child) and
strictcount(Expr e | bodyReturns(parent, e)) > 1 and
prefix.isEmpty()
or
exists(Struct s |
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change is not what solves the timeout, but I saw cases where type information would incorrectly flow between limits in range expressions, so I decided to treat them as LUB conversions.

child = [parent.(RangeExpr).getStart(), parent.(RangeExpr).getEnd()] and
prefix = TypePath::singleton(TTypeParamTypeParameter(s.getGenericParamList().getATypeParam())) and
s = getRangeType(parent)
)
}

/**
Expand Down Expand Up @@ -1031,10 +1030,10 @@ private module StructExprMatchingInput implements MatchingInputSig {
private module StructExprMatching = Matching<StructExprMatchingInput>;

pragma[nomagic]
private Type inferStructExprType0(AstNode n, boolean isReturn, TypePath path) {
private Type inferStructExprType0(AstNode n, FunctionPosition pos, TypePath path) {
exists(StructExprMatchingInput::Access a, StructExprMatchingInput::AccessPosition apos |
n = a.getNodeAt(apos) and
if apos.isStructPos() then isReturn = true else isReturn = false
if apos.isStructPos() then pos.isReturn() else pos.asPosition() = 0 // the acutal position doesn't matter, as long as it is positional
Copy link

Copilot AI Feb 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo in the comment: "acutal" should be "actual".

Suggested change
if apos.isStructPos() then pos.isReturn() else pos.asPosition() = 0 // the acutal position doesn't matter, as long as it is positional
if apos.isStructPos() then pos.isReturn() else pos.asPosition() = 0 // the actual position doesn't matter, as long as it is positional

Copilot uses AI. Check for mistakes.
|
result = StructExprMatching::inferAccessType(a, apos, path)
or
Expand Down Expand Up @@ -1113,6 +1112,25 @@ private Trait getCallExprTraitQualifier(CallExpr ce) {
* Provides functionality related to context-based typing of calls.
*/
private module ContextTyping {
/**
* Holds if `f` mentions type parameter `tp` at some non-return position,
* possibly via a constraint on another mentioned type parameter.
*/
pragma[nomagic]
private predicate assocFunctionMentionsTypeParameterAtNonRetPos(
ImplOrTraitItemNode i, Function f, TypeParameter tp
) {
exists(FunctionPosition nonRetPos |
not nonRetPos.isReturn() and
tp = getAssocFunctionTypeAt(f, i, nonRetPos, _)
)
or
exists(TypeParameter mid |
assocFunctionMentionsTypeParameterAtNonRetPos(i, f, mid) and
tp = getATypeParameterConstraint(mid, _)
)
}

/**
* Holds if the return type of the function `f` inside `i` at `path` is type
* parameter `tp`, and `tp` does not appear in the type of any parameter of
Expand All @@ -1129,12 +1147,7 @@ private module ContextTyping {
) {
pos.isReturn() and
tp = getAssocFunctionTypeAt(f, i, pos, path) and
not exists(FunctionPosition nonResPos | not nonResPos.isReturn() |
tp = getAssocFunctionTypeAt(f, i, nonResPos, _)
or
// `Self` types in traits implicitly mention all type parameters of the trait
getAssocFunctionTypeAt(f, i, nonResPos, _) = TSelfTypeParameter(i)
)
not assocFunctionMentionsTypeParameterAtNonRetPos(i, f, tp)
}

/**
Expand Down Expand Up @@ -1184,7 +1197,7 @@ private module ContextTyping {
pragma[nomagic]
private predicate hasUnknownType(AstNode n) { hasUnknownTypeAt(n, _) }

signature Type inferCallTypeSig(AstNode n, boolean isReturn, TypePath path);
signature Type inferCallTypeSig(AstNode n, FunctionPosition pos, TypePath path);

/**
* Given a predicate `inferCallType` for inferring the type of a call at a given
Expand All @@ -1194,19 +1207,34 @@ private module ContextTyping {
*/
module CheckContextTyping<inferCallTypeSig/3 inferCallType> {
pragma[nomagic]
private Type inferCallTypeFromContextCand(AstNode n, TypePath prefix, TypePath path) {
result = inferCallType(n, false, path) and
private Type inferCallNonReturnType(AstNode n, FunctionPosition pos, TypePath path) {
result = inferCallType(n, pos, path) and
not pos.isReturn()
}

pragma[nomagic]
private Type inferCallNonReturnType(
AstNode n, FunctionPosition pos, TypePath prefix, TypePath path
) {
result = inferCallNonReturnType(n, pos, path) and
hasUnknownType(n) and
prefix = path.getAPrefix()
}

pragma[nomagic]
Type check(AstNode n, TypePath path) {
result = inferCallType(n, true, path)
result = inferCallType(n, any(FunctionPosition pos | pos.isReturn()), path)
or
exists(TypePath prefix |
result = inferCallTypeFromContextCand(n, prefix, path) and
exists(FunctionPosition pos, TypePath prefix |
result = inferCallNonReturnType(n, pos, prefix, path) and
hasUnknownTypeAt(n, prefix)
|
pos.isPosition()
or
// Never propagate type information directly into the receiver, since its type
// must already have been known in order to resolve the call
pos.isSelf() and
not prefix.isEmpty()
)
}
}
Expand Down Expand Up @@ -2607,12 +2635,9 @@ private Type inferMethodCallType0(
}

pragma[nomagic]
private Type inferMethodCallTypeNonSelf(AstNode n, boolean isReturn, TypePath path) {
exists(MethodCallMatchingInput::AccessPosition apos |
result = inferMethodCallType0(_, apos, n, _, path) and
not apos.isSelf() and
if apos.isReturn() then isReturn = true else isReturn = false
)
private Type inferMethodCallTypeNonSelf(AstNode n, FunctionPosition pos, TypePath path) {
result = inferMethodCallType0(_, pos, n, _, path) and
not pos.isSelf()
}

/**
Expand Down Expand Up @@ -2664,11 +2689,11 @@ private Type inferMethodCallTypeSelf(AstNode n, DerefChain derefChain, TypePath
)
}

private Type inferMethodCallTypePreCheck(AstNode n, boolean isReturn, TypePath path) {
result = inferMethodCallTypeNonSelf(n, isReturn, path)
private Type inferMethodCallTypePreCheck(AstNode n, FunctionPosition pos, TypePath path) {
result = inferMethodCallTypeNonSelf(n, pos, path)
or
result = inferMethodCallTypeSelf(n, DerefChain::nil(), path) and
isReturn = false
pos.isSelf()
}

/**
Expand Down Expand Up @@ -3301,14 +3326,11 @@ private module NonMethodCallMatchingInput implements MatchingInputSig {
private module NonMethodCallMatching = Matching<NonMethodCallMatchingInput>;

pragma[nomagic]
private Type inferNonMethodCallType0(AstNode n, boolean isReturn, TypePath path) {
exists(NonMethodCallMatchingInput::Access a, NonMethodCallMatchingInput::AccessPosition apos |
n = a.getNodeAt(apos) and
if apos.isReturn() then isReturn = true else isReturn = false
|
result = NonMethodCallMatching::inferAccessType(a, apos, path)
private Type inferNonMethodCallType0(AstNode n, FunctionPosition pos, TypePath path) {
exists(NonMethodCallMatchingInput::Access a | n = a.getNodeAt(pos) |
result = NonMethodCallMatching::inferAccessType(a, pos, path)
or
a.hasUnknownTypeAt(apos, path) and
a.hasUnknownTypeAt(pos, path) and
result = TUnknownType()
)
}
Expand Down Expand Up @@ -3379,11 +3401,10 @@ private module OperationMatchingInput implements MatchingInputSig {
private module OperationMatching = Matching<OperationMatchingInput>;

pragma[nomagic]
private Type inferOperationType0(AstNode n, boolean isReturn, TypePath path) {
exists(OperationMatchingInput::Access a, OperationMatchingInput::AccessPosition apos |
n = a.getNodeAt(apos) and
result = OperationMatching::inferAccessType(a, apos, path) and
if apos.isReturn() then isReturn = true else isReturn = false
private Type inferOperationType0(AstNode n, FunctionPosition pos, TypePath path) {
exists(OperationMatchingInput::Access a |
n = a.getNodeAt(pos) and
result = OperationMatching::inferAccessType(a, pos, path)
)
}

Expand Down Expand Up @@ -3716,11 +3737,13 @@ private module AwaitSatisfiesConstraintInput implements SatisfiesConstraintInput
}
}

private module AwaitSatisfiesConstraint =
SatisfiesConstraint<AwaitTarget, AwaitSatisfiesConstraintInput>;

pragma[nomagic]
private Type inferAwaitExprType(AstNode n, TypePath path) {
exists(TypePath exprPath |
SatisfiesConstraint<AwaitTarget, AwaitSatisfiesConstraintInput>::satisfiesConstraintType(n.(AwaitExpr)
.getExpr(), _, exprPath, result) and
AwaitSatisfiesConstraint::satisfiesConstraintType(n.(AwaitExpr).getExpr(), _, exprPath, result) and
exprPath.isCons(getFutureOutputTypeParameter(), path)
)
}
Expand Down Expand Up @@ -3922,13 +3945,15 @@ private AssociatedTypeTypeParameter getIntoIteratorItemTypeParameter() {
result = getAssociatedTypeTypeParameter(any(IntoIteratorTrait t).getItemType())
}

private module ForIterableSatisfiesConstraint =
SatisfiesConstraint<ForIterableExpr, ForIterableSatisfiesConstraintInput>;

pragma[nomagic]
private Type inferForLoopExprType(AstNode n, TypePath path) {
// type of iterable -> type of pattern (loop variable)
exists(ForExpr fe, TypePath exprPath, AssociatedTypeTypeParameter tp |
n = fe.getPat() and
SatisfiesConstraint<ForIterableExpr, ForIterableSatisfiesConstraintInput>::satisfiesConstraintType(fe.getIterable(),
_, exprPath, result) and
ForIterableSatisfiesConstraint::satisfiesConstraintType(fe.getIterable(), _, exprPath, result) and
exprPath.isCons(tp, path)
|
tp = getIntoIteratorItemTypeParameter()
Expand Down Expand Up @@ -3963,10 +3988,12 @@ private module InvokedClosureSatisfiesConstraintInput implements
}
}

private module InvokedClosureSatisfiesConstraint =
SatisfiesConstraint<InvokedClosureExpr, InvokedClosureSatisfiesConstraintInput>;

/** Gets the type of `ce` when viewed as an implementation of `FnOnce`. */
private Type invokedClosureFnTypeAt(InvokedClosureExpr ce, TypePath path) {
SatisfiesConstraint<InvokedClosureExpr, InvokedClosureSatisfiesConstraintInput>::satisfiesConstraintType(ce,
_, path, result)
InvokedClosureSatisfiesConstraint::satisfiesConstraintType(ce, _, path, result)
}

/**
Expand Down
1 change: 1 addition & 0 deletions rust/ql/test/library-tests/type-inference/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2740,6 +2740,7 @@ mod blanket_impl;
mod closure;
mod dereference;
mod dyn_type;
mod regressions;

fn main() {
field_access::f(); // $ target=f
Expand Down
34 changes: 34 additions & 0 deletions rust/ql/test/library-tests/type-inference/regressions.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
mod regression1 {

pub struct S<T>(T);

pub enum E {
V { vec: Vec<E> },
}

impl<T> From<S<T>> for Option<T> {
fn from(s: S<T>) -> Self {
Some(s.0) // $ fieldof=S
}
}

pub fn f() -> E {
let mut vec_e = Vec::new(); // $ target=new
let mut opt_e = None;

let e = E::V { vec: Vec::new() }; // $ target=new

if let Some(e) = opt_e {
vec_e.push(e); // $ target=push
}
opt_e = e.into(); // $ target=into

#[rustfmt::skip]
let _ = if let Some(last) = vec_e.pop() // $ target=pop
{
opt_e = last.into(); // $ target=into
};

opt_e.unwrap() // $ target=unwrap
}
}
Loading
Loading