Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 131 additions & 33 deletions docs/PTO_IR_manual.md
Original file line number Diff line number Diff line change
Expand Up @@ -8685,7 +8685,7 @@ This section documents PTO communication primitives. PTOAS currently exposes:

- Synchronous point-to-point ops: `pto.comm.tput`, `pto.comm.tget`
- Synchronous signal ops: `pto.comm.tnotify`, `pto.comm.twait`, `pto.comm.ttest`
- Synchronous collective ops: `pto.comm.tbroadcast`, `pto.comm.comm_tgather`, `pto.comm.comm_tscatter`, `pto.comm.treduce`
- Synchronous collective ops: `pto.comm.tbroadcast`, `pto.comm.tgather`, `pto.comm.tscatter`, `pto.comm.treduce`
- Asynchronous communication/session ops: `pto.comm.build_async_session`, `pto.comm.tput_async`, `pto.comm.tget_async`, `pto.comm.wait_async_event`, `pto.comm.test_async_event`

##### `pto.comm.build_async_session` - Create Async DMA Session
Expand Down Expand Up @@ -8811,24 +8811,23 @@ This section documents PTO communication primitives. PTOAS currently exposes:
|------|------|-------------|
| `dst` | GM memref / `pto.tensor_view` / `pto.partition_tensor_view` | Remote destination buffer |
| `src` | GM memref / `pto.tensor_view` / `pto.partition_tensor_view` | Local source buffer |
| `ping` | `pto.tile_buf` / local VEC memref | Required staging tile |
| `pong` | `pto.tile_buf` / local VEC memref | Optional second staging tile for ping-pong transfer |
| `atomicType` | `#pto.atomic_type<...>` | Atomic mode, default `atomic_none` |
| `buf` | `buf(%ping)` or `buf(%ping, %pong)` | Staging bundle: one or two local VEC tiles |
| `atomicType` | `#pto<atomic_type ...>` | Atomic mode, e.g. `atomic_none` or `atomic_add` |

**Constraints & Verification:**

- `dst` / `src` must be GM-shaped values with positive static shapes.
- `dst` and `src` must have the same element type and static shape.
- `ping` / `pong` must be local VEC tile-like values whose element type matches `src`.

**Basic Example:**
**Examples:**

Staging operands use the `buf(...)` bundle: one tile `buf(%ping)`, or ping–pong `buf(%ping, %pong)` for overlapping transfers.

```mlir
pto.comm.tput %dst, %src, %ping {atomicType = #pto.atomic_type<atomic_none>} :
!pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>
pto.comm.tput(%dst, %src, buf(%ping) : !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>) {atomicType = #pto<atomic_type atomic_none>}

pto.comm.tput %dst, %src, %ping, %pong {atomicType = #pto.atomic_type<atomic_add>} :
!pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>, !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>
pto.comm.tput(%dst, %src, buf(%ping, %pong) : !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>, !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>) {atomicType = #pto<atomic_type atomic_add>}
```

---
Expand All @@ -8843,19 +8842,20 @@ pto.comm.tput %dst, %src, %ping, %pong {atomicType = #pto.atomic_type<atomic_add
|------|------|-------------|
| `dst` | GM memref / `pto.tensor_view` / `pto.partition_tensor_view` | Local destination buffer |
| `src` | GM memref / `pto.tensor_view` / `pto.partition_tensor_view` | Remote source buffer |
| `ping` | `pto.tile_buf` / local VEC memref | Required staging tile |
| `pong` | `pto.tile_buf` / local VEC memref | Optional second staging tile for ping-pong transfer |
| `buf` | `buf(%ping)` or `buf(%ping, %pong)` | Staging bundle: one required local VEC tile, or two for ping-pong transfer |

**Constraints & Verification:**

- Same GM/global-like and staging constraints as `pto.comm.tput`.
- `dst` and `src` must have the same element type and static shape.

**Basic Example:**
**Examples:**

```mlir
pto.comm.tget %dst, %src, %ping :
!pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>
pto.comm.tget(%dst, %src, buf(%ping) : !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>)

pto.comm.tget(%dst, %src, buf(%ping, %pong) : !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>, !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>)
```

---
Expand All @@ -8868,21 +8868,22 @@ pto.comm.tget %dst, %src, %ping :

| Op | Operands | Attributes | Result |
|----|----------|------------|--------|
| `pto.comm.tnotify` | `signal`, `value` | `notifyOp = #pto.notify_op<atomic_add/set>` | none |
| `pto.comm.twait` | `signal`, `cmpValue` | `cmp = #pto.wait_cmp<eq/ne/gt/ge/lt/le>` | none |
| `pto.comm.ttest` | `signal`, `cmpValue` | `cmp = #pto.wait_cmp<eq/ne/gt/ge/lt/le>` | `i1` |
| `pto.comm.tnotify` | `signal`, `value` | `notifyOp = #pto<notify_op atomic_add>` or `#pto<notify_op set>` | none |
| `pto.comm.twait` | `signal`, `cmpValue` | `cmp = #pto<wait_cmp eq/ne/gt/ge/lt/le>` | none |
| `pto.comm.ttest` | `signal`, `cmpValue` | `cmp = #pto<wait_cmp eq/ne/gt/ge/lt/le>` | `i1` |

**Constraints & Verification:**

- `signal` must be a GM-shaped value with element type `i32`.
- `value` / `cmpValue` must be signless integer scalars.

**Basic Example:**
**Examples:**

```mlir
pto.comm.tnotify %sig, %v {notifyOp = #pto.notify_op<set>} : !pto.partition_tensor_view<1xi32>, i32
pto.comm.twait %sig, %v {cmp = #pto.wait_cmp<ge>} : !pto.partition_tensor_view<1xi32>, i32
%ok = pto.comm.ttest %sig, %v {cmp = #pto.wait_cmp<eq>} : !pto.partition_tensor_view<1xi32>, i32 -> i1
pto.comm.tnotify(%sig, %v : !pto.partition_tensor_view<1xi32>, i32) {notifyOp = #pto<notify_op set>}
pto.comm.tnotify(%sig, %v : !pto.partition_tensor_view<1xi32>, i32) {notifyOp = #pto<notify_op atomic_add>}
pto.comm.twait(%sig, %v : !pto.partition_tensor_view<1xi32>, i32) {cmp = #pto<wait_cmp ge>}
%ok = pto.comm.ttest(%sig, %v : !pto.partition_tensor_view<1xi32>, i32) {cmp = #pto<wait_cmp eq>} -> i1
```

---
Expand All @@ -8896,7 +8897,7 @@ pto.comm.twait %sig, %v {cmp = #pto.wait_cmp<ge>} : !pto.partition_tensor_view<1
| Name | Type | Description |
|------|------|-------------|
| `src` | GM-shaped value | Root source buffer |
| `ping` / `pong` | local VEC tile-like values | Staging tiles |
| `recv` | `recv(%ping)` or `recv(%ping, %pong)` | One or two local VEC staging tiles |
| `group` | variadic GM-shaped values | Parallel group members |
| `root` | `i32` attr | Root rank index inside `group` |

Expand All @@ -8906,41 +8907,111 @@ pto.comm.twait %sig, %v {cmp = #pto.wait_cmp<ge>} : !pto.partition_tensor_view<1
- `src` must have the same type as each `group` member.
- `root` must be in range `[0, group.size)`.

**Basic Example:**
**Examples:**

Single receive buffer:

```mlir
pto.comm.tbroadcast(%src, recv(%ping), group(%g0, %g1, %g2) :
!pto.partition_tensor_view<128xf32>,
!pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>,
!pto.partition_tensor_view<128xf32>,
!pto.partition_tensor_view<128xf32>,
!pto.partition_tensor_view<128xf32>) {root = 1 : i32}
```

Optional ping–pong (`recv(%ping, %pong)` adds a second tile type in the operand-type list):

```mlir
pto.comm.tbroadcast %src, %ping, %g0, %g1, %g2 {root = 1, operandSegmentSizes = array<i32: 1, 1, 0, 3>} :
!pto.partition_tensor_view<128xf32>, !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>
pto.comm.tbroadcast(%src, recv(%ping, %pong), group(%g0, %g1, %g2) :
!pto.partition_tensor_view<128xf32>,
!pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>,
!pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>,
!pto.partition_tensor_view<128xf32>,
!pto.partition_tensor_view<128xf32>,
!pto.partition_tensor_view<128xf32>) {root = 1 : i32}
```

---

##### `pto.comm.comm_tgather` - Collective Gather
##### `pto.comm.tgather` - Collective Gather

**Summary:** Communication collective that lowers to `pto::comm::TGATHER(...)`. This op is distinct from tile-level `pto.tgather`.

**Arguments:** `dst`, `ping`, optional `pong`, variadic `group`, `root`
**Arguments:**

| Name | Type | Description |
|------|------|-------------|
| `dst` | GM-shaped value | Destination buffer (gather target) |
| `recv` | `recv(%ping)` or `recv(%ping, %pong)` | Staging tile(s) |
| `group` | variadic GM-shaped values | Parallel group members |
| `root` | `i32` attr | Root rank index inside `group` |

**Constraints & Verification:**

- `group` must be non-empty and all members must have identical types.
- `dst` element type must match the group element type.
- `ping` / `pong` must be local VEC tile-like values with matching element type.

**Examples:**

```mlir
pto.comm.tgather(%dst, recv(%ping), group(%g0, %g1, %g2) :
!pto.partition_tensor_view<128xf32>,
!pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>,
!pto.partition_tensor_view<128xf32>,
!pto.partition_tensor_view<128xf32>,
!pto.partition_tensor_view<128xf32>) {root = 1 : i32}

pto.comm.tgather(%dst, recv(%ping, %pong), group(%g0, %g1, %g2) :
!pto.partition_tensor_view<128xf32>,
!pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>,
!pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>,
!pto.partition_tensor_view<128xf32>,
!pto.partition_tensor_view<128xf32>,
!pto.partition_tensor_view<128xf32>) {root = 1 : i32}
```

---

##### `pto.comm.comm_tscatter` - Collective Scatter
##### `pto.comm.tscatter` - Collective Scatter

**Summary:** Communication collective that lowers to `pto::comm::TSCATTER(...)`. This op is distinct from tile-level `pto.tscatter`.

**Arguments:** `src`, `ping`, optional `pong`, variadic `group`, `root`
**Arguments:**

| Name | Type | Description |
|------|------|-------------|
| `src` | GM-shaped value | Source buffer (scatter root) |
| `recv` | `recv(%ping)` or `recv(%ping, %pong)` | Staging tile(s) |
| `group` | variadic GM-shaped values | Parallel group members |
| `root` | `i32` attr | Root rank index inside `group` |

**Constraints & Verification:**

- `group` must be non-empty and all members must have identical types.
- `src` element type must match the group element type.
- `ping` / `pong` must be local VEC tile-like values with matching element type.

**Examples:**

```mlir
pto.comm.tscatter(%src, recv(%ping), group(%g0, %g1, %g2) :
!pto.partition_tensor_view<128xf32>,
!pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>,
!pto.partition_tensor_view<128xf32>,
!pto.partition_tensor_view<128xf32>,
!pto.partition_tensor_view<128xf32>) {root = 1 : i32}

pto.comm.tscatter(%src, recv(%ping, %pong), group(%g0, %g1, %g2) :
!pto.partition_tensor_view<128xf32>,
!pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>,
!pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>,
!pto.partition_tensor_view<128xf32>,
!pto.partition_tensor_view<128xf32>,
!pto.partition_tensor_view<128xf32>) {root = 1 : i32}
```

---

##### `pto.comm.treduce` - Collective Reduce
Expand All @@ -8951,18 +9022,45 @@ pto.comm.tbroadcast %src, %ping, %g0, %g1, %g2 {root = 1, operandSegmentSizes =

| Name | Type | Description |
|------|------|-------------|
| `dst` | GM-shaped value | Root destination buffer |
| `dst` | GM-shaped value | Reduced output buffer |
| `acc` | local VEC tile-like value | Accumulation tile |
| `recvPing` / `recvPong` | local VEC tile-like values | Receive staging tiles |
| `recv` | `recv(%ping)` or `recv(%ping, %pong)` | One or two receive staging tiles |
| `group` | variadic GM-shaped values | Parallel group members |
| `reduceOp` | `#pto.reduce_op<sum/max/min>` | Reduction mode |
| `reduceOp` | `#pto<reduce_op sum>` / `#pto<reduce_op max>` / `#pto<reduce_op min>` | Reduction mode |
| `root` | `i32` attr | Root rank index inside `group` |

**Constraints & Verification:**

- `group` must be non-empty and all members must have identical types.
- `dst` element type must match the group element type.
- `acc` and `recvPing` / `recvPong` must be local VEC tile-like values whose element type matches `dst`.
- `acc` and the staging tiles in `recv(%ping)` / `recv(%ping, %pong)` must be local VEC tile-like values whose element type matches `dst`.

**Examples:**

Sum with a single receive tile:

```mlir
pto.comm.treduce(%dst, %acc, recv(%ping), group(%g0, %g1, %g2) :
!pto.partition_tensor_view<128xf32>,
!pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>,
!pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>,
!pto.partition_tensor_view<128xf32>,
!pto.partition_tensor_view<128xf32>,
!pto.partition_tensor_view<128xf32>) {reduceOp = #pto<reduce_op sum>, root = 1 : i32}
```

Max with ping–pong receive buffers (two staging tiles — operand-type list includes three `tile_buf` entries: `acc`, `ping`, `pong`):

```mlir
pto.comm.treduce(%dst, %acc, recv(%ping, %pong), group(%g0, %g1, %g2) :
!pto.partition_tensor_view<128xf32>,
!pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>,
!pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>,
!pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>,
!pto.partition_tensor_view<128xf32>,
!pto.partition_tensor_view<128xf32>,
!pto.partition_tensor_view<128xf32>) {reduceOp = #pto<reduce_op max>, root = 1 : i32}
```

---

Expand Down
30 changes: 28 additions & 2 deletions include/PTO/IR/PTOOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1681,6 +1681,11 @@ def TPutOp : PTO_Op<"comm.tput", [
);
let results = (outs);
let hasVerifier = 1;
let assemblyFormat = [{
`(` $dst `,` $src `,` `buf` `(` $ping (`,` $pong^)? `)`
`:` type($dst) `,` type($src) `,` type($ping) (`,` type($pong)^)? `)`
attr-dict
}];
}

def TGetOp : PTO_Op<"comm.tget", [
Expand All @@ -1695,6 +1700,11 @@ def TGetOp : PTO_Op<"comm.tget", [
);
let results = (outs);
let hasVerifier = 1;
let assemblyFormat = [{
`(` $dst `,` $src `,` `buf` `(` $ping (`,` $pong^)? `)`
`:` type($dst) `,` type($src) `,` type($ping) (`,` type($pong)^)? `)`
attr-dict
}];
}

def TNotifyOp : PTO_Op<"comm.tnotify", [
Expand All @@ -1708,6 +1718,10 @@ def TNotifyOp : PTO_Op<"comm.tnotify", [
);
let results = (outs);
let hasVerifier = 1;
let assemblyFormat = [{
`(` $signal `,` $value `:` type($signal) `,` type($value) `)`
attr-dict
}];
}

def TWaitOp : PTO_Op<"comm.twait", [
Expand All @@ -1721,6 +1735,10 @@ def TWaitOp : PTO_Op<"comm.twait", [
);
let results = (outs);
let hasVerifier = 1;
let assemblyFormat = [{
`(` $signal `,` $cmpValue `:` type($signal) `,` type($cmpValue) `)`
attr-dict
}];
}

def TTestOp : PTO_Op<"comm.ttest", [
Expand All @@ -1734,6 +1752,10 @@ def TTestOp : PTO_Op<"comm.ttest", [
);
let results = (outs I1:$result);
let hasVerifier = 1;
let assemblyFormat = [{
`(` $signal `,` $cmpValue `:` type($signal) `,` type($cmpValue) `)`
attr-dict `->` type($result)
}];
}

def TBroadcastOp : PTO_Op<"comm.tbroadcast", [
Expand All @@ -1750,9 +1772,10 @@ def TBroadcastOp : PTO_Op<"comm.tbroadcast", [
);
let results = (outs);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}

def CommTGatherOp : PTO_Op<"comm.comm_tgather", [
def CommTGatherOp : PTO_Op<"comm.tgather", [
AttrSizedOperandSegments,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
]> {
Expand All @@ -1766,9 +1789,10 @@ def CommTGatherOp : PTO_Op<"comm.comm_tgather", [
);
let results = (outs);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}

def CommTScatterOp : PTO_Op<"comm.comm_tscatter", [
def CommTScatterOp : PTO_Op<"comm.tscatter", [
AttrSizedOperandSegments,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
]> {
Expand All @@ -1782,6 +1806,7 @@ def CommTScatterOp : PTO_Op<"comm.comm_tscatter", [
);
let results = (outs);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}

def TReduceOp : PTO_Op<"comm.treduce", [
Expand All @@ -1800,6 +1825,7 @@ def TReduceOp : PTO_Op<"comm.treduce", [
);
let results = (outs);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}

def InitializeL2G2LPipeOp : PTO_Op<"initialize_l2g2l_pipe", [
Expand Down
Loading
Loading