Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions checker/internal/type_check_env.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

#include <cstddef>
#include <cstdint>
#include <string>

#include "absl/base/nullability.h"
#include "absl/status/statusor.h"
Expand Down Expand Up @@ -134,7 +133,7 @@ absl::StatusOr<absl::optional<VariableDecl>> TypeCheckEnv::LookupTypeConstant(
google::protobuf::Arena* absl_nonnull arena, absl::string_view name) const {
CEL_ASSIGN_OR_RETURN(absl::optional<Type> type, LookupTypeName(name));
if (type.has_value()) {
return MakeVariableDecl(std::string(type->name()), TypeType(arena, *type));
return MakeVariableDecl(type->name(), TypeType(arena, *type));
}

if (name.find('.') != name.npos) {
Expand Down Expand Up @@ -185,7 +184,7 @@ absl::StatusOr<absl::optional<StructTypeField>> TypeCheckEnv::LookupStructField(
return absl::nullopt;
}

const VariableDecl* absl_nullable VariableScope::LookupVariable(
const VariableDecl* absl_nullable VariableScope::LookupLocalVariable(
absl::string_view name) const {
const VariableScope* scope = this;
while (scope != nullptr) {
Expand All @@ -194,8 +193,7 @@ const VariableDecl* absl_nullable VariableScope::LookupVariable(
}
scope = scope->parent_;
}

return env_->LookupVariable(name);
return nullptr;
}

} // namespace cel::checker_internal
22 changes: 8 additions & 14 deletions checker/internal/type_check_env.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,11 @@ class TypeCheckEnv;
// Helper class for managing nested scopes and the local variables they
// implicitly declare.
//
// Nested scopes have a lifetime dependency on any parent scopes and the
// parent Type environment. Nested scopes should generally be managed by
// unique_ptrs.
// Nested scopes have a lifetime dependency on any parent scopes and should
// generally be managed by unique_ptrs.
class VariableScope {
public:
explicit VariableScope(const TypeCheckEnv& env ABSL_ATTRIBUTE_LIFETIME_BOUND)
: env_(&env), parent_(nullptr) {}
explicit VariableScope() : parent_(nullptr) {}

VariableScope(const VariableScope&) = delete;
VariableScope& operator=(const VariableScope&) = delete;
Expand All @@ -61,18 +59,17 @@ class VariableScope {

std::unique_ptr<VariableScope> MakeNestedScope() const
ABSL_ATTRIBUTE_LIFETIME_BOUND {
return absl::WrapUnique(new VariableScope(*env_, this));
return absl::WrapUnique(new VariableScope(this));
}

const VariableDecl* absl_nullable LookupVariable(
const VariableDecl* absl_nullable LookupLocalVariable(
absl::string_view name) const;

private:
VariableScope(const TypeCheckEnv& env ABSL_ATTRIBUTE_LIFETIME_BOUND,
const VariableScope* parent ABSL_ATTRIBUTE_LIFETIME_BOUND)
: env_(&env), parent_(parent) {}
explicit VariableScope(
const VariableScope* parent ABSL_ATTRIBUTE_LIFETIME_BOUND)
: parent_(parent) {}

const TypeCheckEnv* absl_nonnull env_;
const VariableScope* absl_nullable parent_;
absl::flat_hash_map<std::string, VariableDecl> variables_;
};
Expand Down Expand Up @@ -190,9 +187,6 @@ class TypeCheckEnv {
TypeCheckEnv MakeExtendedEnvironment() const ABSL_ATTRIBUTE_LIFETIME_BOUND {
return TypeCheckEnv(this);
}
VariableScope MakeVariableScope() const ABSL_ATTRIBUTE_LIFETIME_BOUND {
return VariableScope(*this);
}

const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool() const {
return descriptor_pool_.get();
Expand Down
86 changes: 58 additions & 28 deletions checker/internal/type_checker_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
Expand Down Expand Up @@ -246,7 +247,7 @@ class ResolveVisitor : public AstVisitorBase {
inference_context_(&inference_context),
issues_(&issues),
ast_(&ast),
root_scope_(env.MakeVariableScope()),
root_scope_(),
arena_(arena),
current_scope_(&root_scope_) {}

Expand Down Expand Up @@ -344,9 +345,13 @@ class ResolveVisitor : public AstVisitorBase {
absl::string_view function_name,
int arg_count, bool is_receiver);

// Resolves the function call shape (i.e. the number of arguments and call
// style) for the given function call.
const VariableDecl* absl_nullable LookupIdentifier(absl::string_view name);
// Resolves a global identifier (i.e. declared in the CEL environment).
const VariableDecl* absl_nullable LookupGlobalIdentifier(
absl::string_view name);

// Resolves a local identifier (i.e. a bind or comrprehension var).
const VariableDecl* absl_nullable LookupLocalIdentifier(
absl::string_view name);

// Resolves the applicable function overloads for the given function call.
//
Expand Down Expand Up @@ -967,10 +972,19 @@ void ResolveVisitor::ResolveFunctionOverloads(const Expr& expr,
types_[&expr] = resolution->result_type;
}

const VariableDecl* absl_nullable ResolveVisitor::LookupIdentifier(
const VariableDecl* absl_nullable ResolveVisitor::LookupLocalIdentifier(
absl::string_view name) {
// Container resolution doesn't apply for local vars so .foo is redundant but
// legal.
if (absl::StartsWith(name, ".")) {
name = name.substr(1);
}
return current_scope_->LookupLocalVariable(name);
}

const VariableDecl* absl_nullable ResolveVisitor::LookupGlobalIdentifier(
absl::string_view name) {
if (const VariableDecl* decl = current_scope_->LookupVariable(name);
decl != nullptr) {
if (const VariableDecl* decl = env_->LookupVariable(name); decl != nullptr) {
return decl;
}
absl::StatusOr<absl::optional<VariableDecl>> constant =
Expand All @@ -996,22 +1010,31 @@ const VariableDecl* absl_nullable ResolveVisitor::LookupIdentifier(

void ResolveVisitor::ResolveSimpleIdentifier(const Expr& expr,
absl::string_view name) {
const VariableDecl* decl = nullptr;
// Local variables (comprehension, bind) are simple identifiers so we can
// skip generating the different namespace-qualified candidates.
const VariableDecl* decl = LookupLocalIdentifier(name);

if (decl != nullptr) {
attributes_[&expr] = decl;
types_[&expr] = inference_context_->InstantiateTypeParams(decl->type());
return;
}

namespace_generator_.GenerateCandidates(
name, [&decl, this](absl::string_view candidate) {
decl = LookupIdentifier(candidate);
decl = LookupGlobalIdentifier(candidate);
// continue searching.
return decl == nullptr;
});

if (decl == nullptr) {
ReportMissingReference(expr, name);
types_[&expr] = ErrorType();
if (decl != nullptr) {
attributes_[&expr] = decl;
types_[&expr] = inference_context_->InstantiateTypeParams(decl->type());
return;
}

attributes_[&expr] = decl;
types_[&expr] = inference_context_->InstantiateTypeParams(decl->type());
ReportMissingReference(expr, name);
types_[&expr] = ErrorType();
}

void ResolveVisitor::ResolveQualifiedIdentifier(
Expand All @@ -1021,26 +1044,34 @@ void ResolveVisitor::ResolveQualifiedIdentifier(
return;
}

const VariableDecl* absl_nullable decl = nullptr;
int segment_index_out = -1;
namespace_generator_.GenerateCandidates(
qualifiers, [&decl, &segment_index_out, this](absl::string_view candidate,
int segment_index) {
decl = LookupIdentifier(candidate);
if (decl != nullptr) {
segment_index_out = segment_index;
return false;
}
return true;
});
// Local variables (comprehension, bind) are simple identifiers so we can
// skip generating the different namespace-qualified candidates.
const VariableDecl* decl = LookupLocalIdentifier(qualifiers[0]);
int matched_segment_index = -1;

if (decl != nullptr) {
matched_segment_index = 0;
} else {
namespace_generator_.GenerateCandidates(
qualifiers, [&decl, &matched_segment_index, this](
absl::string_view candidate, int segment_index) {
decl = LookupGlobalIdentifier(candidate);
if (decl != nullptr) {
matched_segment_index = segment_index;
return false;
}
return true;
});
}

if (decl == nullptr) {
ReportMissingReference(expr, FormatCandidate(qualifiers));
types_[&expr] = ErrorType();
return;
}

const int num_select_opts = qualifiers.size() - segment_index_out - 1;
const int num_select_opts = qualifiers.size() - matched_segment_index - 1;

const Expr* root = &expr;
std::vector<const Expr*> select_opts;
select_opts.reserve(num_select_opts);
Expand Down Expand Up @@ -1211,7 +1242,6 @@ class ResolveRewriter : public AstRewriterBase {
auto& ast_ref = reference_map_[expr.id()];
ast_ref.set_name(decl->name());
for (const auto& overload : decl->overloads()) {
// TODO(uncreated-issue/72): narrow based on type inferences and shape.
ast_ref.mutable_overload_id().push_back(overload.id());
}
expr.mutable_call_expr().set_function(decl->name());
Expand Down
16 changes: 8 additions & 8 deletions checker/internal/type_checker_impl_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ using ::testing::ElementsAre;
using ::testing::Eq;
using ::testing::HasSubstr;
using ::testing::IsEmpty;
using ::testing::Not;
using ::testing::Pair;
using ::testing::Property;
using ::testing::SizeIs;
Expand Down Expand Up @@ -750,47 +751,46 @@ TEST(TypeCheckerImplTest, NestedComprehensions) {
EXPECT_THAT(result.GetIssues(), IsEmpty());
}

TEST(TypeCheckerImplTest, ComprehensionVarsFollowNamespacePriorityRules) {
TEST(TypeCheckerImplTest, ComprehensionVarsShadowNamespacePriorityRules) {
TypeCheckEnv env(GetSharedTestingDescriptorPool());
env.set_container("com");
google::protobuf::Arena arena;
ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk());

// Namespace resolution still applies, compre var doesn't shadow com.x
// Namespace compre var shadows com.x
env.InsertVariableIfAbsent(MakeVariableDecl("com.x", IntType()));

TypeCheckerImpl impl(std::move(env));
ASSERT_OK_AND_ASSIGN(auto ast,
MakeTestParsedAst("['1', '2'].all(x, x == 2)"));
MakeTestParsedAst("['1', '2'].exists(x, x == '2')"));
ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast)));

EXPECT_TRUE(result.IsValid());

EXPECT_THAT(result.GetIssues(), IsEmpty());
ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst());
EXPECT_THAT(checked_ast->reference_map(),
Contains(Pair(_, IsVariableReference("com.x"))));
Not(Contains(Pair(_, IsVariableReference("com.x")))));
}

TEST(TypeCheckerImplTest, ComprehensionVarsFollowQualifiedIdentPriority) {
TEST(TypeCheckerImplTest, ComprehensionVarsShadowsQualifiedIdent) {
TypeCheckEnv env(GetSharedTestingDescriptorPool());
google::protobuf::Arena arena;
ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk());

// Namespace resolution still applies, compre var doesn't shadow x.y
env.InsertVariableIfAbsent(MakeVariableDecl("x.y", IntType()));

TypeCheckerImpl impl(std::move(env));
ASSERT_OK_AND_ASSIGN(auto ast,
MakeTestParsedAst("[{'y': '2'}].all(x, x.y == 2)"));
MakeTestParsedAst("[{'y': '2'}].all(x, x.y == '2')"));
ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast)));

EXPECT_TRUE(result.IsValid());

EXPECT_THAT(result.GetIssues(), IsEmpty());
ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst());
EXPECT_THAT(checked_ast->reference_map(),
Contains(Pair(_, IsVariableReference("x.y"))));
Not(Contains(Pair(_, IsVariableReference("x.y")))));
}

TEST(TypeCheckerImplTest, ComprehensionVarsCyclicParamAssignability) {
Expand Down