Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions ast/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -453,9 +453,9 @@ type VariantReceive struct {
}

type VariantCase struct {
Tag string // variant tag name
Variables []string // variables to bind payload fields
Body Statement
Tag string // variant tag name
Variables []string // variables to bind payload fields
Body []Statement // case body (may include scoped declarations)
}

func (vr *VariantReceive) statementNode() {}
Expand Down
40 changes: 26 additions & 14 deletions codegen/codegen.go
Original file line number Diff line number Diff line change
Expand Up @@ -566,8 +566,10 @@ func (g *Generator) containsPar(stmt ast.Statement) bool {
}
case *ast.VariantReceive:
for _, c := range s.Cases {
if c.Body != nil && g.containsPar(c.Body) {
return true
for _, inner := range c.Body {
if g.containsPar(inner) {
return true
}
}
}
}
Expand Down Expand Up @@ -639,8 +641,10 @@ func (g *Generator) containsPrint(stmt ast.Statement) bool {
}
case *ast.VariantReceive:
for _, c := range s.Cases {
if c.Body != nil && g.containsPrint(c.Body) {
return true
for _, inner := range c.Body {
if g.containsPrint(inner) {
return true
}
}
}
}
Expand Down Expand Up @@ -715,8 +719,10 @@ func (g *Generator) containsTimer(stmt ast.Statement) bool {
}
case *ast.VariantReceive:
for _, c := range s.Cases {
if c.Body != nil && g.containsTimer(c.Body) {
return true
for _, inner := range c.Body {
if g.containsTimer(inner) {
return true
}
}
}
}
Expand Down Expand Up @@ -788,8 +794,10 @@ func (g *Generator) containsStop(stmt ast.Statement) bool {
}
case *ast.VariantReceive:
for _, c := range s.Cases {
if c.Body != nil && g.containsStop(c.Body) {
return true
for _, inner := range c.Body {
if g.containsStop(inner) {
return true
}
}
}
}
Expand Down Expand Up @@ -909,8 +917,10 @@ func (g *Generator) containsMostExpr(stmt ast.Statement) bool {
}
case *ast.VariantReceive:
for _, c := range s.Cases {
if c.Body != nil && g.containsMostExpr(c.Body) {
return true
for _, inner := range c.Body {
if g.containsMostExpr(inner) {
return true
}
}
}
}
Expand Down Expand Up @@ -1417,8 +1427,8 @@ func (g *Generator) generateVariantReceive(vr *ast.VariantReceive) {
for i, v := range vc.Variables {
g.writeLine(fmt.Sprintf("%s = _v._%d", goIdent(v), i))
}
if vc.Body != nil {
g.generateStatement(vc.Body)
for _, s := range vc.Body {
g.generateStatement(s)
}
g.indent--
}
Expand Down Expand Up @@ -3084,8 +3094,10 @@ func (g *Generator) walkStatements(stmt ast.Statement, fn func(ast.Expression) b
}
case *ast.VariantReceive:
for _, c := range s.Cases {
if c.Body != nil && g.walkStatements(c.Body, fn) {
return true
for _, inner := range c.Body {
if g.walkStatements(inner, fn) {
return true
}
}
}
}
Expand Down
43 changes: 43 additions & 0 deletions codegen/e2e_protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,49 @@ SEQ
}
}

func TestE2E_VariantReceiveScopedDecl(t *testing.T) {
// Issue #86: scoped declarations in variant receive case bodies
occam := `PROTOCOL CMD
CASE
set.val; INT
evolve
terminate

PROC test(CHAN OF CMD control)
INT state:
BOOL running:
SEQ
state := 0
running := TRUE
WHILE running
control ? CASE
set.val; state
SKIP
evolve
INT next:
SEQ
next := state + 10
state := next
terminate
running := FALSE
print.int(state)
:
CHAN OF CMD ch:
SEQ
PAR
test(ch)
SEQ
ch ! set.val; 5
ch ! evolve
ch ! terminate
`
output := transpileCompileRun(t, occam)
expected := "15\n"
if output != expected {
t.Errorf("expected %q, got %q", expected, output)
}
}

func TestE2E_VariantProtocolTrailingColon(t *testing.T) {
// Issue #73: trailing colon on variant protocol declarations
occam := `PROTOCOL MSG
Expand Down
13 changes: 2 additions & 11 deletions parser/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -1384,12 +1384,7 @@ func (p *Parser) parseVariantReceive(channel string, token lexer.Token) *ast.Var
if p.peekTokenIs(lexer.INDENT) {
p.nextToken() // consume INDENT
p.nextToken() // move to body
vc.Body = p.parseStatement()

// Advance past the last token of the statement if needed
if !p.curTokenIs(lexer.NEWLINE) && !p.curTokenIs(lexer.DEDENT) && !p.curTokenIs(lexer.EOF) {
p.nextToken()
}
vc.Body = p.parseBodyStatements()
}

stmt.Cases = append(stmt.Cases, vc)
Expand Down Expand Up @@ -1479,11 +1474,7 @@ func (p *Parser) parseVariantReceiveWithIndex(channel string, channelIndices []a
if p.peekTokenIs(lexer.INDENT) {
p.nextToken() // consume INDENT
p.nextToken() // move to body
vc.Body = p.parseStatement()

if !p.curTokenIs(lexer.NEWLINE) && !p.curTokenIs(lexer.DEDENT) && !p.curTokenIs(lexer.EOF) {
p.nextToken()
}
vc.Body = p.parseBodyStatements()
}

stmt.Cases = append(stmt.Cases, vc)
Expand Down
76 changes: 76 additions & 0 deletions parser/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3812,3 +3812,79 @@ func TestMultiDimOpenArrayParam(t *testing.T) {
t.Errorf("expected ChanElemType=INT, got %s", p0.ChanElemType)
}
}

func TestVariantReceiveScopedDecl(t *testing.T) {
input := `PROTOCOL CMD
CASE
data; INT
evolve
quit

PROC test(CHAN OF CMD ch)
BOOL done:
SEQ
done := FALSE
WHILE NOT done
ch ? CASE
data; done
SKIP
evolve
BOOL flag:
SEQ
flag := TRUE
done := flag
quit
done := TRUE
:
`
l := lexer.New(input)
pr := New(l)
program := pr.ParseProgram()
checkParserErrors(t, pr)

// Find the PROC
if len(program.Statements) < 2 {
t.Fatalf("expected at least 2 statements, got %d", len(program.Statements))
}
proc, ok := program.Statements[1].(*ast.ProcDecl)
if !ok {
t.Fatalf("expected ProcDecl, got %T", program.Statements[1])
}

// Walk to the variant receive inside the WHILE
// proc body: VarDecl(done), SeqBlock{ assign, WhileLoop{ VariantReceive } }
seq, ok := proc.Body[1].(*ast.SeqBlock)
if !ok {
t.Fatalf("expected SeqBlock, got %T", proc.Body[1])
}
wl, ok := seq.Statements[1].(*ast.WhileLoop)
if !ok {
t.Fatalf("expected WhileLoop, got %T", seq.Statements[1])
}
if len(wl.Body) < 1 {
t.Fatalf("expected at least 1 statement in while body, got %d", len(wl.Body))
}
vr, ok := wl.Body[0].(*ast.VariantReceive)
if !ok {
t.Fatalf("expected VariantReceive, got %T", wl.Body[0])
}

if len(vr.Cases) != 3 {
t.Fatalf("expected 3 variant cases, got %d", len(vr.Cases))
}

// "evolve" case should have 2 body statements: VarDecl + SeqBlock
evolveCase := vr.Cases[1]
if evolveCase.Tag != "evolve" {
t.Errorf("expected tag 'evolve', got %s", evolveCase.Tag)
}
if len(evolveCase.Body) != 2 {
t.Fatalf("expected 2 body statements in 'evolve' case, got %d", len(evolveCase.Body))
}
if _, ok := evolveCase.Body[0].(*ast.VarDecl); !ok {
t.Errorf("expected VarDecl as first body statement, got %T", evolveCase.Body[0])
}
if _, ok := evolveCase.Body[1].(*ast.SeqBlock); !ok {
t.Errorf("expected SeqBlock as second body statement, got %T", evolveCase.Body[1])
}
}
Loading