128 changes: 128 additions & 0 deletions examples/workers/l3/moe_multi_chip_experts/README.md
@@ -0,0 +1,128 @@
# `moe_multi_chip_experts/` — one expert per chip

Runs a small distributed Mixture-of-Experts pipeline across multiple chips.
Each rank owns one expert, exchanges token slices through HCCL window buffers,
applies a simple per-expert compute kernel, and gathers the processed expert
results back to the source ranks.

This example is intentionally tiny: `NUM_TOKENS = 10`, `HIDDEN_DIM = 16`, and
only the first `COUNT = 4` tokens are processed. The small shape makes the
data movement easy to inspect while still exercising cross-chip dispatch,
compute, and combine.

## What This Demonstrates

| Concept | Where it shows up |
| ------- | ----------------- |
| L3 multi-chip worker | `Worker(level=3, device_ids=[...])` in `main.py` |
| HCCL bootstrap buffers | `ChipBootstrapConfig` with `scratch1` and `scratch2` |
| Cross-rank dispatch | `kernels/aiv/moe_dispatch_alltoall.cpp` |
| Per-rank expert compute | `kernels/aiv/moe_simple_compute.cpp` |
| Cross-rank combine | `kernels/aiv/moe_combine_alltoall.cpp` |
| Device orchestration | `kernels/orchestration/moe_end2end_orch.cpp` |
| Pytest integration | `test_moe_multi_chip_experts.py` calls `main.run(...)` |

## Layout

```text
moe_multi_chip_experts/
    main.py                           # CLI demo and reusable run() entry
    test_moe_multi_chip_experts.py    # pytest wrapper, matching other L3 examples
    kernels/
        aiv/
            moe_dispatch_alltoall.cpp # publish each rank's expert input
            moe_simple_compute.cpp    # add 1.0 to dispatched token slices
            moe_combine_alltoall.cpp  # gather processed expert outputs
        orchestration/
            moe_end2end_orch.cpp      # submit dispatch -> compute -> combine
    README.md
```

## Pipeline

For `N` chips, each chip owns one expert and starts with:

```text
send[expert_id][token][hidden]
recv[source_rank][token][hidden]
output[expert_id][token][hidden]
```
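
For the two-chip run, those per-rank buffers look like the following numpy sketch (shapes come from the constants above; the variable names mirror this README rather than `main.py`'s actual code):

```python
import numpy as np

NUM_CARDS = 2    # one expert per chip
NUM_TOKENS = 10
HIDDEN_DIM = 16

shape = (NUM_CARDS, NUM_TOKENS, HIDDEN_DIM)
send = np.zeros(shape, dtype=np.float32)    # send[expert_id][token][hidden]
recv = np.zeros(shape, dtype=np.float32)    # recv[source_rank][token][hidden]
output = np.zeros(shape, dtype=np.float32)  # output[expert_id][token][hidden]
```

Only the first `COUNT = 4` tokens of each slice are actually dispatched, computed, and combined.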

The orchestration submits three AIV kernels:

```text
┌──────────┐ ┌─────────┐ ┌─────────┐
│ Dispatch │ ───▶ │ Compute │ ───▶ │ Combine │
└──────────┘ └─────────┘ └─────────┘
```

1. Dispatch writes each rank's expert slice into the owner rank's `recv`.
2. Compute adds `1.0` to the first `COUNT` tokens in `recv`.
3. Combine copies each expert's processed slice into the source rank's
   `output[expert_id]` row (all three phases are sketched below).
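
Over in-memory numpy arrays standing in for the per-rank buffers and HCCL windows, the three phases reduce to this reference sketch (illustrative names, not `main.py`'s):

```python
import numpy as np

NUM_CARDS, NUM_TOKENS, HIDDEN_DIM, COUNT = 2, 10, 16, 4

def make_rank():
    shape = (NUM_CARDS, NUM_TOKENS, HIDDEN_DIM)
    return {"send": np.zeros(shape, np.float32),
            "recv": np.zeros(shape, np.float32),
            "output": np.zeros(shape, np.float32)}

ranks = [make_rank() for _ in range(NUM_CARDS)]

# 1. Dispatch: rank r's slice for expert e lands in rank e's recv[r].
for r in range(NUM_CARDS):
    for e in range(NUM_CARDS):
        ranks[e]["recv"][r] = ranks[r]["send"][e]

# 2. Compute: each expert adds 1.0 to the first COUNT tokens it received.
for e in range(NUM_CARDS):
    ranks[e]["recv"][:, :COUNT] += 1.0

# 3. Combine: expert e's processed slice returns to source rank r.
for r in range(NUM_CARDS):
    for e in range(NUM_CARDS):
        ranks[r]["output"][e, :COUNT] = ranks[e]["recv"][r, :COUNT]
```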

`scratch1` is the HCCL window used by dispatch. `scratch2` is the HCCL window
used by combine. Compute only updates `recv`; it does not use either scratch
window.

The two communication phases use independent windows mainly because each
kernel places its barrier signal slots at the tail of its scratch buffer and
does not reset those slots before use. Dispatch leaves its signal slots
incremented after its cross-rank barrier. If combine reused the same window,
its `TWAIT` could observe the old dispatch signals and pass before combine has
staged its own data. A separate `scratch2` gives combine independent data
storage and independent signal slots.
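
Concretely, the tail placement works out as follows for the two-chip run (a sketch of the offset arithmetic used by `moe_combine_alltoall.cpp`, assuming float32 data and one int32 signal slot per peer):

```python
NUM_CARDS, NUM_TOKENS, HIDDEN_DIM = 2, 10, 16

# Staging area: scratch[expert][dest_card][token][hidden] as float32
data_elems = NUM_CARDS * NUM_CARDS * NUM_TOKENS * HIDDEN_DIM  # 640 floats
signal_offset = data_elems * 4                                # 2560 bytes: signal slots start here
signal_bytes = NUM_CARDS * 4                                  # one int32 slot per card
print(signal_offset + signal_bytes)                           # 2568 bytes minimum per window
```

Because each phase gets a fresh window, combine's signal slots start at zero and its `TWAIT` cannot be satisfied by dispatch's leftover increments.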

## Data Pattern

Inputs are initialized with unique values:

```text
value = card_id * 1_000_000 + expert_id * 10_000 + token * 100 + dim
```

After compute, every checked output value should be the corresponding input
value plus `1.0`. `main.py` computes the golden reference in Python and checks
every `output[expert_id][token][hidden]` element for the processed token
range.
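
That check is easy to reproduce; a golden-reference sketch (illustrative helper names, not `main.py`'s actual functions):

```python
import numpy as np

NUM_CARDS, HIDDEN_DIM, COUNT = 2, 16, 4

def input_value(card_id: int, expert_id: int, token: int, dim: int) -> float:
    return float(card_id * 1_000_000 + expert_id * 10_000 + token * 100 + dim)

def golden_output(card_id: int) -> np.ndarray:
    # output[expert_id][token][dim] on card_id = its own input to expert_id, plus 1.0
    out = np.empty((NUM_CARDS, COUNT, HIDDEN_DIM), np.float32)
    for e in range(NUM_CARDS):
        for t in range(COUNT):
            for d in range(HIDDEN_DIM):
                out[e, t, d] = input_value(card_id, e, t, d) + 1.0
    return out
```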

## Run

Hardware:

```bash
python examples/workers/l3/moe_multi_chip_experts/main.py -p a2a3 -d 0-1
```

Simulation:

```bash
python examples/workers/l3/moe_multi_chip_experts/main.py -p a2a3sim -d 0-1
```

The pytest wrapper follows the same style as the other L3 examples:

```bash
python -m pytest examples/workers/l3/moe_multi_chip_experts --platform a2a3 --device 0-1
```
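
The wrapper itself can be as thin as the following sketch (hypothetical plumbing: the option lookup and the exact `main.run(...)` signature are assumptions, not the example's actual code):

```python
# Hypothetical thin wrapper: forwards the pytest CLI options straight to main.run().
from examples.workers.l3.moe_multi_chip_experts import main

def test_moe_multi_chip_experts(request):
    platform = request.config.getoption("--platform")  # e.g. "a2a3"
    device = request.config.getoption("--device")      # e.g. "0-1"
    main.run(platform=platform, device=device)
```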

For the CLI, device ids can be written as a range (`-d 0-1`) or a
comma-separated list (`-d 0,1`). For pytest, pass the same device spec to
`--device`. The examples use ranges because that matches the other L3 docs.
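
Either spelling expands to the same id list; a minimal parser sketch (illustrative, not necessarily how `main.py` does it):

```python
def parse_device_spec(spec: str) -> list[int]:
    """Expand "0-1" (range) or "0,1" (comma list) into device ids."""
    if "-" in spec:
        lo, hi = spec.split("-")
        return list(range(int(lo), int(hi) + 1))
    return [int(x) for x in spec.split(",")]

assert parse_device_spec("0-1") == parse_device_spec("0,1") == [0, 1]
```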

Expected successful output for the two-chip commands above includes the following (256 checked values = 2 ranks × 2 experts × 4 tokens × 16 hidden dims):

```text
[End2End] End-to-end pipeline completed!
Total: 256/256 correct
[End2End] All values correct! End-to-end pipeline works perfectly.
```

## Notes

- `test_moe_multi_chip_experts.py` is a thin pytest wrapper around
`main.run(...)`.
- The pytest case runs on `a2a3` hardware and requires two available device
ids.
- Each rank allocates independent `scratch1` and `scratch2` HCCL windows
during worker bootstrap.
9 changes: 9 additions & 0 deletions examples/workers/l3/moe_multi_chip_experts/__init__.py
@@ -0,0 +1,9 @@
# Copyright (c) PyPTO Contributors.
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# CANN Open Software License Agreement Version 2.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# -----------------------------------------------------------------------------------------------------------
"""Multi-chip MoE example package."""
217 changes: 217 additions & 0 deletions examples/workers/l3/moe_multi_chip_experts/kernels/aiv/moe_combine_alltoall.cpp
@@ -0,0 +1,217 @@
/*
* Copyright (c) PyPTO Contributors.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
* -----------------------------------------------------------------------------------------------------------
*/
/**
* MoE Combine All-to-All Kernel (Direct Store Version)
*
* This kernel implements the combine phase of distributed MoE:
* Each card i sends recv[i][card_j] (expert_i's result for card_j) to card j,
* then directly stores all received results to output (one expert per output row).
*
* Data flow:
 * Phase 1 (stage-in): recv[:][:COUNT][:] → scratch[my_rank][:][:COUNT][:]
 * Phase 2 (barrier): signal matrix + TWAIT cross-rank sync
 * Phase 3 (store): for expert_i in num_cards: copy scratch[expert_i][my_rank][:COUNT][:] to output[expert_i][:][:]
*
* Output layout:
* output[expert_i][token_t][:] = data from expert_i for this card, token t
*
* args layout:
* tensor(0) = recv_local [num_cards][num_tokens][hidden_dim]
* tensor(1) = output_local [num_cards][count][hidden_dim] - stores all experts' data
* tensor(2) = scratch HCCL window buffer
* tensor(3) = scratch_print Debug output buffer (Phase 1 stage-in mirror)
* scalar(0) = card_id which card this is
* scalar(1) = num_cards total number of cards
* scalar(2) = CommContext device pointer for cross-card communication
*/

#include <cstdint>
#include <pto/pto-inst.hpp>
#include "pto/comm/comm_types.hpp"
#include "pto/comm/pto_comm_inst.hpp"
#include "platform_comm/comm_context.h"
#include "tensor.h"

#ifndef __gm__
#define __gm__
#endif

#ifndef __aicore__
#define __aicore__ [aicore]
#endif

// Configuration matching the in-test golden references
static constexpr size_t NUM_TOKENS = 10;
static constexpr size_t HIDDEN_DIM = 16;
static constexpr size_t COUNT = 4; // tokens to process per (card, expert) pair
static constexpr int kMaxSupportedCards = 16;

// Translate a pointer inside this rank's HCCL window into the matching
// pointer inside peer `pe`'s window: take the byte offset from the local
// window base and rebase it onto the peer's window base.
template <typename T>
__aicore__ inline __gm__ T *CommRemotePtr(__gm__ CommContext *ctx, __gm__ T *localPtr, int pe) {
    uint64_t localBase = ctx->windowsIn[ctx->rankId];
    uint64_t offset = (uint64_t)localPtr - localBase;
    return (__gm__ T *)(ctx->windowsIn[pe] + offset);
}

extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t *args) {
// Unpack tensors
__gm__ Tensor *recv_tensor = reinterpret_cast<__gm__ Tensor *>(args[0]);
__gm__ Tensor *output_tensor = reinterpret_cast<__gm__ Tensor *>(args[1]);
__gm__ Tensor *scratch_tensor = reinterpret_cast<__gm__ Tensor *>(args[2]);
__gm__ Tensor *scratch_print_tensor = reinterpret_cast<__gm__ Tensor *>(args[3]);

    // Unpack scalars
    int64_t card_id = args[4];  // host-provided card index; this kernel reads its rank from commCtx instead
    (void)card_id;
    int num_cards = static_cast<int>(args[5]);
    __gm__ CommContext *commCtx = reinterpret_cast<__gm__ CommContext *>(args[6]);

// Get base pointers
__gm__ float *recv = reinterpret_cast<__gm__ float *>(recv_tensor->buffer.addr) + recv_tensor->start_offset;
__gm__ float *output = reinterpret_cast<__gm__ float *>(output_tensor->buffer.addr) + output_tensor->start_offset;
__gm__ float *scratch =
reinterpret_cast<__gm__ float *>(scratch_tensor->buffer.addr) + scratch_tensor->start_offset;
__gm__ float *scratch_print =
reinterpret_cast<__gm__ float *>(scratch_print_tensor->buffer.addr) + scratch_print_tensor->start_offset;

// Signal area at tail of scratch: num_cards int32 slots
// Must be placed AFTER all data slots to avoid corruption
size_t total_data_size = num_cards * num_cards * NUM_TOKENS * HIDDEN_DIM;
__gm__ int32_t *signal_base = reinterpret_cast<__gm__ int32_t *>(scratch + total_data_size);

using ShapeDyn = pto::Shape<pto::DYNAMIC, pto::DYNAMIC, pto::DYNAMIC, pto::DYNAMIC, pto::DYNAMIC>;
using StrideDyn = pto::Stride<pto::DYNAMIC, pto::DYNAMIC, pto::DYNAMIC, pto::DYNAMIC, pto::DYNAMIC>;
using Global = pto::GlobalTensor<float, ShapeDyn, StrideDyn, pto::Layout::ND>;

int my_rank = static_cast<int>(commCtx->rankId);

if (num_cards <= 0 || num_cards > kMaxSupportedCards) {
pipe_barrier(PIPE_ALL);
return;
}

    // ------------------------------------------------------------------
    // Phase 1: stage-in — copy this card's expert results into scratch.
    // my_rank doubles as this card's expert_id: for each destination
    // card_j, copy recv[card_j][:COUNT][:] to scratch[my_rank][card_j][:COUNT][:]
    // ------------------------------------------------------------------
for (int card_j = 0; card_j < num_cards; ++card_j) {
for (size_t t = 0; t < COUNT; ++t) {
// Source: recv[card_j][t][:HIDDEN_DIM] (expert_id's processed data from card_j)
// recv layout: [num_cards][NUM_TOKENS][HIDDEN_DIM]
// Base points to current (card_j, t), stride should keep access within current token
ShapeDyn src_shape(1, 1, 1, 1, HIDDEN_DIM);
StrideDyn src_stride(
NUM_TOKENS * HIDDEN_DIM, NUM_TOKENS * HIDDEN_DIM, NUM_TOKENS * HIDDEN_DIM, HIDDEN_DIM, 1
);
Global srcG(recv + card_j * NUM_TOKENS * HIDDEN_DIM + t * HIDDEN_DIM, src_shape, src_stride);

// Destination: scratch[my_rank][card_j][t][:HIDDEN_DIM]
// Offset = my_rank * (num_cards * NUM_TOKENS * HIDDEN_DIM)
// + card_j * (NUM_TOKENS * HIDDEN_DIM)
// + t * HIDDEN_DIM
size_t dst_offset =
my_rank * num_cards * NUM_TOKENS * HIDDEN_DIM + card_j * NUM_TOKENS * HIDDEN_DIM + t * HIDDEN_DIM;

ShapeDyn dst_shape(1, 1, 1, 1, HIDDEN_DIM);
StrideDyn dst_stride(
num_cards * NUM_TOKENS * HIDDEN_DIM, num_cards * NUM_TOKENS * HIDDEN_DIM, NUM_TOKENS * HIDDEN_DIM,
HIDDEN_DIM, 1
);
Global dstG(scratch + dst_offset, dst_shape, dst_stride);
Global dstG_print(scratch_print + dst_offset, dst_shape, dst_stride);

using TileType = pto::Tile<pto::TileType::Vec, float, 1, HIDDEN_DIM, pto::BLayout::RowMajor, -1, -1>;
TileType tile(1, HIDDEN_DIM);
TASSIGN(tile, 0);

            TLOAD(tile, srcG);
            // Order the UB→GM stores (MTE3) after the GM→UB load (MTE2) completes
            set_flag(PIPE_MTE2, PIPE_MTE3, EVENT_ID0);
            wait_flag(PIPE_MTE2, PIPE_MTE3, EVENT_ID0);
            TSTORE(dstG, tile);
            TSTORE(dstG_print, tile);
            // Drain the stores before the next iteration's load reuses the tile
            set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0);
            wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0);
}
}
pipe_barrier(PIPE_ALL);

    // ------------------------------------------------------------------
    // Phase 2: device barrier — each card notifies every peer that its
    // staged expert results are visible in scratch, then waits for the
    // matching notification from all peers.
    // ------------------------------------------------------------------
for (int peer = 0; peer < num_cards; ++peer) {
if (peer == my_rank) continue;
__gm__ int32_t *remote_signal = CommRemotePtr(commCtx, signal_base + my_rank, peer);
pto::comm::Signal sig(remote_signal);
pto::comm::TNOTIFY(sig, (int32_t)1, pto::comm::NotifyOp::AtomicAdd);
}
for (int peer = 0; peer < num_cards; ++peer) {
if (peer == my_rank) continue;
pto::comm::Signal sig(signal_base + peer);
pto::comm::TWAIT(sig, (int32_t)1, pto::comm::WaitCmp::GE);
}
pipe_barrier(PIPE_ALL);

// ------------------------------------------------------------------
// Phase 3: direct store — copy each expert's data to output
// Read scratch[expert_i][my_rank][t][:HIDDEN_DIM] from each expert i
// and store to output[expert_i][t][:HIDDEN_DIM]
//
// For card_id with my_rank:
// output[expert_0][t][:] = scratch[expert_0][my_rank][t][:]
// output[expert_1][t][:] = scratch[expert_1][my_rank][t][:]
// etc.
// ------------------------------------------------------------------
for (int expert_i = 0; expert_i < num_cards; ++expert_i) {
for (size_t t = 0; t < COUNT; ++t) {
// Source: scratch[expert_i][my_rank][t][:HIDDEN_DIM]
// Offset = expert_i * (num_cards * NUM_TOKENS * HIDDEN_DIM)
// + my_rank * (NUM_TOKENS * HIDDEN_DIM)
// + t * HIDDEN_DIM
__gm__ float *src_base = (expert_i == my_rank) ? scratch : CommRemotePtr(commCtx, scratch, expert_i);
size_t src_offset =
expert_i * num_cards * NUM_TOKENS * HIDDEN_DIM + my_rank * NUM_TOKENS * HIDDEN_DIM + t * HIDDEN_DIM;

ShapeDyn src_shape(1, 1, 1, 1, HIDDEN_DIM);
StrideDyn src_stride(
num_cards * NUM_TOKENS * HIDDEN_DIM, num_cards * NUM_TOKENS * HIDDEN_DIM, NUM_TOKENS * HIDDEN_DIM,
HIDDEN_DIM, 1
);
Global srcG(src_base + src_offset, src_shape, src_stride);

// Destination: output[expert_i][t][:HIDDEN_DIM]
// Offset = expert_i * (COUNT * HIDDEN_DIM) + t * HIDDEN_DIM
size_t dst_offset = expert_i * COUNT * HIDDEN_DIM + t * HIDDEN_DIM;

ShapeDyn dst_shape(1, 1, 1, 1, HIDDEN_DIM);
StrideDyn dst_stride(COUNT * HIDDEN_DIM, HIDDEN_DIM, HIDDEN_DIM, HIDDEN_DIM, 1);
Global dstG(output + dst_offset, dst_shape, dst_stride);

using TileType = pto::Tile<pto::TileType::Vec, float, 1, HIDDEN_DIM, pto::BLayout::RowMajor, -1, -1>;
TileType tile(1, HIDDEN_DIM);
TASSIGN(tile, 0);

// Load from scratch
TLOAD(tile, srcG);
set_flag(PIPE_MTE2, PIPE_MTE3, EVENT_ID0);
wait_flag(PIPE_MTE2, PIPE_MTE3, EVENT_ID0);

// Store to output
TSTORE(dstG, tile);
set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0);
wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0);
}
}

pipe_barrier(PIPE_ALL);
}