Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions ast/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -373,10 +373,11 @@ func (s *Send) TokenLiteral() string { return s.Token.Literal }
// Receive represents a channel receive: c ? x or c ? x ; y
type Receive struct {
Token lexer.Token // the ? token
Channel string // channel name
ChannelIndices []Expression // non-empty for cs[i] ? x or cs[i][j] ? x
Variable string // variable to receive into (simple receive)
Variables []string // additional variables for sequential receives (c ? x ; y)
Channel string // channel name
ChannelIndices []Expression // non-empty for cs[i] ? x or cs[i][j] ? x
Variable string // variable to receive into (simple receive)
VariableIndices []Expression // non-empty for c ? flags[0] or c ? grid[i][j]
Variables []string // additional variables for sequential receives (c ? x ; y)
}

func (r *Receive) statementNode() {}
Expand All @@ -396,11 +397,12 @@ func (a *AltBlock) TokenLiteral() string { return a.Token.Literal }

// AltCase represents a single case in an ALT block
type AltCase struct {
Guard Expression // optional guard condition (nil if no guard)
Channel string // channel name
ChannelIndices []Expression // non-empty for cs[i] ? x or cs[i][j] ? x in ALT
Variable string // variable to receive into
Body []Statement // the body to execute
Guard Expression // optional guard condition (nil if no guard)
Channel string // channel name
ChannelIndices []Expression // non-empty for cs[i] ? x or cs[i][j] ? x in ALT
Variable string // variable to receive into
VariableIndices []Expression // non-empty for c ? flags[0] or c ? grid[i][j]
Body []Statement // the body to execute
IsTimer bool // true if this is a timer AFTER case
IsSkip bool // true if this is a guarded SKIP case (guard & SKIP)
Timer string // timer name (when IsTimer)
Expand Down
32 changes: 26 additions & 6 deletions codegen/codegen.go
Original file line number Diff line number Diff line change
Expand Up @@ -1346,7 +1346,9 @@ func (g *Generator) generateReceive(recv *ast.Receive) {
g.tmpCounter++
g.writeLine(fmt.Sprintf("%s := <-%s", tmpName, chanRef))
varRef := goIdent(recv.Variable)
if g.refParams[recv.Variable] {
if len(recv.VariableIndices) > 0 {
varRef += g.generateIndicesStr(recv.VariableIndices)
} else if g.refParams[recv.Variable] {
varRef = "*" + varRef
}
g.writeLine(fmt.Sprintf("%s = %s._0", varRef, tmpName))
Expand All @@ -1359,7 +1361,9 @@ func (g *Generator) generateReceive(recv *ast.Receive) {
}
} else {
varRef := goIdent(recv.Variable)
if g.refParams[recv.Variable] {
if len(recv.VariableIndices) > 0 {
varRef += g.generateIndicesStr(recv.VariableIndices)
} else if g.refParams[recv.Variable] {
varRef = "*" + varRef
}
g.writeLine(fmt.Sprintf("%s = <-%s", varRef, chanRef))
Expand Down Expand Up @@ -1903,13 +1907,25 @@ func (g *Generator) generateAltBlock(alt *ast.AltBlock) {
g.generateExpression(c.Deadline)
g.write(" - int(time.Now().UnixMicro())) * time.Microsecond):\n")
} else if c.Guard != nil {
g.write(fmt.Sprintf("case %s = <-_alt%d:\n", goIdent(c.Variable), i))
varRef := goIdent(c.Variable)
if len(c.VariableIndices) > 0 {
varRef += g.generateIndicesStr(c.VariableIndices)
}
g.write(fmt.Sprintf("case %s = <-_alt%d:\n", varRef, i))
} else if len(c.ChannelIndices) > 0 {
g.write(fmt.Sprintf("case %s = <-%s", goIdent(c.Variable), goIdent(c.Channel)))
varRef := goIdent(c.Variable)
if len(c.VariableIndices) > 0 {
varRef += g.generateIndicesStr(c.VariableIndices)
}
g.write(fmt.Sprintf("case %s = <-%s", varRef, goIdent(c.Channel)))
g.generateIndices(c.ChannelIndices)
g.write(":\n")
} else {
g.write(fmt.Sprintf("case %s = <-%s:\n", goIdent(c.Variable), goIdent(c.Channel)))
varRef := goIdent(c.Variable)
if len(c.VariableIndices) > 0 {
varRef += g.generateIndicesStr(c.VariableIndices)
}
g.write(fmt.Sprintf("case %s = <-%s:\n", varRef, goIdent(c.Channel)))
}
g.indent++
guardedSkip := c.IsSkip && c.Guard != nil
Expand Down Expand Up @@ -2040,7 +2056,11 @@ func (g *Generator) generateReplicatedAlt(alt *ast.AltBlock) {
}

// Assign received value from reflect.Value
g.writeLine(fmt.Sprintf("%s = _altValue.Interface().(%s)", goIdent(c.Variable), recvType))
varRef := goIdent(c.Variable)
if len(c.VariableIndices) > 0 {
varRef += g.generateIndicesStr(c.VariableIndices)
}
g.writeLine(fmt.Sprintf("%s = _altValue.Interface().(%s)", varRef, recvType))

// Generate body
for _, s := range c.Body {
Expand Down
38 changes: 38 additions & 0 deletions codegen/e2e_concurrency_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -392,3 +392,41 @@ func TestE2E_PriPar(t *testing.T) {
t.Errorf("expected %q, got %q", expected, output)
}
}

func TestE2E_ReceiveIntoIndexedVariable(t *testing.T) {
occam := `SEQ
CHAN OF INT c:
[3]INT arr:
arr[0] := 0
arr[1] := 0
arr[2] := 0
PAR
c ! 42
c ? arr[1]
print.int(arr[1])
`
output := transpileCompileRun(t, occam)
expected := "42\n"
if output != expected {
t.Errorf("expected %q, got %q", expected, output)
}
}

func TestE2E_IndexedChannelReceiveIntoIndexedVariable(t *testing.T) {
occam := `SEQ
[2]CHAN OF INT cs:
[3]INT arr:
arr[0] := 0
arr[1] := 0
arr[2] := 0
PAR
cs[0] ! 99
cs[0] ? arr[2]
print.int(arr[2])
`
output := transpileCompileRun(t, occam)
expected := "99\n"
if output != expected {
t.Errorf("expected %q, got %q", expected, output)
}
}
49 changes: 48 additions & 1 deletion parser/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -791,6 +791,16 @@ func (p *Parser) parseIndexedOperation() ast.Statement {
}
stmt.Variable = p.curToken.Literal

// Collect variable indices: cs[i] ? flags[0] or cs[i] ? grid[j][k]
for p.peekTokenIs(lexer.LBRACKET) {
p.nextToken() // move to [
p.nextToken() // move past [
stmt.VariableIndices = append(stmt.VariableIndices, p.parseExpression(LOWEST))
if !p.expectPeek(lexer.RBRACKET) {
return nil
}
}

// Check for sequential receive
for p.peekTokenIs(lexer.SEMICOLON) {
p.nextToken() // move to ;
Expand Down Expand Up @@ -1293,6 +1303,16 @@ func (p *Parser) parseReceive() ast.Statement {
}
stmt.Variable = p.curToken.Literal

// Collect variable indices: c ? flags[0] or c ? grid[i][j]
for p.peekTokenIs(lexer.LBRACKET) {
p.nextToken() // move to [
p.nextToken() // move past [
stmt.VariableIndices = append(stmt.VariableIndices, p.parseExpression(LOWEST))
if !p.expectPeek(lexer.RBRACKET) {
return nil
}
}

// Check for sequential receive: c ? x ; y ; z
for p.peekTokenIs(lexer.SEMICOLON) {
p.nextToken() // move to ;
Expand Down Expand Up @@ -1756,13 +1776,22 @@ func (p *Parser) parseAltCase() *ast.AltCase {
p.nextToken() // move past AFTER
altCase.Deadline = p.parseExpression(LOWEST)
} else {
// Simple case: channel ? var
// Simple case: channel ? var or channel ? var[i]
altCase.Channel = name
p.nextToken() // move to ?
if !p.expectPeek(lexer.IDENT) {
return nil
}
altCase.Variable = p.curToken.Literal
// Collect variable indices: ch ? flags[0]
for p.peekTokenIs(lexer.LBRACKET) {
p.nextToken() // move to [
p.nextToken() // move past [
altCase.VariableIndices = append(altCase.VariableIndices, p.parseExpression(LOWEST))
if !p.expectPeek(lexer.RBRACKET) {
return nil
}
}
}
} else if p.curTokenIs(lexer.IDENT) && p.peekTokenIs(lexer.LBRACKET) {
// Indexed channel case: cs[i] ? var or cs[i][j] ? var
Expand All @@ -1783,6 +1812,15 @@ func (p *Parser) parseAltCase() *ast.AltCase {
return nil
}
altCase.Variable = p.curToken.Literal
// Collect variable indices: cs[i] ? flags[0]
for p.peekTokenIs(lexer.LBRACKET) {
p.nextToken() // move to [
p.nextToken() // move past [
altCase.VariableIndices = append(altCase.VariableIndices, p.parseExpression(LOWEST))
if !p.expectPeek(lexer.RBRACKET) {
return nil
}
}
} else {
// Guard followed by & channel ? var, or guard & SKIP
guard := p.parseExpression(LOWEST)
Expand Down Expand Up @@ -1823,6 +1861,15 @@ func (p *Parser) parseAltCase() *ast.AltCase {
return nil
}
altCase.Variable = p.curToken.Literal
// Collect variable indices: guard & ch ? flags[0]
for p.peekTokenIs(lexer.LBRACKET) {
p.nextToken() // move to [
p.nextToken() // move past [
altCase.VariableIndices = append(altCase.VariableIndices, p.parseExpression(LOWEST))
if !p.expectPeek(lexer.RBRACKET) {
return nil
}
}
}
}

Expand Down
90 changes: 90 additions & 0 deletions parser/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3888,3 +3888,93 @@ PROC test(CHAN OF CMD ch)
t.Errorf("expected SeqBlock as second body statement, got %T", evolveCase.Body[1])
}
}

func TestReceiveIndexedVariable(t *testing.T) {
input := `ch ? flags[0]
`
l := lexer.New(input)
p := New(l)
program := p.ParseProgram()
checkParserErrors(t, p)

if len(program.Statements) != 1 {
t.Fatalf("expected 1 statement, got %d", len(program.Statements))
}

recv, ok := program.Statements[0].(*ast.Receive)
if !ok {
t.Fatalf("expected Receive, got %T", program.Statements[0])
}

if recv.Channel != "ch" {
t.Errorf("expected channel 'ch', got %s", recv.Channel)
}

if recv.Variable != "flags" {
t.Errorf("expected variable 'flags', got %s", recv.Variable)
}

if len(recv.VariableIndices) != 1 {
t.Fatalf("expected 1 variable index, got %d", len(recv.VariableIndices))
}
}

func TestReceiveMultiIndexedVariable(t *testing.T) {
input := `ch ? grid[i][j]
`
l := lexer.New(input)
p := New(l)
program := p.ParseProgram()
checkParserErrors(t, p)

if len(program.Statements) != 1 {
t.Fatalf("expected 1 statement, got %d", len(program.Statements))
}

recv, ok := program.Statements[0].(*ast.Receive)
if !ok {
t.Fatalf("expected Receive, got %T", program.Statements[0])
}

if recv.Variable != "grid" {
t.Errorf("expected variable 'grid', got %s", recv.Variable)
}

if len(recv.VariableIndices) != 2 {
t.Fatalf("expected 2 variable indices, got %d", len(recv.VariableIndices))
}
}

func TestIndexedChannelReceiveIndexedVariable(t *testing.T) {
input := `cs[0] ? flags[1]
`
l := lexer.New(input)
p := New(l)
program := p.ParseProgram()
checkParserErrors(t, p)

if len(program.Statements) != 1 {
t.Fatalf("expected 1 statement, got %d", len(program.Statements))
}

recv, ok := program.Statements[0].(*ast.Receive)
if !ok {
t.Fatalf("expected Receive, got %T", program.Statements[0])
}

if recv.Channel != "cs" {
t.Errorf("expected channel 'cs', got %s", recv.Channel)
}

if len(recv.ChannelIndices) != 1 {
t.Fatalf("expected 1 channel index, got %d", len(recv.ChannelIndices))
}

if recv.Variable != "flags" {
t.Errorf("expected variable 'flags', got %s", recv.Variable)
}

if len(recv.VariableIndices) != 1 {
t.Fatalf("expected 1 variable index, got %d", len(recv.VariableIndices))
}
}
Loading