From 9ada56e8eebb0a5faba1d073ed2acd6c895b6631 Mon Sep 17 00:00:00 2001 From: FangRui Date: Sat, 9 May 2026 16:18:10 +0800 Subject: [PATCH] feat: update comm ops ods and add testcases Signed-off-by: FangRui --- docs/PTO_IR_manual.md | 164 ++++++++-- include/PTO/IR/PTOOps.td | 30 +- lib/PTO/IR/PTO.cpp | 267 +++++++++++++++++ lib/PTO/Transforms/PTOToEmitC.cpp | 283 +++++++----------- test/lit/pto/comm_collective_emitc.pto | 20 +- test/lit/pto/comm_p2p_emitc.pto | 24 +- .../CommSync/tbroadcast_root_binding.pto | 25 ++ .../CommSync/tbroadcast_root_binding.py | 75 +++++ .../samples/CommSync/tgather_root_binding.pto | 25 ++ test/samples/CommSync/tgather_root_binding.py | 75 +++++ .../tnotify_atomic_add_binding-pto-ir.pto | 12 + .../CommSync/tnotify_atomic_add_binding.py | 56 ++++ .../samples/CommSync/treduce_root_binding.pto | 26 ++ test/samples/CommSync/treduce_root_binding.py | 83 +++++ .../CommSync/tscatter_root_binding.pto | 25 ++ .../samples/CommSync/tscatter_root_binding.py | 75 +++++ test/samples/CommSync/twait_atomic_binding.py | 80 +++++ test/samples/runop.sh | 166 ++++++++-- tools/ptobc/generated/ptobc_opcodes_v0.h | 12 +- 19 files changed, 1271 insertions(+), 252 deletions(-) create mode 100644 test/samples/CommSync/tbroadcast_root_binding.pto create mode 100644 test/samples/CommSync/tbroadcast_root_binding.py create mode 100644 test/samples/CommSync/tgather_root_binding.pto create mode 100644 test/samples/CommSync/tgather_root_binding.py create mode 100644 test/samples/CommSync/tnotify_atomic_add_binding-pto-ir.pto create mode 100644 test/samples/CommSync/tnotify_atomic_add_binding.py create mode 100644 test/samples/CommSync/treduce_root_binding.pto create mode 100644 test/samples/CommSync/treduce_root_binding.py create mode 100644 test/samples/CommSync/tscatter_root_binding.pto create mode 100644 test/samples/CommSync/tscatter_root_binding.py create mode 100644 test/samples/CommSync/twait_atomic_binding.py diff --git a/docs/PTO_IR_manual.md b/docs/PTO_IR_manual.md index 9b3a62f9b..783401e98 100644 --- a/docs/PTO_IR_manual.md +++ b/docs/PTO_IR_manual.md @@ -8685,7 +8685,7 @@ This section documents PTO communication primitives. PTOAS currently exposes: - Synchronous point-to-point ops: `pto.comm.tput`, `pto.comm.tget` - Synchronous signal ops: `pto.comm.tnotify`, `pto.comm.twait`, `pto.comm.ttest` -- Synchronous collective ops: `pto.comm.tbroadcast`, `pto.comm.comm_tgather`, `pto.comm.comm_tscatter`, `pto.comm.treduce` +- Synchronous collective ops: `pto.comm.tbroadcast`, `pto.comm.tgather`, `pto.comm.tscatter`, `pto.comm.treduce` - Asynchronous communication/session ops: `pto.comm.build_async_session`, `pto.comm.tput_async`, `pto.comm.tget_async`, `pto.comm.wait_async_event`, `pto.comm.test_async_event` ##### `pto.comm.build_async_session` - Create Async DMA Session @@ -8811,9 +8811,8 @@ This section documents PTO communication primitives. PTOAS currently exposes: |------|------|-------------| | `dst` | GM memref / `pto.tensor_view` / `pto.partition_tensor_view` | Remote destination buffer | | `src` | GM memref / `pto.tensor_view` / `pto.partition_tensor_view` | Local source buffer | -| `ping` | `pto.tile_buf` / local VEC memref | Required staging tile | -| `pong` | `pto.tile_buf` / local VEC memref | Optional second staging tile for ping-pong transfer | -| `atomicType` | `#pto.atomic_type<...>` | Atomic mode, default `atomic_none` | +| `buf` | `buf(%ping)` or `buf(%ping, %pong)` | Staging bundle: one or two local VEC tiles | +| `atomicType` | `#pto` | Atomic mode, e.g. 
`atomic_none` or `atomic_add` |
 
 **Constraints & Verification:**
 
@@ -8821,14 +8820,14 @@
 
 - `dst` and `src` must have the same element type and static shape.
 - `ping` / `pong` must be local VEC tile-like values whose element type matches `src`.
 
-**Basic Example:**
+**Examples:**
+
+Staging operands use the `buf(...)` bundle: one tile `buf(%ping)`, or ping–pong `buf(%ping, %pong)` for overlapping transfers.
 
 ```mlir
-pto.comm.tput %dst, %src, %ping {atomicType = #pto.atomic_type<atomic_none>} :
-  !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.tile_buf
+pto.comm.tput(%dst, %src, buf(%ping) : !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.tile_buf) {atomicType = #pto<atomic_type atomic_none>}
 
-pto.comm.tput %dst, %src, %ping, %pong {atomicType = #pto.atomic_type<atomic_none>} :
-  !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.tile_buf, !pto.tile_buf
+pto.comm.tput(%dst, %src, buf(%ping, %pong) : !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.tile_buf, !pto.tile_buf) {atomicType = #pto<atomic_type atomic_add>}
 ```
 
 ---
 
@@ -8843,19 +8842,20 @@ pto.comm.tput %dst, %src, %ping, %pong {atomicType = #pto.atomic_type<atomic
 
-**Basic Example:**
+**Examples:**
 
 ```mlir
-pto.comm.tget %dst, %src, %ping :
-  !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.tile_buf
+pto.comm.tget(%dst, %src, buf(%ping) : !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.tile_buf)
+
+pto.comm.tget(%dst, %src, buf(%ping, %pong) : !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.tile_buf, !pto.tile_buf)
 ```
 
 ---
 
@@ -8868,21 +8868,22 @@ pto.comm.tget %dst, %src, %ping :
 
 | Op | Operands | Attributes | Result |
 |----|----------|------------|--------|
-| `pto.comm.tnotify` | `signal`, `value` | `notifyOp = #pto.notify_op<...>` | none |
-| `pto.comm.twait` | `signal`, `cmpValue` | `cmp = #pto.wait_cmp<...>` | none |
-| `pto.comm.ttest` | `signal`, `cmpValue` | `cmp = #pto.wait_cmp<...>` | `i1` |
+| `pto.comm.tnotify` | `signal`, `value` | `notifyOp = #pto<notify_op set>` or `#pto<notify_op atomic_add>` | none |
+| `pto.comm.twait` | `signal`, `cmpValue` | `cmp = #pto<wait_cmp ...>` | none |
+| `pto.comm.ttest` | `signal`, `cmpValue` | `cmp = #pto<wait_cmp ...>` | `i1` |
 
 **Constraints & Verification:**
 
 - `signal` must be a GM-shaped value with element type `i32`.
 - `value` / `cmpValue` must be signless integer scalars.
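+
+After EmitC conversion, each of these lowers to a direct `pto::comm` runtime call with the mode passed as a trailing enum argument (this patch removes the old `PTOAS__COMM_*` wrapper templates). A sketch of the emitted call shape; variable names are illustrative, not the exact generated identifiers:
+
+```cpp
+// signal: GM-backed i32 signal tensor; v: scalar value operand.
+pto::comm::TNOTIFY(signal, v, pto::comm::NotifyOp::Set);        // or NotifyOp::AtomicAdd
+pto::comm::TWAIT(signal, v, pto::comm::WaitCmp::EQ);            // blocks until the compare holds
+bool ok = pto::comm::TTEST(signal, v, pto::comm::WaitCmp::GE);  // non-blocking probe
+```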
-**Basic Example:**
+**Examples:**
 
 ```mlir
-pto.comm.tnotify %sig, %v {notifyOp = #pto.notify_op<set>} : !pto.partition_tensor_view<1xi32>, i32
-pto.comm.twait %sig, %v {cmp = #pto.wait_cmp<eq>} : !pto.partition_tensor_view<1xi32>, i32
-%ok = pto.comm.ttest %sig, %v {cmp = #pto.wait_cmp<eq>} : !pto.partition_tensor_view<1xi32>, i32 -> i1
+pto.comm.tnotify(%sig, %v : !pto.partition_tensor_view<1xi32>, i32) {notifyOp = #pto<notify_op set>}
+pto.comm.tnotify(%sig, %v : !pto.partition_tensor_view<1xi32>, i32) {notifyOp = #pto<notify_op atomic_add>}
+pto.comm.twait(%sig, %v : !pto.partition_tensor_view<1xi32>, i32) {cmp = #pto<wait_cmp eq>}
+%ok = pto.comm.ttest(%sig, %v : !pto.partition_tensor_view<1xi32>, i32) {cmp = #pto<wait_cmp ge>} -> i1
 ```
 
 ---
 
@@ -8896,7 +8897,7 @@ pto.comm.twait %sig, %v {cmp = #pto.wait_cmp<eq>} : !pto.partition_tensor_view<1
 
 | Name | Type | Description |
 |------|------|-------------|
 | `src` | GM-shaped value | Root source buffer |
-| `ping` / `pong` | local VEC tile-like values | Staging tiles |
+| `recv` | `recv(%ping)` or `recv(%ping, %pong)` | One or two local VEC staging tiles |
 | `group` | variadic GM-shaped values | Parallel group members |
 | `root` | `i32` attr | Root rank index inside `group` |
 
@@ -8906,20 +8907,45 @@ pto.comm.twait %sig, %v {cmp = #pto.wait_cmp<eq>} : !pto.partition_tensor_view<1
 
 - `src` must have the same type as each `group` member.
 - `root` must be in range `[0, group.size)`.
 
-**Basic Example:**
+**Examples:**
+
+Single receive buffer:
+
+```mlir
+pto.comm.tbroadcast(%src, recv(%ping), group(%g0, %g1, %g2) :
+    !pto.partition_tensor_view<128xf32>,
+    !pto.tile_buf,
+    !pto.partition_tensor_view<128xf32>,
+    !pto.partition_tensor_view<128xf32>,
+    !pto.partition_tensor_view<128xf32>) {root = 1 : i32}
+```
+
+Optional ping–pong (`recv(%ping, %pong)` adds a second tile type in the operand-type list):
 
 ```mlir
-pto.comm.tbroadcast %src, %ping, %g0, %g1, %g2 {root = 1, operandSegmentSizes = array<i32: 1, 1, 0, 3>} :
-  !pto.partition_tensor_view<128xf32>, !pto.tile_buf, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>
+pto.comm.tbroadcast(%src, recv(%ping, %pong), group(%g0, %g1, %g2) :
+    !pto.partition_tensor_view<128xf32>,
+    !pto.tile_buf,
+    !pto.tile_buf,
+    !pto.partition_tensor_view<128xf32>,
+    !pto.partition_tensor_view<128xf32>,
+    !pto.partition_tensor_view<128xf32>) {root = 1 : i32}
 ```
 
 ---
 
-##### `pto.comm.comm_tgather` - Collective Gather
+##### `pto.comm.tgather` - Collective Gather
 
 **Summary:** Communication collective that lowers to `pto::comm::TGATHER(...)`. This op is distinct from tile-level `pto.tgather`.
 
-**Arguments:** `dst`, `ping`, optional `pong`, variadic `group`, `root`
+**Arguments:**
+
+| Name | Type | Description |
+|------|------|-------------|
+| `dst` | GM-shaped value | Destination buffer (gather target) |
+| `recv` | `recv(%ping)` or `recv(%ping, %pong)` | Staging tile(s) |
+| `group` | variadic GM-shaped values | Parallel group members |
+| `root` | `i32` attr | Root rank index inside `group` |
 
 **Constraints & Verification:**
 
@@ -8927,13 +8953,39 @@ pto.comm.tbroadcast %src, %ping, %g0, %g1, %g2 {root = 1, operandSegmentSizes =
 - `dst` element type must match the group element type.
 - `ping` / `pong` must be local VEC tile-like values with matching element type.
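+
+Like the other collectives, lowering first materializes the `group` operands into a `pto::comm::ParallelGroup` carrying the root index, then emits a single call; a sketch of the emitted shape (`GT` stands in for the generated global-tensor type, names illustrative):
+
+```cpp
+GT group[] = {g0, g1, g2};                                    // group members in rank order
+auto pg = pto::comm::ParallelGroup<GT>::Create(group, 3, 1);  // size 3, root rank 1
+pto::comm::TGATHER(pg, dst, ping);                            // single staging tile
+pto::comm::TGATHER(pg, dst, ping, pong);                      // ping-pong variant
+```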
+**Examples:**
+
+```mlir
+pto.comm.tgather(%dst, recv(%ping), group(%g0, %g1, %g2) :
+    !pto.partition_tensor_view<128xf32>,
+    !pto.tile_buf,
+    !pto.partition_tensor_view<128xf32>,
+    !pto.partition_tensor_view<128xf32>,
+    !pto.partition_tensor_view<128xf32>) {root = 1 : i32}
+
+pto.comm.tgather(%dst, recv(%ping, %pong), group(%g0, %g1, %g2) :
+    !pto.partition_tensor_view<128xf32>,
+    !pto.tile_buf,
+    !pto.tile_buf,
+    !pto.partition_tensor_view<128xf32>,
+    !pto.partition_tensor_view<128xf32>,
+    !pto.partition_tensor_view<128xf32>) {root = 1 : i32}
+```
+
 ---
 
-##### `pto.comm.comm_tscatter` - Collective Scatter
+##### `pto.comm.tscatter` - Collective Scatter
 
 **Summary:** Communication collective that lowers to `pto::comm::TSCATTER(...)`. This op is distinct from tile-level `pto.tscatter`.
 
-**Arguments:** `src`, `ping`, optional `pong`, variadic `group`, `root`
+**Arguments:**
+
+| Name | Type | Description |
+|------|------|-------------|
+| `src` | GM-shaped value | Source buffer (scatter root) |
+| `recv` | `recv(%ping)` or `recv(%ping, %pong)` | Staging tile(s) |
+| `group` | variadic GM-shaped values | Parallel group members |
+| `root` | `i32` attr | Root rank index inside `group` |
 
 **Constraints & Verification:**
 
@@ -8941,6 +8993,25 @@ pto.comm.tbroadcast %src, %ping, %g0, %g1, %g2 {root = 1, operandSegmentSizes =
 - `src` element type must match the group element type.
 - `ping` / `pong` must be local VEC tile-like values with matching element type.
 
+**Examples:**
+
+```mlir
+pto.comm.tscatter(%src, recv(%ping), group(%g0, %g1, %g2) :
+    !pto.partition_tensor_view<128xf32>,
+    !pto.tile_buf,
+    !pto.partition_tensor_view<128xf32>,
+    !pto.partition_tensor_view<128xf32>,
+    !pto.partition_tensor_view<128xf32>) {root = 1 : i32}
+
+pto.comm.tscatter(%src, recv(%ping, %pong), group(%g0, %g1, %g2) :
+    !pto.partition_tensor_view<128xf32>,
+    !pto.tile_buf,
+    !pto.tile_buf,
+    !pto.partition_tensor_view<128xf32>,
+    !pto.partition_tensor_view<128xf32>,
+    !pto.partition_tensor_view<128xf32>) {root = 1 : i32}
+```
+
 ---
 
 ##### `pto.comm.treduce` - Collective Reduce
 
@@ -8951,18 +9022,45 @@
 
 | Name | Type | Description |
 |------|------|-------------|
-| `dst` | GM-shaped value | Root destination buffer |
+| `dst` | GM-shaped value | Reduced output buffer |
 | `acc` | local VEC tile-like value | Accumulation tile |
-| `recvPing` / `recvPong` | local VEC tile-like values | Receive staging tiles |
+| `recv` | `recv(%ping)` or `recv(%ping, %pong)` | One or two receive staging tiles |
 | `group` | variadic GM-shaped values | Parallel group members |
-| `reduceOp` | `#pto.reduce_op<...>` | Reduction mode |
+| `reduceOp` | `#pto<reduce_op sum>` / `#pto<reduce_op max>` / `#pto<reduce_op ...>` | Reduction mode |
 | `root` | `i32` attr | Root rank index inside `group` |
 
 **Constraints & Verification:**
 
 - `group` must be non-empty and all members must have identical types.
 - `dst` element type must match the group element type.
-- `acc` and `recvPing` / `recvPong` must be local VEC tile-like values whose element type matches `dst`.
+- `acc` and the `recv(%ping)` / `recv(%ping, %pong)` operands must be local VEC tile-like values whose element type matches `dst`.
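+
+Lowering follows the same `ParallelGroup` scheme as the other collectives, with the reduction mode appended as a trailing `pto::comm::ReduceOp` enum; a sketch of the emitted calls (names illustrative, `GT` again standing in for the generated global-tensor type):
+
+```cpp
+auto pg = pto::comm::ParallelGroup<GT>::Create(group, 3, 1);           // size 3, root rank 1
+pto::comm::TREDUCE(pg, dst, acc, recvPing, pto::comm::ReduceOp::Sum);  // single receive tile
+pto::comm::TREDUCE(pg, dst, acc, recvPing, recvPong, pto::comm::ReduceOp::Max);  // ping-pong
+```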
+
+**Examples:**
+
+Sum with a single receive tile:
+
+```mlir
+pto.comm.treduce(%dst, %acc, recv(%ping), group(%g0, %g1, %g2) :
+    !pto.partition_tensor_view<128xf32>,
+    !pto.tile_buf,
+    !pto.tile_buf,
+    !pto.partition_tensor_view<128xf32>,
+    !pto.partition_tensor_view<128xf32>,
+    !pto.partition_tensor_view<128xf32>) {reduceOp = #pto<reduce_op sum>, root = 1 : i32}
+```
+
+Max with ping–pong receive buffers (two staging tiles, so the operand-type list includes three `tile_buf` entries: `acc`, `ping`, `pong`):
+
+```mlir
+pto.comm.treduce(%dst, %acc, recv(%ping, %pong), group(%g0, %g1, %g2) :
+    !pto.partition_tensor_view<128xf32>,
+    !pto.tile_buf,
+    !pto.tile_buf,
+    !pto.tile_buf,
+    !pto.partition_tensor_view<128xf32>,
+    !pto.partition_tensor_view<128xf32>,
+    !pto.partition_tensor_view<128xf32>) {reduceOp = #pto<reduce_op max>, root = 1 : i32}
+```
 
 ---
diff --git a/include/PTO/IR/PTOOps.td b/include/PTO/IR/PTOOps.td
index 8c81f0824..d0195eb28 100644
--- a/include/PTO/IR/PTOOps.td
+++ b/include/PTO/IR/PTOOps.td
@@ -1681,6 +1681,11 @@ def TPutOp : PTO_Op<"comm.tput", [
   );
   let results = (outs);
   let hasVerifier = 1;
+  let assemblyFormat = [{
+    `(` $dst `,` $src `,` `buf` `(` $ping (`,` $pong^)? `)`
+    `:` type($dst) `,` type($src) `,` type($ping) (`,` type($pong)^)? `)`
+    attr-dict
+  }];
 }
 
 def TGetOp : PTO_Op<"comm.tget", [
@@ -1695,6 +1700,11 @@ def TGetOp : PTO_Op<"comm.tget", [
   );
   let results = (outs);
   let hasVerifier = 1;
+  let assemblyFormat = [{
+    `(` $dst `,` $src `,` `buf` `(` $ping (`,` $pong^)? `)`
+    `:` type($dst) `,` type($src) `,` type($ping) (`,` type($pong)^)? `)`
+    attr-dict
+  }];
 }
 
 def TNotifyOp : PTO_Op<"comm.tnotify", [
@@ -1708,6 +1718,10 @@ def TNotifyOp : PTO_Op<"comm.tnotify", [
   );
   let results = (outs);
   let hasVerifier = 1;
+  let assemblyFormat = [{
+    `(` $signal `,` $value `:` type($signal) `,` type($value) `)`
+    attr-dict
+  }];
 }
 
 def TWaitOp : PTO_Op<"comm.twait", [
@@ -1721,6 +1735,10 @@ def TWaitOp : PTO_Op<"comm.twait", [
   );
   let results = (outs);
   let hasVerifier = 1;
+  let assemblyFormat = [{
+    `(` $signal `,` $cmpValue `:` type($signal) `,` type($cmpValue) `)`
+    attr-dict
+  }];
 }
 
 def TTestOp : PTO_Op<"comm.ttest", [
@@ -1734,6 +1752,10 @@ def TTestOp : PTO_Op<"comm.ttest", [
   );
   let results = (outs I1:$result);
   let hasVerifier = 1;
+  let assemblyFormat = [{
+    `(` $signal `,` $cmpValue `:` type($signal) `,` type($cmpValue) `)`
+    attr-dict `->` type($result)
+  }];
 }
 
 def TBroadcastOp : PTO_Op<"comm.tbroadcast", [
@@ -1750,9 +1772,10 @@ def TBroadcastOp : PTO_Op<"comm.tbroadcast", [
   );
   let results = (outs);
   let hasVerifier = 1;
+  let hasCustomAssemblyFormat = 1;
 }
 
-def CommTGatherOp : PTO_Op<"comm.comm_tgather", [
+def CommTGatherOp : PTO_Op<"comm.tgather", [
   AttrSizedOperandSegments,
   DeclareOpInterfaceMethods
 ]> {
@@ -1766,9 +1789,10 @@ def CommTGatherOp : PTO_Op<"comm.comm_tgather", [
   );
   let results = (outs);
   let hasVerifier = 1;
+  let hasCustomAssemblyFormat = 1;
 }
 
-def CommTScatterOp : PTO_Op<"comm.comm_tscatter", [
+def CommTScatterOp : PTO_Op<"comm.tscatter", [
   AttrSizedOperandSegments,
   DeclareOpInterfaceMethods
 ]> {
@@ -1782,6 +1806,7 @@ def CommTScatterOp : PTO_Op<"comm.comm_tscatter", [
   );
   let results = (outs);
   let hasVerifier = 1;
+  let hasCustomAssemblyFormat = 1;
 }
 
 def TReduceOp : PTO_Op<"comm.treduce", [
@@ -1800,6 +1825,7 @@ def TReduceOp : PTO_Op<"comm.treduce", [
   );
   let results = (outs);
   let hasVerifier = 1;
+  let hasCustomAssemblyFormat = 1;
 }
 
 def InitializeL2G2LPipeOp : PTO_Op<"initialize_l2g2l_pipe", [
diff --git a/lib/PTO/IR/PTO.cpp b/lib/PTO/IR/PTO.cpp
index 7ba2a747e..a24263700 100644
--- a/lib/PTO/IR/PTO.cpp
+++ b/lib/PTO/IR/PTO.cpp
@@ -791,6 +791,273 @@ void mlir::pto::TGatherOp::print(OpAsmPrinter &p) {
   }
 }
 
+namespace {
+
+struct CommRecvClause {
+  OpAsmParser::UnresolvedOperand ping;
+  std::optional<OpAsmParser::UnresolvedOperand> pong;
+  Type pingTy;
+  Type pongTy;
+};
+
+static ParseResult parseCommRecvClause(OpAsmParser &parser,
+                                       CommRecvClause &recvClause) {
+  if (parser.parseKeyword("recv") || parser.parseLParen() ||
+      parser.parseOperand(recvClause.ping))
+    return failure();
+  if (succeeded(parser.parseOptionalComma())) {
+    OpAsmParser::UnresolvedOperand pong;
+    if (parser.parseOperand(pong))
+      return failure();
+    recvClause.pong = pong;
+  }
+  return parser.parseRParen();
+}
+
+static ParseResult parseCommCollectiveTail(
+    OpAsmParser &parser, OperationState &result,
+    ArrayRef<OpAsmParser::UnresolvedOperand> fixedOperands,
+    SmallVectorImpl<Type> &fixedTypes, CommRecvClause &recvClause,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &groupOps,
+    SmallVectorImpl<Type> &groupTypes, ArrayRef<int32_t> operandSegmentsPrefix,
+    ArrayRef<StringRef> requiredAttrs) {
+  if (parser.parseComma() || parser.parseKeyword("group") || parser.parseLParen())
+    return failure();
+
+  OpAsmParser::UnresolvedOperand group;
+  if (parser.parseOperand(group))
+    return failure();
+  groupOps.push_back(group);
+  while (succeeded(parser.parseOptionalComma())) {
+    if (parser.parseOperand(group))
+      return failure();
+    groupOps.push_back(group);
+  }
+
+  if (parser.parseRParen())
+    return failure();
+
+  if (parser.parseColon())
+    return failure();
+
+  for (size_t i = 0; i < fixedTypes.size(); ++i) {
+    if (i != 0 && parser.parseComma())
+      return failure();
+    if (parser.parseType(fixedTypes[i]))
+      return failure();
+  }
+  if (parser.parseComma() || parser.parseType(recvClause.pingTy))
+    return failure();
+  if (recvClause.pong) {
+    if (parser.parseComma() || parser.parseType(recvClause.pongTy))
+      return failure();
+  }
+  for (size_t i = 0; i < groupOps.size(); ++i) {
+    Type groupTy;
+    if (parser.parseComma() || parser.parseType(groupTy))
+      return failure();
+    groupTypes.push_back(groupTy);
+  }
+  if (parser.parseRParen())
+    return failure();
+
+  NamedAttrList attrs;
+  if (parser.parseOptionalAttrDict(attrs))
+    return failure();
+  for (StringRef attrName : requiredAttrs) {
+    if (!attrs.get(attrName)) {
+      return parser.emitError(parser.getCurrentLocation())
+             << "expected '" << attrName << "' attribute";
+    }
+  }
+  result.addAttributes(attrs);
+
+  for (auto [operand, type] : llvm::zip_equal(fixedOperands, fixedTypes)) {
+    if (parser.resolveOperand(operand, type, result.operands))
+      return failure();
+  }
+  if (parser.resolveOperand(recvClause.ping, recvClause.pingTy, result.operands))
+    return failure();
+  if (recvClause.pong &&
+      parser.resolveOperand(*recvClause.pong, recvClause.pongTy, result.operands))
+    return failure();
+  if (parser.resolveOperands(groupOps, groupTypes, parser.getCurrentLocation(),
+                             result.operands))
+    return failure();
+
+  SmallVector<int32_t> segmentSizes(operandSegmentsPrefix.begin(),
+                                    operandSegmentsPrefix.end());
+  segmentSizes.push_back(static_cast<int32_t>(groupOps.size()));
+  result.addAttribute("operandSegmentSizes",
+                      parser.getBuilder().getDenseI32ArrayAttr(segmentSizes));
+  return success();
+}
+
+static void printCommRecvClause(OpAsmPrinter &p, Value ping, Value pong) {
+  p << "recv(" << ping;
+  if (pong)
+    p << ", " << pong;
+  p << ")";
+}
+
+static void printCommGroupTypes(OpAsmPrinter &p, ValueRange group) {
+  for (Value groupValue : group)
+    p << ", " << groupValue.getType();
+}
+
+static void printCommGroupClause(OpAsmPrinter &p, ValueRange group) {
+  p << "group(";
+  p.printOperands(group);
+  p << 
")"; +} + +} // namespace + +ParseResult mlir::pto::TBroadcastOp::parse(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::UnresolvedOperand src; + CommRecvClause recvClause; + SmallVector groupOps; + SmallVector groupTypes; + + if (parser.parseLParen() || parser.parseOperand(src) || parser.parseComma()) + return failure(); + if (failed(parseCommRecvClause(parser, recvClause))) + return failure(); + + SmallVector fixedOperands{src}; + SmallVector fixedTypes(1); + if (failed(parseCommCollectiveTail(parser, result, fixedOperands, fixedTypes, + recvClause, groupOps, groupTypes, + {1, 1, recvClause.pong ? 1 : 0}, {"root"}))) + return failure(); + return success(); +} + +void mlir::pto::TBroadcastOp::print(OpAsmPrinter &p) { + p << "(" << getSrc() << ", "; + printCommRecvClause(p, getPing(), getPong()); + p << ", "; + printCommGroupClause(p, getGroup()); + p << " : " << getSrc().getType() << ", " << getPing().getType(); + if (getPong()) + p << ", " << getPong().getType(); + printCommGroupTypes(p, getGroup()); + p << ")"; + p.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/{"operandSegmentSizes"}); +} + +ParseResult mlir::pto::CommTGatherOp::parse(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::UnresolvedOperand dst; + CommRecvClause recvClause; + SmallVector groupOps; + SmallVector groupTypes; + + if (parser.parseLParen() || parser.parseOperand(dst) || parser.parseComma()) + return failure(); + if (failed(parseCommRecvClause(parser, recvClause))) + return failure(); + + SmallVector fixedOperands{dst}; + SmallVector fixedTypes(1); + if (failed(parseCommCollectiveTail( + parser, result, fixedOperands, fixedTypes, recvClause, groupOps, + groupTypes, {1, 1, recvClause.pong ? 1 : 0}, + {"root"}))) + return failure(); + return success(); +} + +void mlir::pto::CommTGatherOp::print(OpAsmPrinter &p) { + p << "(" << getDst() << ", "; + printCommRecvClause(p, getPing(), getPong()); + p << ", "; + printCommGroupClause(p, getGroup()); + p << " : " << getDst().getType() << ", " << getPing().getType(); + if (getPong()) + p << ", " << getPong().getType(); + printCommGroupTypes(p, getGroup()); + p << ")"; + p.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/{"operandSegmentSizes"}); +} + +ParseResult mlir::pto::CommTScatterOp::parse(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::UnresolvedOperand src; + CommRecvClause recvClause; + SmallVector groupOps; + SmallVector groupTypes; + + if (parser.parseLParen() || parser.parseOperand(src) || parser.parseComma()) + return failure(); + if (failed(parseCommRecvClause(parser, recvClause))) + return failure(); + + SmallVector fixedOperands{src}; + SmallVector fixedTypes(1); + if (failed(parseCommCollectiveTail( + parser, result, fixedOperands, fixedTypes, recvClause, groupOps, + groupTypes, {1, 1, recvClause.pong ? 
1 : 0}, + {"root"}))) + return failure(); + return success(); +} + +void mlir::pto::CommTScatterOp::print(OpAsmPrinter &p) { + p << "(" << getSrc() << ", "; + printCommRecvClause(p, getPing(), getPong()); + p << ", "; + printCommGroupClause(p, getGroup()); + p << " : " << getSrc().getType() << ", " << getPing().getType(); + if (getPong()) + p << ", " << getPong().getType(); + printCommGroupTypes(p, getGroup()); + p << ")"; + p.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/{"operandSegmentSizes"}); +} + +ParseResult mlir::pto::TReduceOp::parse(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::UnresolvedOperand dst, acc; + CommRecvClause recvClause; + SmallVector groupOps; + SmallVector groupTypes; + + if (parser.parseLParen() || parser.parseOperand(dst) || parser.parseComma() || + parser.parseOperand(acc) || parser.parseComma()) + return failure(); + if (failed(parseCommRecvClause(parser, recvClause))) + return failure(); + + SmallVector fixedOperands{dst, acc}; + SmallVector fixedTypes(2); + if (failed(parseCommCollectiveTail( + parser, result, fixedOperands, fixedTypes, recvClause, groupOps, + groupTypes, {1, 1, 1, recvClause.pong ? 1 : 0}, + {"reduceOp", "root"}))) + return failure(); + return success(); +} + +void mlir::pto::TReduceOp::print(OpAsmPrinter &p) { + p << "(" << getDst() << ", " << getAcc() << ", "; + printCommRecvClause(p, getRecvPing(), getRecvPong()); + p << ", "; + printCommGroupClause(p, getGroup()); + p << " : " << getDst().getType() << ", " << getAcc().getType() << ", " + << getRecvPing().getType(); + if (getRecvPong()) + p << ", " << getRecvPong().getType(); + printCommGroupTypes(p, getGroup()); + p << ")"; + p.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/{"operandSegmentSizes"}); +} + ParseResult mlir::pto::MakeTensorViewOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::UnresolvedOperand ptr; diff --git a/lib/PTO/Transforms/PTOToEmitC.cpp b/lib/PTO/Transforms/PTOToEmitC.cpp index 1ef2b6e07..97db0d78d 100644 --- a/lib/PTO/Transforms/PTOToEmitC.cpp +++ b/lib/PTO/Transforms/PTOToEmitC.cpp @@ -5713,6 +5713,48 @@ static FailureOr buildCommTileValue(ConversionPatternRewriter &rewriter, return buildAsyncScratchTileValue(rewriter, loc, originalValue, emittedValue); } +static FailureOr buildCollectiveParallelGroup( + ConversionPatternRewriter &rewriter, Location loc, + ArrayRef groupGTs, int64_t root) { + if (groupGTs.empty()) + return failure(); + + auto firstTy = dyn_cast(groupGTs.front().getType()); + if (!firstTy) + return failure(); + + auto *ctx = rewriter.getContext(); + auto arrayTy = emitc::ArrayType::get({static_cast(groupGTs.size())}, + firstTy); + auto groupArray = cast>( + rewriter + .create(loc, arrayTy, + emitc::OpaqueAttr::get(ctx, "{}")) + .getResult()); + + auto indexTy = emitc::OpaqueType::get(ctx, "int"); + for (auto [idx, groupVal] : llvm::enumerate(groupGTs)) { + Value idxVal = + makeEmitCIntConstant(rewriter, loc, indexTy, static_cast(idx)); + Value slot = + rewriter.create(loc, groupArray, ValueRange{idxVal}) + .getResult(); + rewriter.create(loc, slot, groupVal); + } + + std::string pgTypeStr = + (Twine("pto::comm::ParallelGroup<") + firstTy.getValue() + ">").str(); + auto pgTy = emitc::OpaqueType::get(ctx, pgTypeStr); + Value sizeVal = makeEmitCIntConstant(rewriter, loc, indexTy, + static_cast(groupGTs.size())); + Value rootVal = makeEmitCIntConstant(rewriter, loc, indexTy, root); + return rewriter + .create( + loc, TypeRange{pgTy}, (Twine(pgTypeStr) + "::Create").str(), + ArrayAttr{}, 
ArrayAttr{}, ValueRange{groupArray, sizeVal, rootVal}) + .getResult(0); +} + static std::string notifyOpTok(pto::NotifyOp op) { switch (op) { case pto::NotifyOp::AtomicAdd: @@ -5782,9 +5824,6 @@ struct PTOCommCollectiveToEmitC : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); - SmallVector operands; - std::string helperName; - auto buildPong = [&](Value original, Value emitted, StringRef name) -> FailureOr { if (!original) return failure(); @@ -5801,20 +5840,22 @@ struct PTOCommCollectiveToEmitC : public OpConversionPattern { buildCommGroupGlobalTensors(rewriter, loc, op, op.getGroup(), adaptor.getGroup()); if (failed(srcGT) || failed(pingTile) || failed(groupGTs)) return rewriter.notifyMatchFailure(op, "failed to materialize broadcast operands"); - operands.push_back(*srcGT); - operands.push_back(*pingTile); + FailureOr pg = buildCollectiveParallelGroup(rewriter, loc, *groupGTs, op.getRoot()); + if (failed(pg)) + return rewriter.notifyMatchFailure(op, "failed to materialize broadcast group"); if (op.getPong()) { FailureOr pongTile = buildPong(op.getPong(), adaptor.getPong(), "__pong"); if (failed(pongTile)) return rewriter.notifyMatchFailure(op, "failed to materialize pong tile"); - operands.push_back(*pongTile); + rewriter.create( + loc, TypeRange{}, "pto::comm::TBROADCAST", ArrayAttr{}, ArrayAttr{}, + ValueRange{*pg, *srcGT, *pingTile, *pongTile}); + } else { + rewriter.create( + loc, TypeRange{}, "pto::comm::TBROADCAST", ArrayAttr{}, ArrayAttr{}, + ValueRange{*pg, *srcGT, *pingTile}); } - for (size_t i = 0; i < groupGTs->size(); ++i) - operands.push_back((*groupGTs)[i]); - helperName = "PTOAS__COMM_TBROADCAST"; - if (op.getPong()) - helperName = "PTOAS__COMM_TBROADCAST_PONG"; } else if constexpr (std::is_same_v) { FailureOr dstGT = buildCommGlobalTensorValue(rewriter, loc, op.getDst(), adaptor.getDst(), @@ -5825,20 +5866,22 @@ struct PTOCommCollectiveToEmitC : public OpConversionPattern { buildCommGroupGlobalTensors(rewriter, loc, op, op.getGroup(), adaptor.getGroup()); if (failed(dstGT) || failed(pingTile) || failed(groupGTs)) return rewriter.notifyMatchFailure(op, "failed to materialize gather operands"); - operands.push_back(*dstGT); - operands.push_back(*pingTile); + FailureOr pg = buildCollectiveParallelGroup(rewriter, loc, *groupGTs, op.getRoot()); + if (failed(pg)) + return rewriter.notifyMatchFailure(op, "failed to materialize gather group"); if (op.getPong()) { FailureOr pongTile = buildPong(op.getPong(), adaptor.getPong(), "__pong"); if (failed(pongTile)) return rewriter.notifyMatchFailure(op, "failed to materialize pong tile"); - operands.push_back(*pongTile); + rewriter.create( + loc, TypeRange{}, "pto::comm::TGATHER", ArrayAttr{}, ArrayAttr{}, + ValueRange{*pg, *dstGT, *pingTile, *pongTile}); + } else { + rewriter.create( + loc, TypeRange{}, "pto::comm::TGATHER", ArrayAttr{}, ArrayAttr{}, + ValueRange{*pg, *dstGT, *pingTile}); } - for (size_t i = 0; i < groupGTs->size(); ++i) - operands.push_back((*groupGTs)[i]); - helperName = "PTOAS__COMM_TGATHER"; - if (op.getPong()) - helperName = "PTOAS__COMM_TGATHER_PONG"; } else if constexpr (std::is_same_v) { FailureOr srcGT = buildCommGlobalTensorValue(rewriter, loc, op.getSrc(), adaptor.getSrc(), @@ -5849,20 +5892,22 @@ struct PTOCommCollectiveToEmitC : public OpConversionPattern { buildCommGroupGlobalTensors(rewriter, loc, op, op.getGroup(), adaptor.getGroup()); if (failed(srcGT) || failed(pingTile) || failed(groupGTs)) return rewriter.notifyMatchFailure(op, "failed 
to materialize scatter operands"); - operands.push_back(*srcGT); - operands.push_back(*pingTile); + FailureOr pg = buildCollectiveParallelGroup(rewriter, loc, *groupGTs, op.getRoot()); + if (failed(pg)) + return rewriter.notifyMatchFailure(op, "failed to materialize scatter group"); if (op.getPong()) { FailureOr pongTile = buildPong(op.getPong(), adaptor.getPong(), "__pong"); if (failed(pongTile)) return rewriter.notifyMatchFailure(op, "failed to materialize pong tile"); - operands.push_back(*pongTile); + rewriter.create( + loc, TypeRange{}, "pto::comm::TSCATTER", ArrayAttr{}, ArrayAttr{}, + ValueRange{*pg, *srcGT, *pingTile, *pongTile}); + } else { + rewriter.create( + loc, TypeRange{}, "pto::comm::TSCATTER", ArrayAttr{}, ArrayAttr{}, + ValueRange{*pg, *srcGT, *pingTile}); } - for (size_t i = 0; i < groupGTs->size(); ++i) - operands.push_back((*groupGTs)[i]); - helperName = "PTOAS__COMM_TSCATTER"; - if (op.getPong()) - helperName = "PTOAS__COMM_TSCATTER_PONG"; } else { FailureOr dstGT = buildCommGlobalTensorValue(rewriter, loc, op.getDst(), adaptor.getDst(), @@ -5875,29 +5920,31 @@ struct PTOCommCollectiveToEmitC : public OpConversionPattern { buildCommGroupGlobalTensors(rewriter, loc, op, op.getGroup(), adaptor.getGroup()); if (failed(dstGT) || failed(accTile) || failed(recvPing) || failed(groupGTs)) return rewriter.notifyMatchFailure(op, "failed to materialize reduce operands"); - operands.push_back(*dstGT); - operands.push_back(*accTile); - operands.push_back(*recvPing); + FailureOr pg = buildCollectiveParallelGroup(rewriter, loc, *groupGTs, op.getRoot()); + if (failed(pg)) + return rewriter.notifyMatchFailure(op, "failed to materialize reduce group"); if (op.getRecvPong()) { FailureOr recvPong = buildPong(op.getRecvPong(), adaptor.getRecvPong(), "__recv_pong"); if (failed(recvPong)) return rewriter.notifyMatchFailure(op, "failed to materialize recv_pong"); - operands.push_back(*recvPong); + auto reduceTy = + emitc::OpaqueType::get(rewriter.getContext(), "pto::comm::ReduceOp"); + Value reduceOp = makeEmitCOpaqueConstant(rewriter, loc, reduceTy, + reduceOpTok(op.getReduceOp())); + rewriter.create( + loc, TypeRange{}, "pto::comm::TREDUCE", ArrayAttr{}, ArrayAttr{}, + ValueRange{*pg, *dstGT, *accTile, *recvPing, *recvPong, reduceOp}); + } else { + auto reduceTy = + emitc::OpaqueType::get(rewriter.getContext(), "pto::comm::ReduceOp"); + Value reduceOp = makeEmitCOpaqueConstant(rewriter, loc, reduceTy, + reduceOpTok(op.getReduceOp())); + rewriter.create( + loc, TypeRange{}, "pto::comm::TREDUCE", ArrayAttr{}, ArrayAttr{}, + ValueRange{*pg, *dstGT, *accTile, *recvPing, reduceOp}); } - for (size_t i = 0; i < groupGTs->size(); ++i) - operands.push_back((*groupGTs)[i]); - helperName = "PTOAS__COMM_TREDUCE"; - if (op.getRecvPong()) - helperName = "PTOAS__COMM_TREDUCE_PONG"; - } - - std::string callee = (Twine(helperName) + "<" + Twine(op.getRoot())).str(); - if constexpr (std::is_same_v) - callee += ", " + reduceOpTok(op.getReduceOp()); - callee += ">"; - rewriter.create(loc, TypeRange{}, callee, ArrayAttr{}, - ArrayAttr{}, operands); + } rewriter.eraseOp(op); return success(); } @@ -5966,26 +6013,30 @@ struct PTOSignalCommToEmitC : public OpConversionPattern { return rewriter.notifyMatchFailure(op, "failed to materialize signal operand"); if constexpr (std::is_same_v) { - std::string actualCallee = - "PTOAS__COMM_TNOTIFY<" + notifyOpTok(op.getNotifyOp()) + ">"; - SmallVector operands{*signalGT, peelUnrealized(adaptor.getValue())}; - rewriter.create(op.getLoc(), TypeRange{}, actualCallee, + auto 
notifyTy = + emitc::OpaqueType::get(rewriter.getContext(), "pto::comm::NotifyOp"); + Value notifyOp = makeEmitCOpaqueConstant( + rewriter, op.getLoc(), notifyTy, notifyOpTok(op.getNotifyOp())); + SmallVector operands{*signalGT, peelUnrealized(adaptor.getValue()), + notifyOp}; + rewriter.create(op.getLoc(), TypeRange{}, callee, ArrayAttr{}, ArrayAttr{}, operands); rewriter.eraseOp(op); } else { - SmallVector operands{*signalGT, peelUnrealized(adaptor.getCmpValue())}; + auto waitCmpTy = + emitc::OpaqueType::get(rewriter.getContext(), "pto::comm::WaitCmp"); + Value waitCmp = makeEmitCOpaqueConstant( + rewriter, op.getLoc(), waitCmpTy, waitCmpTok(op.getCmp())); + SmallVector operands{*signalGT, peelUnrealized(adaptor.getCmpValue()), + waitCmp}; if constexpr (std::is_same_v) { Type resultTy = this->getTypeConverter()->convertType(op.getResult().getType()); if (!resultTy) return rewriter.notifyMatchFailure(op, "failed to convert ttest result type"); - std::string actualCallee = - "PTOAS__COMM_TTEST<" + waitCmpTok(op.getCmp()) + ">"; rewriter.replaceOpWithNewOp( - op, TypeRange{resultTy}, actualCallee, ArrayAttr{}, ArrayAttr{}, operands); + op, TypeRange{resultTy}, callee, ArrayAttr{}, ArrayAttr{}, operands); } else { - std::string actualCallee = - "PTOAS__COMM_TWAIT<" + waitCmpTok(op.getCmp()) + ">"; - rewriter.create(op.getLoc(), TypeRange{}, actualCallee, + rewriter.create(op.getLoc(), TypeRange{}, callee, ArrayAttr{}, ArrayAttr{}, operands); rewriter.eraseOp(op); } @@ -11232,11 +11283,11 @@ static void populatePTOToEmitCPatterns(RewritePatternSet &patterns, patterns.add>(typeConverter, ctx, "pto::comm::TGET"); patterns.add>(typeConverter, ctx, - "([&](auto &__signal, auto __value){ pto::comm::TNOTIFY(__signal, __value, "); + "pto::comm::TNOTIFY"); patterns.add>(typeConverter, ctx, - "([&](auto &__signal, auto __cmp){ pto::comm::TWAIT(__signal, __cmp, "); + "pto::comm::TWAIT"); patterns.add>(typeConverter, ctx, - "([&](auto &__signal, auto __cmp){ return pto::comm::TTEST(__signal, __cmp, "); + "pto::comm::TTEST"); patterns.add>(typeConverter, ctx, "TBROADCAST"); patterns.add>(typeConverter, ctx, @@ -11346,7 +11397,6 @@ struct EmitPTOManualPass bool needsEventIdArrayHelper = false; bool needsTRandomHelper = false; bool needsGlobalTensorDataHelper = false; - bool needsCommHelper = false; bool needsCommInclude = false; mop.walk([&](Operation *op) { if (isa(op)) @@ -11355,11 +11405,6 @@ struct EmitPTOManualPass needsTRandomHelper = true; if (isa(op)) needsGlobalTensorDataHelper = true; - if (isa(op)) - needsCommHelper = true; if (isa(dst, key, counter); } -)cpp")); - } - if (needsCommHelper) { - builder.create( - loc, builder.getStringAttr(R"cpp( -template -static AICORE inline void PTOAS__COMM_TBROADCAST( - Src &src, Ping &ping, FirstGroup &firstGroup, RestGroup &...restGroup) { - using GroupTensor = std::decay_t; - GroupTensor group[] = {firstGroup, restGroup...}; - auto pg = pto::comm::ParallelGroup::Create( - group, 1 + sizeof...(RestGroup), Root); - pto::comm::TBROADCAST(pg, src, ping); -} - -template -static AICORE inline void PTOAS__COMM_TBROADCAST_PONG( - Src &src, Ping &ping, Pong &pong, FirstGroup &firstGroup, RestGroup &...restGroup) { - using GroupTensor = std::decay_t; - GroupTensor group[] = {firstGroup, restGroup...}; - auto pg = pto::comm::ParallelGroup::Create( - group, 1 + sizeof...(RestGroup), Root); - pto::comm::TBROADCAST(pg, src, ping, pong); -} - -template -static AICORE inline void PTOAS__COMM_TGATHER( - Dst &dst, Ping &ping, FirstGroup &firstGroup, RestGroup &...restGroup) 
{ - using GroupTensor = std::decay_t; - GroupTensor group[] = {firstGroup, restGroup...}; - auto pg = pto::comm::ParallelGroup::Create( - group, 1 + sizeof...(RestGroup), Root); - pto::comm::TGATHER(pg, dst, ping); -} - -template -static AICORE inline void PTOAS__COMM_TGATHER_PONG( - Dst &dst, Ping &ping, Pong &pong, FirstGroup &firstGroup, RestGroup &...restGroup) { - using GroupTensor = std::decay_t; - GroupTensor group[] = {firstGroup, restGroup...}; - auto pg = pto::comm::ParallelGroup::Create( - group, 1 + sizeof...(RestGroup), Root); - pto::comm::TGATHER(pg, dst, ping, pong); -} - -template -static AICORE inline void PTOAS__COMM_TSCATTER( - Src &src, Ping &ping, FirstGroup &firstGroup, RestGroup &...restGroup) { - using GroupTensor = std::decay_t; - GroupTensor group[] = {firstGroup, restGroup...}; - auto pg = pto::comm::ParallelGroup::Create( - group, 1 + sizeof...(RestGroup), Root); - pto::comm::TSCATTER(pg, src, ping); -} - -template -static AICORE inline void PTOAS__COMM_TSCATTER_PONG( - Src &src, Ping &ping, Pong &pong, FirstGroup &firstGroup, RestGroup &...restGroup) { - using GroupTensor = std::decay_t; - GroupTensor group[] = {firstGroup, restGroup...}; - auto pg = pto::comm::ParallelGroup::Create( - group, 1 + sizeof...(RestGroup), Root); - pto::comm::TSCATTER(pg, src, ping, pong); -} - -template -static AICORE inline void PTOAS__COMM_TREDUCE( - Dst &dst, Acc &acc, RecvPing &recvPing, FirstGroup &firstGroup, - RestGroup &...restGroup) { - using GroupTensor = std::decay_t; - GroupTensor group[] = {firstGroup, restGroup...}; - auto pg = pto::comm::ParallelGroup::Create( - group, 1 + sizeof...(RestGroup), Root); - pto::comm::TREDUCE(pg, dst, acc, recvPing, Op); -} - -template -static AICORE inline void PTOAS__COMM_TREDUCE_PONG( - Dst &dst, Acc &acc, RecvPing &recvPing, RecvPong &recvPong, - FirstGroup &firstGroup, RestGroup &...restGroup) { - using GroupTensor = std::decay_t; - GroupTensor group[] = {firstGroup, restGroup...}; - auto pg = pto::comm::ParallelGroup::Create( - group, 1 + sizeof...(RestGroup), Root); - pto::comm::TREDUCE(pg, dst, acc, recvPing, recvPong, Op); -} - -template -static AICORE inline void PTOAS__COMM_TNOTIFY(Signal &signal, int32_t value) { - pto::comm::TNOTIFY(signal, value, Op); -} - -template -static AICORE inline void PTOAS__COMM_TWAIT(Signal &signal, int32_t value) { - pto::comm::TWAIT(signal, value, Cmp); -} - -template -static AICORE inline bool PTOAS__COMM_TTEST(Signal &signal, int32_t value) { - return pto::comm::TTEST(signal, value, Cmp); -} )cpp")); } builder.create( diff --git a/test/lit/pto/comm_collective_emitc.pto b/test/lit/pto/comm_collective_emitc.pto index cf85be819..01c3afceb 100644 --- a/test/lit/pto/comm_collective_emitc.pto +++ b/test/lit/pto/comm_collective_emitc.pto @@ -18,22 +18,22 @@ module { %ping = pto.alloc_tile : !pto.tile_buf %pong = pto.alloc_tile : !pto.tile_buf %acc = pto.alloc_tile : !pto.tile_buf - "pto.comm.tbroadcast"(%src, %ping, %peer0, %peer1, %peer2) <{operandSegmentSizes = array, root = 1 : i32}> : (!pto.partition_tensor_view<128xf32>, !pto.tile_buf, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>) -> () - "pto.comm.tbroadcast"(%src, %ping, %pong, %peer0, %peer1, %peer2) <{operandSegmentSizes = array, root = 1 : i32}> : (!pto.partition_tensor_view<128xf32>, !pto.tile_buf, !pto.tile_buf, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>) -> () - "pto.comm.comm_tgather"(%dst, %ping, %peer0, 
%peer1, %peer2) <{operandSegmentSizes = array, root = 1 : i32}> : (!pto.partition_tensor_view<128xf32>, !pto.tile_buf, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>) -> () - "pto.comm.comm_tgather"(%dst, %ping, %pong, %peer0, %peer1, %peer2) <{operandSegmentSizes = array, root = 1 : i32}> : (!pto.partition_tensor_view<128xf32>, !pto.tile_buf, !pto.tile_buf, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>) -> () - "pto.comm.comm_tscatter"(%src, %ping, %peer0, %peer1, %peer2) <{operandSegmentSizes = array, root = 1 : i32}> : (!pto.partition_tensor_view<128xf32>, !pto.tile_buf, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>) -> () - "pto.comm.comm_tscatter"(%src, %ping, %pong, %peer0, %peer1, %peer2) <{operandSegmentSizes = array, root = 1 : i32}> : (!pto.partition_tensor_view<128xf32>, !pto.tile_buf, !pto.tile_buf, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>) -> () - "pto.comm.treduce"(%dst, %acc, %ping, %peer0, %peer1, %peer2) <{operandSegmentSizes = array, reduceOp = #pto, root = 1 : i32}> : (!pto.partition_tensor_view<128xf32>, !pto.tile_buf, !pto.tile_buf, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>) -> () - "pto.comm.treduce"(%dst, %acc, %ping, %pong, %peer0, %peer1, %peer2) <{operandSegmentSizes = array, reduceOp = #pto, root = 1 : i32}> : (!pto.partition_tensor_view<128xf32>, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>) -> () + pto.comm.tbroadcast(%src, recv(%ping), group(%peer0, %peer1, %peer2) : !pto.partition_tensor_view<128xf32>, !pto.tile_buf, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>) {root = 1 : i32} + pto.comm.tbroadcast(%src, recv(%ping, %pong), group(%peer0, %peer1, %peer2) : !pto.partition_tensor_view<128xf32>, !pto.tile_buf, !pto.tile_buf, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>) {root = 1 : i32} + pto.comm.tgather(%dst, recv(%ping), group(%peer0, %peer1, %peer2) : !pto.partition_tensor_view<128xf32>, !pto.tile_buf, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>) {root = 1 : i32} + pto.comm.tgather(%dst, recv(%ping, %pong), group(%peer0, %peer1, %peer2) : !pto.partition_tensor_view<128xf32>, !pto.tile_buf, !pto.tile_buf, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>) {root = 1 : i32} + pto.comm.tscatter(%src, recv(%ping), group(%peer0, %peer1, %peer2) : !pto.partition_tensor_view<128xf32>, !pto.tile_buf, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>) {root = 1 : i32} + pto.comm.tscatter(%src, recv(%ping, %pong), group(%peer0, %peer1, %peer2) : !pto.partition_tensor_view<128xf32>, !pto.tile_buf, !pto.tile_buf, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>) {root = 1 : i32} + pto.comm.treduce(%dst, %acc, recv(%ping), group(%peer0, %peer1, %peer2) : !pto.partition_tensor_view<128xf32>, !pto.tile_buf, !pto.tile_buf, !pto.partition_tensor_view<128xf32>, 
!pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>) {reduceOp = #pto<reduce_op sum>, root = 1 : i32}
+    pto.comm.treduce(%dst, %acc, recv(%ping, %pong), group(%peer0, %peer1, %peer2) : !pto.partition_tensor_view<128xf32>, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>) {reduceOp = #pto<reduce_op max>, root = 1 : i32}
     return
   }
 }
+// A3: pto::comm::ReduceOp::Max
+// A3: pto::comm::ReduceOp::Sum
 // A3: pto::comm::ParallelGroup
 // A3: pto::comm::TBROADCAST(
 // A3: pto::comm::TGATHER(
 // A3: pto::comm::TSCATTER(
 // A3: pto::comm::TREDUCE(
-// A3: pto::comm::ReduceOp::Sum
-// A3: pto::comm::ReduceOp::Max
diff --git a/test/lit/pto/comm_p2p_emitc.pto b/test/lit/pto/comm_p2p_emitc.pto
index 44f6e5e4d..191b76777 100644
--- a/test/lit/pto/comm_p2p_emitc.pto
+++ b/test/lit/pto/comm_p2p_emitc.pto
@@ -14,20 +14,24 @@ module {
     %signal = pto.partition_view %signal_view, offsets = [%c0], sizes = [%c1] : !pto.tensor_view<1xi32> -> !pto.partition_tensor_view<1xi32>
     %ping = pto.alloc_tile : !pto.tile_buf
     %pong = pto.alloc_tile : !pto.tile_buf
-    "pto.comm.tput"(%dst, %src, %ping) <{atomicType = #pto<atomic_type atomic_none>}> : (!pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.tile_buf) -> ()
-    "pto.comm.tput"(%dst, %src, %ping, %pong) <{atomicType = #pto<atomic_type atomic_add>}> : (!pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.tile_buf, !pto.tile_buf) -> ()
-    "pto.comm.tget"(%dst, %src, %ping) : (!pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.tile_buf) -> ()
-    "pto.comm.tget"(%dst, %src, %ping, %pong) : (!pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.tile_buf, !pto.tile_buf) -> ()
-    "pto.comm.tnotify"(%signal, %c7_i32) <{notifyOp = #pto<notify_op set>}> : (!pto.partition_tensor_view<1xi32>, i32) -> ()
-    "pto.comm.twait"(%signal, %c7_i32) <{cmp = #pto<wait_cmp eq>}> : (!pto.partition_tensor_view<1xi32>, i32) -> ()
-    %tested = "pto.comm.ttest"(%signal, %c7_i32) <{cmp = #pto<wait_cmp ge>}> : (!pto.partition_tensor_view<1xi32>, i32) -> i1
+    pto.comm.tput(%dst, %src, buf(%ping) : !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.tile_buf) {atomicType = #pto<atomic_type atomic_none>}
+    pto.comm.tput(%dst, %src, buf(%ping, %pong) : !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.tile_buf, !pto.tile_buf) {atomicType = #pto<atomic_type atomic_add>}
+    pto.comm.tget(%dst, %src, buf(%ping) : !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.tile_buf)
+    pto.comm.tget(%dst, %src, buf(%ping, %pong) : !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.tile_buf, !pto.tile_buf)
+    pto.comm.tnotify(%signal, %c7_i32 : !pto.partition_tensor_view<1xi32>, i32) {notifyOp = #pto<notify_op set>}
+    pto.comm.twait(%signal, %c7_i32 : !pto.partition_tensor_view<1xi32>, i32) {cmp = #pto<wait_cmp eq>}
+    %tested = pto.comm.ttest(%signal, %c7_i32 : !pto.partition_tensor_view<1xi32>, i32) {cmp = #pto<wait_cmp ge>} -> i1
     return
   }
 }
+// A3: pto::comm::WaitCmp::EQ
+// A3: pto::comm::WaitCmp::GE
+// A3: pto::comm::NotifyOp::Set
 // A3: pto::comm::TPUT(
 // A3: pto::comm::TPUT(
 // A3: pto::comm::TGET(
-// A3: PTOAS__COMM_TNOTIFY(
-// A3: PTOAS__COMM_TWAIT(
-// A3: PTOAS__COMM_TTEST(
+// A3: pto::comm::TGET(
+// A3: pto::comm::TNOTIFY(
+// A3: pto::comm::TWAIT(
+// A3: pto::comm::TTEST(
diff --git a/test/samples/CommSync/tbroadcast_root_binding.pto b/test/samples/CommSync/tbroadcast_root_binding.pto
new file mode 100644
index 000000000..758422899
--- /dev/null
+++ 
b/test/samples/CommSync/tbroadcast_root_binding.pto @@ -0,0 +1,25 @@ +module { + func.func @TBroadCastKernelImpl(%arg0: !pto.ptr, %arg1: !pto.ptr, %arg2: !pto.ptr, %arg3: !pto.ptr, %arg4: !pto.ptr, %arg5: i32, %arg6: i32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c1_i32 = arith.constant 1 : i32 + %0 = pto.make_tensor_view %arg0, shape = [%c256], strides = [%c1] : !pto.tensor_view<256xf32> + %1 = pto.partition_view %0, offsets = [%c0], sizes = [%c256] : !pto.tensor_view<256xf32> + %2 = pto.make_tensor_view %arg1, shape = [%c256], strides = [%c1] : !pto.tensor_view<256xf32> + %3 = pto.partition_view %2, offsets = [%c0], sizes = [%c256] : !pto.tensor_view<256xf32> + %4 = pto.make_tensor_view %arg2, shape = [%c256], strides = [%c1] : !pto.tensor_view<256xf32> + %5 = pto.partition_view %4, offsets = [%c0], sizes = [%c256] : !pto.tensor_view<256xf32> + %6 = pto.make_tensor_view %arg3, shape = [%c256], strides = [%c1] : !pto.tensor_view<256xf32> + %7 = pto.partition_view %6, offsets = [%c0], sizes = [%c256] : !pto.tensor_view<256xf32> + %8 = pto.make_tensor_view %arg4, shape = [%c256], strides = [%c1] : !pto.tensor_view<256xf32> + %9 = pto.partition_view %8, offsets = [%c0], sizes = [%c256] : !pto.tensor_view<256xf32> + %10 = pto.alloc_tile : !pto.tile_buf + %11 = arith.cmpi eq, %arg5, %c1_i32 : i32 + scf.if %11 { + pto.comm.tbroadcast(%1, recv(%10), group(%3, %5, %7, %9) : !pto.partition_tensor_view<256xf32>, !pto.tile_buf, !pto.partition_tensor_view<256xf32>, !pto.partition_tensor_view<256xf32>, !pto.partition_tensor_view<256xf32>, !pto.partition_tensor_view<256xf32>) {root = 1 : i32} + } + pto.barrier + return + } +} diff --git a/test/samples/CommSync/tbroadcast_root_binding.py b/test/samples/CommSync/tbroadcast_root_binding.py new file mode 100644 index 000000000..a79ecd759 --- /dev/null +++ b/test/samples/CommSync/tbroadcast_root_binding.py @@ -0,0 +1,75 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
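+
+# This sample mirrors tbroadcast_root_binding.pto: it builds five 256-element f32
+# partition views (one broadcast source plus a four-member group), allocates a
+# single VEC staging tile, and emits pto.comm.tbroadcast with root = 1 inside an
+# scf.if that only runs on the root rank, followed by a full barrier.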
+ +from mlir.ir import Context, F32Type, IndexType, InsertionPoint, IntegerType, Location, Module +from mlir.dialects import arith, func, pto, scf + + +def _make_part(ptr, tv_ty, pv_ty, c0, c1, count): + tv = pto.MakeTensorViewOp(tv_ty, ptr, [count], [c1]).result + return pto.PartitionViewOp(pv_ty, tv, offsets=[c0], sizes=[count]).result + + +def build(): + with Context() as ctx: + pto.register_dialect(ctx, load=True) + + with Location.unknown(ctx): + module = Module.create() + + f32 = F32Type.get(ctx) + i32 = IntegerType.get_signless(32, ctx) + idx = IndexType.get(ctx) + + ptr_f32 = pto.PtrType.get(f32, ctx) + tv_f32 = pto.TensorViewType.get([256], f32, ctx) + pv_f32 = pto.PartitionTensorViewType.get([256], f32, ctx) + + vec = pto.AddressSpaceAttr.get(pto.AddressSpace.VEC, ctx) + pipe_all = pto.PipeAttr.get(pto.PIPE.PIPE_ALL, ctx) + tb_f32 = pto.TileBufType.get([1, 256], f32, vec, [1, 256], None, ctx) + + fn_ty = func.FunctionType.get( + [ptr_f32, ptr_f32, ptr_f32, ptr_f32, ptr_f32, i32, i32], [] + ) + with InsertionPoint(module.body): + fn = func.FuncOp("TBroadCastKernelImpl", fn_ty) + entry = fn.add_entry_block() + + with InsertionPoint(entry): + input_ptr, out0_ptr, out1_ptr, out2_ptr, out3_ptr, my_rank, nranks = entry.arguments + c0 = arith.ConstantOp(idx, 0).result + c1 = arith.ConstantOp(idx, 1).result + c256 = arith.ConstantOp(idx, 256).result + c1_i32 = arith.ConstantOp(i32, 1).result + _ = nranks + + src = _make_part(input_ptr, tv_f32, pv_f32, c0, c1, c256) + out0 = _make_part(out0_ptr, tv_f32, pv_f32, c0, c1, c256) + out1 = _make_part(out1_ptr, tv_f32, pv_f32, c0, c1, c256) + out2 = _make_part(out2_ptr, tv_f32, pv_f32, c0, c1, c256) + out3 = _make_part(out3_ptr, tv_f32, pv_f32, c0, c1, c256) + group = [out0, out1, out2, out3] + + ping = pto.AllocTileOp(tb_f32).result + + is_root = arith.CmpIOp(arith.CmpIPredicate.eq, my_rank, c1_i32).result + root_if = scf.IfOp(is_root, [], hasElse=False) + with InsertionPoint(root_if.then_block): + pto.TBroadcastOp(src, ping, group, 1) + scf.YieldOp([]) + + pto.barrier(pipe_all) + func.ReturnOp([]) + + module.operation.verify() + return module + + +if __name__ == "__main__": + print(build()) diff --git a/test/samples/CommSync/tgather_root_binding.pto b/test/samples/CommSync/tgather_root_binding.pto new file mode 100644 index 000000000..15628a1d1 --- /dev/null +++ b/test/samples/CommSync/tgather_root_binding.pto @@ -0,0 +1,25 @@ +module { + func.func @TGatherKernelImpl(%arg0: !pto.ptr, %arg1: !pto.ptr, %arg2: !pto.ptr, %arg3: !pto.ptr, %arg4: !pto.ptr, %arg5: i32, %arg6: i32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c1_i32 = arith.constant 1 : i32 + %0 = pto.make_tensor_view %arg0, shape = [%c256], strides = [%c1] : !pto.tensor_view<256xf32> + %1 = pto.partition_view %0, offsets = [%c0], sizes = [%c256] : !pto.tensor_view<256xf32> + %2 = pto.make_tensor_view %arg1, shape = [%c256], strides = [%c1] : !pto.tensor_view<256xf32> + %3 = pto.partition_view %2, offsets = [%c0], sizes = [%c256] : !pto.tensor_view<256xf32> + %4 = pto.make_tensor_view %arg2, shape = [%c256], strides = [%c1] : !pto.tensor_view<256xf32> + %5 = pto.partition_view %4, offsets = [%c0], sizes = [%c256] : !pto.tensor_view<256xf32> + %6 = pto.make_tensor_view %arg3, shape = [%c256], strides = [%c1] : !pto.tensor_view<256xf32> + %7 = pto.partition_view %6, offsets = [%c0], sizes = [%c256] : !pto.tensor_view<256xf32> + %8 = pto.make_tensor_view %arg4, shape = [%c256], strides = [%c1] : !pto.tensor_view<256xf32> + %9 
= pto.partition_view %8, offsets = [%c0], sizes = [%c256] : !pto.tensor_view<256xf32> + %10 = pto.alloc_tile : !pto.tile_buf + %11 = arith.cmpi eq, %arg5, %c1_i32 : i32 + scf.if %11 { + pto.comm.tgather(%1, recv(%10), group(%3, %5, %7, %9) : !pto.partition_tensor_view<256xf32>, !pto.tile_buf, !pto.partition_tensor_view<256xf32>, !pto.partition_tensor_view<256xf32>, !pto.partition_tensor_view<256xf32>, !pto.partition_tensor_view<256xf32>) {root = 1 : i32} + } + pto.barrier + return + } +} diff --git a/test/samples/CommSync/tgather_root_binding.py b/test/samples/CommSync/tgather_root_binding.py new file mode 100644 index 000000000..10e688b23 --- /dev/null +++ b/test/samples/CommSync/tgather_root_binding.py @@ -0,0 +1,75 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +from mlir.ir import Context, F32Type, IndexType, InsertionPoint, IntegerType, Location, Module +from mlir.dialects import arith, func, pto, scf + + +def _make_part(ptr, tv_ty, pv_ty, c0, c1, count): + tv = pto.MakeTensorViewOp(tv_ty, ptr, [count], [c1]).result + return pto.PartitionViewOp(pv_ty, tv, offsets=[c0], sizes=[count]).result + + +def build(): + with Context() as ctx: + pto.register_dialect(ctx, load=True) + + with Location.unknown(ctx): + module = Module.create() + + f32 = F32Type.get(ctx) + i32 = IntegerType.get_signless(32, ctx) + idx = IndexType.get(ctx) + + ptr_f32 = pto.PtrType.get(f32, ctx) + tv_f32 = pto.TensorViewType.get([256], f32, ctx) + pv_f32 = pto.PartitionTensorViewType.get([256], f32, ctx) + + vec = pto.AddressSpaceAttr.get(pto.AddressSpace.VEC, ctx) + pipe_all = pto.PipeAttr.get(pto.PIPE.PIPE_ALL, ctx) + tb_f32 = pto.TileBufType.get([1, 256], f32, vec, [1, 256], None, ctx) + + fn_ty = func.FunctionType.get( + [ptr_f32, ptr_f32, ptr_f32, ptr_f32, ptr_f32, i32, i32], [] + ) + with InsertionPoint(module.body): + fn = func.FuncOp("TGatherKernelImpl", fn_ty) + entry = fn.add_entry_block() + + with InsertionPoint(entry): + dst_ptr, src0_ptr, src1_ptr, src2_ptr, src3_ptr, my_rank, nranks = entry.arguments + c0 = arith.ConstantOp(idx, 0).result + c1 = arith.ConstantOp(idx, 1).result + c256 = arith.ConstantOp(idx, 256).result + c1_i32 = arith.ConstantOp(i32, 1).result + _ = nranks + + dst = _make_part(dst_ptr, tv_f32, pv_f32, c0, c1, c256) + src0 = _make_part(src0_ptr, tv_f32, pv_f32, c0, c1, c256) + src1 = _make_part(src1_ptr, tv_f32, pv_f32, c0, c1, c256) + src2 = _make_part(src2_ptr, tv_f32, pv_f32, c0, c1, c256) + src3 = _make_part(src3_ptr, tv_f32, pv_f32, c0, c1, c256) + group = [src0, src1, src2, src3] + + ping = pto.AllocTileOp(tb_f32).result + + is_root = arith.CmpIOp(arith.CmpIPredicate.eq, my_rank, c1_i32).result + root_if = scf.IfOp(is_root, [], hasElse=False) + with InsertionPoint(root_if.then_block): + pto.CommTGatherOp(dst, ping, group, 1) + scf.YieldOp([]) + + pto.barrier(pipe_all) + func.ReturnOp([]) + + module.operation.verify() + return module + + +if __name__ == "__main__": + print(build()) diff --git 
a/test/samples/CommSync/tnotify_atomic_add_binding-pto-ir.pto b/test/samples/CommSync/tnotify_atomic_add_binding-pto-ir.pto new file mode 100644 index 000000000..0badefdb8 --- /dev/null +++ b/test/samples/CommSync/tnotify_atomic_add_binding-pto-ir.pto @@ -0,0 +1,12 @@ +module { + func.func @TNotifyAtomicAddKernel(%arg0: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1_i32 = arith.constant 1 : i32 + %0 = pto.make_tensor_view %arg0, shape = [%c1], strides = [%c1] : !pto.tensor_view<1xi32> + %1 = pto.partition_view %0, offsets = [%c0], sizes = [%c1] : !pto.tensor_view<1xi32> + pto.comm.tnotify(%1, %c1_i32 : !pto.partition_tensor_view<1xi32>, i32) {notifyOp = #pto} + return + } +} + diff --git a/test/samples/CommSync/tnotify_atomic_add_binding.py b/test/samples/CommSync/tnotify_atomic_add_binding.py new file mode 100644 index 000000000..d1609e218 --- /dev/null +++ b/test/samples/CommSync/tnotify_atomic_add_binding.py @@ -0,0 +1,56 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +from mlir.ir import Context, IndexType, InsertionPoint, IntegerType, Location, Module +from mlir.dialects import arith, func, pto + + +def build(): + with Context() as ctx: + pto.register_dialect(ctx, load=True) + + with Location.unknown(ctx): + module = Module.create() + + i32 = IntegerType.get_signless(32, ctx) + idx = IndexType.get(ctx) + + ptr_i32 = pto.PtrType.get(i32, ctx) + tv_i32 = pto.TensorViewType.get([1], i32, ctx) + pv_i32 = pto.PartitionTensorViewType.get([1], i32, ctx) + + fn_ty = func.FunctionType.get([ptr_i32], []) + with InsertionPoint(module.body): + fn = func.FuncOp("TNotifyAtomicAddKernel", fn_ty) + entry = fn.add_entry_block() + + with InsertionPoint(entry): + signal_ptr = entry.arguments[0] + c0 = arith.ConstantOp(idx, 0).result + c1 = arith.ConstantOp(idx, 1).result + one_i32 = arith.ConstantOp(i32, 1).result + + signal_view = pto.MakeTensorViewOp(tv_i32, signal_ptr, [c1], [c1]).result + signal = pto.PartitionViewOp( + pv_i32, signal_view, offsets=[c0], sizes=[c1] + ).result + + pto.TNotifyOp( + signal, + one_i32, + pto.NotifyOpAttr.get(pto.NotifyOp.AtomicAdd, ctx), + ) + + func.ReturnOp([]) + + module.operation.verify() + return module + + +if __name__ == "__main__": + print(build()) diff --git a/test/samples/CommSync/treduce_root_binding.pto b/test/samples/CommSync/treduce_root_binding.pto new file mode 100644 index 000000000..6c79bc2e9 --- /dev/null +++ b/test/samples/CommSync/treduce_root_binding.pto @@ -0,0 +1,26 @@ +module { + func.func @TReduceKernelImpl(%arg0: !pto.ptr, %arg1: !pto.ptr, %arg2: !pto.ptr, %arg3: !pto.ptr, %arg4: !pto.ptr, %arg5: i32, %arg6: i32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c1_i32 = arith.constant 1 : i32 + %0 = pto.make_tensor_view %arg0, shape = [%c256], strides = [%c1] : !pto.tensor_view<256xf32> + %1 = pto.partition_view %0, offsets = [%c0], sizes = [%c256] : 
+
+from mlir.ir import Context, IndexType, InsertionPoint, IntegerType, Location, Module
+from mlir.dialects import arith, func, pto
+
+
+def build():
+    with Context() as ctx:
+        pto.register_dialect(ctx, load=True)
+
+        with Location.unknown(ctx):
+            module = Module.create()
+
+            i32 = IntegerType.get_signless(32, ctx)
+            idx = IndexType.get(ctx)
+
+            ptr_i32 = pto.PtrType.get(i32, ctx)
+            tv_i32 = pto.TensorViewType.get([1], i32, ctx)
+            pv_i32 = pto.PartitionTensorViewType.get([1], i32, ctx)
+
+            fn_ty = func.FunctionType.get([ptr_i32], [])
+            with InsertionPoint(module.body):
+                fn = func.FuncOp("TNotifyAtomicAddKernel", fn_ty)
+                entry = fn.add_entry_block()
+
+            with InsertionPoint(entry):
+                signal_ptr = entry.arguments[0]
+                c0 = arith.ConstantOp(idx, 0).result
+                c1 = arith.ConstantOp(idx, 1).result
+                one_i32 = arith.ConstantOp(i32, 1).result
+
+                signal_view = pto.MakeTensorViewOp(tv_i32, signal_ptr, [c1], [c1]).result
+                signal = pto.PartitionViewOp(
+                    pv_i32, signal_view, offsets=[c0], sizes=[c1]
+                ).result
+
+                pto.TNotifyOp(
+                    signal,
+                    one_i32,
+                    pto.NotifyOpAttr.get(pto.NotifyOp.AtomicAdd, ctx),
+                )
+
+                func.ReturnOp([])
+
+    module.operation.verify()
+    return module
+
+
+if __name__ == "__main__":
+    print(build())
diff --git a/test/samples/CommSync/treduce_root_binding.pto b/test/samples/CommSync/treduce_root_binding.pto
new file mode 100644
index 000000000..6c79bc2e9
--- /dev/null
+++ b/test/samples/CommSync/treduce_root_binding.pto
@@ -0,0 +1,26 @@
+module {
+  func.func @TReduceKernelImpl(%arg0: !pto.ptr<f32>, %arg1: !pto.ptr<f32>, %arg2: !pto.ptr<f32>, %arg3: !pto.ptr<f32>, %arg4: !pto.ptr<f32>, %arg5: i32, %arg6: i32) {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c256 = arith.constant 256 : index
+    %c1_i32 = arith.constant 1 : i32
+    %0 = pto.make_tensor_view %arg0, shape = [%c256], strides = [%c1] : !pto.tensor_view<256xf32>
+    %1 = pto.partition_view %0, offsets = [%c0], sizes = [%c256] : !pto.tensor_view<256xf32>
+    %2 = pto.make_tensor_view %arg1, shape = [%c256], strides = [%c1] : !pto.tensor_view<256xf32>
+    %3 = pto.partition_view %2, offsets = [%c0], sizes = [%c256] : !pto.tensor_view<256xf32>
+    %4 = pto.make_tensor_view %arg2, shape = [%c256], strides = [%c1] : !pto.tensor_view<256xf32>
+    %5 = pto.partition_view %4, offsets = [%c0], sizes = [%c256] : !pto.tensor_view<256xf32>
+    %6 = pto.make_tensor_view %arg3, shape = [%c256], strides = [%c1] : !pto.tensor_view<256xf32>
+    %7 = pto.partition_view %6, offsets = [%c0], sizes = [%c256] : !pto.tensor_view<256xf32>
+    %8 = pto.make_tensor_view %arg4, shape = [%c256], strides = [%c1] : !pto.tensor_view<256xf32>
+    %9 = pto.partition_view %8, offsets = [%c0], sizes = [%c256] : !pto.tensor_view<256xf32>
+    %10 = pto.alloc_tile : !pto.tile_buf
+    %11 = pto.alloc_tile : !pto.tile_buf
+    %12 = arith.cmpi eq, %arg5, %c1_i32 : i32
+    scf.if %12 {
+      pto.comm.treduce(%9, %10, recv(%11), group(%1, %3, %5, %7) : !pto.partition_tensor_view<256xf32>, !pto.tile_buf, !pto.tile_buf, !pto.partition_tensor_view<256xf32>, !pto.partition_tensor_view<256xf32>, !pto.partition_tensor_view<256xf32>, !pto.partition_tensor_view<256xf32>) {reduceOp = #pto<reduce_op sum>, root = 1 : i32}
+    }
+    pto.barrier
+    return
+  }
+}
diff --git a/test/samples/CommSync/treduce_root_binding.py b/test/samples/CommSync/treduce_root_binding.py
new file mode 100644
index 000000000..4d3d0de38
--- /dev/null
+++ b/test/samples/CommSync/treduce_root_binding.py
@@ -0,0 +1,83 @@
+# Copyright (c) 2026 Huawei Technologies Co., Ltd.
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+# CANN Open Software License Agreement Version 2.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
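+#
+# Reduces four 256xf32 input views into one output view with pto.ReduceOp.Sum;
+# the collective is issued only on the root rank (my_rank == 1), and every
+# rank joins the trailing PIPE_ALL barrier.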
+
+from mlir.ir import Context, F32Type, IndexType, InsertionPoint, IntegerType, Location, Module
+from mlir.dialects import arith, func, pto, scf
+
+
+def _make_part(ptr, tv_ty, pv_ty, c0, c1, count):
+    tv = pto.MakeTensorViewOp(tv_ty, ptr, [count], [c1]).result
+    return pto.PartitionViewOp(pv_ty, tv, offsets=[c0], sizes=[count]).result
+
+
+def build():
+    with Context() as ctx:
+        pto.register_dialect(ctx, load=True)
+
+        with Location.unknown(ctx):
+            module = Module.create()
+
+            f32 = F32Type.get(ctx)
+            i32 = IntegerType.get_signless(32, ctx)
+            idx = IndexType.get(ctx)
+
+            ptr_f32 = pto.PtrType.get(f32, ctx)
+            tv_f32 = pto.TensorViewType.get([256], f32, ctx)
+            pv_f32 = pto.PartitionTensorViewType.get([256], f32, ctx)
+
+            vec = pto.AddressSpaceAttr.get(pto.AddressSpace.VEC, ctx)
+            pipe_all = pto.PipeAttr.get(pto.PIPE.PIPE_ALL, ctx)
+            tb_f32 = pto.TileBufType.get([1, 256], f32, vec, [1, 256], None, ctx)
+
+            fn_ty = func.FunctionType.get(
+                [ptr_f32, ptr_f32, ptr_f32, ptr_f32, ptr_f32, i32, i32], []
+            )
+            with InsertionPoint(module.body):
+                fn = func.FuncOp("TReduceKernelImpl", fn_ty)
+                entry = fn.add_entry_block()
+
+            with InsertionPoint(entry):
+                input0_ptr, input1_ptr, input2_ptr, input3_ptr, output_ptr, my_rank, nranks = entry.arguments
+                c0 = arith.ConstantOp(idx, 0).result
+                c1 = arith.ConstantOp(idx, 1).result
+                c256 = arith.ConstantOp(idx, 256).result
+                c1_i32 = arith.ConstantOp(i32, 1).result
+                _ = nranks
+
+                in0 = _make_part(input0_ptr, tv_f32, pv_f32, c0, c1, c256)
+                in1 = _make_part(input1_ptr, tv_f32, pv_f32, c0, c1, c256)
+                in2 = _make_part(input2_ptr, tv_f32, pv_f32, c0, c1, c256)
+                in3 = _make_part(input3_ptr, tv_f32, pv_f32, c0, c1, c256)
+                dst = _make_part(output_ptr, tv_f32, pv_f32, c0, c1, c256)
+                group = [in0, in1, in2, in3]
+
+                acc = pto.AllocTileOp(tb_f32).result
+                recv = pto.AllocTileOp(tb_f32).result
+
+                is_root = arith.CmpIOp(arith.CmpIPredicate.eq, my_rank, c1_i32).result
+                root_if = scf.IfOp(is_root, [], hasElse=False)
+                with InsertionPoint(root_if.then_block):
+                    pto.TReduceOp(
+                        dst,
+                        acc,
+                        recv,
+                        group,
+                        pto.ReduceOpAttr.get(pto.ReduceOp.Sum, ctx),
+                        1,
+                    )
+                    scf.YieldOp([])
+
+                pto.barrier(pipe_all)
+                func.ReturnOp([])
+
+    module.operation.verify()
+    return module
+
+
+if __name__ == "__main__":
+    print(build())
diff --git a/test/samples/CommSync/tscatter_root_binding.pto b/test/samples/CommSync/tscatter_root_binding.pto
new file mode 100644
index 000000000..2a9ad3ef9
--- /dev/null
+++ b/test/samples/CommSync/tscatter_root_binding.pto
@@ -0,0 +1,25 @@
+module {
+  func.func @TScatterKernelImpl(%arg0: !pto.ptr<f32>, %arg1: !pto.ptr<f32>, %arg2: !pto.ptr<f32>, %arg3: !pto.ptr<f32>, %arg4: !pto.ptr<f32>, %arg5: i32, %arg6: i32) {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c256 = arith.constant 256 : index
+    %c1_i32 = arith.constant 1 : i32
+    %0 = pto.make_tensor_view %arg0, shape = [%c256], strides = [%c1] : !pto.tensor_view<256xf32>
+    %1 = pto.partition_view %0, offsets = [%c0], sizes = [%c256] : !pto.tensor_view<256xf32>
+    %2 = pto.make_tensor_view %arg1, shape = [%c256], strides = [%c1] : !pto.tensor_view<256xf32>
+    %3 = pto.partition_view %2, offsets = [%c0], sizes = [%c256] : !pto.tensor_view<256xf32>
+    %4 = pto.make_tensor_view %arg2, shape = [%c256], strides = [%c1] : !pto.tensor_view<256xf32>
+    %5 = pto.partition_view %4, offsets = [%c0], sizes = [%c256] : !pto.tensor_view<256xf32>
+    %6 = pto.make_tensor_view %arg3, shape = [%c256], strides = [%c1] : !pto.tensor_view<256xf32>
+    %7 = pto.partition_view %6, offsets = [%c0], sizes = [%c256] : !pto.tensor_view<256xf32>
+    %8 = pto.make_tensor_view %arg4, shape = [%c256], strides = [%c1] : !pto.tensor_view<256xf32>
+    %9 = pto.partition_view %8, offsets = [%c0], sizes = [%c256] : !pto.tensor_view<256xf32>
+    %10 = pto.alloc_tile : !pto.tile_buf
+    %11 = arith.cmpi eq, %arg5, %c1_i32 : i32
+    scf.if %11 {
+      pto.comm.tscatter(%1, recv(%10), group(%3, %5, %7, %9) : !pto.partition_tensor_view<256xf32>, !pto.tile_buf, !pto.partition_tensor_view<256xf32>, !pto.partition_tensor_view<256xf32>, !pto.partition_tensor_view<256xf32>, !pto.partition_tensor_view<256xf32>) {root = 1 : i32}
+    }
+    pto.barrier
+    return
+  }
+}
diff --git a/test/samples/CommSync/tscatter_root_binding.py b/test/samples/CommSync/tscatter_root_binding.py
new file mode 100644
index 000000000..b7441cbdb
--- /dev/null
+++ b/test/samples/CommSync/tscatter_root_binding.py
@@ -0,0 +1,75 @@
+# Copyright (c) 2026 Huawei Technologies Co., Ltd.
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+# CANN Open Software License Agreement Version 2.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
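+#
+# The root rank (my_rank == 1) scatters one 256xf32 source view to four
+# destination views through a single staging tile; non-root ranks only join
+# the trailing PIPE_ALL barrier.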
+
+from mlir.ir import Context, F32Type, IndexType, InsertionPoint, IntegerType, Location, Module
+from mlir.dialects import arith, func, pto, scf
+
+
+def _make_part(ptr, tv_ty, pv_ty, c0, c1, count):
+    tv = pto.MakeTensorViewOp(tv_ty, ptr, [count], [c1]).result
+    return pto.PartitionViewOp(pv_ty, tv, offsets=[c0], sizes=[count]).result
+
+
+def build():
+    with Context() as ctx:
+        pto.register_dialect(ctx, load=True)
+
+        with Location.unknown(ctx):
+            module = Module.create()
+
+            f32 = F32Type.get(ctx)
+            i32 = IntegerType.get_signless(32, ctx)
+            idx = IndexType.get(ctx)
+
+            ptr_f32 = pto.PtrType.get(f32, ctx)
+            tv_f32 = pto.TensorViewType.get([256], f32, ctx)
+            pv_f32 = pto.PartitionTensorViewType.get([256], f32, ctx)
+
+            vec = pto.AddressSpaceAttr.get(pto.AddressSpace.VEC, ctx)
+            pipe_all = pto.PipeAttr.get(pto.PIPE.PIPE_ALL, ctx)
+            tb_f32 = pto.TileBufType.get([1, 256], f32, vec, [1, 256], None, ctx)
+
+            fn_ty = func.FunctionType.get(
+                [ptr_f32, ptr_f32, ptr_f32, ptr_f32, ptr_f32, i32, i32], []
+            )
+            with InsertionPoint(module.body):
+                fn = func.FuncOp("TScatterKernelImpl", fn_ty)
+                entry = fn.add_entry_block()
+
+            with InsertionPoint(entry):
+                src_ptr, dst0_ptr, dst1_ptr, dst2_ptr, dst3_ptr, my_rank, nranks = entry.arguments
+                c0 = arith.ConstantOp(idx, 0).result
+                c1 = arith.ConstantOp(idx, 1).result
+                c256 = arith.ConstantOp(idx, 256).result
+                c1_i32 = arith.ConstantOp(i32, 1).result
+                _ = nranks
+
+                src = _make_part(src_ptr, tv_f32, pv_f32, c0, c1, c256)
+                dst0 = _make_part(dst0_ptr, tv_f32, pv_f32, c0, c1, c256)
+                dst1 = _make_part(dst1_ptr, tv_f32, pv_f32, c0, c1, c256)
+                dst2 = _make_part(dst2_ptr, tv_f32, pv_f32, c0, c1, c256)
+                dst3 = _make_part(dst3_ptr, tv_f32, pv_f32, c0, c1, c256)
+                group = [dst0, dst1, dst2, dst3]
+
+                ping = pto.AllocTileOp(tb_f32).result
+
+                is_root = arith.CmpIOp(arith.CmpIPredicate.eq, my_rank, c1_i32).result
+                root_if = scf.IfOp(is_root, [], hasElse=False)
+                with InsertionPoint(root_if.then_block):
+                    pto.CommTScatterOp(src, ping, group, 1)
+                    scf.YieldOp([])
+
+                pto.barrier(pipe_all)
+                func.ReturnOp([])
+
+    module.operation.verify()
+    return module
+
+
+if __name__ == "__main__":
+    print(build())
diff --git a/test/samples/CommSync/twait_atomic_binding.py b/test/samples/CommSync/twait_atomic_binding.py
new file mode 100644
index 000000000..60be0becb
--- /dev/null
+++ b/test/samples/CommSync/twait_atomic_binding.py
@@ -0,0 +1,80 @@
+# Copyright (c) 2026 Huawei Technologies Co., Ltd.
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+# CANN Open Software License Agreement Version 2.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
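+#
+# Non-root ranks bump a remote i32 counter with TNOTIFY (AtomicAdd) in a
+# loop; rank 0 blocks in TWAIT until its local counter reaches the threshold
+# (WaitCmp.GE).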
+
+from mlir.ir import Context, IndexType, InsertionPoint, IntegerType, Location, Module
+from mlir.dialects import arith, func, pto, scf
+
+
+def build():
+    with Context() as ctx:
+        pto.register_dialect(ctx, load=True)
+
+        with Location.unknown(ctx):
+            module = Module.create()
+
+            i32 = IntegerType.get_signless(32, ctx)
+            idx = IndexType.get(ctx)
+            ptr_i32 = pto.PtrType.get(i32, ctx)
+            tv_i32 = pto.TensorViewType.get([1], i32, ctx)
+            pv_i32 = pto.PartitionTensorViewType.get([1], i32, ctx)
+            pipe_all = pto.PipeAttr.get(pto.PIPE.PIPE_ALL, ctx)
+
+            fn_ty = func.FunctionType.get([ptr_i32, ptr_i32, i32, i32, i32], [])
+            with InsertionPoint(module.body):
+                fn = func.FuncOp("TWaitAtomicKernel", fn_ty)
+                entry = fn.add_entry_block()
+
+            with InsertionPoint(entry):
+                local_counter_ptr, remote_counter_ptr, threshold, iters, my_rank = entry.arguments
+                c0 = arith.ConstantOp(idx, 0).result
+                c1 = arith.ConstantOp(idx, 1).result
+                c0_i32 = arith.ConstantOp(i32, 0).result
+                c1_i32 = arith.ConstantOp(i32, 1).result
+
+                local_view = pto.MakeTensorViewOp(tv_i32, local_counter_ptr, [c1], [c1]).result
+                remote_view = pto.MakeTensorViewOp(tv_i32, remote_counter_ptr, [c1], [c1]).result
+                local_counter = pto.PartitionViewOp(
+                    pv_i32, local_view, offsets=[c0], sizes=[c1]
+                ).result
+                remote_counter = pto.PartitionViewOp(
+                    pv_i32, remote_view, offsets=[c0], sizes=[c1]
+                ).result
+
+                is_non_root = arith.CmpIOp(arith.CmpIPredicate.ne, my_rank, c0_i32).result
+                branch = scf.IfOp(is_non_root, [], hasElse=True)
+
+                with InsertionPoint(branch.then_block):
+                    iters_idx = arith.IndexCastOp(idx, iters).result
+                    loop = scf.ForOp(c0, iters_idx, c1, [])
+                    with InsertionPoint(loop.body):
+                        pto.TNotifyOp(
+                            remote_counter,
+                            c1_i32,
+                            pto.NotifyOpAttr.get(pto.NotifyOp.AtomicAdd, ctx),
+                        )
+                        scf.YieldOp([])
+                    pto.barrier(pipe_all)
+                    scf.YieldOp([])
+
+                with InsertionPoint(branch.else_block):
+                    pto.TWaitOp(
+                        local_counter,
+                        threshold,
+                        pto.WaitCmpAttr.get(pto.WaitCmp.GE, ctx),
+                    )
+                    pto.barrier(pipe_all)
+                    scf.YieldOp([])
+
+                func.ReturnOp([])
+
+    module.operation.verify()
+    return module
+
+
+if __name__ == "__main__":
+    print(build())
diff --git a/test/samples/runop.sh b/test/samples/runop.sh
index 7753c438c..a337c053a 100755
--- a/test/samples/runop.sh
+++ b/test/samples/runop.sh
@@ -19,7 +19,7 @@ PYTHON_BIN="${PYTHON_BIN:-}"
 PTOAS_OUT_DIR="${PTOAS_OUT_DIR:-}"
 PTOAS_ENABLE_INSERT_SYNC="${PTOAS_ENABLE_INSERT_SYNC:-1}"
 PTOAS_FLAGS="${PTOAS_FLAGS:-}"
-PTO_PTO_DIRS="${PTO_PTO_DIRS:-Sync Qwen3DecodeA3 Qwen3DecodeA5}"
+PTO_PTO_DIRS="${PTO_PTO_DIRS:-Sync Qwen3DecodeA3 Qwen3DecodeA5 CommSync}"
 ENABLE_BC=0
 
 usage() {
@@ -1034,13 +1034,13 @@ PY
         fi
       fi
 
-      if [[ "$base" == "comm_p2p" || "$base" == "comm_p2p_binding_variants" ]]; then
+      if [[ "$base" == "comm_p2p" || "$base" == "comm_p2p_binding_variants" || "$base" == "comm_multicard_all_ops" ]]; then
         for pat in \
           "pto::comm::TPUT(" \
          "pto::comm::TGET(" \
-          "PTOAS__COMM_TNOTIFY<" \
-          "PTOAS__COMM_TWAIT<" \
-          "PTOAS__COMM_TTEST<"; do
+          "pto::comm::TNOTIFY(" \
+          "pto::comm::TWAIT(" \
+          "pto::comm::TTEST("; do
           if ! grep -Fq "$pat" "$cpp"; then
             echo -e "${A}(${base}.py)\tFAIL\tmissing $pat lowering"
             overall=1
@@ -1052,37 +1052,80 @@ PY
         overall=1
         continue
      fi
-      if [[ "$base" == "comm_p2p_binding_variants" ]]; then
+      if [[ "$base" == "comm_p2p_binding_variants" || "$base" == "comm_multicard_all_ops" ]]; then
         for pat in \
-          "pto::comm::NotifyOp::Set" \
           "pto::comm::NotifyOp::AtomicAdd" \
-          "pto::comm::WaitCmp::GE" \
-          "pto::comm::WaitCmp::LE" \
-          "pto::comm::WaitCmp::EQ" \
-          "pto::comm::WaitCmp::NE"; do
+          "pto::comm::WaitCmp::GE"; do
           if ! grep -Fq "$pat" "$cpp"; then
             echo -e "${A}(${base}.py)\tFAIL\tmissing $pat lowering"
             overall=1
             continue 2
           fi
         done
+        if [[ "$base" != "twait_atomic_binding" ]]; then
+          for pat in \
+            "pto::comm::NotifyOp::Set" \
+            "pto::comm::WaitCmp::LE" \
+            "pto::comm::WaitCmp::EQ" \
+            "pto::comm::WaitCmp::NE"; do
+            if ! grep -Fq "$pat" "$cpp"; then
+              echo -e "${A}(${base}.py)\tFAIL\tmissing $pat lowering"
+              overall=1
+              continue 2
+            fi
+          done
+        fi
       fi
     fi
 
-    if [[ "$base" == "comm_collective" || "$base" == "comm_collective_binding_variants" ]]; then
+    if [[ "$base" == "twait_atomic_binding" ]]; then
+      for pat in \
+        "__global__ AICORE void TWaitAtomicKernel(" \
+        "pto::comm::TNOTIFY(" \
+        "pto::comm::TWAIT("; do
+        if ! grep -Fq "$pat" "$cpp"; then
+          echo -e "${A}(${base}.py)\tFAIL\tmissing $pat lowering"
+          overall=1
+          continue 2
+        fi
+      done
+    fi
+
+    if [[ "$base" == "tnotify_atomic_add_binding" ]]; then
+      for pat in \
+        "__global__ AICORE void TNotifyAtomicAddKernel(" \
+        "pto::comm::TNOTIFY("; do
+        if ! grep -Fq "$pat" "$cpp"; then
+          echo -e "${A}(${base}.py)\tFAIL\tmissing $pat lowering"
+          overall=1
+          continue 2
+        fi
+      done
+    fi
+
+    if [[ "$base" == "comm_collective" || "$base" == "comm_collective_binding_variants" || "$base" == "comm_multicard_all_ops" ]]; then
       for pat in \
         "pto::comm::ParallelGroup" \
-        "pto::comm::TBROADCAST(" \
-        "pto::comm::TGATHER(" \
-        "pto::comm::TSCATTER(" \
-        "pto::comm::TREDUCE("; do
+        "pto::comm::TBROADCAST("; do
         if ! grep -Fq "$pat" "$cpp"; then
           echo -e "${A}(${base}.py)\tFAIL\tmissing $pat lowering"
           overall=1
           continue 2
         fi
       done
-      if ! grep -Fq "pto::comm::ReduceOp::Sum" "$cpp" || ! grep -Fq "pto::comm::ReduceOp::Max" "$cpp"; then
+      if [[ "$base" != "tbroadcast_root_binding" && "$base" != "tgather_root_binding" && "$base" != "tscatter_root_binding" && "$base" != "treduce_root_binding" ]]; then
+        for pat in \
+          "pto::comm::TGATHER(" \
+          "pto::comm::TSCATTER(" \
+          "pto::comm::TREDUCE("; do
+          if ! grep -Fq "$pat" "$cpp"; then
+            echo -e "${A}(${base}.py)\tFAIL\tmissing $pat lowering"
+            overall=1
+            continue 2
+          fi
+        done
+      fi
+      if [[ "$base" != "tbroadcast_root_binding" && "$base" != "tgather_root_binding" && "$base" != "tscatter_root_binding" && "$base" != "treduce_root_binding" ]] && (! grep -Fq "pto::comm::ReduceOp::Sum" "$cpp" || ! grep -Fq "pto::comm::ReduceOp::Max" "$cpp"); then
         echo -e "${A}(${base}.py)\tFAIL\tmissing reduce-op enum lowering"
         overall=1
         continue
@@ -1096,6 +1139,95 @@ PY
       fi
     fi
 
+    if [[ "$base" == "tbroadcast_root_binding" ]]; then
+      for pat in \
+        "__global__ AICORE void TBroadCastKernelImpl(" \
+        "pto::comm::TBROADCAST(" \
+        "pto::comm::ParallelGroup"; do
+        if ! grep -Fq "$pat" "$cpp"; then
+          echo -e "${A}(${base}.py)\tFAIL\tmissing $pat lowering"
+          overall=1
+          continue 2
+        fi
+      done
+    fi
+
+    if [[ "$base" == "tgather_root_binding" ]]; then
+      for pat in \
+        "__global__ AICORE void TGatherKernelImpl(" \
+        "pto::comm::TGATHER(" \
+        "pto::comm::ParallelGroup"; do
+        if ! grep -Fq "$pat" "$cpp"; then
+          echo -e "${A}(${base}.py)\tFAIL\tmissing $pat lowering"
+          overall=1
+          continue 2
+        fi
+      done
+    fi
+
+    if [[ "$base" == "tscatter_root_binding" ]]; then
+      for pat in \
+        "__global__ AICORE void TScatterKernelImpl(" \
+        "pto::comm::TSCATTER(" \
+        "pto::comm::ParallelGroup"; do
+        if ! grep -Fq "$pat" "$cpp"; then
+          echo -e "${A}(${base}.py)\tFAIL\tmissing $pat lowering"
+          overall=1
+          continue 2
+        fi
+      done
+    fi
+
+    if [[ "$base" == "treduce_root_binding" ]]; then
+      for pat in \
+        "__global__ AICORE void TReduceKernelImpl(" \
+        "pto::comm::TREDUCE(" \
+        "pto::comm::ParallelGroup" \
+        "pto::comm::ReduceOp::Sum"; do
+        if ! grep -Fq "$pat" "$cpp"; then
+          echo -e "${A}(${base}.py)\tFAIL\tmissing $pat lowering"
+          overall=1
+          continue 2
+        fi
+      done
+    fi
+
+    if [[ "$base" == "a5_comm_st_sync_flows" ]]; then
+      for pat in \
+        "pto::comm::TPUT(" \
+        "pto::comm::TGET(" \
+        "pto::comm::TNOTIFY(" \
+        "pto::comm::TWAIT(" \
+        "pto::comm::TTEST(" \
+        "pto::comm::ParallelGroup" \
+        "pto::comm::TBROADCAST(" \
+        "pto::comm::TGATHER(" \
+        "pto::comm::TSCATTER(" \
+        "pto::comm::TREDUCE(" \
+        "pto::comm::ReduceOp::Sum"; do
+        if ! grep -Fq "$pat" "$cpp"; then
+          echo -e "${A}(${base}.py)\tFAIL\tmissing $pat lowering"
+          overall=1
+          continue 2
+        fi
+      done
+    fi
+
+    if [[ "$base" == "a5_comm_st_async_flows" ]]; then
+      for pat in \
+        "pto::comm::BuildAsyncSession<" \
+        "pto::comm::TGET_ASYNC<" \
+        "pto::comm::TPUT_ASYNC<" \
+        "pto::comm::AsyncEvent" \
+        "pto::comm::AsyncSession"; do
+        if ! grep -Fq "$pat" "$cpp"; then
+          echo -e "${A}(${base}.py)\tFAIL\tmissing $pat lowering"
+          overall=1
+          continue 2
+        fi
+      done
+    fi
+
     # Regression guard for Issue #190:
     # Infer layout for a 2D column-vector view (16 x 1) should prefer DN.
if [[ "$base" == "tensor_view_infer_layout_dn" ]]; then diff --git a/tools/ptobc/generated/ptobc_opcodes_v0.h b/tools/ptobc/generated/ptobc_opcodes_v0.h index fffc5cc51..1f1b529ed 100644 --- a/tools/ptobc/generated/ptobc_opcodes_v0.h +++ b/tools/ptobc/generated/ptobc_opcodes_v0.h @@ -210,8 +210,8 @@ inline constexpr OpInfo kOpTable[] = { {0x1094, "pto.comm.twait", 0, 0x00, 0x02, 0, 0, 0, 0x00}, {0x1095, "pto.comm.ttest", 0, 0x01, 0x02, 0, 1, 0, 0x00}, {0x1096, "pto.comm.tbroadcast", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1097, "pto.comm.comm_tgather", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1098, "pto.comm.comm_tscatter", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1097, "pto.comm.tgather", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1098, "pto.comm.tscatter", 0, 0x00, 0x02, 0, 0, 0, 0x00}, {0x1099, "pto.comm.treduce", 0, 0x00, 0x02, 0, 0, 0, 0x00}, {0x2000, "arith.addi", 0, 0x01, 0x00, 2, 1, 0, 0x00}, {0x2001, "arith.ceildivsi", 0, 0x01, 0x00, 2, 1, 0, 0x00}, @@ -413,8 +413,8 @@ inline std::optional lookupOpcodeByName(llvm::StringRef name) { .Case("pto.comm.twait", 0x1094) .Case("pto.comm.ttest", 0x1095) .Case("pto.comm.tbroadcast", 0x1096) - .Case("pto.comm.comm_tgather", 0x1097) - .Case("pto.comm.comm_tscatter", 0x1098) + .Case("pto.comm.tgather", 0x1097) + .Case("pto.comm.tscatter", 0x1098) .Case("pto.comm.treduce", 0x1099) .Case("scf.for", 0x4000) .Case("scf.if", 0x4001) @@ -601,8 +601,8 @@ inline std::optional lookupOpcodeAndVariantByFullName(llvm::St .Case("pto.comm.twait", OpcodeAndVariant{0x1094, 0, 0}) .Case("pto.comm.ttest", OpcodeAndVariant{0x1095, 0, 0}) .Case("pto.comm.tbroadcast", OpcodeAndVariant{0x1096, 0, 0}) - .Case("pto.comm.comm_tgather", OpcodeAndVariant{0x1097, 0, 0}) - .Case("pto.comm.comm_tscatter", OpcodeAndVariant{0x1098, 0, 0}) + .Case("pto.comm.tgather", OpcodeAndVariant{0x1097, 0, 0}) + .Case("pto.comm.tscatter", OpcodeAndVariant{0x1098, 0, 0}) .Case("pto.comm.treduce", OpcodeAndVariant{0x1099, 0, 0}) .Case("scf.for", OpcodeAndVariant{0x4000, 0, 0}) .Case("scf.if", OpcodeAndVariant{0x4001, 0, 0})