diff --git a/ast/ast.go b/ast/ast.go index 95752e6..a504e54 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -453,9 +453,9 @@ type VariantReceive struct { } type VariantCase struct { - Tag string // variant tag name - Variables []string // variables to bind payload fields - Body Statement + Tag string // variant tag name + Variables []string // variables to bind payload fields + Body []Statement // case body (may include scoped declarations) } func (vr *VariantReceive) statementNode() {} diff --git a/codegen/codegen.go b/codegen/codegen.go index cec8b5b..849ec88 100644 --- a/codegen/codegen.go +++ b/codegen/codegen.go @@ -566,8 +566,10 @@ func (g *Generator) containsPar(stmt ast.Statement) bool { } case *ast.VariantReceive: for _, c := range s.Cases { - if c.Body != nil && g.containsPar(c.Body) { - return true + for _, inner := range c.Body { + if g.containsPar(inner) { + return true + } } } } @@ -639,8 +641,10 @@ func (g *Generator) containsPrint(stmt ast.Statement) bool { } case *ast.VariantReceive: for _, c := range s.Cases { - if c.Body != nil && g.containsPrint(c.Body) { - return true + for _, inner := range c.Body { + if g.containsPrint(inner) { + return true + } } } } @@ -715,8 +719,10 @@ func (g *Generator) containsTimer(stmt ast.Statement) bool { } case *ast.VariantReceive: for _, c := range s.Cases { - if c.Body != nil && g.containsTimer(c.Body) { - return true + for _, inner := range c.Body { + if g.containsTimer(inner) { + return true + } } } } @@ -788,8 +794,10 @@ func (g *Generator) containsStop(stmt ast.Statement) bool { } case *ast.VariantReceive: for _, c := range s.Cases { - if c.Body != nil && g.containsStop(c.Body) { - return true + for _, inner := range c.Body { + if g.containsStop(inner) { + return true + } } } } @@ -909,8 +917,10 @@ func (g *Generator) containsMostExpr(stmt ast.Statement) bool { } case *ast.VariantReceive: for _, c := range s.Cases { - if c.Body != nil && g.containsMostExpr(c.Body) { - return true + for _, inner := range c.Body { + if g.containsMostExpr(inner) { + return true + } } } } @@ -1417,8 +1427,8 @@ func (g *Generator) generateVariantReceive(vr *ast.VariantReceive) { for i, v := range vc.Variables { g.writeLine(fmt.Sprintf("%s = _v._%d", goIdent(v), i)) } - if vc.Body != nil { - g.generateStatement(vc.Body) + for _, s := range vc.Body { + g.generateStatement(s) } g.indent-- } @@ -3084,8 +3094,10 @@ func (g *Generator) walkStatements(stmt ast.Statement, fn func(ast.Expression) b } case *ast.VariantReceive: for _, c := range s.Cases { - if c.Body != nil && g.walkStatements(c.Body, fn) { - return true + for _, inner := range c.Body { + if g.walkStatements(inner, fn) { + return true + } } } } diff --git a/codegen/e2e_protocol_test.go b/codegen/e2e_protocol_test.go index 900ff95..292d718 100644 --- a/codegen/e2e_protocol_test.go +++ b/codegen/e2e_protocol_test.go @@ -174,6 +174,49 @@ SEQ } } +func TestE2E_VariantReceiveScopedDecl(t *testing.T) { + // Issue #86: scoped declarations in variant receive case bodies + occam := `PROTOCOL CMD + CASE + set.val; INT + evolve + terminate + +PROC test(CHAN OF CMD control) + INT state: + BOOL running: + SEQ + state := 0 + running := TRUE + WHILE running + control ? CASE + set.val; state + SKIP + evolve + INT next: + SEQ + next := state + 10 + state := next + terminate + running := FALSE + print.int(state) +: +CHAN OF CMD ch: +SEQ + PAR + test(ch) + SEQ + ch ! set.val; 5 + ch ! evolve + ch ! terminate +` + output := transpileCompileRun(t, occam) + expected := "15\n" + if output != expected { + t.Errorf("expected %q, got %q", expected, output) + } +} + func TestE2E_VariantProtocolTrailingColon(t *testing.T) { // Issue #73: trailing colon on variant protocol declarations occam := `PROTOCOL MSG diff --git a/parser/parser.go b/parser/parser.go index 3f78b71..704fcb8 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -1384,12 +1384,7 @@ func (p *Parser) parseVariantReceive(channel string, token lexer.Token) *ast.Var if p.peekTokenIs(lexer.INDENT) { p.nextToken() // consume INDENT p.nextToken() // move to body - vc.Body = p.parseStatement() - - // Advance past the last token of the statement if needed - if !p.curTokenIs(lexer.NEWLINE) && !p.curTokenIs(lexer.DEDENT) && !p.curTokenIs(lexer.EOF) { - p.nextToken() - } + vc.Body = p.parseBodyStatements() } stmt.Cases = append(stmt.Cases, vc) @@ -1479,11 +1474,7 @@ func (p *Parser) parseVariantReceiveWithIndex(channel string, channelIndices []a if p.peekTokenIs(lexer.INDENT) { p.nextToken() // consume INDENT p.nextToken() // move to body - vc.Body = p.parseStatement() - - if !p.curTokenIs(lexer.NEWLINE) && !p.curTokenIs(lexer.DEDENT) && !p.curTokenIs(lexer.EOF) { - p.nextToken() - } + vc.Body = p.parseBodyStatements() } stmt.Cases = append(stmt.Cases, vc) diff --git a/parser/parser_test.go b/parser/parser_test.go index ad499ac..1002b29 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -3812,3 +3812,79 @@ func TestMultiDimOpenArrayParam(t *testing.T) { t.Errorf("expected ChanElemType=INT, got %s", p0.ChanElemType) } } + +func TestVariantReceiveScopedDecl(t *testing.T) { + input := `PROTOCOL CMD + CASE + data; INT + evolve + quit + +PROC test(CHAN OF CMD ch) + BOOL done: + SEQ + done := FALSE + WHILE NOT done + ch ? CASE + data; done + SKIP + evolve + BOOL flag: + SEQ + flag := TRUE + done := flag + quit + done := TRUE +: +` + l := lexer.New(input) + pr := New(l) + program := pr.ParseProgram() + checkParserErrors(t, pr) + + // Find the PROC + if len(program.Statements) < 2 { + t.Fatalf("expected at least 2 statements, got %d", len(program.Statements)) + } + proc, ok := program.Statements[1].(*ast.ProcDecl) + if !ok { + t.Fatalf("expected ProcDecl, got %T", program.Statements[1]) + } + + // Walk to the variant receive inside the WHILE + // proc body: VarDecl(done), SeqBlock{ assign, WhileLoop{ VariantReceive } } + seq, ok := proc.Body[1].(*ast.SeqBlock) + if !ok { + t.Fatalf("expected SeqBlock, got %T", proc.Body[1]) + } + wl, ok := seq.Statements[1].(*ast.WhileLoop) + if !ok { + t.Fatalf("expected WhileLoop, got %T", seq.Statements[1]) + } + if len(wl.Body) < 1 { + t.Fatalf("expected at least 1 statement in while body, got %d", len(wl.Body)) + } + vr, ok := wl.Body[0].(*ast.VariantReceive) + if !ok { + t.Fatalf("expected VariantReceive, got %T", wl.Body[0]) + } + + if len(vr.Cases) != 3 { + t.Fatalf("expected 3 variant cases, got %d", len(vr.Cases)) + } + + // "evolve" case should have 2 body statements: VarDecl + SeqBlock + evolveCase := vr.Cases[1] + if evolveCase.Tag != "evolve" { + t.Errorf("expected tag 'evolve', got %s", evolveCase.Tag) + } + if len(evolveCase.Body) != 2 { + t.Fatalf("expected 2 body statements in 'evolve' case, got %d", len(evolveCase.Body)) + } + if _, ok := evolveCase.Body[0].(*ast.VarDecl); !ok { + t.Errorf("expected VarDecl as first body statement, got %T", evolveCase.Body[0]) + } + if _, ok := evolveCase.Body[1].(*ast.SeqBlock); !ok { + t.Errorf("expected SeqBlock as second body statement, got %T", evolveCase.Body[1]) + } +}