Skip to content

Commit 3e1b202

Browse files
stackdumpclaude
andcommitted
Add DSL guardrails: type checking, name validation, guard verification
Catches errors at .btw build time instead of at Solidity compile time: - Duplicate names: registers, events, functions with same name - Reserved names: schema/register/function names that conflict with Solidity keywords (function, event, mapping), built-ins (msg, block), forge-std (Test, Script), or generated internals (contractOwner) - Arc type mismatch: indexing a scalar register X[key], or wrong key depth on nested maps (1 key on a map[address,address]uint256) - Unknown registers in arcs: NOPE -|amount|> fn - Unknown identifiers in guards: require(NOPE >= amount) - Undeclared events: @event NoSuchEvent - Unknown types: register X maps[address]uint256 (typo) 10 test cases in TestValidation covering each guardrail. All existing examples and tests still pass. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 8158982 commit 3e1b202

File tree

4 files changed

+334
-3
lines changed

4 files changed

+334
-3
lines changed

dsl/builder.go

Lines changed: 228 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,237 @@ func Build(ast *Schema) (*metamodel.Schema, error) {
8383
}
8484
}
8585

86+
// Validate the built schema
87+
if err := validate(ast, s); err != nil {
88+
return nil, err
89+
}
90+
8691
return s, nil
8792
}
8893

94+
// validate checks the DSL AST and built schema for common errors.
95+
func validate(ast *Schema, s *metamodel.Schema) error {
96+
// Build lookup tables
97+
registers := make(map[string]Register)
98+
for _, r := range ast.Registers {
99+
registers[r.Name] = r
100+
}
101+
102+
fnVars := make(map[string]map[string]bool) // fn name → var names
103+
for _, fn := range ast.Functions {
104+
vars := make(map[string]bool)
105+
for _, v := range fn.Vars {
106+
vars[v.Name] = true
107+
}
108+
fnVars[fn.Name] = vars
109+
}
110+
111+
// 1. Duplicate detection
112+
{
113+
seen := make(map[string]string) // name → kind
114+
for _, r := range ast.Registers {
115+
if prev, ok := seen[r.Name]; ok {
116+
return fmt.Errorf("duplicate name %q (already declared as %s)", r.Name, prev)
117+
}
118+
seen[r.Name] = "register"
119+
}
120+
for _, e := range ast.Events {
121+
if prev, ok := seen[e.Name]; ok {
122+
return fmt.Errorf("duplicate name %q (already declared as %s)", e.Name, prev)
123+
}
124+
seen[e.Name] = "event"
125+
}
126+
for _, f := range ast.Functions {
127+
if prev, ok := seen[f.Name]; ok {
128+
return fmt.Errorf("duplicate name %q (already declared as %s)", f.Name, prev)
129+
}
130+
seen[f.Name] = "function"
131+
}
132+
}
133+
134+
// 2. Reserved name checking
135+
if err := checkReservedName(ast.Name, "schema"); err != nil {
136+
return err
137+
}
138+
for _, r := range ast.Registers {
139+
if err := checkReservedName(r.Name, "register"); err != nil {
140+
return err
141+
}
142+
}
143+
for _, f := range ast.Functions {
144+
if err := checkReservedName(f.Name, "function"); err != nil {
145+
return err
146+
}
147+
}
148+
149+
// 3. Arc type checking
150+
for _, fn := range ast.Functions {
151+
for _, arc := range fn.Arcs {
152+
// Determine the place side (not the function side)
153+
placeName := arc.Source
154+
indices := arc.SourceIndices
155+
if arc.Source == fn.Name {
156+
placeName = arc.Target
157+
indices = arc.TargetIndices
158+
}
159+
160+
reg, ok := registers[placeName]
161+
if !ok {
162+
// Arc references a non-existent register
163+
return fmt.Errorf("function %s: arc references unknown register %q", fn.Name, placeName)
164+
}
165+
166+
mapDepth := mapKeyDepth(reg.Type)
167+
indexCount := len(indices)
168+
169+
if mapDepth == 0 && indexCount > 0 {
170+
return fmt.Errorf("function %s: register %s is %s (scalar), cannot index with [%s]",
171+
fn.Name, reg.Name, reg.Type, strings.Join(indices, "]["))
172+
}
173+
174+
if indexCount > 0 && indexCount != mapDepth {
175+
return fmt.Errorf("function %s: register %s needs %d index key(s) (type %s), got %d",
176+
fn.Name, reg.Name, mapDepth, reg.Type, indexCount)
177+
}
178+
}
179+
}
180+
181+
// 4. Guard variable validation
182+
for _, fn := range ast.Functions {
183+
if fn.Require == "" {
184+
continue
185+
}
186+
vars := fnVars[fn.Name]
187+
if err := validateGuardIdents(fn.Name, fn.Require, registers, vars); err != nil {
188+
return err
189+
}
190+
}
191+
192+
// 5. Event reference validation
193+
eventNames := make(map[string]bool)
194+
for _, e := range ast.Events {
195+
eventNames[e.Name] = true
196+
}
197+
for _, fn := range ast.Functions {
198+
if fn.EventRef != "" && !eventNames[fn.EventRef] {
199+
return fmt.Errorf("function %s: @event references undeclared event %q", fn.Name, fn.EventRef)
200+
}
201+
}
202+
203+
return nil
204+
}
205+
206+
// checkReservedName returns an error if the name conflicts with Solidity or Foundry.
207+
func checkReservedName(name, kind string) error {
208+
reserved := map[string]string{
209+
// Solidity keywords
210+
"function": "Solidity keyword", "event": "Solidity keyword",
211+
"mapping": "Solidity keyword", "constructor": "Solidity keyword",
212+
"require": "Solidity keyword", "revert": "Solidity keyword",
213+
"assert": "Solidity keyword", "return": "Solidity keyword",
214+
"if": "Solidity keyword", "else": "Solidity keyword",
215+
"for": "Solidity keyword", "while": "Solidity keyword",
216+
"true": "Solidity keyword", "false": "Solidity keyword",
217+
// Solidity built-in variables
218+
"msg": "Solidity built-in", "block": "Solidity built-in",
219+
"tx": "Solidity built-in", "this": "Solidity built-in",
220+
// Forge-std conflicts
221+
"Test": "forge-std class", "Script": "forge-std class",
222+
"Vm": "forge-std class", "console": "forge-std class",
223+
// Generated contract internals
224+
"contractOwner": "generated internal", "currentEpoch": "generated internal",
225+
"eventSequence": "generated internal",
226+
}
227+
if reason, ok := reserved[name]; ok {
228+
return fmt.Errorf("%s name %q conflicts with %s", kind, name, reason)
229+
}
230+
return nil
231+
}
232+
233+
// mapKeyDepth returns the nesting depth of a map type.
234+
// "uint256" → 0, "map[address]uint256" → 1, "map[address]map[address]uint256" → 2
235+
func mapKeyDepth(typ string) int {
236+
depth := 0
237+
remaining := typ
238+
for strings.HasPrefix(remaining, "map[") {
239+
close := strings.Index(remaining, "]")
240+
if close == -1 {
241+
break
242+
}
243+
depth++
244+
remaining = remaining[close+1:]
245+
}
246+
return depth
247+
}
248+
249+
// validateGuardIdents checks that all identifiers in a guard expression
250+
// are known registers, function variables, or built-in names.
251+
func validateGuardIdents(fnName, guard string, registers map[string]Register, vars map[string]bool) error {
252+
// Extract identifiers from the guard using simple scanning
253+
// (avoid importing guard package to prevent circular deps)
254+
idents := extractIdents(guard)
255+
256+
builtins := map[string]bool{
257+
"caller": true, "address": true, "true": true, "false": true,
258+
"msg": true, "sender": true, // for msg.sender
259+
}
260+
261+
for _, id := range idents {
262+
if builtins[id] {
263+
continue
264+
}
265+
if _, ok := registers[id]; ok {
266+
continue
267+
}
268+
if vars[id] {
269+
continue
270+
}
271+
// Allow numeric literals
272+
if isNumeric(id) {
273+
continue
274+
}
275+
// Allow known function-like calls (vestedAmount, etc.)
276+
// These are generated helper functions, not user-defined
277+
if strings.Contains(guard, id+"(") {
278+
continue
279+
}
280+
return fmt.Errorf("function %s: guard references unknown identifier %q", fnName, id)
281+
}
282+
return nil
283+
}
284+
285+
// extractIdents returns all identifier-like tokens from an expression string.
286+
func extractIdents(expr string) []string {
287+
var idents []string
288+
i := 0
289+
for i < len(expr) {
290+
if isLetter(expr[i]) || expr[i] == '_' {
291+
start := i
292+
for i < len(expr) && (isLetter(expr[i]) || isDigit(expr[i]) || expr[i] == '_' || expr[i] == '.') {
293+
i++
294+
}
295+
idents = append(idents, expr[start:i])
296+
} else {
297+
i++
298+
}
299+
}
300+
return idents
301+
}
302+
303+
func isLetter(c byte) bool { return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') }
304+
func isDigit(c byte) bool { return c >= '0' && c <= '9' }
305+
func isNumeric(s string) bool {
306+
if len(s) == 0 {
307+
return false
308+
}
309+
for _, c := range s {
310+
if c < '0' || c > '9' {
311+
return false
312+
}
313+
}
314+
return true
315+
}
316+
89317
// buildArc converts a DSL Arc into a metamodel.Arc.
90318
// The function name determines direction:
91319
// - PLACE -|w|> fnName => input arc (Source=PLACE, Target=fnName)

dsl/dsl_test.go

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package dsl
22

33
import (
4+
"strings"
45
"testing"
56
)
67

@@ -526,3 +527,91 @@ schema Test {
526527
t.Errorf("expected map[uint256]map[address]map[uint256]bool, got %s", ast.Registers[1].Type)
527528
}
528529
}
530+
531+
func TestValidation(t *testing.T) {
532+
cases := []struct {
533+
name string
534+
src string
535+
want string // substring of expected error
536+
}{
537+
{
538+
name: "duplicate register",
539+
src: `schema Foo { version "1.0" register X uint256 register X uint256 }`,
540+
want: `duplicate name "X"`,
541+
},
542+
{
543+
name: "reserved schema name",
544+
src: `schema Test { version "1.0" }`,
545+
want: `schema name "Test" conflicts with forge-std class`,
546+
},
547+
{
548+
name: "reserved function name",
549+
src: `schema Foo { version "1.0" register X uint256 fn(msg) { X -|1|> msg } }`,
550+
want: `function name "msg" conflicts with Solidity built-in`,
551+
},
552+
{
553+
name: "scalar indexed",
554+
src: `schema Foo { version "1.0" register X uint256 fn(f) { var k address X[k] -|1|> f } }`,
555+
want: "register X is uint256 (scalar), cannot index",
556+
},
557+
{
558+
name: "unknown register in arc",
559+
src: `schema Foo { version "1.0" fn(f) { var x amount NOPE -|x|> f } }`,
560+
want: `unknown register "NOPE"`,
561+
},
562+
{
563+
name: "unknown identifier in guard",
564+
src: `schema Foo { version "1.0" register X uint256 fn(f) { var amount amount require(NOPE >= amount) f -|amount|> X } }`,
565+
want: `guard references unknown identifier "NOPE"`,
566+
},
567+
{
568+
name: "undeclared event",
569+
src: `schema Foo { version "1.0" register X uint256 fn(f) { var amount amount @event Nope f -|amount|> X } }`,
570+
want: `undeclared event "Nope"`,
571+
},
572+
{
573+
name: "unknown type",
574+
src: `schema Foo { version "1.0" register X maps[address]uint256 }`,
575+
want: `unknown type "maps"`,
576+
},
577+
{
578+
name: "map key depth mismatch",
579+
src: `schema Foo { version "1.0" register X map[address,address]uint256 fn(f) { var k address X[k] -|1|> f } }`,
580+
want: "needs 2 index key(s)",
581+
},
582+
{
583+
name: "valid schema passes",
584+
src: `schema Foo { version "1.0"
585+
register X map[address]uint256 observable
586+
event E { to: address indexed amount: uint256 }
587+
fn(inc) { var to address var amount amount @event E inc -|amount|> X[to] }
588+
}`,
589+
want: "", // no error
590+
},
591+
}
592+
593+
for _, tc := range cases {
594+
t.Run(tc.name, func(t *testing.T) {
595+
ast, err := Parse(tc.src)
596+
if err != nil {
597+
if tc.want != "" && strings.Contains(err.Error(), tc.want) {
598+
return // parse error matches expected
599+
}
600+
t.Fatalf("unexpected parse error: %v", err)
601+
}
602+
_, err = Build(ast)
603+
if tc.want == "" {
604+
if err != nil {
605+
t.Fatalf("expected no error, got: %v", err)
606+
}
607+
return
608+
}
609+
if err == nil {
610+
t.Fatalf("expected error containing %q, got nil", tc.want)
611+
}
612+
if !strings.Contains(err.Error(), tc.want) {
613+
t.Fatalf("expected error containing %q, got: %v", tc.want, err)
614+
}
615+
})
616+
}
617+
}

dsl/parser.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,9 @@ func (p *Parser) parseTypeString() (string, error) {
218218
// Check for map[...] type
219219
// Supports: map[address]uint256, map[address]map[address]uint256,
220220
// and shorthand: map[address,address]uint256 → map[address]map[address]uint256
221+
if tok.Value != "map" && !isKnownType(tok.Value) {
222+
return "", fmt.Errorf("line %d: unknown type %q (expected uint256, address, bool, or map[...])", tok.Line, tok.Value)
223+
}
221224
if tok.Value == "map" && p.peek().Type == TokenLBracket {
222225
p.advance() // [
223226
// Collect comma-separated key types
@@ -492,3 +495,14 @@ func (p *Parser) parsePlaceRef() (string, []string, error) {
492495

493496
return nameTok.Value, indices, nil
494497
}
498+
499+
// isKnownType returns true for valid Solidity-compatible type names.
500+
func isKnownType(t string) bool {
501+
known := map[string]bool{
502+
"uint256": true, "uint128": true, "uint64": true, "uint32": true, "uint16": true, "uint8": true,
503+
"int256": true, "int128": true, "int64": true, "int32": true, "int16": true, "int8": true,
504+
"address": true, "bool": true, "bytes32": true, "bytes": true, "string": true,
505+
"amount": true, // DSL shorthand for uint256
506+
}
507+
return known[t]
508+
}

internal/server/server_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ func TestPostSVG(t *testing.T) {
278278

279279
func TestCompile(t *testing.T) {
280280
srv := testServer(t)
281-
btw := `schema Test {
281+
btw := `schema Counter {
282282
version "1.0"
283283
register balance uint256 observable
284284
fn(inc) {
@@ -295,7 +295,7 @@ func TestCompile(t *testing.T) {
295295
}
296296
var resp map[string]interface{}
297297
json.Unmarshal(w.Body.Bytes(), &resp)
298-
if resp["name"] != "Test" {
299-
t.Fatalf("expected name=Test, got %v", resp["name"])
298+
if resp["name"] != "Counter" {
299+
t.Fatalf("expected name=Counter, got %v", resp["name"])
300300
}
301301
}

0 commit comments

Comments
 (0)