diff --git a/ast/ast.go b/ast/ast.go index a504e54..634f647 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -373,10 +373,11 @@ func (s *Send) TokenLiteral() string { return s.Token.Literal } // Receive represents a channel receive: c ? x or c ? x ; y type Receive struct { Token lexer.Token // the ? token - Channel string // channel name - ChannelIndices []Expression // non-empty for cs[i] ? x or cs[i][j] ? x - Variable string // variable to receive into (simple receive) - Variables []string // additional variables for sequential receives (c ? x ; y) + Channel string // channel name + ChannelIndices []Expression // non-empty for cs[i] ? x or cs[i][j] ? x + Variable string // variable to receive into (simple receive) + VariableIndices []Expression // non-empty for c ? flags[0] or c ? grid[i][j] + Variables []string // additional variables for sequential receives (c ? x ; y) } func (r *Receive) statementNode() {} @@ -396,11 +397,12 @@ func (a *AltBlock) TokenLiteral() string { return a.Token.Literal } // AltCase represents a single case in an ALT block type AltCase struct { - Guard Expression // optional guard condition (nil if no guard) - Channel string // channel name - ChannelIndices []Expression // non-empty for cs[i] ? x or cs[i][j] ? x in ALT - Variable string // variable to receive into - Body []Statement // the body to execute + Guard Expression // optional guard condition (nil if no guard) + Channel string // channel name + ChannelIndices []Expression // non-empty for cs[i] ? x or cs[i][j] ? x in ALT + Variable string // variable to receive into + VariableIndices []Expression // non-empty for c ? flags[0] or c ? grid[i][j] + Body []Statement // the body to execute IsTimer bool // true if this is a timer AFTER case IsSkip bool // true if this is a guarded SKIP case (guard & SKIP) Timer string // timer name (when IsTimer) diff --git a/codegen/codegen.go b/codegen/codegen.go index 849ec88..8d6e4c8 100644 --- a/codegen/codegen.go +++ b/codegen/codegen.go @@ -1346,7 +1346,9 @@ func (g *Generator) generateReceive(recv *ast.Receive) { g.tmpCounter++ g.writeLine(fmt.Sprintf("%s := <-%s", tmpName, chanRef)) varRef := goIdent(recv.Variable) - if g.refParams[recv.Variable] { + if len(recv.VariableIndices) > 0 { + varRef += g.generateIndicesStr(recv.VariableIndices) + } else if g.refParams[recv.Variable] { varRef = "*" + varRef } g.writeLine(fmt.Sprintf("%s = %s._0", varRef, tmpName)) @@ -1359,7 +1361,9 @@ func (g *Generator) generateReceive(recv *ast.Receive) { } } else { varRef := goIdent(recv.Variable) - if g.refParams[recv.Variable] { + if len(recv.VariableIndices) > 0 { + varRef += g.generateIndicesStr(recv.VariableIndices) + } else if g.refParams[recv.Variable] { varRef = "*" + varRef } g.writeLine(fmt.Sprintf("%s = <-%s", varRef, chanRef)) @@ -1903,13 +1907,25 @@ func (g *Generator) generateAltBlock(alt *ast.AltBlock) { g.generateExpression(c.Deadline) g.write(" - int(time.Now().UnixMicro())) * time.Microsecond):\n") } else if c.Guard != nil { - g.write(fmt.Sprintf("case %s = <-_alt%d:\n", goIdent(c.Variable), i)) + varRef := goIdent(c.Variable) + if len(c.VariableIndices) > 0 { + varRef += g.generateIndicesStr(c.VariableIndices) + } + g.write(fmt.Sprintf("case %s = <-_alt%d:\n", varRef, i)) } else if len(c.ChannelIndices) > 0 { - g.write(fmt.Sprintf("case %s = <-%s", goIdent(c.Variable), goIdent(c.Channel))) + varRef := goIdent(c.Variable) + if len(c.VariableIndices) > 0 { + varRef += g.generateIndicesStr(c.VariableIndices) + } + g.write(fmt.Sprintf("case %s = <-%s", varRef, goIdent(c.Channel))) g.generateIndices(c.ChannelIndices) g.write(":\n") } else { - g.write(fmt.Sprintf("case %s = <-%s:\n", goIdent(c.Variable), goIdent(c.Channel))) + varRef := goIdent(c.Variable) + if len(c.VariableIndices) > 0 { + varRef += g.generateIndicesStr(c.VariableIndices) + } + g.write(fmt.Sprintf("case %s = <-%s:\n", varRef, goIdent(c.Channel))) } g.indent++ guardedSkip := c.IsSkip && c.Guard != nil @@ -2040,7 +2056,11 @@ func (g *Generator) generateReplicatedAlt(alt *ast.AltBlock) { } // Assign received value from reflect.Value - g.writeLine(fmt.Sprintf("%s = _altValue.Interface().(%s)", goIdent(c.Variable), recvType)) + varRef := goIdent(c.Variable) + if len(c.VariableIndices) > 0 { + varRef += g.generateIndicesStr(c.VariableIndices) + } + g.writeLine(fmt.Sprintf("%s = _altValue.Interface().(%s)", varRef, recvType)) // Generate body for _, s := range c.Body { diff --git a/codegen/e2e_concurrency_test.go b/codegen/e2e_concurrency_test.go index 229120a..ccb8317 100644 --- a/codegen/e2e_concurrency_test.go +++ b/codegen/e2e_concurrency_test.go @@ -392,3 +392,41 @@ func TestE2E_PriPar(t *testing.T) { t.Errorf("expected %q, got %q", expected, output) } } + +func TestE2E_ReceiveIntoIndexedVariable(t *testing.T) { + occam := `SEQ + CHAN OF INT c: + [3]INT arr: + arr[0] := 0 + arr[1] := 0 + arr[2] := 0 + PAR + c ! 42 + c ? arr[1] + print.int(arr[1]) +` + output := transpileCompileRun(t, occam) + expected := "42\n" + if output != expected { + t.Errorf("expected %q, got %q", expected, output) + } +} + +func TestE2E_IndexedChannelReceiveIntoIndexedVariable(t *testing.T) { + occam := `SEQ + [2]CHAN OF INT cs: + [3]INT arr: + arr[0] := 0 + arr[1] := 0 + arr[2] := 0 + PAR + cs[0] ! 99 + cs[0] ? arr[2] + print.int(arr[2]) +` + output := transpileCompileRun(t, occam) + expected := "99\n" + if output != expected { + t.Errorf("expected %q, got %q", expected, output) + } +} diff --git a/parser/parser.go b/parser/parser.go index 704fcb8..9e90386 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -791,6 +791,16 @@ func (p *Parser) parseIndexedOperation() ast.Statement { } stmt.Variable = p.curToken.Literal + // Collect variable indices: cs[i] ? flags[0] or cs[i] ? grid[j][k] + for p.peekTokenIs(lexer.LBRACKET) { + p.nextToken() // move to [ + p.nextToken() // move past [ + stmt.VariableIndices = append(stmt.VariableIndices, p.parseExpression(LOWEST)) + if !p.expectPeek(lexer.RBRACKET) { + return nil + } + } + // Check for sequential receive for p.peekTokenIs(lexer.SEMICOLON) { p.nextToken() // move to ; @@ -1293,6 +1303,16 @@ func (p *Parser) parseReceive() ast.Statement { } stmt.Variable = p.curToken.Literal + // Collect variable indices: c ? flags[0] or c ? grid[i][j] + for p.peekTokenIs(lexer.LBRACKET) { + p.nextToken() // move to [ + p.nextToken() // move past [ + stmt.VariableIndices = append(stmt.VariableIndices, p.parseExpression(LOWEST)) + if !p.expectPeek(lexer.RBRACKET) { + return nil + } + } + // Check for sequential receive: c ? x ; y ; z for p.peekTokenIs(lexer.SEMICOLON) { p.nextToken() // move to ; @@ -1756,13 +1776,22 @@ func (p *Parser) parseAltCase() *ast.AltCase { p.nextToken() // move past AFTER altCase.Deadline = p.parseExpression(LOWEST) } else { - // Simple case: channel ? var + // Simple case: channel ? var or channel ? var[i] altCase.Channel = name p.nextToken() // move to ? if !p.expectPeek(lexer.IDENT) { return nil } altCase.Variable = p.curToken.Literal + // Collect variable indices: ch ? flags[0] + for p.peekTokenIs(lexer.LBRACKET) { + p.nextToken() // move to [ + p.nextToken() // move past [ + altCase.VariableIndices = append(altCase.VariableIndices, p.parseExpression(LOWEST)) + if !p.expectPeek(lexer.RBRACKET) { + return nil + } + } } } else if p.curTokenIs(lexer.IDENT) && p.peekTokenIs(lexer.LBRACKET) { // Indexed channel case: cs[i] ? var or cs[i][j] ? var @@ -1783,6 +1812,15 @@ func (p *Parser) parseAltCase() *ast.AltCase { return nil } altCase.Variable = p.curToken.Literal + // Collect variable indices: cs[i] ? flags[0] + for p.peekTokenIs(lexer.LBRACKET) { + p.nextToken() // move to [ + p.nextToken() // move past [ + altCase.VariableIndices = append(altCase.VariableIndices, p.parseExpression(LOWEST)) + if !p.expectPeek(lexer.RBRACKET) { + return nil + } + } } else { // Guard followed by & channel ? var, or guard & SKIP guard := p.parseExpression(LOWEST) @@ -1823,6 +1861,15 @@ func (p *Parser) parseAltCase() *ast.AltCase { return nil } altCase.Variable = p.curToken.Literal + // Collect variable indices: guard & ch ? flags[0] + for p.peekTokenIs(lexer.LBRACKET) { + p.nextToken() // move to [ + p.nextToken() // move past [ + altCase.VariableIndices = append(altCase.VariableIndices, p.parseExpression(LOWEST)) + if !p.expectPeek(lexer.RBRACKET) { + return nil + } + } } } diff --git a/parser/parser_test.go b/parser/parser_test.go index 1002b29..1fed10a 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -3888,3 +3888,93 @@ PROC test(CHAN OF CMD ch) t.Errorf("expected SeqBlock as second body statement, got %T", evolveCase.Body[1]) } } + +func TestReceiveIndexedVariable(t *testing.T) { + input := `ch ? flags[0] +` + l := lexer.New(input) + p := New(l) + program := p.ParseProgram() + checkParserErrors(t, p) + + if len(program.Statements) != 1 { + t.Fatalf("expected 1 statement, got %d", len(program.Statements)) + } + + recv, ok := program.Statements[0].(*ast.Receive) + if !ok { + t.Fatalf("expected Receive, got %T", program.Statements[0]) + } + + if recv.Channel != "ch" { + t.Errorf("expected channel 'ch', got %s", recv.Channel) + } + + if recv.Variable != "flags" { + t.Errorf("expected variable 'flags', got %s", recv.Variable) + } + + if len(recv.VariableIndices) != 1 { + t.Fatalf("expected 1 variable index, got %d", len(recv.VariableIndices)) + } +} + +func TestReceiveMultiIndexedVariable(t *testing.T) { + input := `ch ? grid[i][j] +` + l := lexer.New(input) + p := New(l) + program := p.ParseProgram() + checkParserErrors(t, p) + + if len(program.Statements) != 1 { + t.Fatalf("expected 1 statement, got %d", len(program.Statements)) + } + + recv, ok := program.Statements[0].(*ast.Receive) + if !ok { + t.Fatalf("expected Receive, got %T", program.Statements[0]) + } + + if recv.Variable != "grid" { + t.Errorf("expected variable 'grid', got %s", recv.Variable) + } + + if len(recv.VariableIndices) != 2 { + t.Fatalf("expected 2 variable indices, got %d", len(recv.VariableIndices)) + } +} + +func TestIndexedChannelReceiveIndexedVariable(t *testing.T) { + input := `cs[0] ? flags[1] +` + l := lexer.New(input) + p := New(l) + program := p.ParseProgram() + checkParserErrors(t, p) + + if len(program.Statements) != 1 { + t.Fatalf("expected 1 statement, got %d", len(program.Statements)) + } + + recv, ok := program.Statements[0].(*ast.Receive) + if !ok { + t.Fatalf("expected Receive, got %T", program.Statements[0]) + } + + if recv.Channel != "cs" { + t.Errorf("expected channel 'cs', got %s", recv.Channel) + } + + if len(recv.ChannelIndices) != 1 { + t.Fatalf("expected 1 channel index, got %d", len(recv.ChannelIndices)) + } + + if recv.Variable != "flags" { + t.Errorf("expected variable 'flags', got %s", recv.Variable) + } + + if len(recv.VariableIndices) != 1 { + t.Fatalf("expected 1 variable index, got %d", len(recv.VariableIndices)) + } +}