From d8dba325c08c2ef02fc3328809e0d87251f3ad9b Mon Sep 17 00:00:00 2001 From: PKUZHOU Date: Wed, 6 May 2026 16:32:46 +0800 Subject: [PATCH 1/6] Add Python distributed L4 to L3 dispatch --- docs/distributed-l4-implementation.zh.md | 341 ++++++++++++++++ docs/distributed-l4.md | 56 +++ examples/distributed/l4_l3_remote/README.md | 19 + examples/distributed/l4_l3_remote/__init__.py | 1 + .../distributed/l4_l3_remote/l3_worker.py | 5 + .../distributed/l4_l3_remote/l4_master.py | 52 +++ pyproject.toml | 3 +- python/simpler/distributed/__init__.py | 32 ++ python/simpler/distributed/catalog.py | 131 +++++++ python/simpler/distributed/l3_daemon.py | 221 +++++++++++ python/simpler/distributed/proto/__init__.py | 1 + python/simpler/distributed/proto/_gen.sh | 19 + .../simpler/distributed/proto/dispatch.proto | 92 +++++ .../simpler/distributed/proto/dispatch_pb2.py | 62 +++ .../distributed/proto/dispatch_pb2_grpc.py | 370 ++++++++++++++++++ python/simpler/distributed/remote_proxy.py | 131 +++++++ python/simpler/distributed/rpc.py | 110 ++++++ python/simpler/distributed/serialization.py | 71 ++++ python/simpler/distributed/tensor_pool.py | 69 ++++ python/simpler/worker.py | 106 +++++ tests/ut/py/test_distributed/test_catalog.py | 44 +++ .../ut/py/test_distributed/test_heartbeat.py | 29 ++ tests/ut/py/test_distributed/test_import.py | 6 + .../py/test_distributed/test_l4_l3_remote.py | 235 +++++++++++ .../py/test_distributed/test_rpc_roundtrip.py | 63 +++ .../py/test_distributed/test_tensor_pool.py | 39 ++ 26 files changed, 2307 insertions(+), 1 deletion(-) create mode 100644 docs/distributed-l4-implementation.zh.md create mode 100644 docs/distributed-l4.md create mode 100644 examples/distributed/l4_l3_remote/README.md create mode 100644 examples/distributed/l4_l3_remote/__init__.py create mode 100644 examples/distributed/l4_l3_remote/l3_worker.py create mode 100644 examples/distributed/l4_l3_remote/l4_master.py create mode 100644 python/simpler/distributed/__init__.py create 
mode 100644 python/simpler/distributed/catalog.py create mode 100644 python/simpler/distributed/l3_daemon.py create mode 100644 python/simpler/distributed/proto/__init__.py create mode 100755 python/simpler/distributed/proto/_gen.sh create mode 100644 python/simpler/distributed/proto/dispatch.proto create mode 100644 python/simpler/distributed/proto/dispatch_pb2.py create mode 100644 python/simpler/distributed/proto/dispatch_pb2_grpc.py create mode 100644 python/simpler/distributed/remote_proxy.py create mode 100644 python/simpler/distributed/rpc.py create mode 100644 python/simpler/distributed/serialization.py create mode 100644 python/simpler/distributed/tensor_pool.py create mode 100644 tests/ut/py/test_distributed/test_catalog.py create mode 100644 tests/ut/py/test_distributed/test_heartbeat.py create mode 100644 tests/ut/py/test_distributed/test_import.py create mode 100644 tests/ut/py/test_distributed/test_l4_l3_remote.py create mode 100644 tests/ut/py/test_distributed/test_rpc_roundtrip.py create mode 100644 tests/ut/py/test_distributed/test_tensor_pool.py diff --git a/docs/distributed-l4-implementation.zh.md b/docs/distributed-l4-implementation.zh.md new file mode 100644 index 000000000..6249797ff --- /dev/null +++ b/docs/distributed-l4-implementation.zh.md @@ -0,0 +1,341 @@ +# L4 到 L3 跨 Host Dispatch 当前实现说明 + +本文记录当前 `simpler.distributed` 的实现状态。当前版本是 Python-first 的 MVP:使用 gRPC/protobuf 做跨进程或跨 Host 控制面传输,通过本地 mailbox shim 复用现有 C++ scheduler,不新增 C++ 或 nanobind 接口。 + +## 当前目标 + +原有单机 L4 到 L3 路径依赖 `Worker.add_worker(l3_worker)`、`os.fork()` 和共享内存 mailbox。跨 Host 后,L4 不能再直接 fork 出远端 L3,也不能依赖 fork 后继承的 callable registry 和共享地址空间。 + +当前实现把边界替换为: + +```text +L4 Worker + -> 本地 PROCESS mailbox + -> Python shim thread + -> gRPC Dispatch + -> L3Daemon + -> backend process + -> inner Worker(level=3).run(...) 
+``` + +这使 L4 用户侧仍然通过 `orch.submit_next_level(...)` 下发任务,C++ scheduler 仍然看到一个普通的 PROCESS-mode next-level worker。 + +## 新增代码结构 + +```text +python/simpler/distributed/ + __init__.py + rpc.py # RpcServer / RpcClient 薄封装 + catalog.py # callable_id + version + cloudpickle payload + serialization.py # CallConfig / TaskArgs 与 protobuf 的转换 + remote_proxy.py # L4 侧 RemoteWorkerProxy + l3_daemon.py # L3 侧长驻 daemon + tensor_pool.py # inline / handle 字节池表面 + proto/ + dispatch.proto + dispatch_pb2.py + dispatch_pb2_grpc.py + _gen.sh + +tests/ut/py/test_distributed/ + test_import.py + test_rpc_roundtrip.py + test_catalog.py + test_l4_l3_remote.py + test_tensor_pool.py + test_heartbeat.py + +examples/distributed/l4_l3_remote/ + l3_worker.py + l4_master.py + README.md +``` + +`pyproject.toml` 新增运行时依赖: + +```toml +dependencies = ["grpcio>=1.80", "protobuf>=4.25", "cloudpickle>=2.2"] +``` + +`grpcio-tools>=1.80` 放在 test optional dependencies,用于重新生成 protobuf 文件。 + +## 协议设计 + +协议定义在 `python/simpler/distributed/proto/dispatch.proto`。 + +当前主要 service: + +```protobuf +service L3Worker { + rpc Dispatch(DispatchReq) returns (DispatchResp); + rpc Heartbeat(Empty) returns (Health); +} + +service Catalog { + rpc PullCallable(CallableRef) returns (CallablePayload); + rpc PushCallable(CallablePayload) returns (Empty); +} + +service TensorPool { + rpc PullTensor(TensorHandle) returns (stream TensorChunk); + rpc PushTensor(stream TensorChunk) returns (TensorHandle); +} +``` + +`DispatchReq` 当前承载: + +- `task_id`: L4 侧生成的请求 id +- `callable_id`: L4 registry 中的 callable id +- `callable_version`: callable payload 的 blake2b 版本号 +- `config_blob`: 序列化后的 `CallConfig` +- `scalar_args`: 标量参数 +- `tensor_args`: `ContinuousTensor` 元数据 +- `tensor_refs`: 为后续真实 tensor 数据面预留 + +`DispatchResp` 当前承载: + +- `error_code`: `0` 表示成功 +- `error_msg`: 远端失败摘要 +- `remote_traceback`: 远端 Python traceback +- `output_tensors`: 为后续 output 回传预留 + +## L4 侧实现 + +入口是 `Worker.add_remote_worker(endpoint, **options)`。 + +调用时机要求和 
`add_worker()` 一致: + +- 只能在 `level >= 4` 的 Worker 上调用 +- 必须在 `Worker.init()` 前调用 +- 会复用 L4 侧已经 `register()` 的 callable registry + +初始化时做的事: + +1. 为每个 remote worker 分配一个本地 `SharedMemory` mailbox。 +2. 创建 `RemoteWorkerProxy(endpoint, catalog, **options)`。 +3. `RemoteWorkerProxy.handshake()`: + - 先发 `Heartbeat` + - 把本地 catalog 里的 callable payload 全部 `PushCallable` 到远端 + - 启动后台 heartbeat thread +4. 启动 `_remote_worker_loop` shim thread。 +5. 把 remote mailbox 注册给 C++ `_Worker.add_next_level_process(...)`。 + +之后 C++ scheduler 下发任务时,只是在本地 mailbox 写入 `TASK_READY`。shim thread 读出: + +- callable id +- `TaskArgs` +- `CallConfig` + +然后调用: + +```python +RemoteWorkerProxy.dispatch(callable_id, args, cfg) +``` + +远端返回成功后,shim thread 把 mailbox 状态写回 `TASK_DONE`。如果远端失败,则把错误写入 mailbox error 区域,后续由现有 drain/error propagation 路径抛回 L4 调用者。 + +## Callable Catalog + +`Catalog` 解决 fork-COW registry 在跨 Host 场景不可用的问题。 + +注册逻辑: + +```python +cid, version = catalog.register(fn, callable_id=cid) +``` + +版本号计算方式: + +```text +version = uint64(blake2b(cloudpickle.dumps(fn), digest_size=8)) +``` + +当前使用 `cloudpickle` 而不是标准库 `pickle`,原因是现有 L4/L3 测试和用户代码经常使用嵌套函数、lambda、closure。标准库 `pickle` 无法覆盖这些形态。 + +安全边界: + +callable payload 是可执行 Python 代码的反序列化结果,只能用于受信任集群内部。不要把 `Catalog` service 暴露给不可信客户端。 + +## L3Daemon 实现 + +`L3Daemon` 是远端 L3 节点的常驻入口。 + +启动方式: + +```bash +python -m simpler.distributed.l3_daemon --port 5050 --num-sub-workers 1 +``` + +重要实现点:daemon 不是直接在 gRPC handler thread 中运行 `Worker`。它会先启动一个 backend process: + +```text +L3Daemon process + - gRPC server threads + - Catalog service + - L3Worker service + - Pipe to backend + +Backend process + - Catalog mirror + - lazy inner Worker(level=3) + - inner Worker 的 sub/chip fork +``` + +这样做是为了避开 grpcio 与 fork 的冲突。`Worker(level=3)` 内部仍会 fork sub worker 或 chip worker;如果这个 fork 发生在已经启动 gRPC worker threads 的进程中,grpcio 可能触发 fork-safety 问题。当前设计让 backend process 在 gRPC server 启动前 fork 出来,后续所有 inner Worker 逻辑都在 backend process 内完成。 + +Catalog push 时: + +1. 
gRPC `Catalog.PushCallable` 先安装到 daemon 进程的 catalog。 +2. 同步转发 `("push", cid, version, payload)` 到 backend process。 +3. backend process 安装到自己的 catalog mirror。 + +Dispatch 时: + +1. gRPC handler 收到 `DispatchReq`。 +2. 将 protobuf bytes 通过 pipe 发给 backend。 +3. backend lazy 创建 inner `Worker(level=3)`。 +4. backend 把 catalog 中所有 callable 安装进 inner worker 的 `_callable_registry`。 +5. 查找 `req.callable_id / req.callable_version` 对应的 orch fn。 +6. 反序列化 `TaskArgs` 和 `CallConfig`。 +7. 调用: + +```python +inner.run(orch_fn, args, cfg) +``` + +## Heartbeat 与错误传播 + +`RemoteWorkerProxy` 在 handshake 后启动 heartbeat thread。 + +默认参数: + +- `heartbeat_interval=5.0` +- `heartbeat_timeout=1.0` +- `heartbeat_failures=3` + +连续失败达到阈值后,proxy 标记为 unavailable。后续 dispatch 会 fast-fail,不再进入 RPC 热路径。 + +错误传播路径: + +```text +backend exception + -> DispatchResp(error_code=1, error_msg, remote_traceback) + -> RemoteWorkerProxy.dispatch raises RuntimeError + -> _remote_worker_loop writes mailbox error + -> existing Worker.run drain path raises to caller +``` + +## Tensor 数据面当前状态 + +当前已经有 `tensor_pool.py` 和 proto 中的 `TensorRef / TensorHandle / TensorChunk`。 + +已实现: + +- 小字节数据 inline +- 大字节数据注册为 handle +- `PullTensor` streaming +- `PushTensor` streaming +- `ContinuousTensor` 元数据随 `DispatchReq.tensor_args` 传输 + +尚未完成: + +- 远端真实 tensor materialization +- output tensor 回写 +- `OUTPUT_EXISTING` 的远端到本地同步 +- 与 torch tensor / NPU device memory 的完整数据面打通 + +所以当前端到端 remote dispatch 测试主要覆盖 scalar `TaskArgs` 和 Python callable 执行链路。 + +## 使用示例 + +终端 1 启动 L3: + +```bash +python examples/distributed/l4_l3_remote/l3_worker.py --port 5050 +``` + +终端 2 启动 L4: + +```bash +python examples/distributed/l4_l3_remote/l4_master.py --remotes 127.0.0.1:5050 +``` + +期望输出: + +```text +remote counter=7 +``` + +## 当前测试方式 + +安装或构建: + +```bash +python -m pip install -e . 
+``` + +运行新增分布式测试和原有 L4 递归回归: + +```bash +python -m pytest tests/ut/py/test_distributed tests/ut/py/test_worker/test_l4_recursive.py -q +``` + +当前验证结果: + +```text +32 passed +``` + +额外检查: + +```bash +python -m compileall -q python/simpler/distributed tests/ut/py/test_distributed examples/distributed/l4_l3_remote +git diff --check +``` + +当前注意事项: + +- Python 3.13 下测试会出现多线程进程中 fork 的 `DeprecationWarning`。 +- 现有本地 L4 recursive 测试也会触发同类 warning。 +- 当前测试通过;warning 不代表断言失败。 + +## 当前实现边界 + +已完成: + +- gRPC/protobuf 控制面 +- callable catalog push/install +- L4 `add_remote_worker()` +- mailbox shim thread +- L3 daemon backend process +- scalar args dispatch +- remote traceback 传播 +- heartbeat fail-fast +- 示例和测试 + +未完成: + +- 完整 tensor 数据面 +- 多 remote 负载均衡策略 +- 节点发现或服务注册 +- 鉴权、TLS、租户隔离 +- C++ hot-path `RemoteWorkerThread` +- Urma/RDMA 数据面 + +## 设计取舍 + +当前实现优先保证低侵入: + +- 不改 C++ scheduler +- 不改 nanobind binding +- 不改变用户 orch function 的写法 +- 通过本地 mailbox shim 接入现有 PROCESS-mode 语义 + +代价是: + +- 每个 remote worker 多一个 Python shim thread +- gRPC 路径不是 hot-path 最优 +- tensor 数据面还没有真正接入 +- callable 反序列化要求可信环境 + +这个版本适合作为 L4 到 L3 跨 Host dispatch 的功能 MVP,用于继续验证协议、调度语义和错误传播;后续性能优化可以再下沉到 C++ 或替换数据面。 diff --git a/docs/distributed-l4.md b/docs/distributed-l4.md new file mode 100644 index 000000000..8084388c7 --- /dev/null +++ b/docs/distributed-l4.md @@ -0,0 +1,56 @@ +# Distributed L4 to L3 Dispatch + +This package adds a Python-first remote L3 transport for L4 `Worker` instances. +It uses gRPC/protobuf for control messages and keeps the existing C++ scheduler +and mailbox layout unchanged. 
+ +## API + +Start an L3 daemon: + +```bash +python -m simpler.distributed.l3_daemon --port 5050 --num-sub-workers 1 +``` + +Attach it to an L4 worker: + +```python +from simpler.worker import Worker + +w4 = Worker(level=4, num_sub_workers=0) +l3_sub_cid = w4.register(l3_sub) +l3_orch_cid = w4.register(l3_orch) +w4.add_remote_worker("127.0.0.1:5050") +w4.init() +``` + +`add_remote_worker()` allocates one local mailbox and registers it as a normal +next-level PROCESS worker. A Python shim thread polls that mailbox and forwards +ready tasks to the remote L3 daemon with `L3Worker.Dispatch`. + +## Callable Catalog + +The L4 side serializes registered callables with `cloudpickle` and pushes them +to the daemon during handshake. Callable ids are preserved, so an L3 orch can +submit L3 sub callables by the same ids that were registered on L4. + +Callable payloads are trusted cluster traffic. Do not expose the catalog service +to untrusted clients. + +## Daemon Lifecycle + +The daemon starts a backend process before accepting gRPC traffic. gRPC handler +threads forward catalog and dispatch requests to that backend process. The +backend owns the inner `Worker`, so Worker child forks do not happen in a +process with active gRPC threads. + +## Tensors + +`tensor_pool.py` provides the planned inline/handle byte pool surface. Scalar +`TaskArgs` and `ContinuousTensor` metadata are wired through dispatch today; +full remote tensor materialization is isolated behind `TensorPool`. + +## Health + +Each `RemoteWorkerProxy` starts a heartbeat thread after handshake. Consecutive +heartbeat failures mark the remote unavailable, and later dispatches fail fast. 
diff --git a/examples/distributed/l4_l3_remote/README.md b/examples/distributed/l4_l3_remote/README.md new file mode 100644 index 000000000..6867657c2 --- /dev/null +++ b/examples/distributed/l4_l3_remote/README.md @@ -0,0 +1,19 @@ +# L4 to L3 Remote Dispatch + +Terminal 1: + +```bash +python examples/distributed/l4_l3_remote/l3_worker.py --port 5050 +``` + +Terminal 2: + +```bash +python examples/distributed/l4_l3_remote/l4_master.py --remotes 127.0.0.1:5050 +``` + +Expected output: + +```text +remote counter=7 +``` diff --git a/examples/distributed/l4_l3_remote/__init__.py b/examples/distributed/l4_l3_remote/__init__.py new file mode 100644 index 000000000..da80af90b --- /dev/null +++ b/examples/distributed/l4_l3_remote/__init__.py @@ -0,0 +1 @@ +"""Distributed L4 -> L3 example.""" diff --git a/examples/distributed/l4_l3_remote/l3_worker.py b/examples/distributed/l4_l3_remote/l3_worker.py new file mode 100644 index 000000000..6e2e74ac1 --- /dev/null +++ b/examples/distributed/l4_l3_remote/l3_worker.py @@ -0,0 +1,5 @@ +from simpler.distributed.l3_daemon import main + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/examples/distributed/l4_l3_remote/l4_master.py b/examples/distributed/l4_l3_remote/l4_master.py new file mode 100644 index 000000000..aee49911d --- /dev/null +++ b/examples/distributed/l4_l3_remote/l4_master.py @@ -0,0 +1,52 @@ +import argparse + +from simpler.task_interface import CallConfig, TaskArgs +from simpler.worker import Worker + + +class Counter: + def __init__(self) -> None: + self.value = 0 + + def add(self, amount: int) -> None: + self.value += int(amount) + + +def main() -> int: + parser = argparse.ArgumentParser() + parser.add_argument("--remotes", default="127.0.0.1:5050") + args = parser.parse_args() + + counter = Counter() + endpoints = [item.strip() for item in args.remotes.split(",") if item.strip()] + + def l3_sub(task_args): + counter.add(task_args.scalar(0)) + + w4 = Worker(level=4, num_sub_workers=0) + 
sub_cid = w4.register(l3_sub) + + def l3_orch(orch, task_args, config): + orch.submit_sub(sub_cid, task_args) + + l3_cid = w4.register(l3_orch) + for endpoint in endpoints: + w4.add_remote_worker(endpoint) + w4.init() + try: + def l4_orch(orch, task_args, config): + for value in (2, 5): + sub_args = TaskArgs() + sub_args.add_scalar(value) + orch.submit_next_level(l3_cid, sub_args, CallConfig()) + + w4.run(l4_orch) + finally: + w4.close() + + print(f"remote counter={counter.value}") + return 0 if counter.value == 7 else 1 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/pyproject.toml b/pyproject.toml index 20296b40b..e66eaa977 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,13 +15,14 @@ build-backend = "scikit_build_core.build" name = "simpler" version = "0.1.0" requires-python = ">=3.9" +dependencies = ["grpcio>=1.80", "protobuf>=4.25", "cloudpickle>=2.2"] [project.optional-dependencies] # ``torch>=2.3`` is required by ``simpler_setup.torch_interop`` (uses # ``torch.uint16`` / ``torch.uint32``, added in PyTorch 2.3). CI installs # torch via a custom index URL and does not rely on this pin; it is here # so ``pip install -e '.[test]'`` resolves a usable version for local devs. 
-test = ["pytest>=6.0", "pytest-xdist>=3.0", "torch>=2.3"] +test = ["pytest>=6.0", "pytest-xdist>=3.0", "torch>=2.3", "grpcio-tools>=1.80"] [tool.ruff] line-length = 120 diff --git a/python/simpler/distributed/__init__.py b/python/simpler/distributed/__init__.py new file mode 100644 index 000000000..e08e69d23 --- /dev/null +++ b/python/simpler/distributed/__init__.py @@ -0,0 +1,32 @@ +"""Python-first distributed L4 -> L3 dispatch support.""" + +__all__ = [ + "Catalog", + "CatalogError", + "L3Daemon", + "RemoteUnavailable", + "RemoteWorkerProxy", + "RpcClient", + "RpcError", + "RpcServer", +] + + +def __getattr__(name): + if name in {"Catalog", "CatalogError"}: + from .catalog import Catalog, CatalogError + + return {"Catalog": Catalog, "CatalogError": CatalogError}[name] + if name == "L3Daemon": + from .l3_daemon import L3Daemon + + return L3Daemon + if name in {"RemoteUnavailable", "RemoteWorkerProxy"}: + from .remote_proxy import RemoteUnavailable, RemoteWorkerProxy + + return {"RemoteUnavailable": RemoteUnavailable, "RemoteWorkerProxy": RemoteWorkerProxy}[name] + if name in {"RpcClient", "RpcError", "RpcServer"}: + from .rpc import RpcClient, RpcError, RpcServer + + return {"RpcClient": RpcClient, "RpcError": RpcError, "RpcServer": RpcServer}[name] + raise AttributeError(name) diff --git a/python/simpler/distributed/catalog.py b/python/simpler/distributed/catalog.py new file mode 100644 index 000000000..13ba634f7 --- /dev/null +++ b/python/simpler/distributed/catalog.py @@ -0,0 +1,131 @@ +"""Callable catalog for cross-host Worker dispatch.""" + +from __future__ import annotations + +import hashlib +import importlib +import pickle +from collections.abc import Callable +from typing import Optional, Tuple + +import grpc + +from .proto import dispatch_pb2, dispatch_pb2_grpc + +try: + import cloudpickle as _pickle_impl +except Exception: # noqa: BLE001 + _pickle_impl = pickle + + +class CatalogError(RuntimeError): + pass + + +class Catalog: + """Stable callable ids 
plus versioned serialized payloads.""" + + def __init__(self, *, allowed_modules: Optional[Tuple[str, ...]] = None) -> None: + self._functions: dict[tuple[int, int], Callable] = {} + self._payloads: dict[tuple[int, int], bytes] = {} + self._latest: dict[int, int] = {} + self._next_id = 0 + self._allowed_modules = allowed_modules + + def register(self, fn: Callable, callable_id: Optional[int] = None) -> tuple[int, int]: + payload = _pickle_impl.dumps(fn) + version = _version(payload) + cid = self._next_id if callable_id is None else int(callable_id) + self.install_from_payload(cid, version, payload) + if callable_id is None: + self._next_id = max(self._next_id, cid + 1) + else: + self._next_id = max(self._next_id, cid + 1) + return cid, version + + def lookup(self, cid: int, version: Optional[int] = None) -> Optional[Callable]: + cid = int(cid) + if version is None or int(version) == 0: + version = self._latest.get(cid) + if version is None: + return None + return self._functions.get((cid, int(version))) + + def install_from_payload(self, cid: int, version: int, payload: bytes) -> None: + cid = int(cid) + version = int(version) + actual = _version(payload) + if version != actual: + raise CatalogError(f"callable {cid} version mismatch: expected {version}, payload has {actual}") + fn = _loads_with_allowlist(payload, self._allowed_modules) + if not callable(fn): + raise CatalogError(f"payload for callable {cid} did not deserialize to a callable") + key = (cid, version) + self._functions[key] = fn + self._payloads[key] = bytes(payload) + self._latest[cid] = version + self._next_id = max(self._next_id, cid + 1) + + def export_payload(self, cid: int, version: Optional[int] = None) -> bytes: + cid = int(cid) + if version is None or int(version) == 0: + version = self._latest.get(cid) + key = (cid, int(version)) if version is not None else None + if key not in self._payloads: + raise CatalogError(f"callable {cid} version {version} not found") + return self._payloads[key] + 
+ def refs(self) -> list[tuple[int, int]]: + return sorted(self._payloads) + + def refs_by_id(self) -> dict[int, int]: + return dict(self._latest) + + def payloads(self) -> list[dispatch_pb2.CallablePayload]: + return [ + dispatch_pb2.CallablePayload(callable_id=cid, version=version, pickled=payload) + for (cid, version), payload in sorted(self._payloads.items()) + ] + + +def _version(payload: bytes) -> int: + return int.from_bytes(hashlib.blake2b(payload, digest_size=8).digest(), "big") + + +def _loads_with_allowlist(payload: bytes, allowed_modules: Optional[Tuple[str, ...]]) -> Callable: + if allowed_modules is None: + return pickle.loads(payload) + + class AllowlistUnpickler(pickle.Unpickler): + def find_class(self, module: str, name: str): # noqa: ANN001 + if not any(module == prefix or module.startswith(prefix + ".") for prefix in allowed_modules): + raise CatalogError(f"module {module!r} is not allowed in callable payload") + importlib.import_module(module) + return super().find_class(module, name) + + import io + + return AllowlistUnpickler(io.BytesIO(payload)).load() + + +class CatalogService(dispatch_pb2_grpc.CatalogServicer): + def __init__(self, catalog: Catalog) -> None: + self._catalog = catalog + + def PullCallable(self, request, context): # noqa: N802, ANN001 + try: + payload = self._catalog.export_payload(request.callable_id, request.version) + except Exception as e: # noqa: BLE001 + context.abort(grpc.StatusCode.NOT_FOUND, str(e)) + return dispatch_pb2.CallablePayload( + callable_id=request.callable_id, + version=request.version, + pickled=payload, + ) + + def PushCallable(self, request, context): # noqa: N802, ANN001 + try: + self._catalog.install_from_payload(request.callable_id, request.version, request.pickled) + except Exception as e: # noqa: BLE001 + context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e)) + return dispatch_pb2.Empty() diff --git a/python/simpler/distributed/l3_daemon.py b/python/simpler/distributed/l3_daemon.py new file mode 
100644 index 000000000..0f0ac2a9e --- /dev/null +++ b/python/simpler/distributed/l3_daemon.py @@ -0,0 +1,221 @@ +"""Long-running L3 worker daemon for remote L4 dispatch.""" + +from __future__ import annotations + +import argparse +import multiprocessing as mp +import threading +import traceback +from collections.abc import Callable +from typing import Optional + +import grpc + +from simpler.worker import Worker + +from .catalog import Catalog, CatalogService +from .proto import dispatch_pb2, dispatch_pb2_grpc +from .rpc import RpcServer +from .serialization import decode_config, decode_task_args +from .tensor_pool import TensorPool + + +class L3Daemon(dispatch_pb2_grpc.L3WorkerServicer): + """RPC facade that delegates dispatches to a lazily initialized inner Worker.""" + + def __init__(self, port: int = 0, worker_factory: Optional[Callable[[], Worker]] = None) -> None: + self.port = int(port) + self.catalog = Catalog() + self.tensor_pool = TensorPool() + self._worker_factory = worker_factory or (lambda: Worker(level=3, num_sub_workers=1)) + self._server: Optional[RpcServer] = None + self._backend_proc = None + self._backend_conn = None + self._backend_lock = threading.Lock() + + def start(self, host: str = "127.0.0.1") -> int: + self._start_backend() + server = RpcServer() + server.add_l3_worker(self) + server.add_catalog(_BackendCatalogService(self.catalog, self._backend_call)) + server.add_tensor_pool(self.tensor_pool.service()) + self.port = server.start(self.port, host) + self._server = server + return self.port + + def serve_forever(self, host: str = "127.0.0.1") -> None: + self.start(host) + assert self._server is not None + self._server.wait_for_termination() + + def stop(self) -> None: + if self._server is not None: + self._server.stop(0) + self._server = None + if self._backend_conn is not None: + try: + self._backend_call(("stop",)) + except Exception: # noqa: BLE001 + pass + self._backend_conn.close() + self._backend_conn = None + if self._backend_proc 
is not None: + self._backend_proc.join(timeout=5.0) + if self._backend_proc.is_alive(): + self._backend_proc.terminate() + self._backend_proc.join(timeout=5.0) + self._backend_proc = None + + def Dispatch(self, request, context): # noqa: N802, ANN001 + try: + return self._on_dispatch(request) + except Exception as e: # noqa: BLE001 + tb = traceback.format_exc() + return dispatch_pb2.DispatchResp( + task_id=request.task_id, + error_code=1, + error_msg=f"{type(e).__name__}: {e}", + remote_traceback=[tb], + ) + + def Heartbeat(self, request, context): # noqa: N802, ANN001 + return dispatch_pb2.Health(ok=True, message="ok") + + def _on_dispatch(self, req: dispatch_pb2.DispatchReq) -> dispatch_pb2.DispatchResp: + resp_bytes = self._backend_call(("dispatch", req.SerializeToString())) + resp = dispatch_pb2.DispatchResp() + resp.ParseFromString(resp_bytes) + return resp + + def _start_backend(self) -> None: + if self._backend_proc is not None: + return + ctx = mp.get_context("fork") if hasattr(mp, "get_context") else mp + parent_conn, child_conn = ctx.Pipe() + proc = ctx.Process(target=_backend_loop, args=(child_conn, self._worker_factory), daemon=True) + proc.start() + child_conn.close() + self._backend_conn = parent_conn + self._backend_proc = proc + + def _backend_call(self, msg): + if self._backend_conn is None: + raise RuntimeError("L3 daemon backend is not running") + with self._backend_lock: + self._backend_conn.send(msg) + ok, payload = self._backend_conn.recv() + if not ok: + raise RuntimeError(payload) + return payload + + +class _BackendCatalogService(CatalogService): + def __init__(self, catalog: Catalog, backend_call) -> None: + super().__init__(catalog) + self._backend_call = backend_call + + def PushCallable(self, request, context): # noqa: N802, ANN001 + super().PushCallable(request, context) + self._backend_call(("push", request.callable_id, request.version, bytes(request.pickled))) + return dispatch_pb2.Empty() + + +def _backend_loop(conn, worker_factory) 
-> None: + catalog = Catalog() + inner: Optional[Worker] = None + try: + while True: + msg = conn.recv() + op = msg[0] + if op == "stop": + conn.send((True, None)) + break + if op == "push": + _, cid, version, payload = msg + catalog.install_from_payload(cid, version, payload) + conn.send((True, None)) + continue + if op == "dispatch": + _, req_bytes = msg + req = dispatch_pb2.DispatchReq() + req.ParseFromString(req_bytes) + resp, inner = _backend_dispatch(req, catalog, worker_factory, inner) + conn.send((True, resp.SerializeToString())) + continue + raise RuntimeError(f"unknown backend op {op!r}") + except EOFError: + pass + except Exception as e: # noqa: BLE001 + try: + conn.send((False, f"{type(e).__name__}: {e}\n{traceback.format_exc()}")) + except Exception: # noqa: BLE001 + pass + finally: + if inner is not None: + inner.close() + + +def _backend_dispatch( + req: dispatch_pb2.DispatchReq, + catalog: Catalog, + worker_factory: Callable[[], Worker], + inner: Optional[Worker], +) -> tuple[dispatch_pb2.DispatchResp, Optional[Worker]]: + try: + if inner is None: + inner = worker_factory() + for cid, version in catalog.refs(): + fn = catalog.lookup(cid, version) + if fn is not None: + inner._callable_registry[int(cid)] = fn + inner.init() + orch_fn = catalog.lookup(req.callable_id, req.callable_version) + if orch_fn is None: + return ( + dispatch_pb2.DispatchResp( + task_id=req.task_id, + error_code=2, + error_msg=f"callable {req.callable_id} version {req.callable_version} not in catalog", + ), + inner, + ) + cfg = decode_config(req.config_blob) + args = decode_task_args(req.tensor_args, req.scalar_args) + inner.run(orch_fn, args, cfg) + return dispatch_pb2.DispatchResp(task_id=req.task_id, error_code=0), inner + except Exception as e: # noqa: BLE001 + return ( + dispatch_pb2.DispatchResp( + task_id=req.task_id, + error_code=1, + error_msg=f"{type(e).__name__}: {e}", + remote_traceback=[traceback.format_exc()], + ), + inner, + ) + + +def main(argv: 
Optional[list[str]] = None) -> int: + parser = argparse.ArgumentParser() + parser.add_argument("--host", default="127.0.0.1") + parser.add_argument("--port", type=int, default=5050) + parser.add_argument("--num-sub-workers", type=int, default=1) + args = parser.parse_args(argv) + + def make_worker() -> Worker: + return Worker(level=3, num_sub_workers=args.num_sub_workers) + + daemon = L3Daemon(args.port, make_worker) + try: + daemon.serve_forever(args.host) + except KeyboardInterrupt: + daemon.stop() + return 130 + except grpc.RpcError: + daemon.stop() + raise + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/python/simpler/distributed/proto/__init__.py b/python/simpler/distributed/proto/__init__.py new file mode 100644 index 000000000..a0a6c1bff --- /dev/null +++ b/python/simpler/distributed/proto/__init__.py @@ -0,0 +1 @@ +"""Generated protobuf modules for simpler.distributed.""" diff --git a/python/simpler/distributed/proto/_gen.sh b/python/simpler/distributed/proto/_gen.sh new file mode 100755 index 000000000..a230da8d9 --- /dev/null +++ b/python/simpler/distributed/proto/_gen.sh @@ -0,0 +1,19 @@ +#!/usr/bin/env bash +set -euo pipefail + +HERE="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +python -m grpc_tools.protoc \ + -I "${HERE}" \ + --python_out="${HERE}" \ + --grpc_python_out="${HERE}" \ + "${HERE}/dispatch.proto" + +python - <<'PY' "${HERE}/dispatch_pb2_grpc.py" +from pathlib import Path +import sys + +path = Path(sys.argv[1]) +text = path.read_text() +text = text.replace("import dispatch_pb2 as dispatch__pb2", "from . 
import dispatch_pb2 as dispatch__pb2") +path.write_text(text) +PY diff --git a/python/simpler/distributed/proto/dispatch.proto b/python/simpler/distributed/proto/dispatch.proto new file mode 100644 index 000000000..bec873e8c --- /dev/null +++ b/python/simpler/distributed/proto/dispatch.proto @@ -0,0 +1,92 @@ +syntax = "proto3"; + +package simpler.distributed.v1; + +message Empty {} + +message Health { + bool ok = 1; + string message = 2; +} + +message ContinuousTensorRef { + uint64 data = 1; + repeated uint64 shape = 2; + uint32 dtype = 3; + uint32 tag = 4; +} + +message TensorHandle { + string node_id = 1; + uint64 handle_id = 2; +} + +message TensorRef { + oneof source { + bytes inline_data = 1; + TensorHandle handle = 2; + } + repeated int64 shape = 10; + int32 dtype = 11; + int32 tag = 12; +} + +message TensorChunk { + TensorHandle handle = 1; + uint64 offset = 2; + bytes data = 3; + bool last = 4; +} + +message CallConfigWire { + int32 block_dim = 1; + int32 aicpu_thread_num = 2; + bool enable_l2_swimlane = 3; + bool enable_dump_tensor = 4; + int32 enable_pmu = 5; + string output_prefix = 6; +} + +message DispatchReq { + uint64 task_id = 1; + uint64 callable_id = 2; + uint64 callable_version = 3; + bytes config_blob = 4; + repeated uint64 scalar_args = 5; + repeated ContinuousTensorRef tensor_args = 6; + repeated TensorRef tensor_refs = 7; +} + +message DispatchResp { + uint64 task_id = 1; + int32 error_code = 2; + string error_msg = 3; + repeated string remote_traceback = 4; + repeated TensorRef output_tensors = 5; +} + +message CallableRef { + uint64 callable_id = 1; + uint64 version = 2; +} + +message CallablePayload { + uint64 callable_id = 1; + uint64 version = 2; + bytes pickled = 3; +} + +service L3Worker { + rpc Dispatch(DispatchReq) returns (DispatchResp); + rpc Heartbeat(Empty) returns (Health); +} + +service Catalog { + rpc PullCallable(CallableRef) returns (CallablePayload); + rpc PushCallable(CallablePayload) returns (Empty); +} + +service 
TensorPool { + rpc PullTensor(TensorHandle) returns (stream TensorChunk); + rpc PushTensor(stream TensorChunk) returns (TensorHandle); +} diff --git a/python/simpler/distributed/proto/dispatch_pb2.py b/python/simpler/distributed/proto/dispatch_pb2.py new file mode 100644 index 000000000..2207b5c1e --- /dev/null +++ b/python/simpler/distributed/proto/dispatch_pb2.py @@ -0,0 +1,62 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# NO CHECKED-IN PROTOBUF GENCODE +# source: dispatch.proto +# Protobuf Python Version: 6.31.1 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import runtime_version as _runtime_version +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +_runtime_version.ValidateProtobufRuntimeVersion( + _runtime_version.Domain.PUBLIC, + 6, + 31, + 1, + '', + 'dispatch.proto' +) +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0e\x64ispatch.proto\x12\x16simpler.distributed.v1\"\x07\n\x05\x45mpty\"%\n\x06Health\x12\n\n\x02ok\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"N\n\x13\x43ontinuousTensorRef\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x04\x12\r\n\x05shape\x18\x02 \x03(\x04\x12\r\n\x05\x64type\x18\x03 \x01(\r\x12\x0b\n\x03tag\x18\x04 \x01(\r\"2\n\x0cTensorHandle\x12\x0f\n\x07node_id\x18\x01 \x01(\t\x12\x11\n\thandle_id\x18\x02 \x01(\x04\"\x8f\x01\n\tTensorRef\x12\x15\n\x0binline_data\x18\x01 \x01(\x0cH\x00\x12\x36\n\x06handle\x18\x02 \x01(\x0b\x32$.simpler.distributed.v1.TensorHandleH\x00\x12\r\n\x05shape\x18\n \x03(\x03\x12\r\n\x05\x64type\x18\x0b \x01(\x05\x12\x0b\n\x03tag\x18\x0c \x01(\x05\x42\x08\n\x06source\"o\n\x0bTensorChunk\x12\x34\n\x06handle\x18\x01 
\x01(\x0b\x32$.simpler.distributed.v1.TensorHandle\x12\x0e\n\x06offset\x18\x02 \x01(\x04\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\x12\x0c\n\x04last\x18\x04 \x01(\x08\"\xa0\x01\n\x0e\x43\x61llConfigWire\x12\x11\n\tblock_dim\x18\x01 \x01(\x05\x12\x18\n\x10\x61icpu_thread_num\x18\x02 \x01(\x05\x12\x1a\n\x12\x65nable_l2_swimlane\x18\x03 \x01(\x08\x12\x1a\n\x12\x65nable_dump_tensor\x18\x04 \x01(\x08\x12\x12\n\nenable_pmu\x18\x05 \x01(\x05\x12\x15\n\routput_prefix\x18\x06 \x01(\t\"\xf1\x01\n\x0b\x44ispatchReq\x12\x0f\n\x07task_id\x18\x01 \x01(\x04\x12\x13\n\x0b\x63\x61llable_id\x18\x02 \x01(\x04\x12\x18\n\x10\x63\x61llable_version\x18\x03 \x01(\x04\x12\x13\n\x0b\x63onfig_blob\x18\x04 \x01(\x0c\x12\x13\n\x0bscalar_args\x18\x05 \x03(\x04\x12@\n\x0btensor_args\x18\x06 \x03(\x0b\x32+.simpler.distributed.v1.ContinuousTensorRef\x12\x36\n\x0btensor_refs\x18\x07 \x03(\x0b\x32!.simpler.distributed.v1.TensorRef\"\x9b\x01\n\x0c\x44ispatchResp\x12\x0f\n\x07task_id\x18\x01 \x01(\x04\x12\x12\n\nerror_code\x18\x02 \x01(\x05\x12\x11\n\terror_msg\x18\x03 \x01(\t\x12\x18\n\x10remote_traceback\x18\x04 \x03(\t\x12\x39\n\x0eoutput_tensors\x18\x05 \x03(\x0b\x32!.simpler.distributed.v1.TensorRef\"3\n\x0b\x43\x61llableRef\x12\x13\n\x0b\x63\x61llable_id\x18\x01 \x01(\x04\x12\x0f\n\x07version\x18\x02 \x01(\x04\"H\n\x0f\x43\x61llablePayload\x12\x13\n\x0b\x63\x61llable_id\x18\x01 \x01(\x04\x12\x0f\n\x07version\x18\x02 \x01(\x04\x12\x0f\n\x07pickled\x18\x03 
\x01(\x0c\x32\xad\x01\n\x08L3Worker\x12U\n\x08\x44ispatch\x12#.simpler.distributed.v1.DispatchReq\x1a$.simpler.distributed.v1.DispatchResp\x12J\n\tHeartbeat\x12\x1d.simpler.distributed.v1.Empty\x1a\x1e.simpler.distributed.v1.Health2\xbf\x01\n\x07\x43\x61talog\x12\\\n\x0cPullCallable\x12#.simpler.distributed.v1.CallableRef\x1a\'.simpler.distributed.v1.CallablePayload\x12V\n\x0cPushCallable\x12\'.simpler.distributed.v1.CallablePayload\x1a\x1d.simpler.distributed.v1.Empty2\xc2\x01\n\nTensorPool\x12Y\n\nPullTensor\x12$.simpler.distributed.v1.TensorHandle\x1a#.simpler.distributed.v1.TensorChunk0\x01\x12Y\n\nPushTensor\x12#.simpler.distributed.v1.TensorChunk\x1a$.simpler.distributed.v1.TensorHandle(\x01\x62\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'dispatch_pb2', _globals) +if not _descriptor._USE_C_DESCRIPTORS: + DESCRIPTOR._loaded_options = None + _globals['_EMPTY']._serialized_start=42 + _globals['_EMPTY']._serialized_end=49 + _globals['_HEALTH']._serialized_start=51 + _globals['_HEALTH']._serialized_end=88 + _globals['_CONTINUOUSTENSORREF']._serialized_start=90 + _globals['_CONTINUOUSTENSORREF']._serialized_end=168 + _globals['_TENSORHANDLE']._serialized_start=170 + _globals['_TENSORHANDLE']._serialized_end=220 + _globals['_TENSORREF']._serialized_start=223 + _globals['_TENSORREF']._serialized_end=366 + _globals['_TENSORCHUNK']._serialized_start=368 + _globals['_TENSORCHUNK']._serialized_end=479 + _globals['_CALLCONFIGWIRE']._serialized_start=482 + _globals['_CALLCONFIGWIRE']._serialized_end=642 + _globals['_DISPATCHREQ']._serialized_start=645 + _globals['_DISPATCHREQ']._serialized_end=886 + _globals['_DISPATCHRESP']._serialized_start=889 + _globals['_DISPATCHRESP']._serialized_end=1044 + _globals['_CALLABLEREF']._serialized_start=1046 + _globals['_CALLABLEREF']._serialized_end=1097 + _globals['_CALLABLEPAYLOAD']._serialized_start=1099 + 
_globals['_CALLABLEPAYLOAD']._serialized_end=1171 + _globals['_L3WORKER']._serialized_start=1174 + _globals['_L3WORKER']._serialized_end=1347 + _globals['_CATALOG']._serialized_start=1350 + _globals['_CATALOG']._serialized_end=1541 + _globals['_TENSORPOOL']._serialized_start=1544 + _globals['_TENSORPOOL']._serialized_end=1738 +# @@protoc_insertion_point(module_scope) diff --git a/python/simpler/distributed/proto/dispatch_pb2_grpc.py b/python/simpler/distributed/proto/dispatch_pb2_grpc.py new file mode 100644 index 000000000..f0ef4284c --- /dev/null +++ b/python/simpler/distributed/proto/dispatch_pb2_grpc.py @@ -0,0 +1,370 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc +import warnings + +from . import dispatch_pb2 as dispatch__pb2 + +GRPC_GENERATED_VERSION = '1.80.0' +GRPC_VERSION = grpc.__version__ +_version_not_supported = False + +try: + from grpc._utilities import first_version_is_lower + _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION) +except ImportError: + _version_not_supported = True + +if _version_not_supported: + raise RuntimeError( + f'The grpc package installed is at version {GRPC_VERSION},' + + ' but the generated code in dispatch_pb2_grpc.py depends on' + + f' grpcio>={GRPC_GENERATED_VERSION}.' + + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' + + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' + ) + + +class L3WorkerStub(object): + """Missing associated documentation comment in .proto file.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. 
+ """ + self.Dispatch = channel.unary_unary( + '/simpler.distributed.v1.L3Worker/Dispatch', + request_serializer=dispatch__pb2.DispatchReq.SerializeToString, + response_deserializer=dispatch__pb2.DispatchResp.FromString, + _registered_method=True) + self.Heartbeat = channel.unary_unary( + '/simpler.distributed.v1.L3Worker/Heartbeat', + request_serializer=dispatch__pb2.Empty.SerializeToString, + response_deserializer=dispatch__pb2.Health.FromString, + _registered_method=True) + + +class L3WorkerServicer(object): + """Missing associated documentation comment in .proto file.""" + + def Dispatch(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Heartbeat(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_L3WorkerServicer_to_server(servicer, server): + rpc_method_handlers = { + 'Dispatch': grpc.unary_unary_rpc_method_handler( + servicer.Dispatch, + request_deserializer=dispatch__pb2.DispatchReq.FromString, + response_serializer=dispatch__pb2.DispatchResp.SerializeToString, + ), + 'Heartbeat': grpc.unary_unary_rpc_method_handler( + servicer.Heartbeat, + request_deserializer=dispatch__pb2.Empty.FromString, + response_serializer=dispatch__pb2.Health.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'simpler.distributed.v1.L3Worker', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + server.add_registered_method_handlers('simpler.distributed.v1.L3Worker', rpc_method_handlers) + + + # This class is part of an EXPERIMENTAL API. 
+class L3Worker(object): + """Missing associated documentation comment in .proto file.""" + + @staticmethod + def Dispatch(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/simpler.distributed.v1.L3Worker/Dispatch', + dispatch__pb2.DispatchReq.SerializeToString, + dispatch__pb2.DispatchResp.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def Heartbeat(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/simpler.distributed.v1.L3Worker/Heartbeat', + dispatch__pb2.Empty.SerializeToString, + dispatch__pb2.Health.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + +class CatalogStub(object): + """Missing associated documentation comment in .proto file.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. 
+ """ + self.PullCallable = channel.unary_unary( + '/simpler.distributed.v1.Catalog/PullCallable', + request_serializer=dispatch__pb2.CallableRef.SerializeToString, + response_deserializer=dispatch__pb2.CallablePayload.FromString, + _registered_method=True) + self.PushCallable = channel.unary_unary( + '/simpler.distributed.v1.Catalog/PushCallable', + request_serializer=dispatch__pb2.CallablePayload.SerializeToString, + response_deserializer=dispatch__pb2.Empty.FromString, + _registered_method=True) + + +class CatalogServicer(object): + """Missing associated documentation comment in .proto file.""" + + def PullCallable(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def PushCallable(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_CatalogServicer_to_server(servicer, server): + rpc_method_handlers = { + 'PullCallable': grpc.unary_unary_rpc_method_handler( + servicer.PullCallable, + request_deserializer=dispatch__pb2.CallableRef.FromString, + response_serializer=dispatch__pb2.CallablePayload.SerializeToString, + ), + 'PushCallable': grpc.unary_unary_rpc_method_handler( + servicer.PushCallable, + request_deserializer=dispatch__pb2.CallablePayload.FromString, + response_serializer=dispatch__pb2.Empty.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'simpler.distributed.v1.Catalog', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + server.add_registered_method_handlers('simpler.distributed.v1.Catalog', rpc_method_handlers) + + + # This class is part of an EXPERIMENTAL API. 
+class Catalog(object): + """Missing associated documentation comment in .proto file.""" + + @staticmethod + def PullCallable(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/simpler.distributed.v1.Catalog/PullCallable', + dispatch__pb2.CallableRef.SerializeToString, + dispatch__pb2.CallablePayload.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def PushCallable(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/simpler.distributed.v1.Catalog/PushCallable', + dispatch__pb2.CallablePayload.SerializeToString, + dispatch__pb2.Empty.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + +class TensorPoolStub(object): + """Missing associated documentation comment in .proto file.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. 
+ """ + self.PullTensor = channel.unary_stream( + '/simpler.distributed.v1.TensorPool/PullTensor', + request_serializer=dispatch__pb2.TensorHandle.SerializeToString, + response_deserializer=dispatch__pb2.TensorChunk.FromString, + _registered_method=True) + self.PushTensor = channel.stream_unary( + '/simpler.distributed.v1.TensorPool/PushTensor', + request_serializer=dispatch__pb2.TensorChunk.SerializeToString, + response_deserializer=dispatch__pb2.TensorHandle.FromString, + _registered_method=True) + + +class TensorPoolServicer(object): + """Missing associated documentation comment in .proto file.""" + + def PullTensor(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def PushTensor(self, request_iterator, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_TensorPoolServicer_to_server(servicer, server): + rpc_method_handlers = { + 'PullTensor': grpc.unary_stream_rpc_method_handler( + servicer.PullTensor, + request_deserializer=dispatch__pb2.TensorHandle.FromString, + response_serializer=dispatch__pb2.TensorChunk.SerializeToString, + ), + 'PushTensor': grpc.stream_unary_rpc_method_handler( + servicer.PushTensor, + request_deserializer=dispatch__pb2.TensorChunk.FromString, + response_serializer=dispatch__pb2.TensorHandle.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'simpler.distributed.v1.TensorPool', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + server.add_registered_method_handlers('simpler.distributed.v1.TensorPool', rpc_method_handlers) + + + # This class is part of an EXPERIMENTAL API. 
+class TensorPool(object): + """Missing associated documentation comment in .proto file.""" + + @staticmethod + def PullTensor(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream( + request, + target, + '/simpler.distributed.v1.TensorPool/PullTensor', + dispatch__pb2.TensorHandle.SerializeToString, + dispatch__pb2.TensorChunk.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def PushTensor(request_iterator, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.stream_unary( + request_iterator, + target, + '/simpler.distributed.v1.TensorPool/PushTensor', + dispatch__pb2.TensorChunk.SerializeToString, + dispatch__pb2.TensorHandle.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) diff --git a/python/simpler/distributed/remote_proxy.py b/python/simpler/distributed/remote_proxy.py new file mode 100644 index 000000000..6965e5634 --- /dev/null +++ b/python/simpler/distributed/remote_proxy.py @@ -0,0 +1,131 @@ +"""L4-side proxy for a remote L3 worker.""" + +from __future__ import annotations + +import itertools +import threading +import time +from typing import Optional + +from simpler.task_interface import CallConfig, TaskArgs + +from .catalog import Catalog +from .proto import dispatch_pb2 +from .rpc import RpcClient, RpcError +from .serialization import encode_config, encode_task_args + + +class RemoteUnavailable(RuntimeError): + pass + + +class RemoteWorkerProxy: + """Synchronous L4-side stub for one remote L3 worker.""" + 
+ def __init__( + self, + endpoint: str, + l4_catalog: Catalog, + *, + timeout: float = 10.0, + heartbeat_timeout: float = 1.0, + heartbeat_interval: float = 5.0, + heartbeat_failures: int = 3, + ) -> None: + self.endpoint = endpoint + self._client = RpcClient(endpoint) + self._catalog = l4_catalog + self._timeout = float(timeout) + self._heartbeat_timeout = float(heartbeat_timeout) + self._heartbeat_interval = float(heartbeat_interval) + self._heartbeat_failures = int(heartbeat_failures) + self._task_ids = itertools.count(1) + self._available = True + self._closed = threading.Event() + self._heartbeat_thread: Optional[threading.Thread] = None + + def handshake(self) -> None: + self._check_heartbeat() + for payload in self._catalog.payloads(): + self._client.call_unary("Catalog.PushCallable", payload, self._timeout) + self._start_heartbeat() + + def heartbeat(self) -> None: + self._check_heartbeat() + + def _check_heartbeat(self) -> None: + try: + health = self._heartbeat_rpc() + except RemoteUnavailable: + self._available = False + raise + if not health.ok: + self._available = False + raise RemoteUnavailable(f"remote {self.endpoint} unhealthy: {health.message}") + self._available = True + + def _heartbeat_rpc(self): + try: + return self._client.heartbeat(self._heartbeat_timeout) + except RpcError as e: + raise RemoteUnavailable(f"remote {self.endpoint} heartbeat failed: {e}") from e + + def _start_heartbeat(self) -> None: + if self._heartbeat_interval <= 0 or self._heartbeat_thread is not None: + return + self._heartbeat_thread = threading.Thread( + target=self._heartbeat_loop, + name=f"simpler-remote-heartbeat-{self.endpoint}", + daemon=True, + ) + self._heartbeat_thread.start() + + def _heartbeat_loop(self) -> None: + failures = 0 + while not self._closed.wait(self._heartbeat_interval): + try: + health = self._heartbeat_rpc() + if not health.ok: + raise RemoteUnavailable(f"remote {self.endpoint} unhealthy: {health.message}") + failures = 0 + self._available = 
True + except RemoteUnavailable: + failures += 1 + if failures >= self._heartbeat_failures: + self._available = False + + def dispatch(self, callable_id: int, args: Optional[TaskArgs], cfg: Optional[CallConfig]) -> None: + if not self._available: + raise RemoteUnavailable(f"remote {self.endpoint} is unavailable") + config = cfg if cfg is not None else CallConfig() + tensor_args, scalar_args = encode_task_args(args) + version = self._catalog.refs_by_id().get(int(callable_id), 0) + req = dispatch_pb2.DispatchReq( + task_id=next(self._task_ids), + callable_id=int(callable_id), + callable_version=int(version), + config_blob=encode_config(config), + scalar_args=scalar_args, + tensor_args=tensor_args, + ) + try: + resp = self._client.dispatch(req, self._timeout) + except RpcError as e: + self._available = False + raise RemoteUnavailable(f"remote {self.endpoint} dispatch RPC failed: {e}") from e + if resp.error_code != 0: + detail = resp.error_msg + if resp.remote_traceback: + detail = detail + "\nremote traceback:\n" + "\n".join(resp.remote_traceback) + raise RuntimeError(f"remote dispatch failed on {self.endpoint}: {detail}") + + def close(self) -> None: + self._closed.set() + if self._heartbeat_thread is not None: + self._heartbeat_thread.join(timeout=1.0) + self._heartbeat_thread = None + self._client.close() + + +def sleep_poll_interval() -> None: + time.sleep(0.0005) diff --git a/python/simpler/distributed/rpc.py b/python/simpler/distributed/rpc.py new file mode 100644 index 000000000..4e0ef4c4d --- /dev/null +++ b/python/simpler/distributed/rpc.py @@ -0,0 +1,110 @@ +"""Small grpcio wrappers used by distributed dispatch.""" + +from __future__ import annotations + +from concurrent import futures +from typing import Any, Callable, Optional + +import grpc + +from .proto import dispatch_pb2, dispatch_pb2_grpc + +_MAX_MESSAGE_BYTES = 64 * 1024 * 1024 +_CHANNEL_OPTIONS = [ + ("grpc.max_send_message_length", _MAX_MESSAGE_BYTES), + ("grpc.max_receive_message_length", 
_MAX_MESSAGE_BYTES), + ("grpc.so_reuseport", 0), +] + + +class RpcError(RuntimeError): + """Raised when a gRPC call fails.""" + + def __init__(self, message: str, *, code: Optional[grpc.StatusCode] = None, remote_traceback: str = "") -> None: + super().__init__(message) + self.code = code + self.remote_traceback = remote_traceback + + +class RpcServer: + """Thin owner for a grpc.server instance.""" + + def __init__(self, *, max_workers: int = 8) -> None: + self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=max_workers), options=_CHANNEL_OPTIONS) + self._port: Optional[int] = None + + @property + def port(self) -> int: + if self._port is None: + raise RuntimeError("RpcServer has not been started") + return self._port + + def add_l3_worker(self, impl: dispatch_pb2_grpc.L3WorkerServicer) -> None: + dispatch_pb2_grpc.add_L3WorkerServicer_to_server(impl, self._server) + + def add_catalog(self, impl: dispatch_pb2_grpc.CatalogServicer) -> None: + dispatch_pb2_grpc.add_CatalogServicer_to_server(impl, self._server) + + def add_tensor_pool(self, impl: dispatch_pb2_grpc.TensorPoolServicer) -> None: + dispatch_pb2_grpc.add_TensorPoolServicer_to_server(impl, self._server) + + def add_handler(self, service: str, impl: Any) -> None: + if service == "L3Worker": + self.add_l3_worker(impl) + elif service == "Catalog": + self.add_catalog(impl) + elif service == "TensorPool": + self.add_tensor_pool(impl) + else: + raise ValueError(f"unknown service {service!r}") + + def start(self, port: int = 0, host: str = "127.0.0.1") -> int: + try: + bound = self._server.add_insecure_port(f"{host}:{int(port)}") + except RuntimeError as e: + raise RpcError(f"failed to bind gRPC server on {host}:{port}: {e}") from e + if bound == 0: + raise RpcError(f"failed to bind gRPC server on {host}:{port}") + self._server.start() + self._port = bound + return bound + + def wait_for_termination(self) -> None: + self._server.wait_for_termination() + + def stop(self, grace: Optional[float] = 0) -> 
None: + self._server.stop(grace) + + +class RpcClient: + """Typed client wrapper for the distributed proto services.""" + + def __init__(self, endpoint: str) -> None: + self.endpoint = endpoint + self._channel = grpc.insecure_channel(endpoint, options=_CHANNEL_OPTIONS) + self.l3_worker = dispatch_pb2_grpc.L3WorkerStub(self._channel) + self.catalog = dispatch_pb2_grpc.CatalogStub(self._channel) + self.tensor_pool = dispatch_pb2_grpc.TensorPoolStub(self._channel) + + def call_unary(self, method: str, req: Any, timeout: Optional[float] = None) -> Any: + mapping: dict[str, Callable[..., Any]] = { + "L3Worker.Dispatch": self.l3_worker.Dispatch, + "L3Worker.Heartbeat": self.l3_worker.Heartbeat, + "Catalog.PullCallable": self.catalog.PullCallable, + "Catalog.PushCallable": self.catalog.PushCallable, + } + try: + return mapping[method](req, timeout=timeout) + except KeyError as e: + raise ValueError(f"unknown unary method {method!r}") from e + except grpc.RpcError as e: + raise RpcError(str(e.details() or e), code=e.code()) from e + + def dispatch(self, req: dispatch_pb2.DispatchReq, timeout: Optional[float] = None) -> dispatch_pb2.DispatchResp: + return self.call_unary("L3Worker.Dispatch", req, timeout) + + def heartbeat(self, timeout: Optional[float] = None) -> dispatch_pb2.Health: + return self.call_unary("L3Worker.Heartbeat", dispatch_pb2.Empty(), timeout) + + def close(self) -> None: + self._channel.close() diff --git a/python/simpler/distributed/serialization.py b/python/simpler/distributed/serialization.py new file mode 100644 index 000000000..7c73da86b --- /dev/null +++ b/python/simpler/distributed/serialization.py @@ -0,0 +1,71 @@ +"""Wire serialization helpers shared by distributed dispatch components.""" + +from __future__ import annotations + +from collections.abc import Iterable +from typing import Optional + +from simpler.task_interface import CallConfig, ContinuousTensor, DataType, TaskArgs, TensorArgType + +from .proto import dispatch_pb2 + + +def 
encode_config(config: CallConfig) -> bytes: + cfg = dispatch_pb2.CallConfigWire( + block_dim=int(config.block_dim), + aicpu_thread_num=int(config.aicpu_thread_num), + enable_l2_swimlane=bool(config.enable_l2_swimlane), + enable_dump_tensor=bool(config.enable_dump_tensor), + enable_pmu=int(config.enable_pmu), + output_prefix=str(config.output_prefix), + ) + return cfg.SerializeToString() + + +def decode_config(blob: bytes) -> CallConfig: + cfg = CallConfig() + if not blob: + return cfg + wire = dispatch_pb2.CallConfigWire() + wire.ParseFromString(blob) + cfg.block_dim = int(wire.block_dim) + cfg.aicpu_thread_num = int(wire.aicpu_thread_num) + cfg.enable_l2_swimlane = bool(wire.enable_l2_swimlane) + cfg.enable_dump_tensor = bool(wire.enable_dump_tensor) + cfg.enable_pmu = int(wire.enable_pmu) + cfg.output_prefix = wire.output_prefix + return cfg + + +def encode_task_args(args: Optional[TaskArgs]) -> tuple[list[dispatch_pb2.ContinuousTensorRef], list[int]]: + if args is None: + return [], [] + tensors = [] + for i in range(args.tensor_count()): + tensor = args.tensor(i) + tag = args.tag(i) + tensors.append( + dispatch_pb2.ContinuousTensorRef( + data=int(tensor.data), + shape=[int(x) for x in tensor.shapes[: int(tensor.ndims)]], + dtype=int(tensor.dtype.value), + tag=int(tag.value), + ) + ) + scalars = [int(args.scalar(i)) for i in range(args.scalar_count())] + return tensors, scalars + + +def decode_task_args( + tensor_refs: Iterable[dispatch_pb2.ContinuousTensorRef], + scalar_args: Iterable[int], +) -> TaskArgs: + args = TaskArgs() + for ref in tensor_refs: + shape = tuple(int(x) for x in ref.shape) + dtype = DataType(int(ref.dtype)) + tag = TensorArgType(int(ref.tag)) + args.add_tensor(ContinuousTensor.make(int(ref.data), shape, dtype), tag) + for scalar in scalar_args: + args.add_scalar(int(scalar)) + return args diff --git a/python/simpler/distributed/tensor_pool.py b/python/simpler/distributed/tensor_pool.py new file mode 100644 index 000000000..3b4787901 --- 
/dev/null +++ b/python/simpler/distributed/tensor_pool.py @@ -0,0 +1,69 @@ +"""Tensor byte pool used by distributed dispatch tensor references.""" + +from __future__ import annotations + +import itertools +import uuid +from collections.abc import Iterable +from typing import Optional + +import grpc + +from .proto import dispatch_pb2, dispatch_pb2_grpc + + +class TensorPool: + def __init__(self, *, node_id: Optional[str] = None, inline_threshold: int = 1024 * 1024) -> None: + self.node_id = node_id or str(uuid.uuid4()) + self.inline_threshold = int(inline_threshold) + self._next_id = itertools.count(1) + self._data: dict[int, bytes] = {} + + def put_bytes(self, data: bytes) -> dispatch_pb2.TensorRef: + data = bytes(data) + if len(data) <= self.inline_threshold: + return dispatch_pb2.TensorRef(inline_data=data) + handle_id = next(self._next_id) + self._data[handle_id] = data + return dispatch_pb2.TensorRef(handle=dispatch_pb2.TensorHandle(node_id=self.node_id, handle_id=handle_id)) + + def get_bytes(self, handle: dispatch_pb2.TensorHandle) -> bytes: + if handle.node_id != self.node_id: + raise KeyError(f"tensor handle belongs to node {handle.node_id!r}, not {self.node_id!r}") + return self._data[int(handle.handle_id)] + + def service(self) -> "TensorPoolService": + return TensorPoolService(self) + + +class TensorPoolService(dispatch_pb2_grpc.TensorPoolServicer): + def __init__(self, pool: TensorPool, *, chunk_size: int = 1024 * 1024) -> None: + self._pool = pool + self._chunk_size = int(chunk_size) + + def PullTensor(self, request, context): # noqa: N802, ANN001 + try: + data = self._pool.get_bytes(request) + except Exception as e: # noqa: BLE001 + context.abort(grpc.StatusCode.NOT_FOUND, str(e)) + for offset in range(0, len(data), self._chunk_size): + chunk = data[offset : offset + self._chunk_size] + yield dispatch_pb2.TensorChunk( + handle=request, + offset=offset, + data=chunk, + last=offset + len(chunk) >= len(data), + ) + if not data: + yield 
dispatch_pb2.TensorChunk(handle=request, offset=0, data=b"", last=True) + + def PushTensor(self, request_iterator: Iterable[dispatch_pb2.TensorChunk], context): # noqa: N802, ANN001 + parts = [] + for chunk in request_iterator: + parts.append(bytes(chunk.data)) + ref = self._pool.put_bytes(b"".join(parts)) + if not ref.HasField("handle"): + handle_id = next(self._pool._next_id) + self._pool._data[handle_id] = ref.inline_data + return dispatch_pb2.TensorHandle(node_id=self._pool.node_id, handle_id=handle_id) + return ref.handle diff --git a/python/simpler/worker.py b/python/simpler/worker.py index ea6ece5b4..31a4b459b 100644 --- a/python/simpler/worker.py +++ b/python/simpler/worker.py @@ -51,6 +51,7 @@ def my_l4_orch(orch, args, config): import signal import struct import sys +import threading import time import traceback from multiprocessing.shared_memory import SharedMemory @@ -544,6 +545,36 @@ def _child_worker_loop( break +def _remote_worker_loop(buf: memoryview, proxy: Any) -> None: + """L4-side mailbox shim for one remote L3 worker. + + The C++ scheduler sees a normal PROCESS-mode next-level mailbox. This + thread owns the other side of that mailbox and forwards every TASK_READY + payload to the remote daemon over RPC. 
+ """ + state_addr = _buffer_field_addr(buf, _OFF_STATE) + while True: + state = _mailbox_load_i32(state_addr) + if state == _TASK_READY: + cid = struct.unpack_from("Q", buf, _OFF_CALLABLE)[0] + code = 0 + msg = "" + try: + args = _read_args_from_mailbox(buf) + cfg = _read_config_from_mailbox(buf) + proxy.dispatch(int(cid), args, cfg) + except Exception as e: # noqa: BLE001 + code = 1 + msg = _format_exc("remote_worker", e) + _write_error(buf, code, msg) + _mailbox_store_i32(state_addr, _TASK_DONE) + elif state == _SHUTDOWN: + proxy.close() + break + else: + time.sleep(_BOOTSTRAP_POLL_INTERVAL_S) + + # --------------------------------------------------------------------------- # Worker factory # --------------------------------------------------------------------------- @@ -586,6 +617,11 @@ def __init__( self._next_level_workers: list[Worker] = [] self._next_level_shms: list[SharedMemory] = [] self._next_level_pids: list[int] = [] + self._remote_worker_specs: list[tuple[str, dict[str, Any]]] = [] + self._remote_workers: list[Any] = [] + self._remote_worker_shms: list[SharedMemory] = [] + self._remote_worker_threads: list[threading.Thread] = [] + self._distributed_catalog: Any = None # Per-chip bootstrap: one `ChipBootstrapConfig` per device_id plus a # matching shared-memory mailbox the child publishes its @@ -616,6 +652,8 @@ def register(self, fn: Callable) -> int: raise RuntimeError("Worker.register() must be called before init()") cid = len(self._callable_registry) self._callable_registry[cid] = fn + if self._distributed_catalog is not None: + self._distributed_catalog.register(fn, callable_id=cid) return cid def add_worker(self, worker: "Worker") -> None: @@ -633,6 +671,24 @@ def add_worker(self, worker: "Worker") -> None: raise RuntimeError("Child worker must not be initialized before add_worker()") self._next_level_workers.append(worker) + def add_remote_worker(self, endpoint: str, **options: Any) -> None: + """Add a remote lower-level Worker daemon as a 
NEXT_LEVEL child.""" + if self.level < 4: + raise RuntimeError("Worker.add_remote_worker() requires level >= 4") + if self._initialized: + raise RuntimeError("Worker.add_remote_worker() must be called before init()") + self._ensure_distributed_catalog() + self._remote_worker_specs.append((str(endpoint), dict(options))) + + def _ensure_distributed_catalog(self): + if self._distributed_catalog is None: + from .distributed.catalog import Catalog # noqa: PLC0415 + + self._distributed_catalog = Catalog() + for cid, fn in self._callable_registry.items(): + self._distributed_catalog.register(fn, callable_id=cid) + return self._distributed_catalog + # ------------------------------------------------------------------ # init — auto-discovery # ------------------------------------------------------------------ @@ -723,6 +779,19 @@ def _init_hierarchical(self) -> None: _mailbox_store_i32(_buffer_field_addr(shm.buf, _OFF_STATE), _IDLE) self._next_level_shms.append(shm) + # 3a. Allocate remote next-level mailboxes. These are consumed by + # Python shim threads rather than forked child processes. + if self._remote_worker_specs: + from .distributed.remote_proxy import RemoteWorkerProxy # noqa: PLC0415 + + catalog = self._ensure_distributed_catalog() + for endpoint, options in self._remote_worker_specs: + shm = SharedMemory(create=True, size=MAILBOX_SIZE) + assert shm.buf is not None + _mailbox_store_i32(_buffer_field_addr(shm.buf, _OFF_STATE), _IDLE) + self._remote_worker_shms.append(shm) + self._remote_workers.append(RemoteWorkerProxy(endpoint, catalog, **options)) + # 3b. Allocate per-chip bootstrap mailboxes (one per device_id). Must # live in shared memory so the forked child's `ChipBootstrapChannel` # and the parent's read-side view see the same region. SharedMemory @@ -832,6 +901,19 @@ def _start_hierarchical(self) -> None: # noqa: PLR0912 -- three parallel fork l else: self._next_level_pids.append(pid) + # Start remote mailbox shim threads. 
Handshake pushes the current + # callable catalog before the scheduler can publish TASK_READY. + for idx, proxy in enumerate(self._remote_workers): + proxy.handshake() + thread = threading.Thread( + target=_remote_worker_loop, + args=(self._remote_worker_shms[idx].buf, proxy), + name=f"simpler-remote-worker-{idx}", + daemon=True, + ) + thread.start() + self._remote_worker_threads.append(thread) + # When chip_bootstrap_configs was provided, block here until every # chip child publishes its result on its bootstrap mailbox. We wait # *before* registering the chip mailboxes with the scheduler so a @@ -860,6 +942,10 @@ def _start_hierarchical(self) -> None: # noqa: PLR0912 -- three parallel fork l for shm in self._next_level_shms: dw.add_next_level_process(_mailbox_addr(shm)) + # Register remote Worker shim mailboxes as NEXT_LEVEL (L4+) + for shm in self._remote_worker_shms: + dw.add_next_level_process(_mailbox_addr(shm)) + for shm in self._sub_shms: dw.add_sub_process(_mailbox_addr(shm)) @@ -1160,6 +1246,22 @@ def close(self) -> None: # noqa: PLR0912 -- parallel teardown for _worker + sub shm.close() shm.unlink() + # Shutdown remote Worker shim threads. 
+ for shm in self._remote_worker_shms: + buf = shm.buf + assert buf is not None + _mailbox_store_i32(_buffer_field_addr(buf, _OFF_STATE), _SHUTDOWN) + for thread in self._remote_worker_threads: + thread.join(timeout=5.0) + for proxy in self._remote_workers: + try: + proxy.close() + except Exception: # noqa: BLE001 + pass + for shm in self._remote_worker_shms: + shm.close() + shm.unlink() + # Unlink the bootstrap mailboxes last — chip children touch their # `ChipBootstrapChannel` from inside `shutdown_bootstrap()` + # `finalize()`, which runs after they leave the main loop on @@ -1182,6 +1284,10 @@ def close(self) -> None: # noqa: PLR0912 -- parallel teardown for _worker + sub self._next_level_shms.clear() self._next_level_pids.clear() self._next_level_workers.clear() + self._remote_worker_shms.clear() + self._remote_worker_threads.clear() + self._remote_workers.clear() + self._remote_worker_specs.clear() self._bootstrap_shms.clear() self._chip_contexts.clear() diff --git a/tests/ut/py/test_distributed/test_catalog.py b/tests/ut/py/test_distributed/test_catalog.py new file mode 100644 index 000000000..407f84e53 --- /dev/null +++ b/tests/ut/py/test_distributed/test_catalog.py @@ -0,0 +1,44 @@ +import pytest + +from simpler.distributed.catalog import Catalog, CatalogError + + +def test_catalog_export_install_lookup(): + l4 = Catalog() + + def fn(args): + return args.scalar(0) + 1 + + cid, version = l4.register(fn) + payload = l4.export_payload(cid, version) + + l3 = Catalog() + l3.install_from_payload(cid, version, payload) + + got = l3.lookup(cid, version) + assert got is not None + + +def test_catalog_pull_mock_install(): + l4 = Catalog() + cid, version = l4.register(lambda args: args.scalar(0) * 2) + + class MockClient: + def call_unary(self, method, req, timeout=None): + assert method == "Catalog.PullCallable" + return type("Payload", (), {"callable_id": cid, "version": version, "pickled": l4.export_payload(cid, version)}) + + l3 = Catalog() + req = type("Req", 
(), {"callable_id": cid, "version": version})() + payload = MockClient().call_unary("Catalog.PullCallable", req) + l3.install_from_payload(payload.callable_id, payload.version, payload.pickled) + + assert l3.lookup(cid, version) is not None + + +def test_catalog_version_mismatch(): + catalog = Catalog() + cid, version = catalog.register(lambda args: None) + payload = catalog.export_payload(cid, version) + with pytest.raises(CatalogError, match="version mismatch"): + catalog.install_from_payload(cid, version + 1, payload) diff --git a/tests/ut/py/test_distributed/test_heartbeat.py b/tests/ut/py/test_distributed/test_heartbeat.py new file mode 100644 index 000000000..a05796474 --- /dev/null +++ b/tests/ut/py/test_distributed/test_heartbeat.py @@ -0,0 +1,29 @@ +import time + +import pytest + +from simpler.distributed.l3_daemon import L3Daemon +from simpler.distributed.remote_proxy import RemoteUnavailable +from simpler.task_interface import TaskArgs +from simpler.worker import Worker + + +def test_remote_proxy_marks_down_after_heartbeat_failures(): + daemon = L3Daemon(0, lambda: Worker(level=3, num_sub_workers=0)) + endpoint = f"127.0.0.1:{daemon.start()}" + w4 = Worker(level=4, num_sub_workers=0) + w4.register(lambda orch, args, config: None) + w4.add_remote_worker(endpoint, heartbeat_interval=0.05, heartbeat_failures=1, heartbeat_timeout=0.05) + w4.init() + w4.run(lambda orch, args, config: None) + daemon.stop() + try: + time.sleep(0.2) + + def l4_orch(orch, args, config): + orch.submit_next_level(0, TaskArgs(), config) + + with pytest.raises(RuntimeError, match="unavailable|dispatch RPC failed"): + w4.run(l4_orch) + finally: + w4.close() diff --git a/tests/ut/py/test_distributed/test_import.py b/tests/ut/py/test_distributed/test_import.py new file mode 100644 index 000000000..b03e11e02 --- /dev/null +++ b/tests/ut/py/test_distributed/test_import.py @@ -0,0 +1,6 @@ +def test_distributed_imports(): + import simpler.distributed + from simpler.distributed.proto import 
dispatch_pb2 + + assert simpler.distributed is not None + assert dispatch_pb2.DispatchReq(task_id=1).task_id == 1 diff --git a/tests/ut/py/test_distributed/test_l4_l3_remote.py b/tests/ut/py/test_distributed/test_l4_l3_remote.py new file mode 100644 index 000000000..b2b0e9128 --- /dev/null +++ b/tests/ut/py/test_distributed/test_l4_l3_remote.py @@ -0,0 +1,235 @@ +import struct +from multiprocessing.shared_memory import SharedMemory + +from simpler.distributed.l3_daemon import L3Daemon +from simpler.task_interface import CallConfig, TaskArgs +from simpler.worker import Worker + +def _scalar_value(args: TaskArgs) -> int: + return int(args.scalar(0)) if args is not None and args.scalar_count() else 1 + + +def _make_shared_counter(): + shm = SharedMemory(create=True, size=4) + buf = shm.buf + assert buf is not None + struct.pack_into("i", buf, 0, 0) + return shm, buf + + +def _read_counter(buf) -> int: + return struct.unpack_from("i", buf, 0)[0] + + +def _increment_counter(buf) -> None: + value = struct.unpack_from("i", buf, 0)[0] + struct.pack_into("i", buf, 0, value + 1) + + +def _start_daemon(): + daemon = L3Daemon(0, lambda: Worker(level=3, num_sub_workers=1)) + port = daemon.start() + return daemon, f"127.0.0.1:{port}" + + +def _make_file_counter(path): + path.write_text("0") + + def read() -> int: + return int(path.read_text()) + + def add(amount: int) -> None: + path.write_text(str(read() + int(amount))) + + return read, add + + +def test_l4_remote_init_close_no_dispatch(): + daemon, endpoint = _start_daemon() + try: + w4 = Worker(level=4, num_sub_workers=0) + w4.add_remote_worker(endpoint) + w4.init() + w4.close() + finally: + daemon.stop() + + +def test_l4_remote_single_dispatch(tmp_path): + read_counter, add_counter = _make_file_counter(tmp_path / "remote_counter.txt") + daemon, endpoint = _start_daemon() + + try: + w4 = Worker(level=4, num_sub_workers=0) + def l3_sub(args): + add_counter(_scalar_value(args)) + + l3_sub_cid = w4.register(l3_sub) + + def 
l3_orch(orch, args, config): + orch.submit_sub(l3_sub_cid, args) + + l3_cid = w4.register(l3_orch) + w4.add_remote_worker(endpoint) + w4.init() + + def l4_orch(orch, args, config): + sub_args = TaskArgs() + sub_args.add_scalar(3) + orch.submit_next_level(l3_cid, sub_args, CallConfig()) + + w4.run(l4_orch) + w4.close() + assert read_counter() == 3 + finally: + daemon.stop() + + +def test_l4_remote_multiple_dispatches(tmp_path): + read_counter, add_counter = _make_file_counter(tmp_path / "remote_counter.txt") + daemon, endpoint = _start_daemon() + + try: + w4 = Worker(level=4, num_sub_workers=0) + def l3_sub(args): + add_counter(_scalar_value(args)) + + l3_sub_cid = w4.register(l3_sub) + + def l3_orch(orch, args, config): + orch.submit_sub(l3_sub_cid, args) + + l3_cid = w4.register(l3_orch) + w4.add_remote_worker(endpoint) + w4.init() + + def l4_orch(orch, args, config): + for value in (1, 2, 4): + sub_args = TaskArgs() + sub_args.add_scalar(value) + orch.submit_next_level(l3_cid, sub_args, CallConfig()) + + w4.run(l4_orch) + w4.close() + assert read_counter() == 7 + finally: + daemon.stop() + + +def test_l4_remote_with_local_sub(tmp_path): + read_remote_counter, add_remote_counter = _make_file_counter(tmp_path / "remote_counter.txt") + local_shm, local_buf = _make_shared_counter() + daemon, endpoint = _start_daemon() + + try: + w4 = Worker(level=4, num_sub_workers=1) + def l3_sub(args): + add_remote_counter(1) + + l3_sub_cid = w4.register(l3_sub) + local_cid = w4.register(lambda args: _increment_counter(local_buf)) + + def l3_orch(orch, args, config): + orch.submit_sub(l3_sub_cid) + + l3_cid = w4.register(l3_orch) + w4.add_remote_worker(endpoint) + w4.init() + + def l4_orch(orch, args, config): + orch.submit_next_level(l3_cid, TaskArgs(), CallConfig()) + orch.submit_sub(local_cid) + + w4.run(l4_orch) + w4.close() + assert read_remote_counter() == 1 + assert _read_counter(local_buf) == 1 + finally: + daemon.stop() + local_shm.close() + local_shm.unlink() + + +def 
test_l4_remote_multiple_runs(tmp_path): + read_counter, add_counter = _make_file_counter(tmp_path / "remote_counter.txt") + daemon, endpoint = _start_daemon() + + try: + w4 = Worker(level=4, num_sub_workers=0) + def l3_sub(args): + add_counter(1) + + l3_sub_cid = w4.register(l3_sub) + + def l3_orch(orch, args, config): + orch.submit_sub(l3_sub_cid) + + l3_cid = w4.register(l3_orch) + w4.add_remote_worker(endpoint) + w4.init() + + def l4_orch(orch, args, config): + orch.submit_next_level(l3_cid, TaskArgs(), CallConfig()) + + for _ in range(5): + w4.run(l4_orch) + w4.close() + assert read_counter() == 5 + finally: + daemon.stop() + + +def test_l4_remote_l3_multiple_subs(tmp_path): + read_counter, add_counter = _make_file_counter(tmp_path / "remote_counter.txt") + daemon, endpoint = _start_daemon() + + try: + w4 = Worker(level=4, num_sub_workers=0) + def l3_sub(args): + add_counter(1) + + l3_sub_cid = w4.register(l3_sub) + + def l3_orch(orch, args, config): + orch.submit_sub(l3_sub_cid) + orch.submit_sub(l3_sub_cid) + + l3_cid = w4.register(l3_orch) + w4.add_remote_worker(endpoint) + w4.init() + + def l4_orch(orch, args, config): + orch.submit_next_level(l3_cid, TaskArgs(), CallConfig()) + + w4.run(l4_orch) + w4.close() + assert read_counter() == 2 + finally: + daemon.stop() + + +def test_l4_remote_error_propagates(): + daemon, endpoint = _start_daemon() + + def broken_l3_orch(orch, args, config): + raise ValueError("remote failure") + + try: + w4 = Worker(level=4, num_sub_workers=0) + l3_cid = w4.register(broken_l3_orch) + w4.add_remote_worker(endpoint) + w4.init() + + def l4_orch(orch, args, config): + orch.submit_next_level(l3_cid, TaskArgs(), CallConfig()) + + try: + w4.run(l4_orch) + except RuntimeError as e: + assert "remote failure" in str(e) + else: + raise AssertionError("remote failure did not propagate") + finally: + w4.close() + finally: + daemon.stop() diff --git a/tests/ut/py/test_distributed/test_rpc_roundtrip.py 
b/tests/ut/py/test_distributed/test_rpc_roundtrip.py new file mode 100644 index 000000000..ddc4380f4 --- /dev/null +++ b/tests/ut/py/test_distributed/test_rpc_roundtrip.py @@ -0,0 +1,63 @@ +import pytest + +from simpler.distributed.proto import dispatch_pb2, dispatch_pb2_grpc +from simpler.distributed.rpc import RpcClient, RpcError, RpcServer + + +class EchoL3(dispatch_pb2_grpc.L3WorkerServicer): + def Dispatch(self, request, context): # noqa: N802 + return dispatch_pb2.DispatchResp(task_id=request.task_id, error_code=0) + + def Heartbeat(self, request, context): # noqa: N802 + return dispatch_pb2.Health(ok=True, message="ok") + + +class FailingL3(dispatch_pb2_grpc.L3WorkerServicer): + def Dispatch(self, request, context): # noqa: N802 + context.abort(13, "boom") + + def Heartbeat(self, request, context): # noqa: N802 + return dispatch_pb2.Health(ok=True, message="ok") + + +def test_rpc_roundtrip(): + server = RpcServer() + server.add_l3_worker(EchoL3()) + port = server.start(0) + client = RpcClient(f"127.0.0.1:{port}") + try: + resp = client.call_unary( + "L3Worker.Dispatch", + dispatch_pb2.DispatchReq(task_id=42, callable_id=7), + timeout=2, + ) + assert resp.task_id == 42 + assert resp.error_code == 0 + finally: + client.close() + server.stop(0) + + +def test_rpc_error_maps_to_exception(): + server = RpcServer() + server.add_l3_worker(FailingL3()) + port = server.start(0) + client = RpcClient(f"127.0.0.1:{port}") + try: + with pytest.raises(RpcError, match="boom"): + client.dispatch(dispatch_pb2.DispatchReq(task_id=1), timeout=2) + finally: + client.close() + server.stop(0) + + +def test_port_conflict_reports_clear_error(): + first = RpcServer() + port = first.start(0) + second = RpcServer() + try: + with pytest.raises(RpcError, match="failed to bind"): + second.start(port) + finally: + first.stop(0) + second.stop(0) diff --git a/tests/ut/py/test_distributed/test_tensor_pool.py b/tests/ut/py/test_distributed/test_tensor_pool.py new file mode 100644 index 
000000000..0dfaa5868 --- /dev/null +++ b/tests/ut/py/test_distributed/test_tensor_pool.py @@ -0,0 +1,39 @@ +from simpler.distributed.proto import dispatch_pb2 +from simpler.distributed.tensor_pool import TensorPool + + +def test_tensor_pool_inline_bytes(): + pool = TensorPool(inline_threshold=8) + ref = pool.put_bytes(b"abc") + assert ref.inline_data == b"abc" + + +def test_tensor_pool_handle_bytes(): + pool = TensorPool(inline_threshold=2) + ref = pool.put_bytes(b"abcdef") + assert ref.HasField("handle") + assert pool.get_bytes(ref.handle) == b"abcdef" + + +def test_tensor_pool_service_pull(): + pool = TensorPool(inline_threshold=1) + ref = pool.put_bytes(b"abcdef") + service = pool.service() + chunks = list(service.PullTensor(ref.handle, None)) + assert b"".join(chunk.data for chunk in chunks) == b"abcdef" + assert chunks[-1].last + + +def test_tensor_pool_service_push(): + pool = TensorPool(inline_threshold=1) + service = pool.service() + handle = service.PushTensor( + iter( + [ + dispatch_pb2.TensorChunk(offset=0, data=b"abc"), + dispatch_pb2.TensorChunk(offset=3, data=b"def", last=True), + ] + ), + None, + ) + assert pool.get_bytes(handle) == b"abcdef" From 2adeb0ae2b01eda3369da903ec5f404a8215d46c Mon Sep 17 00:00:00 2001 From: PKUZHOU Date: Fri, 8 May 2026 09:48:15 +0800 Subject: [PATCH 2/6] feat(distributed): add L4 L3 RXE data plane --- ...istributed-l4-control-data-plane-rxe.zh.md | 360 ++++ docs/distributed-l4-implementation.zh.md | 162 +- docs/distributed-l4.md | 103 +- python/simpler/distributed/hcomm_abi_shim.cc | 56 + python/simpler/distributed/l3_daemon.py | 214 ++- .../simpler/distributed/proto/dispatch.proto | 26 + .../simpler/distributed/proto/dispatch_pb2.py | 52 +- .../distributed/proto/dispatch_pb2_grpc.py | 129 ++ python/simpler/distributed/remote_proxy.py | 343 +++- python/simpler/distributed/rxe_verbs_helper.c | 505 +++++ python/simpler/distributed/serialization.py | 150 +- python/simpler/distributed/tensor_pool.py | 278 ++- 
.../simpler/distributed/transport_backend.py | 1653 +++++++++++++++++ python/simpler/worker.py | 15 +- src/common/hierarchical/worker_manager.cpp | 16 +- tests/ut/py/test_distributed/test_catalog.py | 17 - .../test_distributed/test_hcomm_e2e_real.py | 180 ++ .../py/test_distributed/test_l4_l3_remote.py | 231 ++- .../test_distributed/test_real_e2e_smoke.py | 233 +++ tests/ut/py/test_distributed/test_rxe_real.py | 97 + .../py/test_distributed/test_tensor_pool.py | 139 +- .../test_transport_backend.py | 163 ++ tools/benchmark_rxe_data_plane.py | 108 ++ tools/test_rxe_data_plane.sh | 49 + 24 files changed, 5174 insertions(+), 105 deletions(-) create mode 100644 docs/distributed-l4-control-data-plane-rxe.zh.md create mode 100644 python/simpler/distributed/hcomm_abi_shim.cc create mode 100644 python/simpler/distributed/rxe_verbs_helper.c create mode 100644 python/simpler/distributed/transport_backend.py create mode 100644 tests/ut/py/test_distributed/test_hcomm_e2e_real.py create mode 100644 tests/ut/py/test_distributed/test_real_e2e_smoke.py create mode 100644 tests/ut/py/test_distributed/test_rxe_real.py create mode 100644 tests/ut/py/test_distributed/test_transport_backend.py create mode 100755 tools/benchmark_rxe_data_plane.py create mode 100755 tools/test_rxe_data_plane.sh diff --git a/docs/distributed-l4-control-data-plane-rxe.zh.md b/docs/distributed-l4-control-data-plane-rxe.zh.md new file mode 100644 index 000000000..75fa568e3 --- /dev/null +++ b/docs/distributed-l4-control-data-plane-rxe.zh.md @@ -0,0 +1,360 @@ +# L4/L3 分布式控制面与 RXE 数据面实现说明 + +本文记录当前 Simpler L4 到 L3 分布式 dispatch 的实现状态,重点说明这次新增的真实 RXE/ibverbs 数据面、L4 控制面流程、TensorPool handle 语义、测试方式和已知局限。 + +当前实现遵循一个边界原则:**不修改 `3rd/hcomm` 源码**。HCOMM 只作为可选运行时能力被 Simpler 侧 shim/adapter 使用;RXE 数据面实现、构建逻辑、测试脚本都放在 `simpler/` 内。 + +## 这次整体修改 + +主要新增和修改的代码如下: + +```text +python/simpler/distributed/ + transport_backend.py # gRPC/HCOMM/RXE transport backend 抽象与运行时加载 + rxe_verbs_helper.c # Simpler 自有 ibverbs RC RDMA write 
helper + remote_proxy.py # L4 侧 remote worker proxy,接入 RXE input/output 数据面 + l3_daemon.py # L3 daemon,backend process 内创建 TensorPool transport + serialization.py # TensorRef 解码、output writeback、RXE fallback + tensor_pool.py # TensorPool refresh hook,支持 RXE region 重建 + +tests/ut/py/test_distributed/ + test_transport_backend.py # RXE desc 编解码、backend 基础测试 + test_real_e2e_smoke.py # 实机 L4->L3 RXE 数据面 E2E + test_rxe_real.py # 实机 ibv_rc_pingpong smoke + +tools/ + test_rxe_data_plane.sh # 一键测试脚本 + benchmark_rxe_data_plane.py # gRPC vs RXE 端到端 benchmark +``` + +已有的 protobuf 消息没有新增字段。当前 `TensorHandle.transport` 和 `TensorHandle.transport_desc` 已足够承载不同 transport 的数据面描述。 + +## 总体架构 + +分布式路径分成控制面和数据面: + +```text +L4 Worker + -> 本地 PROCESS mailbox + -> Python remote shim thread + -> RemoteWorkerProxy + -> gRPC control plane + -> L3Daemon + -> backend process + -> Worker(level=3) + +Tensor data plane: + small tensor <= inline threshold + -> DispatchReq.tensor_refs.inline_data + + large input tensor + -> L3 TensorPool.AllocTensor + -> L3 TensorPool returns TensorHandle(transport=rxe/grpc/hcomm) + -> L4 writes payload into that handle + -> DispatchReq only carries TensorRef(handle=...) + + large output tensor + -> L4 registers local output buffer as RXE region + -> DispatchReq carries TensorRef(handle=local-rxe-output) + -> L3 runs task into temporary local buffer + -> L3 writes output bytes back to L4 RXE handle + -> DispatchResp returns ACK-style TensorRef(handle=local-rxe-output) +``` + +控制面负责调度、catalog、租约、错误传播和 handle 生命周期;数据面负责 tensor payload 的实际搬运。 + +## L4 控制面实现 + +L4 用户仍然通过原有接口使用远端 L3: + +```python +w4 = Worker(level=4, num_sub_workers=0) +w4.add_remote_worker("127.0.0.1:5050", tensor_transport="rxe") +w4.init() +``` + +`Worker.add_remote_worker()` 会在 `Worker.init()` 时创建一个本地 mailbox 和一个 `RemoteWorkerProxy`。C++ scheduler 仍然看到的是一个普通 PROCESS-mode next-level worker,Python shim thread 负责把 mailbox 中的任务转换成 gRPC dispatch。 + +L4 dispatch 的关键步骤: + +1. 
shim thread 从 mailbox 读出 callable id、`TaskArgs` 和 `CallConfig`。 +2. `RemoteWorkerProxy` 把 callable catalog 预先推送到 L3 daemon。 +3. 对每个 tensor 参数判断 inline、remote input handle 或 local output handle。 +4. 通过 `L3Worker.Dispatch` 发出 `DispatchReq`。 +5. 收到 `DispatchResp` 后,把 output 写回本地用户 buffer,释放临时 handle。 + +控制面的核心文件: + +- `remote_proxy.py` +- `l3_daemon.py` +- `catalog.py` +- `rpc.py` +- `serialization.py` +- `proto/dispatch.proto` + +## L3 控制面实现 + +`L3Daemon` 是远端 L3 节点入口。它不会直接在 gRPC handler 线程里运行 `Worker(level=3)`,而是启动一个 backend process: + +```text +L3Daemon process + - gRPC server + - Catalog service + - TensorPool control service facade + - Pipe to backend process + +Backend process + - Catalog mirror + - TensorPool + - transport backend: grpc / rxe / hcomm / auto + - lazy Worker(level=3) +``` + +这样做是为了避免 grpcio 线程和 `Worker(level=3)` 内部 fork sub/chip worker 发生冲突。TensorPool 的真实对象也在 backend process 内,因此它注册的 buffer 地址和实际执行任务的地址空间一致。 + +启动示例: + +```bash +python -m simpler.distributed.l3_daemon --port 5050 --tensor-transport rxe +``` + +## 数据面抽象 + +数据面抽象定义在 `transport_backend.py`: + +```python +class TensorTransportBackend: + name = "grpc" + def register_region(self, data: bytearray, *, tag: str) -> RegisteredRegion: ... + def unregister_region(self, region: RegisteredRegion) -> None: ... 
+``` + +TensorPool 分配大 tensor 时,会创建 `bytearray` 并调用 backend 的 `register_region()`。返回的 `RegisteredRegion` 被编码到 `TensorHandle`: + +```protobuf +message TensorHandle { + string node_id = 1; + uint64 handle_id = 2; + uint64 remote_addr = 3; + uint32 rkey = 4; + uint64 nbytes = 5; + uint64 lease_deadline_unix_ms = 6; + string transport = 7; + bytes transport_desc = 8; +} +``` + +当前 backend: + +- `GrpcTensorTransport`:默认路径,payload 仍通过 `TensorPool.PushTensor/PullTensor` 的 gRPC chunk 传输。 +- `HcommTensorTransport`:可选 HCOMM C API 适配层,只在 Simpler 内做 ABI shim/loader,不修改 HCOMM。 +- `RxeTensorTransport`:真实 RXE/ibverbs 数据面。 + +`build_tensor_transport()` 支持: + +```text +grpc +rxe +hcomm +auto +``` + +`auto` 默认保守,不自动启用 RXE。需要显式 `SIMPLER_RXE_AUTO=1` 才会在 auto 模式优先尝试 RXE。 + +## RXE 数据面实现 + +RXE 数据面由两层组成: + +```text +Python: + RxeTensorTransport + RxeDataPlaneClient + RxeRuntime + +C helper: + rxe_verbs_helper.c +``` + +### L3 侧注册 region + +L3 TensorPool 分配大 input handle 时: + +1. 创建 `bytearray(nbytes)`。 +2. `RxeTensorTransport.register_region()` 获取 buffer 地址。 +3. `RxeRuntime.server_start()` 调 C helper: + - 打开 RXE device + - 创建 PD/CQ/QP + - 注册 MR + - 启动 TCP 控制 server + - 等待 L4 建立 RC QP 后接收一次 RDMA write +4. 返回 `TensorHandle(transport="rxe", transport_desc=...)`。 + +### L4 侧写 input + +L4 收到 L3 分配的 RXE handle 后: + +1. `RemoteWorkerProxy._push_remote_tensor_rxe()` 创建本地 source buffer。 +2. `RxeDataPlaneClient.write_handle()` 解析 `transport_desc`。 +3. `simpler_rxe_write()` 建立 TCP 控制连接,交换 QP 信息。 +4. L4 发起 `IBV_WR_RDMA_WRITE`。 +5. CQ completion 成功后,L4 调 `TensorPool.RefreshTensor`。 + +`RefreshTensor` 不只是续租。对 RXE backend,它还会调用 `refresh_region()`,关闭旧的一次性 server 并在同一 buffer 上重建新 server。因此同一个 TensorPool handle 后续仍可再次写入。 + +### L3 到 L4 output 写回 + +这次新增了 output 方向的真实数据面: + +1. L4 遇到大 `OUTPUT / OUTPUT_EXISTING` tensor 时,不再先把旧内容推到 L3。 +2. L4 直接把本地 output buffer 注册成 RXE region,生成本地 `TensorHandle(node_id="l4-rxe-...", transport="rxe")`。 +3. `DispatchReq.tensor_refs` 携带这个 local output handle。 +4. 
L3 `decode_task_args_with_tensor_refs_and_writebacks()` 识别该 handle: + - 在 L3 backend process 内分配临时 mmap buffer 给 Worker 执行。 + - 记录 `RemoteTensorWriteback`。 +5. L3 task 执行结束后,`encode_output_tensor_refs()` 用 `RxeDataPlaneClient` 把临时 output buffer RDMA write 回 L4 handle。 +6. `DispatchResp.output_tensors` 返回同一个 handle 作为 ACK。 +7. L4 看到 ACK 属于本地 output handle,就不再 PullTensor。 + +当前 output RXE writeback 覆盖: + +- `TensorArgType.OUTPUT` +- `TensorArgType.OUTPUT_EXISTING` + +`INOUT` 暂时仍走 input staging 路径,因为它同时需要把初始值送到 L3,再把结果写回 L4。这个双向语义还没有在单个 handle 上完全优化。 + +### RXE transport desc v2 + +旧版本 `transport_desc` 是 JSON。当前版本改为二进制头,减少解析歧义并为后续扩展留空间: + +```text +magic = "SRXE" +version = 2 +header_size +port +gid_index +rkey +addr +size +ip[64] +device[64] +``` + +解析逻辑仍兼容旧 JSON desc,便于已有测试和临时 handle 过渡。 + +## HCOMM 现状 + +HCOMM 相关改动只保留在 Simpler 侧: + +- `hcomm_abi_shim.cc` +- `HcommRuntime` +- `HcommTensorTransport` +- `HcommDataPlaneClient` + +当前机器上 stock HCOMM CPU RoCE channel 对 910B1 host 场景不满足能力要求,因此没有把 HCOMM channel E2E 作为主路径。RXE backend 是当前真实数据面 smoke/E2E 的主验证路径。 + +## 测试与验证 + +一键测试脚本: + +```bash +cd /mnt/data/ntlab/zhouzhe/simpler_l4/simpler +tools/test_rxe_data_plane.sh +``` + +脚本执行: + +1. Python 编译检查。 +2. distributed 常规 UT。 +3. RXE/ibverbs `ibv_rc_pingpong` smoke。 +4. L4/L3 RXE tensor 数据面 E2E。 + +已验证结果: + +```text +38 passed, 3 skipped +1 passed +2 passed, 2 deselected +RXE data-plane tests passed. 
+``` + +可选 benchmark: + +```bash +SIMPLER_RUN_RXE_BENCHMARK=1 tools/test_rxe_data_plane.sh +``` + +也可以单独运行: + +```bash +PYTHONPATH=python tools/benchmark_rxe_data_plane.py \ + --sizes 8192,65536,1048576 \ + --repeats 10 \ + --warmup 2 \ + --transports grpc,rxe +``` + +输出 CSV: + +```text +transport,size_bytes,repeats,mean_ms,p50_ms,p95_ms,min_ms,max_ms +``` + +## 环境变量 + +常用配置: + +```bash +export SIMPLER_TENSOR_TRANSPORT=rxe +export SIMPLER_RXE_DEVICE=rxe0 +export SIMPLER_RXE_GID_INDEX=1 +export SIMPLER_RXE_SERVER_IP=192.168.0.243 +``` + +rdma-core 构建路径默认使用本机已验证路径: + +```bash +export SIMPLER_RXE_INCLUDE_DIR=/home/ntlab/rdma-build/rdma-core-50.0/build/include +export SIMPLER_RXE_LIB_DIR=/home/ntlab/rdma-build/rdma-core-50.0/build/lib +``` + +如果不设置 `SIMPLER_RXE_DEVICE / SIMPLER_RXE_GID_INDEX / SIMPLER_RXE_SERVER_IP`,`RxeRuntime` 会尝试从 `/sys/class/infiniband/rxe*` 和 IPv4-mapped GID 自动推断。 + +## 当前局限性 + +1. RXE helper 仍是 MVP + + 当前 C helper 是 RC QP + TCP 控制面的一次 write server。虽然 TensorPool refresh 会重建 server,使同一 handle 后续可以继续写,但它还不是长期连接池,也没有 QP 复用。 + +2. 性能不是最终形态 + + 当前 RXE 路径每个 region/write 都包含 TCP 控制连接、QP 创建、MR 注册等成本。小 tensor 或短迭代下可能比 gRPC 更慢。benchmark 用来观察趋势,不代表最终设计性能上限。 + +3. `INOUT` 还没有完整双向 RXE 优化 + + `OUTPUT / OUTPUT_EXISTING` 大 tensor 已支持 L3->L4 RXE writeback。`INOUT` 仍按 input staging 处理,后续需要同时支持初始值 L4->L3 和结果 L3->L4。 + +4. transport desc 还不是 protobuf + + 当前 v2 是 Simpler 自定义二进制头,旧 JSON 兼容。后续如果 desc 需要跨语言稳定演进,可以把 desc 单独 protobuf 化或纳入统一 metadata schema。 + +5. 错误恢复偏保守 + + L4->L3 input 的显式 `rxe` 模式失败会报错;`auto` 模式才回退 gRPC。L3->L4 output 写回失败会退回 TensorPool/gRPC response 路径,以保证语义正确。 + +6. 当前主要验证是单机 RXE + + 实机测试覆盖本机 RXE device 和 ibverbs RC pingpong。跨节点 RoCE、多 rank、多并发 worker、长时间压测还需要补测试。 + +7. 安全边界仍是受信任集群 + + Catalog 使用 cloudpickle 传 callable payload,本质是可执行代码反序列化。Catalog/gRPC 服务不应暴露给不可信客户端。 + +## 后续建议 + +优先级较高的后续工作: + +1. 把 RXE C helper 改成长连接或连接池,复用 PD/CQ/QP/MR。 +2. 完成 `INOUT` 双向 RXE 数据面。 +3. 增加多并发 dispatch 压测,覆盖 server refresh 和 FreeTensor 时序。 +4. 
把 `transport_desc` 迁移到稳定 schema。 +5. 加入性能指标基线,持续对比 gRPC chunk 与 RXE write 的延迟和吞吐。 diff --git a/docs/distributed-l4-implementation.zh.md b/docs/distributed-l4-implementation.zh.md index 6249797ff..a2d16e64f 100644 --- a/docs/distributed-l4-implementation.zh.md +++ b/docs/distributed-l4-implementation.zh.md @@ -77,6 +77,9 @@ service Catalog { } service TensorPool { + rpc AllocTensor(TensorAllocReq) returns (TensorHandle); + rpc FreeTensor(TensorFreeReq) returns (Empty); + rpc RefreshTensor(TensorRefreshReq) returns (TensorHandle); rpc PullTensor(TensorHandle) returns (stream TensorChunk); rpc PushTensor(stream TensorChunk) returns (TensorHandle); } @@ -90,14 +93,14 @@ service TensorPool { - `config_blob`: 序列化后的 `CallConfig` - `scalar_args`: 标量参数 - `tensor_args`: `ContinuousTensor` 元数据 -- `tensor_refs`: 为后续真实 tensor 数据面预留 +- `tensor_refs`: 当前数据面使用的 tensor 引用;小 tensor 直接 inline,大 tensor 使用 L3 `TensorPool` handle `DispatchResp` 当前承载: - `error_code`: `0` 表示成功 - `error_msg`: 远端失败摘要 - `remote_traceback`: 远端 Python traceback -- `output_tensors`: 为后续 output 回传预留 +- `output_tensors`: 当前用于 output 回传;L3 返回 `OUTPUT / INOUT / OUTPUT_EXISTING` tensor 的 inline bytes 或 handle ## L4 侧实现 @@ -132,6 +135,25 @@ service TensorPool { RemoteWorkerProxy.dispatch(callable_id, args, cfg) ``` +如果 `TaskArgs` 里包含 `ContinuousTensor`,L4 侧不再把 `tensor.data` 裸地址发给远端。当前策略是: + +1. 读取本地 tensor 指针和 `tensor.nbytes()`,拷贝出一份 bytes。`OUTPUT / OUTPUT_EXISTING` 只需要远端写入空间,当前按同样路径 staging。 +2. `nbytes <= 4KB` 时直接放进 `DispatchReq.tensor_refs.inline_data`。 +3. `nbytes > 4KB` 时先通过 `TensorPool.AllocTensor` 在 L3 backend pool 分配 handle。 +4. L4 通过 `TensorPool.PushTensor` 分片把 bytes 写入该 handle。 +5. `DispatchReq.tensor_refs` 只携带 handle、shape、dtype、tag。 +6. `Dispatch` 成功返回后,L4 读取 `DispatchResp.output_tensors`,按本地 output tensor 顺序写回原始 buffer。 +7. 
L4 best-effort 调 `TensorPool.FreeTensor` 释放 input/output handle。 + +为支持 remote mailbox 恢复 tensor tags,C++ PROCESS mailbox 在旧的 `[T][S][tensors][scalars]` 后追加了向后兼容的 tags 扩展: + +```text +uint32 magic = "SL4T" +int32 tags[T] +``` + +C++ `read_blob` 仍按旧格式读取并忽略尾部;Python `_read_args_from_mailbox()` 识别该扩展并恢复 `TensorArgType`。没有这个扩展时,远端无法区分 `INPUT / OUTPUT / INOUT`。 + 远端返回成功后,shim thread 把 mailbox 状态写回 `TASK_DONE`。如果远端失败,则把错误写入 mailbox error 区域,后续由现有 drain/error propagation 路径抛回 L4 调用者。 ## Callable Catalog @@ -193,11 +215,12 @@ Dispatch 时: 1. gRPC handler 收到 `DispatchReq`。 2. 将 protobuf bytes 通过 pipe 发给 backend。 -3. backend lazy 创建 inner `Worker(level=3)`。 -4. backend 把 catalog 中所有 callable 安装进 inner worker 的 `_callable_registry`。 -5. 查找 `req.callable_id / req.callable_version` 对应的 orch fn。 -6. 反序列化 `TaskArgs` 和 `CallConfig`。 -7. 调用: +3. 查找 `req.callable_id / req.callable_version` 对应的 orch fn。 +4. 反序列化 `TaskArgs` 和 `CallConfig`。 +5. 无 tensor 的 scalar dispatch 使用持久 inner `Worker(level=3)`。 +6. 带 `tensor_refs` 的 dispatch 使用临时 inner `Worker(level=3)`:先把 tensor materialize 到 shared mmap,再 init/fork L3 sub/chip 子进程,保证子进程继承同一段映射。 +7. backend 把 catalog 中所有 callable 安装进 inner worker 的 `_callable_registry`。 +8. 
调用: ```python inner.run(orch_fn, args, cfg) @@ -231,20 +254,120 @@ backend exception 已实现: -- 小字节数据 inline -- 大字节数据注册为 handle +- 小 tensor inline,默认阈值 4KB +- L3 backend process 内的 Python `TensorPool` +- `TensorPool` 的 transport backend 抽象,默认后端是 `grpc` +- `AllocTensor / FreeTensor / RefreshTensor` +- handle lease、TTL、过期 GC、pool 容量限制 - `PullTensor` streaming - `PushTensor` streaming -- `ContinuousTensor` 元数据随 `DispatchReq.tensor_args` 传输 +- L4 `RemoteWorkerProxy` 的 tensor staging:本地指针读取 bytes,小 tensor inline,大 tensor remote alloc + push + handle +- 可选 HCOMM backend:L3 pool 注册 byte buffer,handle 中带 `transport="hcomm"` 和 `transport_desc` +- L4 收到 HCOMM handle 时,可用 `HcommWriteWithNotifyNbi` + `HcommChannelFence` 推 input tensor +- L3 backend dispatch 时把 `TensorRef` materialize 成 shared mmap,并构造 `ContinuousTensor` +- 带 tensor 的 L3 dispatch 使用 per-dispatch 临时 inner worker,使 L3 sub/chip fork 后继承 tensor mmap +- `OUTPUT / INOUT / OUTPUT_EXISTING` output tensor 回传到 L4 原始 buffer +- scalar args 与 tensor refs 混合传输 尚未完成: -- 远端真实 tensor materialization -- output tensor 回写 -- `OUTPUT_EXISTING` 的远端到本地同步 +- 完整 RDMA/Urma/SHM 零拷贝 transport;当前默认数据面仍通过 gRPC streaming,HCOMM 只覆盖 input push 的第一版接线 +- 持久 L3 worker 复用场景下的动态 tensor mmap 注入;当前带 tensor 的 dispatch 为保证 fork 继承映射,会使用临时 inner worker - 与 torch tensor / NPU device memory 的完整数据面打通 +- output tensor 的 HCOMM 写回协议;当前 output 仍走 gRPC `PullTensor` +- pool 中默认 `grpc` 后端的 `remote_addr/rkey` 仍是协议占位:`remote_addr` 是 Python bytearray buffer 地址,`rkey=0` + +## HCOMM backend 当前边界 + +当前没有直接依赖 HCOMM 内部 `HostCpuRoceChannel` 类,而是通过公开/实验 C API 做一层 Python facade: + +- L3 侧:`HcommMemReg` 注册 `TensorPool` 的 byte buffer,`HcommMemExport` 导出内存描述。 +- proto:`TensorHandle` 新增 `transport` 和 `transport_desc`,避免 HCOMM 内存描述丢失。 +- L4 侧:当 `RemoteWorkerProxy` 配置 `tensor_transport="hcomm"` 或 `auto`,且远端 handle 标记为 `hcomm` 时,优先用 `HcommMemImport` 导入 `transport_desc`,把本地源数据拷贝到 HCOMM 注册过的 host staging buffer,再调用 `HcommWriteWithNotifyNbi` 和 `HcommChannelFence`。 +- endpoint:可以直接传 
`SIMPLER_HCOMM_ENDPOINT_HANDLE`,也可以用 `SIMPLER_HCOMM_ENDPOINT_IP` 和可选 location 字段自动创建。 +- channel:可以直接传 `SIMPLER_HCOMM_CHANNEL_HANDLE`;也可以基于最新 public `HcommChannelCreate` ABI 自动创建 CPU RoCE channel。`SIMPLER_HCOMM_CHANNEL_ROLE` 选择 `client`/`server`,`SIMPLER_HCOMM_CHANNEL_PORT` 选择 listen/connect 端口;`SIMPLER_HCOMM_SOCKET_HANDLE` 仍可选传入,但不再是 Python facade 自动建 channel 的必要条件。 +- `auto` 模式:HCOMM 必要资源不存在或写入失败时自动回落到 gRPC。 +- 显式 `hcomm` 模式:缺库或缺必要资源会报错,不静默假装走 RDMA。 + +启用方式: + +```bash +SIMPLER_TENSOR_TRANSPORT=hcomm \ +SIMPLER_HCOMM_LIB=/path/to/libhcomm.so \ +SIMPLER_HCOMM_ENDPOINT_IP=192.168.0.243 \ +SIMPLER_HCOMM_CHANNEL_ROLE=client \ +SIMPLER_HCOMM_CHANNEL_PORT=60001 \ +python -m simpler.distributed.l3_daemon --port 5050 --tensor-transport hcomm +``` + +或者先用 `auto` 做兼容运行: + +```bash +SIMPLER_TENSOR_TRANSPORT=auto python -m simpler.distributed.l3_daemon --port 5050 --tensor-transport auto +``` + +注意: + +- 当前 facade 只依赖 public `include/hcomm_res.h`/`include/hcomm_primitives.h` 中的 endpoint、memory、channel create/destroy 和 write/fence 接口;不会直接包含内部 `api_c_adpt/hcomm_c_adpt.h`。 +- 最新 `HcommChannelCreate` 的 CPU RoCE 路径可以在 socket 为空时按 `role`/`port` 自建连接。server 侧用 `exchangeAllMems=true` 触发 listen,client 侧注册 staging mem 后连接并写远端导入内存。 +- `transport_desc` 已随 handle 传输,L4 input push 已可通过 `HcommMemImport` 解析远端内存;缺 endpoint handle 时会退回使用 handle 中的 `remote_addr`。 +- 如果使用外部预建 channel,调用方需要保证该 channel 的本地 MR 覆盖 L4 发送 staging buffer;自动创建 channel 时会使用 client 自己注册的 staging mem handle。 +- output tensor 仍走 gRPC pull,因为当前 CPU RoCE 公开路径还缺少和 Simpler output 语义匹配的读回或远端写回协议。 + +真实 HCOMM smoke test 默认跳过,需要显式打开: + +```bash +SIMPLER_HCOMM_REAL_TEST=1 \ +SIMPLER_HCOMM_LIB=/path/to/libhcomm.so \ +SIMPLER_HCOMM_ENDPOINT_IP=127.0.0.1 \ +python -m pytest tests/ut/py/test_distributed/test_transport_backend.py -q +``` + +如果要跑真实 `WriteWithNotify + Fence`,还需要外部预建 channel,并设置: +`SIMPLER_HCOMM_ENDPOINT_HANDLE`、`SIMPLER_HCOMM_CHANNEL_HANDLE`、`SIMPLER_HCOMM_REMOTE_ADDR`、`SIMPLER_HCOMM_REMOTE_NBYTES`。 + +单机 HCOMM E2E smoke 
test 会创建 server/client 两个进程,通过最新 channel desc +结构自动建 CPU RoCE channel,然后 client 使用 `HcommWriteWithNotifyNbi` 写 server +注册的 host buffer。该测试默认跳过,需要显式打开,并要求已经构建好 +`libhcomm.so`: + +```bash +SIMPLER_HCOMM_E2E_REAL_TEST=1 \ +SIMPLER_HCOMM_LIB=/path/to/libhcomm.so \ +SIMPLER_HCOMM_ENDPOINT_IP=192.168.0.243 \ +SIMPLER_HCOMM_CHANNEL_PORT=60001 \ +python -m pytest tests/ut/py/test_distributed/test_hcomm_e2e_real.py -q +``` + +RXE/Soft-RoCE 可以作为更底层的实机 smoke test,用来确认本机 ibverbs/RoCE +数据通路能不能真的跑起来。它验证的是 `rxe*` 设备、GID、RC QP 和 verbs +读写握手,不等价于 HCOMM 端到端 channel 测试;HCOMM 仍然需要自己的 endpoint、 +channel 资源和 shared library。 + +当前机器上 `ibv_devices` 能看到: + +- `rxe0`:绑定 `192.168.0.243`,GID index 1 为 `::ffff:192.168.0.243` +- `rxe_lo`:绑定 localhost + +已固化一个默认跳过的 pytest: + +```bash +SIMPLER_RXE_REAL_TEST=1 \ +SIMPLER_RXE_DEVICE=rxe0 \ +SIMPLER_RXE_GID_INDEX=1 \ +SIMPLER_RXE_SERVER_IP=192.168.0.243 \ +python -m pytest tests/ut/py/test_distributed/test_rxe_real.py -q +``` + +在当前开发机上该测试实际启动 server/client 两个 `ibv_rc_pingpong` 进程并通过。 +如果不设置 `SIMPLER_RXE_GID_INDEX` 和 `SIMPLER_RXE_SERVER_IP`,测试会尝试从 +`/sys/class/infiniband//ports/1/gids` 自动找第一个 IPv4-mapped GID。 + +所以当前端到端 remote dispatch 测试覆盖两类路径: -所以当前端到端 remote dispatch 测试主要覆盖 scalar `TaskArgs` 和 Python callable 执行链路。 +- scalar `TaskArgs` 和 Python callable 执行链路 +- L4 到 L3 backend/orch fn 的 tensor input/output 数据面,包括 inline 小 tensor、handle 大 tensor、INOUT 写回 +- L4 到 L3 sub worker 的 tensor input/output 路径 ## 使用示例 @@ -283,7 +406,7 @@ python -m pytest tests/ut/py/test_distributed tests/ut/py/test_worker/test_l4_re 当前验证结果: ```text -32 passed +44 passed, 5 skipped, 1 warning ``` 额外检查: @@ -309,13 +432,20 @@ git diff --check - mailbox shim thread - L3 daemon backend process - scalar args dispatch +- input tensor data plane MVP(L4 到 L3 backend/orch fn 和 L3 sub) +- HCOMM input push 的可选后端接线和 fake 单测覆盖 +- output tensor 自动回传(`OUTPUT / INOUT / OUTPUT_EXISTING` 写回 L4 原始 buffer) +- 带 tensor dispatch 的 shared mmap materialization 和 per-dispatch inner worker - remote 
traceback 传播 - heartbeat fail-fast - 示例和测试 未完成: -- 完整 tensor 数据面 +- 完整 RDMA/Urma/SHM 零拷贝 tensor transport +- HCOMM socket/channel 生命周期完全自动创建与跨节点握手 +- HCOMM output tensor 写回路径 +- 持久 L3 worker 中向已 fork 子进程动态注入新 tensor mapping - 多 remote 负载均衡策略 - 节点发现或服务注册 - 鉴权、TLS、租户隔离 diff --git a/docs/distributed-l4.md b/docs/distributed-l4.md index 8084388c7..afb358b52 100644 --- a/docs/distributed-l4.md +++ b/docs/distributed-l4.md @@ -46,9 +46,106 @@ process with active gRPC threads. ## Tensors -`tensor_pool.py` provides the planned inline/handle byte pool surface. Scalar -`TaskArgs` and `ContinuousTensor` metadata are wired through dispatch today; -full remote tensor materialization is isolated behind `TensorPool`. +`tensor_pool.py` now provides the Python MVP data-plane bridge: + +- tensors up to 4KB are sent inline in `DispatchReq.tensor_refs`; +- larger tensors are staged into the remote L3 backend `TensorPool` with + `AllocTensor` + streaming `PushTensor`, then dispatched by handle; +- the L3 backend materializes `TensorRef` values into shared mmap-backed + `ContinuousTensor` buffers before calling the inner `Worker`; +- `OUTPUT`, `INOUT`, and `OUTPUT_EXISTING` tensors are returned in + `DispatchResp.output_tensors` and copied back into the original L4 buffers; +- tensor dispatches use a per-dispatch inner L3 worker so sub/chip children are + forked after the mmap exists and inherit the same tensor storage. + +This is still not the production zero-copy path. `remote_addr` and `rkey` are +kept in the protocol shape for future SHM/RDMA/Urma backends, while the current +MVP uses Python byte buffers and gRPC streaming for inter-host data movement. +Persistent L3 worker reuse for tensor dispatches still needs a later mailbox or +transport path that can inject new shared mappings into already-forked children. + +### Optional HCOMM Backend + +The tensor pool now has a narrow transport backend boundary. 
The default remains +`grpc`, preserving the existing byte-pool + gRPC streaming behavior. An +experimental `hcomm` backend can be selected explicitly: + +```bash +SIMPLER_TENSOR_TRANSPORT=hcomm \ +SIMPLER_HCOMM_LIB=/path/to/libhcomm.so \ +SIMPLER_HCOMM_ENDPOINT_IP=192.168.0.243 \ +SIMPLER_HCOMM_CHANNEL_ROLE=client \ +SIMPLER_HCOMM_CHANNEL_PORT=60001 \ +python -m simpler.distributed.l3_daemon --port 5050 --tensor-transport hcomm +``` + +`SIMPLER_HCOMM_ENDPOINT_HANDLE` can still be supplied directly. If it is not +set, the backend can create an endpoint from `SIMPLER_HCOMM_ENDPOINT_IP` and the +optional `SIMPLER_HCOMM_ENDPOINT_*` location fields. `--tensor-transport auto` +tries HCOMM only when the required resources are present; otherwise it falls +back to `grpc`. + +In this first version, the L3 `TensorPool` registers large tensor buffers with +HCOMM and publishes `TensorHandle.transport = "hcomm"` plus the exported memory +descriptor. The L4 side imports that descriptor when an endpoint handle is +available, then uses `HcommWriteWithNotifyNbi` + `HcommChannelFence` for input +tensor push. The L4 client stages source bytes through a HCOMM-registered host +buffer before issuing the write, so the local source address is covered by a +registered memory region. A pre-created `SIMPLER_HCOMM_CHANNEL_HANDLE` is still +accepted, but the facade can now also call the latest public `HcommChannelCreate` +ABI directly. For CPU RoCE channels, `SIMPLER_HCOMM_CHANNEL_ROLE` selects +`client` or `server`, and `SIMPLER_HCOMM_CHANNEL_PORT` selects the listen/connect +port. A socket handle may still be supplied for environments that pre-create one, +but it is no longer required by the Python facade. + +Output tensor writeback still uses the existing gRPC `PullTensor` path, because +the current public CPU RoCE path does not yet provide the full +readback/remote-writeback protocol needed by the L4 output semantics. 
+ +Real HCOMM smoke tests are opt-in so normal CI does not require HCOMM hardware: + +```bash +SIMPLER_HCOMM_REAL_TEST=1 \ +SIMPLER_HCOMM_LIB=/path/to/libhcomm.so \ +SIMPLER_HCOMM_ENDPOINT_IP=127.0.0.1 \ +python -m pytest tests/ut/py/test_distributed/test_transport_backend.py -q +``` + +The pre-created channel write smoke additionally requires +`SIMPLER_HCOMM_ENDPOINT_HANDLE`, `SIMPLER_HCOMM_CHANNEL_HANDLE`, +`SIMPLER_HCOMM_REMOTE_ADDR`, and `SIMPLER_HCOMM_REMOTE_NBYTES`. + +The single-node HCOMM E2E smoke creates a server/client channel pair through the +latest channel descriptor shape and writes bytes over CPU RoCE. It is also +opt-in because it needs a built HCOMM shared library and a working RoCE/RXE +device: + +```bash +SIMPLER_HCOMM_E2E_REAL_TEST=1 \ +SIMPLER_HCOMM_LIB=/path/to/libhcomm.so \ +SIMPLER_HCOMM_ENDPOINT_IP=192.168.0.243 \ +SIMPLER_HCOMM_CHANNEL_PORT=60001 \ +python -m pytest tests/ut/py/test_distributed/test_hcomm_e2e_real.py -q +``` + +Local Soft-RoCE/RXE can be used as a lower-level real-machine smoke test for +the ibverbs data path. This validates that the host has an active RXE device and +that RC queue pairs can exchange data over the selected GID; it does not prove +that HCOMM channel creation is available, because HCOMM still needs its own +endpoint/channel resources and shared library. + +```bash +SIMPLER_RXE_REAL_TEST=1 \ +SIMPLER_RXE_DEVICE=rxe0 \ +SIMPLER_RXE_GID_INDEX=1 \ +SIMPLER_RXE_SERVER_IP=192.168.0.243 \ +python -m pytest tests/ut/py/test_distributed/test_rxe_real.py -q +``` + +If `SIMPLER_RXE_GID_INDEX` and `SIMPLER_RXE_SERVER_IP` are omitted, the test +tries to infer the first IPv4-mapped GID from `/sys/class/infiniband/`. +On the current development host, `rxe0` GID index `1` maps to +`::ffff:192.168.0.243` and the RC pingpong smoke passes. 
## Health diff --git a/python/simpler/distributed/hcomm_abi_shim.cc b/python/simpler/distributed/hcomm_abi_shim.cc new file mode 100644 index 000000000..5843c134b --- /dev/null +++ b/python/simpler/distributed/hcomm_abi_shim.cc @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2026. + * + * ABI shim for loading stock local HCOMM builds from Simpler tests. + * + * Some local libhcomm.so builds reference a C++ template instantiation that is + * not exported by the linked HCOMM sidecar libraries. Keep the fix outside + * HCOMM by providing the missing instantiation from Simpler's adapter layer. + */ + +#include "hccl_communicator.h" + +namespace hccl { + +template <> +HcclResult HcclCommunicator::GenIbvAiRMAInfo( + u32 rankid, + const std::shared_ptr &transport, + const std::string &tag, + HcclAiRMAInfo *aiRMAInfoPtr) +{ + (void)tag; + std::vector aiQpVec; + HcclResult ret = transport->GetAiRMAQueueInfo(aiQpVec); + if (ret != HCCL_SUCCESS) { + return ret; + } + if (aiRMAInfoPtr == nullptr) { + return HCCL_E_PTR; + } + if (aiQpVec.size() != aiRMAInfoPtr->qpNum) { + return HCCL_E_INTERNAL; + } + if (aiSqMem_ == nullptr || aiScqMem_ == nullptr || aiRqMem_ == nullptr || aiRcqMem_ == nullptr) { + return HCCL_E_PTR; + } + + HcclAiRMAWQ *aiSqHost = reinterpret_cast(aiSqMem_->ptr()); + HcclAiRMACQ *aiScqHost = reinterpret_cast(aiScqMem_->ptr()); + HcclAiRMAWQ *aiRqHost = reinterpret_cast(aiRqMem_->ptr()); + HcclAiRMACQ *aiRcqHost = reinterpret_cast(aiRcqMem_->ptr()); + if (aiSqHost == nullptr || aiScqHost == nullptr || aiRqHost == nullptr || aiRcqHost == nullptr) { + return HCCL_E_PTR; + } + + for (u32 j = 0; j < aiRMAInfoPtr->qpNum; ++j) { + const auto &aiQpInfo = aiQpVec[j]; + aiSqHost[rankid * aiRMAInfoPtr->qpNum + j] = aiQpInfo.sq; + aiScqHost[rankid * aiRMAInfoPtr->qpNum + j] = aiQpInfo.scq; + aiRqHost[rankid * aiRMAInfoPtr->qpNum + j] = aiQpInfo.rq; + aiRcqHost[rankid * aiRMAInfoPtr->qpNum + j] = aiQpInfo.rcq; + } + return HCCL_SUCCESS; +} + +} // namespace hccl diff --git 
a/python/simpler/distributed/l3_daemon.py b/python/simpler/distributed/l3_daemon.py index 0f0ac2a9e..6e2686a40 100644 --- a/python/simpler/distributed/l3_daemon.py +++ b/python/simpler/distributed/l3_daemon.py @@ -4,6 +4,7 @@ import argparse import multiprocessing as mp +import os import threading import traceback from collections.abc import Callable @@ -16,17 +17,30 @@ from .catalog import Catalog, CatalogService from .proto import dispatch_pb2, dispatch_pb2_grpc from .rpc import RpcServer -from .serialization import decode_config, decode_task_args +from .serialization import ( + decode_config, + decode_task_args, + decode_task_args_with_tensor_refs_and_writebacks, + encode_output_tensor_refs, +) from .tensor_pool import TensorPool +from .transport_backend import build_tensor_transport class L3Daemon(dispatch_pb2_grpc.L3WorkerServicer): """RPC facade that delegates dispatches to a lazily initialized inner Worker.""" - def __init__(self, port: int = 0, worker_factory: Optional[Callable[[], Worker]] = None) -> None: + def __init__( + self, + port: int = 0, + worker_factory: Optional[Callable[[], Worker]] = None, + *, + tensor_transport: Optional[str] = None, + ) -> None: self.port = int(port) self.catalog = Catalog() self.tensor_pool = TensorPool() + self.tensor_transport = tensor_transport or os.getenv("SIMPLER_TENSOR_TRANSPORT", "grpc") self._worker_factory = worker_factory or (lambda: Worker(level=3, num_sub_workers=1)) self._server: Optional[RpcServer] = None self._backend_proc = None @@ -38,7 +52,7 @@ def start(self, host: str = "127.0.0.1") -> int: server = RpcServer() server.add_l3_worker(self) server.add_catalog(_BackendCatalogService(self.catalog, self._backend_call)) - server.add_tensor_pool(self.tensor_pool.service()) + server.add_tensor_pool(_BackendTensorPoolService(self._backend_call)) self.port = server.start(self.port, host) self._server = server return self.port @@ -92,7 +106,19 @@ def _start_backend(self) -> None: return ctx = 
mp.get_context("fork") if hasattr(mp, "get_context") else mp parent_conn, child_conn = ctx.Pipe() - proc = ctx.Process(target=_backend_loop, args=(child_conn, self._worker_factory), daemon=True) + proc = ctx.Process( + target=_backend_loop, + args=( + child_conn, + self._worker_factory, + self.tensor_pool.node_id, + self.tensor_pool.inline_threshold, + self.tensor_pool.capacity_bytes, + self.tensor_pool.default_ttl_ms, + self.tensor_transport, + ), + daemon=True, + ) proc.start() child_conn.close() self._backend_conn = parent_conn @@ -120,10 +146,76 @@ def PushCallable(self, request, context): # noqa: N802, ANN001 return dispatch_pb2.Empty() -def _backend_loop(conn, worker_factory) -> None: - catalog = Catalog() +class _BackendTensorPoolService(dispatch_pb2_grpc.TensorPoolServicer): + def __init__(self, backend_call) -> None: + self._backend_call = backend_call + + def AllocTensor(self, request, context): # noqa: N802, ANN001 + try: + data = self._backend_call(("tensor_alloc", request.SerializeToString())) + except Exception as e: # noqa: BLE001 + context.abort(grpc.StatusCode.RESOURCE_EXHAUSTED, str(e)) + handle = dispatch_pb2.TensorHandle() + handle.ParseFromString(data) + return handle + + def FreeTensor(self, request, context): # noqa: N802, ANN001 + try: + self._backend_call(("tensor_free", request.handle.SerializeToString())) + except Exception as e: # noqa: BLE001 + context.abort(grpc.StatusCode.NOT_FOUND, str(e)) + return dispatch_pb2.Empty() + + def RefreshTensor(self, request, context): # noqa: N802, ANN001 + try: + data = self._backend_call(("tensor_refresh", request.SerializeToString())) + except Exception as e: # noqa: BLE001 + context.abort(grpc.StatusCode.NOT_FOUND, str(e)) + handle = dispatch_pb2.TensorHandle() + handle.ParseFromString(data) + return handle + + def PullTensor(self, request, context): # noqa: N802, ANN001 + try: + payload = self._backend_call(("tensor_pull", request.SerializeToString())) + except Exception as e: # noqa: BLE001 + 
context.abort(grpc.StatusCode.NOT_FOUND, str(e)) + chunk_size = 1024 * 1024 + for offset in range(0, len(payload), chunk_size): + chunk = payload[offset : offset + chunk_size] + yield dispatch_pb2.TensorChunk( + handle=request, + offset=offset, + data=chunk, + last=offset + len(chunk) >= len(payload), + ) + if not payload: + yield dispatch_pb2.TensorChunk(handle=request, offset=0, data=b"", last=True) + + def PushTensor(self, request_iterator, context): # noqa: N802, ANN001 + chunks = list(request_iterator) + handle = chunks[0].handle if chunks else dispatch_pb2.TensorHandle() + payload = _join_chunks(chunks) + try: + data = self._backend_call(("tensor_push", handle.SerializeToString(), payload)) + except Exception as e: # noqa: BLE001 + context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e)) + out = dispatch_pb2.TensorHandle() + out.ParseFromString(data) + return out + + +def _backend_loop(conn, worker_factory, node_id, inline_threshold, capacity_bytes, default_ttl_ms, tensor_transport) -> None: inner: Optional[Worker] = None try: + catalog = Catalog() + tensor_pool = TensorPool( + node_id=node_id, + inline_threshold=inline_threshold, + capacity_bytes=capacity_bytes, + default_ttl_ms=default_ttl_ms, + transport_backend=build_tensor_transport(tensor_transport), + ) while True: msg = conn.recv() op = msg[0] @@ -135,11 +227,56 @@ def _backend_loop(conn, worker_factory) -> None: catalog.install_from_payload(cid, version, payload) conn.send((True, None)) continue + if op == "tensor_alloc": + _, req_bytes = msg + req = dispatch_pb2.TensorAllocReq() + req.ParseFromString(req_bytes) + handle = tensor_pool.alloc( + req.nbytes, + ttl_ms=req.ttl_ms, + shape=req.shape, + dtype=req.dtype, + tag=req.tag, + ) + conn.send((True, handle.SerializeToString())) + continue + if op == "tensor_free": + _, handle_bytes = msg + handle = dispatch_pb2.TensorHandle() + handle.ParseFromString(handle_bytes) + tensor_pool.free(handle) + conn.send((True, None)) + continue + if op == 
"tensor_refresh": + _, req_bytes = msg + req = dispatch_pb2.TensorRefreshReq() + req.ParseFromString(req_bytes) + handle = tensor_pool.refresh(req.handle, req.ttl_ms) + conn.send((True, handle.SerializeToString())) + continue + if op == "tensor_pull": + _, handle_bytes = msg + handle = dispatch_pb2.TensorHandle() + handle.ParseFromString(handle_bytes) + conn.send((True, tensor_pool.get_bytes(handle))) + continue + if op == "tensor_push": + _, handle_bytes, payload = msg + handle = dispatch_pb2.TensorHandle() + handle.ParseFromString(handle_bytes) + if handle.handle_id: + tensor_pool.write_bytes(handle, payload) + out = tensor_pool.refresh(handle) + else: + out_ref = tensor_pool.put_bytes(payload, force_handle=True) + out = out_ref.handle + conn.send((True, out.SerializeToString())) + continue if op == "dispatch": _, req_bytes = msg req = dispatch_pb2.DispatchReq() req.ParseFromString(req_bytes) - resp, inner = _backend_dispatch(req, catalog, worker_factory, inner) + resp, inner = _backend_dispatch(req, catalog, tensor_pool, worker_factory, inner) conn.send((True, resp.SerializeToString())) continue raise RuntimeError(f"unknown backend op {op!r}") @@ -151,6 +288,8 @@ def _backend_loop(conn, worker_factory) -> None: except Exception: # noqa: BLE001 pass finally: + if "tensor_pool" in locals(): + tensor_pool.close() if inner is not None: inner.close() @@ -158,17 +297,13 @@ def _backend_loop(conn, worker_factory) -> None: def _backend_dispatch( req: dispatch_pb2.DispatchReq, catalog: Catalog, + tensor_pool: TensorPool, worker_factory: Callable[[], Worker], inner: Optional[Worker], ) -> tuple[dispatch_pb2.DispatchResp, Optional[Worker]]: + run_inner = inner + ephemeral_inner = False try: - if inner is None: - inner = worker_factory() - for cid, version in catalog.refs(): - fn = catalog.lookup(cid, version) - if fn is not None: - inner._callable_registry[int(cid)] = fn - inner.init() orch_fn = catalog.lookup(req.callable_id, req.callable_version) if orch_fn is None: 
return ( @@ -180,9 +315,30 @@ def _backend_dispatch( inner, ) cfg = decode_config(req.config_blob) - args = decode_task_args(req.tensor_args, req.scalar_args) - inner.run(orch_fn, args, cfg) - return dispatch_pb2.DispatchResp(task_id=req.task_id, error_code=0), inner + keepalive = [] + writebacks = [] + if req.tensor_refs: + args, keepalive, writebacks = decode_task_args_with_tensor_refs_and_writebacks( + req.tensor_refs, + req.scalar_args, + tensor_pool, + ) + else: + args = decode_task_args(req.tensor_args, req.scalar_args) + if req.tensor_refs: + run_inner = worker_factory() + ephemeral_inner = True + _install_catalog(run_inner, catalog) + run_inner.init() + elif run_inner is None: + run_inner = worker_factory() + _install_catalog(run_inner, catalog) + run_inner.init() + inner = run_inner + run_inner.run(orch_fn, args, cfg) + output_tensors = encode_output_tensor_refs(args, tensor_pool, writebacks) + keepalive.clear() + return dispatch_pb2.DispatchResp(task_id=req.task_id, error_code=0, output_tensors=output_tensors), inner except Exception as e: # noqa: BLE001 return ( dispatch_pb2.DispatchResp( @@ -193,6 +349,27 @@ def _backend_dispatch( ), inner, ) + finally: + if ephemeral_inner and run_inner is not None: + run_inner.close() + + +def _install_catalog(worker: Worker, catalog: Catalog) -> None: + for cid, version in catalog.refs(): + fn = catalog.lookup(cid, version) + if fn is not None: + worker._callable_registry[int(cid)] = fn + + +def _join_chunks(chunks: list[dispatch_pb2.TensorChunk]) -> bytes: + total = 0 + for chunk in chunks: + total = max(total, int(chunk.offset) + len(chunk.data)) + out = bytearray(total) + for chunk in chunks: + offset = int(chunk.offset) + out[offset : offset + len(chunk.data)] = chunk.data + return bytes(out) def main(argv: Optional[list[str]] = None) -> int: @@ -200,12 +377,13 @@ def main(argv: Optional[list[str]] = None) -> int: parser.add_argument("--host", default="127.0.0.1") parser.add_argument("--port", type=int, 
default=5050) parser.add_argument("--num-sub-workers", type=int, default=1) + parser.add_argument("--tensor-transport", default=None, choices=("grpc", "rxe", "hcomm", "auto")) args = parser.parse_args(argv) def make_worker() -> Worker: return Worker(level=3, num_sub_workers=args.num_sub_workers) - daemon = L3Daemon(args.port, make_worker) + daemon = L3Daemon(args.port, make_worker, tensor_transport=args.tensor_transport) try: daemon.serve_forever(args.host) except KeyboardInterrupt: diff --git a/python/simpler/distributed/proto/dispatch.proto b/python/simpler/distributed/proto/dispatch.proto index bec873e8c..60f9513c3 100644 --- a/python/simpler/distributed/proto/dispatch.proto +++ b/python/simpler/distributed/proto/dispatch.proto @@ -19,6 +19,12 @@ message ContinuousTensorRef { message TensorHandle { string node_id = 1; uint64 handle_id = 2; + uint64 remote_addr = 3; + uint32 rkey = 4; + uint64 nbytes = 5; + uint64 lease_deadline_unix_ms = 6; + string transport = 7; + bytes transport_desc = 8; } message TensorRef { @@ -38,6 +44,23 @@ message TensorChunk { bool last = 4; } +message TensorAllocReq { + uint64 nbytes = 1; + uint64 ttl_ms = 2; + repeated int64 shape = 10; + int32 dtype = 11; + int32 tag = 12; +} + +message TensorFreeReq { + TensorHandle handle = 1; +} + +message TensorRefreshReq { + TensorHandle handle = 1; + uint64 ttl_ms = 2; +} + message CallConfigWire { int32 block_dim = 1; int32 aicpu_thread_num = 2; @@ -87,6 +110,9 @@ service Catalog { } service TensorPool { + rpc AllocTensor(TensorAllocReq) returns (TensorHandle); + rpc FreeTensor(TensorFreeReq) returns (Empty); + rpc RefreshTensor(TensorRefreshReq) returns (TensorHandle); rpc PullTensor(TensorHandle) returns (stream TensorChunk); rpc PushTensor(stream TensorChunk) returns (TensorHandle); } diff --git a/python/simpler/distributed/proto/dispatch_pb2.py b/python/simpler/distributed/proto/dispatch_pb2.py index 2207b5c1e..d98f3bd39 100644 --- a/python/simpler/distributed/proto/dispatch_pb2.py +++ 
b/python/simpler/distributed/proto/dispatch_pb2.py @@ -24,7 +24,7 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0e\x64ispatch.proto\x12\x16simpler.distributed.v1\"\x07\n\x05\x45mpty\"%\n\x06Health\x12\n\n\x02ok\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"N\n\x13\x43ontinuousTensorRef\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x04\x12\r\n\x05shape\x18\x02 \x03(\x04\x12\r\n\x05\x64type\x18\x03 \x01(\r\x12\x0b\n\x03tag\x18\x04 \x01(\r\"2\n\x0cTensorHandle\x12\x0f\n\x07node_id\x18\x01 \x01(\t\x12\x11\n\thandle_id\x18\x02 \x01(\x04\"\x8f\x01\n\tTensorRef\x12\x15\n\x0binline_data\x18\x01 \x01(\x0cH\x00\x12\x36\n\x06handle\x18\x02 \x01(\x0b\x32$.simpler.distributed.v1.TensorHandleH\x00\x12\r\n\x05shape\x18\n \x03(\x03\x12\r\n\x05\x64type\x18\x0b \x01(\x05\x12\x0b\n\x03tag\x18\x0c \x01(\x05\x42\x08\n\x06source\"o\n\x0bTensorChunk\x12\x34\n\x06handle\x18\x01 \x01(\x0b\x32$.simpler.distributed.v1.TensorHandle\x12\x0e\n\x06offset\x18\x02 \x01(\x04\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\x12\x0c\n\x04last\x18\x04 \x01(\x08\"\xa0\x01\n\x0e\x43\x61llConfigWire\x12\x11\n\tblock_dim\x18\x01 \x01(\x05\x12\x18\n\x10\x61icpu_thread_num\x18\x02 \x01(\x05\x12\x1a\n\x12\x65nable_l2_swimlane\x18\x03 \x01(\x08\x12\x1a\n\x12\x65nable_dump_tensor\x18\x04 \x01(\x08\x12\x12\n\nenable_pmu\x18\x05 \x01(\x05\x12\x15\n\routput_prefix\x18\x06 \x01(\t\"\xf1\x01\n\x0b\x44ispatchReq\x12\x0f\n\x07task_id\x18\x01 \x01(\x04\x12\x13\n\x0b\x63\x61llable_id\x18\x02 \x01(\x04\x12\x18\n\x10\x63\x61llable_version\x18\x03 \x01(\x04\x12\x13\n\x0b\x63onfig_blob\x18\x04 \x01(\x0c\x12\x13\n\x0bscalar_args\x18\x05 \x03(\x04\x12@\n\x0btensor_args\x18\x06 \x03(\x0b\x32+.simpler.distributed.v1.ContinuousTensorRef\x12\x36\n\x0btensor_refs\x18\x07 \x03(\x0b\x32!.simpler.distributed.v1.TensorRef\"\x9b\x01\n\x0c\x44ispatchResp\x12\x0f\n\x07task_id\x18\x01 \x01(\x04\x12\x12\n\nerror_code\x18\x02 \x01(\x05\x12\x11\n\terror_msg\x18\x03 \x01(\t\x12\x18\n\x10remote_traceback\x18\x04 
\x03(\t\x12\x39\n\x0eoutput_tensors\x18\x05 \x03(\x0b\x32!.simpler.distributed.v1.TensorRef\"3\n\x0b\x43\x61llableRef\x12\x13\n\x0b\x63\x61llable_id\x18\x01 \x01(\x04\x12\x0f\n\x07version\x18\x02 \x01(\x04\"H\n\x0f\x43\x61llablePayload\x12\x13\n\x0b\x63\x61llable_id\x18\x01 \x01(\x04\x12\x0f\n\x07version\x18\x02 \x01(\x04\x12\x0f\n\x07pickled\x18\x03 \x01(\x0c\x32\xad\x01\n\x08L3Worker\x12U\n\x08\x44ispatch\x12#.simpler.distributed.v1.DispatchReq\x1a$.simpler.distributed.v1.DispatchResp\x12J\n\tHeartbeat\x12\x1d.simpler.distributed.v1.Empty\x1a\x1e.simpler.distributed.v1.Health2\xbf\x01\n\x07\x43\x61talog\x12\\\n\x0cPullCallable\x12#.simpler.distributed.v1.CallableRef\x1a\'.simpler.distributed.v1.CallablePayload\x12V\n\x0cPushCallable\x12\'.simpler.distributed.v1.CallablePayload\x1a\x1d.simpler.distributed.v1.Empty2\xc2\x01\n\nTensorPool\x12Y\n\nPullTensor\x12$.simpler.distributed.v1.TensorHandle\x1a#.simpler.distributed.v1.TensorChunk0\x01\x12Y\n\nPushTensor\x12#.simpler.distributed.v1.TensorChunk\x1a$.simpler.distributed.v1.TensorHandle(\x01\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0e\x64ispatch.proto\x12\x16simpler.distributed.v1\"\x07\n\x05\x45mpty\"%\n\x06Health\x12\n\n\x02ok\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"N\n\x13\x43ontinuousTensorRef\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x04\x12\r\n\x05shape\x18\x02 \x03(\x04\x12\r\n\x05\x64type\x18\x03 \x01(\r\x12\x0b\n\x03tag\x18\x04 \x01(\r\"\xb0\x01\n\x0cTensorHandle\x12\x0f\n\x07node_id\x18\x01 \x01(\t\x12\x11\n\thandle_id\x18\x02 \x01(\x04\x12\x13\n\x0bremote_addr\x18\x03 \x01(\x04\x12\x0c\n\x04rkey\x18\x04 \x01(\r\x12\x0e\n\x06nbytes\x18\x05 \x01(\x04\x12\x1e\n\x16lease_deadline_unix_ms\x18\x06 \x01(\x04\x12\x11\n\ttransport\x18\x07 \x01(\t\x12\x16\n\x0etransport_desc\x18\x08 \x01(\x0c\"\x8f\x01\n\tTensorRef\x12\x15\n\x0binline_data\x18\x01 \x01(\x0cH\x00\x12\x36\n\x06handle\x18\x02 
\x01(\x0b\x32$.simpler.distributed.v1.TensorHandleH\x00\x12\r\n\x05shape\x18\n \x03(\x03\x12\r\n\x05\x64type\x18\x0b \x01(\x05\x12\x0b\n\x03tag\x18\x0c \x01(\x05\x42\x08\n\x06source\"o\n\x0bTensorChunk\x12\x34\n\x06handle\x18\x01 \x01(\x0b\x32$.simpler.distributed.v1.TensorHandle\x12\x0e\n\x06offset\x18\x02 \x01(\x04\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\x12\x0c\n\x04last\x18\x04 \x01(\x08\"[\n\x0eTensorAllocReq\x12\x0e\n\x06nbytes\x18\x01 \x01(\x04\x12\x0e\n\x06ttl_ms\x18\x02 \x01(\x04\x12\r\n\x05shape\x18\n \x03(\x03\x12\r\n\x05\x64type\x18\x0b \x01(\x05\x12\x0b\n\x03tag\x18\x0c \x01(\x05\"E\n\rTensorFreeReq\x12\x34\n\x06handle\x18\x01 \x01(\x0b\x32$.simpler.distributed.v1.TensorHandle\"X\n\x10TensorRefreshReq\x12\x34\n\x06handle\x18\x01 \x01(\x0b\x32$.simpler.distributed.v1.TensorHandle\x12\x0e\n\x06ttl_ms\x18\x02 \x01(\x04\"\xa0\x01\n\x0e\x43\x61llConfigWire\x12\x11\n\tblock_dim\x18\x01 \x01(\x05\x12\x18\n\x10\x61icpu_thread_num\x18\x02 \x01(\x05\x12\x1a\n\x12\x65nable_l2_swimlane\x18\x03 \x01(\x08\x12\x1a\n\x12\x65nable_dump_tensor\x18\x04 \x01(\x08\x12\x12\n\nenable_pmu\x18\x05 \x01(\x05\x12\x15\n\routput_prefix\x18\x06 \x01(\t\"\xf1\x01\n\x0b\x44ispatchReq\x12\x0f\n\x07task_id\x18\x01 \x01(\x04\x12\x13\n\x0b\x63\x61llable_id\x18\x02 \x01(\x04\x12\x18\n\x10\x63\x61llable_version\x18\x03 \x01(\x04\x12\x13\n\x0b\x63onfig_blob\x18\x04 \x01(\x0c\x12\x13\n\x0bscalar_args\x18\x05 \x03(\x04\x12@\n\x0btensor_args\x18\x06 \x03(\x0b\x32+.simpler.distributed.v1.ContinuousTensorRef\x12\x36\n\x0btensor_refs\x18\x07 \x03(\x0b\x32!.simpler.distributed.v1.TensorRef\"\x9b\x01\n\x0c\x44ispatchResp\x12\x0f\n\x07task_id\x18\x01 \x01(\x04\x12\x12\n\nerror_code\x18\x02 \x01(\x05\x12\x11\n\terror_msg\x18\x03 \x01(\t\x12\x18\n\x10remote_traceback\x18\x04 \x03(\t\x12\x39\n\x0eoutput_tensors\x18\x05 \x03(\x0b\x32!.simpler.distributed.v1.TensorRef\"3\n\x0b\x43\x61llableRef\x12\x13\n\x0b\x63\x61llable_id\x18\x01 \x01(\x04\x12\x0f\n\x07version\x18\x02 
\x01(\x04\"H\n\x0f\x43\x61llablePayload\x12\x13\n\x0b\x63\x61llable_id\x18\x01 \x01(\x04\x12\x0f\n\x07version\x18\x02 \x01(\x04\x12\x0f\n\x07pickled\x18\x03 \x01(\x0c\x32\xad\x01\n\x08L3Worker\x12U\n\x08\x44ispatch\x12#.simpler.distributed.v1.DispatchReq\x1a$.simpler.distributed.v1.DispatchResp\x12J\n\tHeartbeat\x12\x1d.simpler.distributed.v1.Empty\x1a\x1e.simpler.distributed.v1.Health2\xbf\x01\n\x07\x43\x61talog\x12\\\n\x0cPullCallable\x12#.simpler.distributed.v1.CallableRef\x1a\'.simpler.distributed.v1.CallablePayload\x12V\n\x0cPushCallable\x12\'.simpler.distributed.v1.CallablePayload\x1a\x1d.simpler.distributed.v1.Empty2\xd4\x03\n\nTensorPool\x12[\n\x0b\x41llocTensor\x12&.simpler.distributed.v1.TensorAllocReq\x1a$.simpler.distributed.v1.TensorHandle\x12R\n\nFreeTensor\x12%.simpler.distributed.v1.TensorFreeReq\x1a\x1d.simpler.distributed.v1.Empty\x12_\n\rRefreshTensor\x12(.simpler.distributed.v1.TensorRefreshReq\x1a$.simpler.distributed.v1.TensorHandle\x12Y\n\nPullTensor\x12$.simpler.distributed.v1.TensorHandle\x1a#.simpler.distributed.v1.TensorChunk0\x01\x12Y\n\nPushTensor\x12#.simpler.distributed.v1.TensorChunk\x1a$.simpler.distributed.v1.TensorHandle(\x01\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -37,26 +37,32 @@ _globals['_HEALTH']._serialized_end=88 _globals['_CONTINUOUSTENSORREF']._serialized_start=90 _globals['_CONTINUOUSTENSORREF']._serialized_end=168 - _globals['_TENSORHANDLE']._serialized_start=170 - _globals['_TENSORHANDLE']._serialized_end=220 - _globals['_TENSORREF']._serialized_start=223 - _globals['_TENSORREF']._serialized_end=366 - _globals['_TENSORCHUNK']._serialized_start=368 - _globals['_TENSORCHUNK']._serialized_end=479 - _globals['_CALLCONFIGWIRE']._serialized_start=482 - _globals['_CALLCONFIGWIRE']._serialized_end=642 - _globals['_DISPATCHREQ']._serialized_start=645 - _globals['_DISPATCHREQ']._serialized_end=886 - _globals['_DISPATCHRESP']._serialized_start=889 - 
_globals['_DISPATCHRESP']._serialized_end=1044 - _globals['_CALLABLEREF']._serialized_start=1046 - _globals['_CALLABLEREF']._serialized_end=1097 - _globals['_CALLABLEPAYLOAD']._serialized_start=1099 - _globals['_CALLABLEPAYLOAD']._serialized_end=1171 - _globals['_L3WORKER']._serialized_start=1174 - _globals['_L3WORKER']._serialized_end=1347 - _globals['_CATALOG']._serialized_start=1350 - _globals['_CATALOG']._serialized_end=1541 - _globals['_TENSORPOOL']._serialized_start=1544 - _globals['_TENSORPOOL']._serialized_end=1738 + _globals['_TENSORHANDLE']._serialized_start=171 + _globals['_TENSORHANDLE']._serialized_end=347 + _globals['_TENSORREF']._serialized_start=350 + _globals['_TENSORREF']._serialized_end=493 + _globals['_TENSORCHUNK']._serialized_start=495 + _globals['_TENSORCHUNK']._serialized_end=606 + _globals['_TENSORALLOCREQ']._serialized_start=608 + _globals['_TENSORALLOCREQ']._serialized_end=699 + _globals['_TENSORFREEREQ']._serialized_start=701 + _globals['_TENSORFREEREQ']._serialized_end=770 + _globals['_TENSORREFRESHREQ']._serialized_start=772 + _globals['_TENSORREFRESHREQ']._serialized_end=860 + _globals['_CALLCONFIGWIRE']._serialized_start=863 + _globals['_CALLCONFIGWIRE']._serialized_end=1023 + _globals['_DISPATCHREQ']._serialized_start=1026 + _globals['_DISPATCHREQ']._serialized_end=1267 + _globals['_DISPATCHRESP']._serialized_start=1270 + _globals['_DISPATCHRESP']._serialized_end=1425 + _globals['_CALLABLEREF']._serialized_start=1427 + _globals['_CALLABLEREF']._serialized_end=1478 + _globals['_CALLABLEPAYLOAD']._serialized_start=1480 + _globals['_CALLABLEPAYLOAD']._serialized_end=1552 + _globals['_L3WORKER']._serialized_start=1555 + _globals['_L3WORKER']._serialized_end=1728 + _globals['_CATALOG']._serialized_start=1731 + _globals['_CATALOG']._serialized_end=1922 + _globals['_TENSORPOOL']._serialized_start=1925 + _globals['_TENSORPOOL']._serialized_end=2393 # @@protoc_insertion_point(module_scope) diff --git 
a/python/simpler/distributed/proto/dispatch_pb2_grpc.py b/python/simpler/distributed/proto/dispatch_pb2_grpc.py index f0ef4284c..32913b551 100644 --- a/python/simpler/distributed/proto/dispatch_pb2_grpc.py +++ b/python/simpler/distributed/proto/dispatch_pb2_grpc.py @@ -264,6 +264,21 @@ def __init__(self, channel): Args: channel: A grpc.Channel. """ + self.AllocTensor = channel.unary_unary( + '/simpler.distributed.v1.TensorPool/AllocTensor', + request_serializer=dispatch__pb2.TensorAllocReq.SerializeToString, + response_deserializer=dispatch__pb2.TensorHandle.FromString, + _registered_method=True) + self.FreeTensor = channel.unary_unary( + '/simpler.distributed.v1.TensorPool/FreeTensor', + request_serializer=dispatch__pb2.TensorFreeReq.SerializeToString, + response_deserializer=dispatch__pb2.Empty.FromString, + _registered_method=True) + self.RefreshTensor = channel.unary_unary( + '/simpler.distributed.v1.TensorPool/RefreshTensor', + request_serializer=dispatch__pb2.TensorRefreshReq.SerializeToString, + response_deserializer=dispatch__pb2.TensorHandle.FromString, + _registered_method=True) self.PullTensor = channel.unary_stream( '/simpler.distributed.v1.TensorPool/PullTensor', request_serializer=dispatch__pb2.TensorHandle.SerializeToString, @@ -279,6 +294,24 @@ def __init__(self, channel): class TensorPoolServicer(object): """Missing associated documentation comment in .proto file.""" + def AllocTensor(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def FreeTensor(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def RefreshTensor(self, request, context): + """Missing 
associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + def PullTensor(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) @@ -294,6 +327,21 @@ def PushTensor(self, request_iterator, context): def add_TensorPoolServicer_to_server(servicer, server): rpc_method_handlers = { + 'AllocTensor': grpc.unary_unary_rpc_method_handler( + servicer.AllocTensor, + request_deserializer=dispatch__pb2.TensorAllocReq.FromString, + response_serializer=dispatch__pb2.TensorHandle.SerializeToString, + ), + 'FreeTensor': grpc.unary_unary_rpc_method_handler( + servicer.FreeTensor, + request_deserializer=dispatch__pb2.TensorFreeReq.FromString, + response_serializer=dispatch__pb2.Empty.SerializeToString, + ), + 'RefreshTensor': grpc.unary_unary_rpc_method_handler( + servicer.RefreshTensor, + request_deserializer=dispatch__pb2.TensorRefreshReq.FromString, + response_serializer=dispatch__pb2.TensorHandle.SerializeToString, + ), 'PullTensor': grpc.unary_stream_rpc_method_handler( servicer.PullTensor, request_deserializer=dispatch__pb2.TensorHandle.FromString, @@ -315,6 +363,87 @@ def add_TensorPoolServicer_to_server(servicer, server): class TensorPool(object): """Missing associated documentation comment in .proto file.""" + @staticmethod + def AllocTensor(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/simpler.distributed.v1.TensorPool/AllocTensor', + dispatch__pb2.TensorAllocReq.SerializeToString, + dispatch__pb2.TensorHandle.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + 
_registered_method=True) + + @staticmethod + def FreeTensor(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/simpler.distributed.v1.TensorPool/FreeTensor', + dispatch__pb2.TensorFreeReq.SerializeToString, + dispatch__pb2.Empty.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def RefreshTensor(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/simpler.distributed.v1.TensorPool/RefreshTensor', + dispatch__pb2.TensorRefreshReq.SerializeToString, + dispatch__pb2.TensorHandle.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + @staticmethod def PullTensor(request, target, diff --git a/python/simpler/distributed/remote_proxy.py b/python/simpler/distributed/remote_proxy.py index 6965e5634..4f01ec29c 100644 --- a/python/simpler/distributed/remote_proxy.py +++ b/python/simpler/distributed/remote_proxy.py @@ -2,23 +2,49 @@ from __future__ import annotations +import ctypes import itertools +import os import threading import time +import uuid +from dataclasses import dataclass from typing import Optional +import grpc + from simpler.task_interface import CallConfig, TaskArgs from .catalog import Catalog from .proto import dispatch_pb2 from .rpc import RpcClient, RpcError from .serialization import encode_config, encode_task_args +from .tensor_pool import DEFAULT_INLINE_THRESHOLD +from .transport_backend import ( + HcommDataPlaneClient, + 
RxeDataPlaneClient, + RxeRuntime, + TransportBackendError, + TransportUnavailable, + _encode_rxe_desc, +) class RemoteUnavailable(RuntimeError): pass +@dataclass +class _LocalOutputRegion: + handle: dispatch_pb2.TensorHandle + runtime: RxeRuntime + server_handle: int + + def close(self) -> None: + self.runtime.server_stop(self.server_handle) + self.server_handle = 0 + + class RemoteWorkerProxy: """Synchronous L4-side stub for one remote L3 worker.""" @@ -31,6 +57,11 @@ def __init__( heartbeat_timeout: float = 1.0, heartbeat_interval: float = 5.0, heartbeat_failures: int = 3, + tensor_inline_threshold: int = DEFAULT_INLINE_THRESHOLD, + tensor_chunk_size: int = 1024 * 1024, + tensor_transport: Optional[str] = None, + hcomm_client: Optional[HcommDataPlaneClient] = None, + rxe_client: Optional[RxeDataPlaneClient] = None, ) -> None: self.endpoint = endpoint self._client = RpcClient(endpoint) @@ -39,6 +70,12 @@ def __init__( self._heartbeat_timeout = float(heartbeat_timeout) self._heartbeat_interval = float(heartbeat_interval) self._heartbeat_failures = int(heartbeat_failures) + self._tensor_inline_threshold = int(tensor_inline_threshold) + self._tensor_chunk_size = int(tensor_chunk_size) + self._tensor_transport = (tensor_transport or os.getenv("SIMPLER_TENSOR_TRANSPORT", "grpc")).lower() + self._hcomm_client = hcomm_client + self._rxe_client = rxe_client + self._local_node_id = f"l4-rxe-{os.getpid()}-{uuid.uuid4().hex}" self._task_ids = itertools.count(1) self._available = True self._closed = threading.Event() @@ -99,6 +136,7 @@ def dispatch(self, callable_id: int, args: Optional[TaskArgs], cfg: Optional[Cal raise RemoteUnavailable(f"remote {self.endpoint} is unavailable") config = cfg if cfg is not None else CallConfig() tensor_args, scalar_args = encode_task_args(args) + tensor_refs, remote_handles, local_output_regions = self._stage_tensor_args(args) version = self._catalog.refs_by_id().get(int(callable_id), 0) req = dispatch_pb2.DispatchReq( 
task_id=next(self._task_ids), @@ -106,26 +144,329 @@ def dispatch(self, callable_id: int, args: Optional[TaskArgs], cfg: Optional[Cal callable_version=int(version), config_blob=encode_config(config), scalar_args=scalar_args, - tensor_args=tensor_args, + tensor_args=[] if tensor_refs else tensor_args, + tensor_refs=tensor_refs, ) try: resp = self._client.dispatch(req, self._timeout) except RpcError as e: self._available = False + self._free_remote_handles(remote_handles) + self._close_local_output_regions(local_output_regions) raise RemoteUnavailable(f"remote {self.endpoint} dispatch RPC failed: {e}") from e if resp.error_code != 0: + self._free_remote_handles(remote_handles) + self._close_local_output_regions(local_output_regions) detail = resp.error_msg if resp.remote_traceback: detail = detail + "\nremote traceback:\n" + "\n".join(resp.remote_traceback) raise RuntimeError(f"remote dispatch failed on {self.endpoint}: {detail}") + try: + self._write_output_tensors(args, resp.output_tensors) + finally: + self._free_remote_handles(remote_handles) + self._free_response_handles(resp.output_tensors) + self._close_local_output_regions(local_output_regions) + + def _stage_tensor_args( + self, + args: Optional[TaskArgs], + ) -> tuple[list[dispatch_pb2.TensorRef], list[dispatch_pb2.TensorHandle], list[_LocalOutputRegion]]: + if args is None or args.tensor_count() == 0: + return [], [], [] + refs = [] + remote_handles = [] + local_output_regions = [] + try: + for i in range(args.tensor_count()): + tensor = args.tensor(i) + tag = args.tag(i) + nbytes = _tensor_nbytes(tensor) + shape = [int(x) for x in tensor.shapes[: int(tensor.ndims)]] + dtype = int(tensor.dtype.value) + tag_value = int(tag.value) + if self._should_stage_local_output(tag, nbytes): + try: + ref, region = self._stage_local_output_tensor(tensor, nbytes, shape, dtype, tag_value) + except (TransportBackendError, TransportUnavailable): + if self._tensor_transport != "auto": + raise + else: + refs.append(ref) + 
local_output_regions.append(region) + continue + data = ctypes.string_at(int(tensor.data), nbytes) if nbytes else b"" + if nbytes <= self._tensor_inline_threshold: + refs.append( + dispatch_pb2.TensorRef( + inline_data=data, + shape=shape, + dtype=dtype, + tag=tag_value, + ) + ) + continue + handle = self._alloc_remote_tensor(nbytes, shape, dtype, tag_value) + remote_handles.append(handle) + self._push_remote_tensor(handle, data) + refs.append( + dispatch_pb2.TensorRef( + handle=handle, + shape=shape, + dtype=dtype, + tag=tag_value, + ) + ) + except Exception: + self._free_remote_handles(remote_handles) + self._close_local_output_regions(local_output_regions) + raise + return refs, remote_handles, local_output_regions + + def _should_stage_local_output(self, tag, nbytes: int) -> bool: # noqa: ANN001 + return ( + self._tensor_transport in {"rxe", "auto"} + and getattr(tag, "name", "") in {"OUTPUT", "OUTPUT_EXISTING"} + and int(nbytes) > self._tensor_inline_threshold + ) + + def _stage_local_output_tensor( + self, + tensor, # noqa: ANN001 + nbytes: int, + shape: list[int], + dtype: int, + tag: int, + ) -> tuple[dispatch_pb2.TensorRef, _LocalOutputRegion]: + runtime = RxeRuntime.from_env(required=True) + desc, server_handle = runtime.server_start(int(tensor.data), int(nbytes)) + transport_desc = _encode_rxe_desc(desc, runtime.device or "", runtime.gid_index) + handle = dispatch_pb2.TensorHandle( + node_id=self._local_node_id, + handle_id=0, + remote_addr=int(desc.addr), + rkey=int(desc.rkey), + nbytes=int(nbytes), + transport="rxe", + transport_desc=transport_desc, + ) + ref = dispatch_pb2.TensorRef(handle=handle, shape=shape, dtype=int(dtype), tag=int(tag)) + return ref, _LocalOutputRegion(handle=handle, runtime=runtime, server_handle=server_handle) + + def _alloc_remote_tensor( + self, + nbytes: int, + shape: list[int], + dtype: int, + tag: int, + ) -> dispatch_pb2.TensorHandle: + req = dispatch_pb2.TensorAllocReq(nbytes=int(nbytes), shape=shape, dtype=int(dtype), 
tag=int(tag)) + try: + return self._client.tensor_pool.AllocTensor(req, timeout=self._timeout) + except grpc.RpcError as e: + self._available = False + raise RemoteUnavailable(f"remote {self.endpoint} tensor alloc failed: {e.details() or e}") from e + + def _push_remote_tensor(self, handle: dispatch_pb2.TensorHandle, data: bytes) -> None: + if self._should_use_rxe(handle): + self._push_remote_tensor_rxe(handle, data) + return + if self._should_use_hcomm(handle): + self._push_remote_tensor_hcomm(handle, data) + return + self._push_remote_tensor_grpc(handle, data) + + def _should_use_rxe(self, handle: dispatch_pb2.TensorHandle) -> bool: + return ( + self._tensor_transport in {"rxe", "auto"} + and handle.transport == "rxe" + and int(handle.nbytes) > 0 + ) + + def _should_use_hcomm(self, handle: dispatch_pb2.TensorHandle) -> bool: + return ( + self._tensor_transport in {"hcomm", "auto"} + and handle.transport == "hcomm" + and int(handle.nbytes) > 0 + ) + + def _push_remote_tensor_rxe(self, handle: dispatch_pb2.TensorHandle, data: bytes) -> None: + if len(data) != int(handle.nbytes): + raise ValueError(f"RXE tensor push size mismatch: data={len(data)}, handle={handle.nbytes}") + client = self._rxe_client or RxeDataPlaneClient.from_env() + if self._rxe_client is None: + self._rxe_client = client + local = ctypes.create_string_buffer(data) + local_addr = ctypes.addressof(local) + try: + client.write_handle(handle, local_addr, len(data)) + client.fence() + self._client.tensor_pool.RefreshTensor( + dispatch_pb2.TensorRefreshReq(handle=handle), + timeout=self._timeout, + ) + except (TransportBackendError, TransportUnavailable) as e: + if self._tensor_transport == "auto": + self._push_remote_tensor_grpc(handle, data) + return + self._available = False + raise RemoteUnavailable(f"remote {self.endpoint} RXE tensor push unavailable: {e}") from e + except grpc.RpcError as e: + self._available = False + raise RemoteUnavailable(f"remote {self.endpoint} tensor refresh failed: 
{e.details() or e}") from e + + def _push_remote_tensor_hcomm(self, handle: dispatch_pb2.TensorHandle, data: bytes) -> None: + if len(data) != int(handle.nbytes): + raise ValueError(f"HCOMM tensor push size mismatch: data={len(data)}, handle={handle.nbytes}") + client = self._hcomm_client or HcommDataPlaneClient.from_env() + if self._hcomm_client is None: + self._hcomm_client = client + local = ctypes.create_string_buffer(data) + local_addr = ctypes.addressof(local) + try: + if hasattr(client, "write_handle"): + client.write_handle(handle, local_addr, len(data)) + else: + client.write_with_notify(int(handle.remote_addr), local_addr, len(data)) + client.fence() + self._client.tensor_pool.RefreshTensor( + dispatch_pb2.TensorRefreshReq(handle=handle), + timeout=self._timeout, + ) + except (TransportBackendError, TransportUnavailable) as e: + if self._tensor_transport == "auto": + self._push_remote_tensor_grpc(handle, data) + return + self._available = False + raise RemoteUnavailable(f"remote {self.endpoint} HCOMM tensor push unavailable: {e}") from e + except grpc.RpcError as e: + self._available = False + raise RemoteUnavailable(f"remote {self.endpoint} tensor refresh failed: {e.details() or e}") from e + + def _push_remote_tensor_grpc(self, handle: dispatch_pb2.TensorHandle, data: bytes) -> None: + def chunks(): + if not data: + yield dispatch_pb2.TensorChunk(handle=handle, offset=0, data=b"", last=True) + return + for offset in range(0, len(data), self._tensor_chunk_size): + chunk = data[offset : offset + self._tensor_chunk_size] + yield dispatch_pb2.TensorChunk( + handle=handle, + offset=offset, + data=chunk, + last=offset + len(chunk) >= len(data), + ) + + try: + self._client.tensor_pool.PushTensor(chunks(), timeout=self._timeout) + except grpc.RpcError as e: + self._available = False + raise RemoteUnavailable(f"remote {self.endpoint} tensor push failed: {e.details() or e}") from e + + def _free_remote_handles(self, handles: list[dispatch_pb2.TensorHandle]) -> 
None: + for handle in handles: + try: + self._client.tensor_pool.FreeTensor(dispatch_pb2.TensorFreeReq(handle=handle), timeout=self._timeout) + except grpc.RpcError: + pass + + def _write_output_tensors(self, args: Optional[TaskArgs], refs) -> None: # noqa: ANN001 + if args is None: + return + output_indexes = _output_tensor_indexes(args) + if len(refs) != len(output_indexes): + if refs: + raise RuntimeError( + f"remote returned {len(refs)} output tensors for {len(output_indexes)} local output tensors" + ) + return + for ref, tensor_index in zip(refs, output_indexes): + tensor = args.tensor(tensor_index) + nbytes = _tensor_nbytes(tensor) + if self._is_local_output_ack(ref, nbytes): + continue + data = self._read_tensor_ref(ref) + if len(data) != nbytes: + raise RuntimeError( + f"remote output tensor {tensor_index} has {len(data)} bytes, expected {nbytes}" + ) + if nbytes: + ctypes.memmove(int(tensor.data), data, nbytes) + + def _read_tensor_ref(self, ref: dispatch_pb2.TensorRef) -> bytes: + if ref.HasField("inline_data"): + return bytes(ref.inline_data) + if ref.HasField("handle"): + try: + chunks = self._client.tensor_pool.PullTensor(ref.handle, timeout=self._timeout) + return _join_chunks(chunks) + except grpc.RpcError as e: + self._available = False + raise RemoteUnavailable(f"remote {self.endpoint} tensor pull failed: {e.details() or e}") from e + raise RuntimeError("remote output tensor has neither inline_data nor handle") + + def _is_local_output_ack(self, ref: dispatch_pb2.TensorRef, nbytes: int) -> bool: + return ( + ref.HasField("handle") + and ref.handle.transport == "rxe" + and ref.handle.node_id == self._local_node_id + and int(ref.handle.nbytes) == int(nbytes) + ) + + def _free_response_handles(self, refs) -> None: # noqa: ANN001 + handles = [ + ref.handle + for ref in refs + if ref.HasField("handle") and ref.handle.node_id != self._local_node_id + ] + self._free_remote_handles(handles) + + def _close_local_output_regions(self, regions: 
list[_LocalOutputRegion]) -> None: + while regions: + region = regions.pop() + try: + region.close() + except Exception: + pass def close(self) -> None: self._closed.set() if self._heartbeat_thread is not None: self._heartbeat_thread.join(timeout=1.0) self._heartbeat_thread = None + if self._hcomm_client is not None and hasattr(self._hcomm_client, "close"): + self._hcomm_client.close() + self._hcomm_client = None + if self._rxe_client is not None and hasattr(self._rxe_client, "close"): + self._rxe_client.close() + self._rxe_client = None self._client.close() def sleep_poll_interval() -> None: time.sleep(0.0005) + + +def _tensor_nbytes(tensor) -> int: # noqa: ANN001 + nbytes = tensor.nbytes + return int(nbytes() if callable(nbytes) else nbytes) + + +def _output_tensor_indexes(args: TaskArgs) -> list[int]: + return [ + i + for i in range(args.tensor_count()) + if args.tag(i).name in {"OUTPUT", "INOUT", "OUTPUT_EXISTING"} + ] + + +def _join_chunks(chunks) -> bytes: # noqa: ANN001 + chunks = list(chunks) + total = 0 + for chunk in chunks: + total = max(total, int(chunk.offset) + len(chunk.data)) + out = bytearray(total) + for chunk in chunks: + offset = int(chunk.offset) + out[offset : offset + len(chunk.data)] = chunk.data + return bytes(out) diff --git a/python/simpler/distributed/rxe_verbs_helper.c b/python/simpler/distributed/rxe_verbs_helper.c new file mode 100644 index 000000000..1a8324869 --- /dev/null +++ b/python/simpler/distributed/rxe_verbs_helper.c @@ -0,0 +1,505 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct simpler_rxe_qp_info { + uint32_t qpn; + uint32_t psn; + uint32_t rkey; + uint64_t addr; + uint32_t size; + uint8_t gid[16]; +}; + +struct simpler_rxe_server_desc { + char ip[64]; + uint16_t port; + uint32_t rkey; + uint64_t addr; + uint32_t size; +}; + +struct simpler_rxe_handle { + struct ibv_context *ctx; + struct ibv_pd *pd; + struct ibv_cq *cq; + struct ibv_qp *qp; + struct 
ibv_mr *mr; + pthread_t thread; + void *addr; + size_t size; + char device[64]; + char ip[64]; + int gid_index; + int listen_fd; + int conn_fd; + int ready_fd; + uint16_t port; + volatile int stop; + volatile int rc; + char err[256]; +}; + +static void set_err(struct simpler_rxe_handle *h, const char *msg) +{ + if (h != NULL && msg != NULL) { + snprintf(h->err, sizeof(h->err), "%s: errno=%d", msg, errno); + h->rc = errno ? -errno : -1; + } +} + +static int send_all(int fd, const void *buf, size_t len) +{ + const char *p = (const char *)buf; + while (len > 0) { + ssize_t n = send(fd, p, len, 0); + if (n <= 0) { + return -1; + } + p += n; + len -= (size_t)n; + } + return 0; +} + +static int recv_all(int fd, void *buf, size_t len) +{ + char *p = (char *)buf; + while (len > 0) { + ssize_t n = recv(fd, p, len, 0); + if (n <= 0) { + return -1; + } + p += n; + len -= (size_t)n; + } + return 0; +} + +static struct ibv_context *open_device(const char *device) +{ + int num = 0; + struct ibv_device **list = ibv_get_device_list(&num); + if (list == NULL) { + return NULL; + } + struct ibv_context *ctx = NULL; + for (int i = 0; i < num; ++i) { + const char *name = ibv_get_device_name(list[i]); + if (name != NULL && strcmp(name, device) == 0) { + ctx = ibv_open_device(list[i]); + break; + } + } + ibv_free_device_list(list); + return ctx; +} + +static int gid_from_context(struct ibv_context *ctx, int gid_index, uint8_t gid[16]) +{ + union ibv_gid raw_gid; + if (ibv_query_gid(ctx, 1, gid_index, &raw_gid) != 0) { + return -1; + } + memcpy(gid, raw_gid.raw, 16); + return 0; +} + +static int modify_qp_init(struct ibv_qp *qp) +{ + struct ibv_qp_attr attr; + memset(&attr, 0, sizeof(attr)); + attr.qp_state = IBV_QPS_INIT; + attr.port_num = 1; + attr.pkey_index = 0; + attr.qp_access_flags = IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_LOCAL_WRITE; + return ibv_modify_qp(qp, &attr, + IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS); +} + +static int modify_qp_rtr(struct ibv_qp 
*qp, const struct simpler_rxe_qp_info *remote, int gid_index) +{ + struct ibv_qp_attr attr; + memset(&attr, 0, sizeof(attr)); + attr.qp_state = IBV_QPS_RTR; + attr.path_mtu = IBV_MTU_1024; + attr.dest_qp_num = remote->qpn; + attr.rq_psn = remote->psn; + attr.max_dest_rd_atomic = 1; + attr.min_rnr_timer = 12; + attr.ah_attr.is_global = 1; + memcpy(attr.ah_attr.grh.dgid.raw, remote->gid, 16); + attr.ah_attr.grh.sgid_index = gid_index; + attr.ah_attr.grh.hop_limit = 1; + attr.ah_attr.dlid = 0; + attr.ah_attr.sl = 0; + attr.ah_attr.src_path_bits = 0; + attr.ah_attr.port_num = 1; + return ibv_modify_qp(qp, &attr, + IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | + IBV_QP_RQ_PSN | IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER); +} + +static int modify_qp_rts(struct ibv_qp *qp, uint32_t psn) +{ + struct ibv_qp_attr attr; + memset(&attr, 0, sizeof(attr)); + attr.qp_state = IBV_QPS_RTS; + attr.timeout = 14; + attr.retry_cnt = 7; + attr.rnr_retry = 7; + attr.sq_psn = psn; + attr.max_rd_atomic = 1; + return ibv_modify_qp(qp, &attr, + IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | + IBV_QP_SQ_PSN | IBV_QP_MAX_QP_RD_ATOMIC); +} + +static int setup_verbs(struct simpler_rxe_handle *h) +{ + h->ctx = open_device(h->device); + if (h->ctx == NULL) { + set_err(h, "ibv_open_device failed"); + return -1; + } + h->pd = ibv_alloc_pd(h->ctx); + if (h->pd == NULL) { + set_err(h, "ibv_alloc_pd failed"); + return -1; + } + h->cq = ibv_create_cq(h->ctx, 16, NULL, NULL, 0); + if (h->cq == NULL) { + set_err(h, "ibv_create_cq failed"); + return -1; + } + struct ibv_qp_init_attr qp_init; + memset(&qp_init, 0, sizeof(qp_init)); + qp_init.send_cq = h->cq; + qp_init.recv_cq = h->cq; + qp_init.qp_type = IBV_QPT_RC; + qp_init.cap.max_send_wr = 16; + qp_init.cap.max_recv_wr = 16; + qp_init.cap.max_send_sge = 1; + qp_init.cap.max_recv_sge = 1; + h->qp = ibv_create_qp(h->pd, &qp_init); + if (h->qp == NULL) { + set_err(h, "ibv_create_qp failed"); + return -1; + } + 
h->mr = ibv_reg_mr(h->pd, h->addr, h->size, + IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE); + if (h->mr == NULL) { + set_err(h, "ibv_reg_mr failed"); + return -1; + } + if (modify_qp_init(h->qp) != 0) { + set_err(h, "modify_qp_init failed"); + return -1; + } + return 0; +} + +static int listen_socket(struct simpler_rxe_handle *h) +{ + int fd = socket(AF_INET, SOCK_STREAM, 0); + if (fd < 0) { + set_err(h, "socket failed"); + return -1; + } + int yes = 1; + setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &yes, sizeof(yes)); + struct sockaddr_in addr; + memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = htons(0); + if (inet_pton(AF_INET, h->ip, &addr.sin_addr) != 1) { + close(fd); + set_err(h, "inet_pton failed"); + return -1; + } + if (bind(fd, (struct sockaddr *)&addr, sizeof(addr)) != 0) { + close(fd); + set_err(h, "bind failed"); + return -1; + } + socklen_t len = sizeof(addr); + if (getsockname(fd, (struct sockaddr *)&addr, &len) != 0) { + close(fd); + set_err(h, "getsockname failed"); + return -1; + } + if (listen(fd, 1) != 0) { + close(fd); + set_err(h, "listen failed"); + return -1; + } + h->listen_fd = fd; + h->port = ntohs(addr.sin_port); + return 0; +} + +static void *server_main(void *arg) +{ + struct simpler_rxe_handle *h = (struct simpler_rxe_handle *)arg; + if (setup_verbs(h) != 0 || listen_socket(h) != 0) { + if (h->ready_fd >= 0) { + (void)write(h->ready_fd, "E", 1); + } + return NULL; + } + if (h->ready_fd >= 0) { + (void)write(h->ready_fd, "R", 1); + } + int fd = accept(h->listen_fd, NULL, NULL); + if (fd < 0) { + if (!h->stop) { + set_err(h, "accept failed"); + } + return NULL; + } + h->conn_fd = fd; + + struct simpler_rxe_qp_info local; + struct simpler_rxe_qp_info remote; + memset(&local, 0, sizeof(local)); + memset(&remote, 0, sizeof(remote)); + local.qpn = h->qp->qp_num; + local.psn = 0x111111; + local.rkey = h->mr->rkey; + local.addr = (uint64_t)(uintptr_t)h->addr; + local.size = (uint32_t)h->size; + if 
(gid_from_context(h->ctx, h->gid_index, local.gid) != 0) { + set_err(h, "ibv_query_gid failed"); + close(fd); + h->conn_fd = -1; + return NULL; + } + if (send_all(fd, &local, sizeof(local)) != 0 || recv_all(fd, &remote, sizeof(remote)) != 0) { + set_err(h, "server qp info exchange failed"); + close(fd); + h->conn_fd = -1; + return NULL; + } + if (modify_qp_rtr(h->qp, &remote, h->gid_index) != 0 || modify_qp_rts(h->qp, local.psn) != 0) { + set_err(h, "server qp transition failed"); + close(fd); + h->conn_fd = -1; + return NULL; + } + char done = 0; + if (recv_all(fd, &done, 1) != 0 || done != 'D') { + set_err(h, "server completion wait failed"); + } + close(fd); + h->conn_fd = -1; + return NULL; +} + +int simpler_rxe_server_start(const char *device, int gid_index, const char *ip, void *addr, uint64_t size, + struct simpler_rxe_server_desc *desc, void **out) +{ + if (device == NULL || ip == NULL || addr == NULL || size == 0 || desc == NULL || out == NULL) { + return -EINVAL; + } + struct simpler_rxe_handle *h = (struct simpler_rxe_handle *)calloc(1, sizeof(*h)); + if (h == NULL) { + return -ENOMEM; + } + snprintf(h->device, sizeof(h->device), "%s", device); + snprintf(h->ip, sizeof(h->ip), "%s", ip); + h->gid_index = gid_index; + h->addr = addr; + h->size = (size_t)size; + h->listen_fd = -1; + h->conn_fd = -1; + h->ready_fd = -1; + int pipefd[2]; + if (pipe(pipefd) != 0) { + free(h); + return -errno; + } + h->ready_fd = pipefd[1]; + if (pthread_create(&h->thread, NULL, server_main, h) != 0) { + int rc = -errno; + close(pipefd[0]); + close(pipefd[1]); + free(h); + return rc; + } + char ready = 0; + if (read(pipefd[0], &ready, 1) != 1 || ready != 'R') { + int rc = h->rc ? 
h->rc : -1; + pthread_join(h->thread, NULL); + close(pipefd[0]); + close(pipefd[1]); + free(h); + return rc; + } + close(pipefd[0]); + close(pipefd[1]); + h->ready_fd = -1; + memset(desc, 0, sizeof(*desc)); + snprintf(desc->ip, sizeof(desc->ip), "%s", h->ip); + desc->port = h->port; + desc->rkey = h->mr->rkey; + desc->addr = (uint64_t)(uintptr_t)h->addr; + desc->size = (uint32_t)h->size; + *out = h; + return 0; +} + +void simpler_rxe_server_stop(void *handle) +{ + struct simpler_rxe_handle *h = (struct simpler_rxe_handle *)handle; + if (h == NULL) { + return; + } + h->stop = 1; + if (h->listen_fd >= 0) { + shutdown(h->listen_fd, SHUT_RDWR); + close(h->listen_fd); + h->listen_fd = -1; + } + if (h->conn_fd >= 0) { + shutdown(h->conn_fd, SHUT_RDWR); + } + pthread_join(h->thread, NULL); + if (h->conn_fd >= 0) { + close(h->conn_fd); + h->conn_fd = -1; + } + if (h->mr != NULL) { + ibv_dereg_mr(h->mr); + } + if (h->qp != NULL) { + ibv_destroy_qp(h->qp); + } + if (h->cq != NULL) { + ibv_destroy_cq(h->cq); + } + if (h->pd != NULL) { + ibv_dealloc_pd(h->pd); + } + if (h->ctx != NULL) { + ibv_close_device(h->ctx); + } + free(h); +} + +int simpler_rxe_write(const char *device, int gid_index, const char *ip, uint16_t port, + const void *local_addr, uint64_t size) +{ + if (device == NULL || ip == NULL || local_addr == NULL || size == 0) { + return -EINVAL; + } + struct simpler_rxe_handle h; + memset(&h, 0, sizeof(h)); + snprintf(h.device, sizeof(h.device), "%s", device); + h.gid_index = gid_index; + h.addr = (void *)local_addr; + h.size = (size_t)size; + h.listen_fd = -1; + if (setup_verbs(&h) != 0) { + goto fail; + } + int fd = socket(AF_INET, SOCK_STREAM, 0); + if (fd < 0) { + goto fail; + } + struct sockaddr_in addr; + memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = htons(port); + if (inet_pton(AF_INET, ip, &addr.sin_addr) != 1 || connect(fd, (struct sockaddr *)&addr, sizeof(addr)) != 0) { + close(fd); + goto fail; + } + + struct 
simpler_rxe_qp_info local; + struct simpler_rxe_qp_info remote; + memset(&local, 0, sizeof(local)); + memset(&remote, 0, sizeof(remote)); + if (recv_all(fd, &remote, sizeof(remote)) != 0) { + close(fd); + goto fail; + } + local.qpn = h.qp->qp_num; + local.psn = 0x222222; + local.rkey = h.mr->rkey; + local.addr = (uint64_t)(uintptr_t)local_addr; + local.size = (uint32_t)size; + if (gid_from_context(h.ctx, gid_index, local.gid) != 0) { + close(fd); + goto fail; + } + if (send_all(fd, &local, sizeof(local)) != 0) { + close(fd); + goto fail; + } + if (modify_qp_rtr(h.qp, &remote, gid_index) != 0 || modify_qp_rts(h.qp, local.psn) != 0) { + close(fd); + goto fail; + } + + struct ibv_sge sge; + memset(&sge, 0, sizeof(sge)); + sge.addr = (uint64_t)(uintptr_t)local_addr; + sge.length = (uint32_t)size; + sge.lkey = h.mr->lkey; + struct ibv_send_wr wr; + struct ibv_send_wr *bad = NULL; + memset(&wr, 0, sizeof(wr)); + wr.wr_id = 1; + wr.sg_list = &sge; + wr.num_sge = 1; + wr.opcode = IBV_WR_RDMA_WRITE; + wr.send_flags = IBV_SEND_SIGNALED; + wr.wr.rdma.remote_addr = remote.addr; + wr.wr.rdma.rkey = remote.rkey; + if (size > remote.size || ibv_post_send(h.qp, &wr, &bad) != 0) { + close(fd); + goto fail; + } + struct ibv_wc wc; + int polls = 0; + do { + int n = ibv_poll_cq(h.cq, 1, &wc); + if (n < 0) { + close(fd); + goto fail; + } + if (n > 0) { + break; + } + usleep(1000); + } while (++polls < 15000); + if (polls >= 15000 || wc.status != IBV_WC_SUCCESS) { + close(fd); + goto fail; + } + char done = 'D'; + (void)send_all(fd, &done, 1); + close(fd); + if (h.mr != NULL) ibv_dereg_mr(h.mr); + if (h.qp != NULL) ibv_destroy_qp(h.qp); + if (h.cq != NULL) ibv_destroy_cq(h.cq); + if (h.pd != NULL) ibv_dealloc_pd(h.pd); + if (h.ctx != NULL) ibv_close_device(h.ctx); + return 0; + +fail: + if (h.mr != NULL) ibv_dereg_mr(h.mr); + if (h.qp != NULL) ibv_destroy_qp(h.qp); + if (h.cq != NULL) ibv_destroy_cq(h.cq); + if (h.pd != NULL) ibv_dealloc_pd(h.pd); + if (h.ctx != NULL) 
ibv_close_device(h.ctx); + return errno ? -errno : -1; +} diff --git a/python/simpler/distributed/serialization.py b/python/simpler/distributed/serialization.py index 7c73da86b..44caadcf4 100644 --- a/python/simpler/distributed/serialization.py +++ b/python/simpler/distributed/serialization.py @@ -2,12 +2,34 @@ from __future__ import annotations +import ctypes +import mmap from collections.abc import Iterable +from dataclasses import dataclass from typing import Optional -from simpler.task_interface import CallConfig, ContinuousTensor, DataType, TaskArgs, TensorArgType +from simpler.task_interface import CallConfig, ContinuousTensor, DataType, TaskArgs, TensorArgType, get_element_size from .proto import dispatch_pb2 +from .tensor_pool import TensorPool +from .transport_backend import RxeDataPlaneClient, TransportBackendError, TransportUnavailable + +_OUTPUT_TAGS = { + TensorArgType.OUTPUT, + TensorArgType.INOUT, + TensorArgType.OUTPUT_EXISTING, +} + +_REMOTE_OUTPUT_TAGS = { + TensorArgType.OUTPUT, + TensorArgType.OUTPUT_EXISTING, +} + + +@dataclass(frozen=True) +class RemoteTensorWriteback: + tensor_index: int + handle: dispatch_pb2.TensorHandle def encode_config(config: CallConfig) -> bytes: @@ -69,3 +91,129 @@ def decode_task_args( for scalar in scalar_args: args.add_scalar(int(scalar)) return args + + +def encode_tensor_ref( + data: bytes, + *, + shape: Iterable[int], + dtype: DataType, + tag: TensorArgType, + pool: TensorPool, + ttl_ms: Optional[int] = None, + force_handle: bool = False, +) -> dispatch_pb2.TensorRef: + return pool.put_bytes( + data, + shape=shape, + dtype=int(dtype.value), + tag=int(tag.value), + ttl_ms=ttl_ms, + force_handle=force_handle, + ) + + +def decode_task_args_with_tensor_refs( + tensor_refs: Iterable[dispatch_pb2.TensorRef], + scalar_args: Iterable[int], + pool: TensorPool, +) -> tuple[TaskArgs, list[object]]: + args, keepalive, _ = decode_task_args_with_tensor_refs_and_writebacks(tensor_refs, scalar_args, pool) + return args, 
keepalive + + +def decode_task_args_with_tensor_refs_and_writebacks( + tensor_refs: Iterable[dispatch_pb2.TensorRef], + scalar_args: Iterable[int], + pool: TensorPool, +) -> tuple[TaskArgs, list[object], list[RemoteTensorWriteback]]: + args = TaskArgs() + keepalive: list[object] = [] + writebacks: list[RemoteTensorWriteback] = [] + for tensor_index, ref in enumerate(tensor_refs): + shape = tuple(int(x) for x in ref.shape) + dtype = DataType(int(ref.dtype)) + tag = TensorArgType(int(ref.tag)) + nbytes = _shape_nbytes(shape, dtype) + remote_output = ( + tag in _REMOTE_OUTPUT_TAGS + and ref.HasField("handle") + and ref.handle.transport == "rxe" + and ref.handle.node_id != pool.node_id + ) + data = b"" if remote_output else pool.materialize_ref(ref) + size = max(1, nbytes if remote_output else len(data)) + buf = mmap.mmap(-1, size) + if data: + buf.write(data) + else: + buf.write(b"\x00") + keepalive.append(buf) + ptr = ctypes.addressof(ctypes.c_char.from_buffer(buf)) + args.add_tensor(ContinuousTensor.make(ptr, shape, dtype), tag) + if remote_output: + writebacks.append(RemoteTensorWriteback(tensor_index=tensor_index, handle=ref.handle)) + for scalar in scalar_args: + args.add_scalar(int(scalar)) + return args, keepalive, writebacks + + +def encode_output_tensor_refs( + args: TaskArgs, + pool: TensorPool, + writebacks: Optional[Iterable[RemoteTensorWriteback]] = None, +) -> list[dispatch_pb2.TensorRef]: + refs = [] + writeback_by_index = {item.tensor_index: item for item in (writebacks or [])} + rxe_client = None + for i in range(args.tensor_count()): + tag = args.tag(i) + if tag not in _OUTPUT_TAGS: + continue + tensor = args.tensor(i) + data = ctypes.string_at(int(tensor.data), _tensor_nbytes(tensor)) + writeback = writeback_by_index.get(i) + if writeback is not None: + try: + if writeback.handle.transport != "rxe": + raise RuntimeError(f"unsupported remote output transport {writeback.handle.transport!r}") + rxe_client = rxe_client or RxeDataPlaneClient.from_env() + 
local = ctypes.create_string_buffer(data, len(data)) + rxe_client.write_handle(writeback.handle, ctypes.addressof(local), len(data)) + rxe_client.fence() + refs.append( + dispatch_pb2.TensorRef( + handle=writeback.handle, + shape=[int(x) for x in tensor.shapes[: int(tensor.ndims)]], + dtype=int(tensor.dtype.value), + tag=int(tag.value), + ) + ) + continue + except (RuntimeError, TransportBackendError, TransportUnavailable): + pass + refs.append( + pool.put_bytes( + data, + shape=[int(x) for x in tensor.shapes[: int(tensor.ndims)]], + dtype=int(tensor.dtype.value), + tag=int(tag.value), + ) + ) + return refs + + +def _tensor_nbytes(tensor) -> int: # noqa: ANN001 + nbytes = tensor.nbytes + return int(nbytes() if callable(nbytes) else nbytes) + + +def _shape_nbytes(shape: Iterable[int], dtype: DataType) -> int: + count = 1 + for dim in shape: + count *= int(dim) + return count * _dtype_nbytes(dtype) + + +def _dtype_nbytes(dtype: DataType) -> int: + return int(get_element_size(dtype)) diff --git a/python/simpler/distributed/tensor_pool.py b/python/simpler/distributed/tensor_pool.py index 3b4787901..50d2922dc 100644 --- a/python/simpler/distributed/tensor_pool.py +++ b/python/simpler/distributed/tensor_pool.py @@ -1,51 +1,254 @@ -"""Tensor byte pool used by distributed dispatch tensor references.""" +"""Tensor byte pool used by distributed dispatch tensor references. + +This is the Python MVP of the data-plane bridge described in +``L4_L3_data_plane_design.md``. It is intentionally a byte pool rather than an +RDMA implementation: handles, leases, alloc/free, chunked pull/push, and inline +vs handle decisions are represented explicitly so the backend can later swap the +storage implementation for SHM/RDMA/Urma without changing the control protocol. 
+""" from __future__ import annotations +import ctypes import itertools +import time import uuid from collections.abc import Iterable +from dataclasses import dataclass from typing import Optional import grpc from .proto import dispatch_pb2, dispatch_pb2_grpc +from .transport_backend import GrpcTensorTransport, RegisteredRegion, TensorTransportBackend + +DEFAULT_INLINE_THRESHOLD = 4 * 1024 +DEFAULT_POOL_CAPACITY = 64 * 1024 * 1024 +DEFAULT_LEASE_TTL_MS = 60_000 + + +class TensorPoolError(RuntimeError): + pass + + +class TensorPoolFull(TensorPoolError): + pass + + +@dataclass +class _Entry: + data: bytearray + nbytes: int + expires_at_ms: int + shape: tuple[int, ...] + dtype: int + tag: int + region: RegisteredRegion class TensorPool: - def __init__(self, *, node_id: Optional[str] = None, inline_threshold: int = 1024 * 1024) -> None: + def __init__( + self, + *, + node_id: Optional[str] = None, + inline_threshold: int = DEFAULT_INLINE_THRESHOLD, + capacity_bytes: int = DEFAULT_POOL_CAPACITY, + default_ttl_ms: int = DEFAULT_LEASE_TTL_MS, + transport_backend: Optional[TensorTransportBackend] = None, + ) -> None: self.node_id = node_id or str(uuid.uuid4()) self.inline_threshold = int(inline_threshold) + self.capacity_bytes = int(capacity_bytes) + self.default_ttl_ms = int(default_ttl_ms) + self.transport_backend = transport_backend or GrpcTensorTransport() self._next_id = itertools.count(1) - self._data: dict[int, bytes] = {} + self._entries: dict[int, _Entry] = {} + self._used_bytes = 0 - def put_bytes(self, data: bytes) -> dispatch_pb2.TensorRef: - data = bytes(data) - if len(data) <= self.inline_threshold: - return dispatch_pb2.TensorRef(inline_data=data) + @property + def used_bytes(self) -> int: + self.gc_expired() + return self._used_bytes + + def alloc( + self, + nbytes: int, + *, + ttl_ms: Optional[int] = None, + shape: Iterable[int] = (), + dtype: int = 0, + tag: int = 0, + ) -> dispatch_pb2.TensorHandle: + nbytes = int(nbytes) + if nbytes < 0: + raise 
ValueError(f"nbytes must be non-negative, got {nbytes}") + self.gc_expired() + if self._used_bytes + nbytes > self.capacity_bytes: + raise TensorPoolFull( + f"tensor pool {self.node_id} is full: requested={nbytes}, " + f"used={self._used_bytes}, capacity={self.capacity_bytes}" + ) handle_id = next(self._next_id) - self._data[handle_id] = data - return dispatch_pb2.TensorRef(handle=dispatch_pb2.TensorHandle(node_id=self.node_id, handle_id=handle_id)) + ttl = self.default_ttl_ms if ttl_ms is None or int(ttl_ms) == 0 else int(ttl_ms) + data = bytearray(nbytes) + region = self.transport_backend.register_region(data, tag=f"{self.node_id}:{handle_id}:{int(tag)}") + entry = _Entry( + data=data, + nbytes=nbytes, + expires_at_ms=_now_ms() + ttl, + shape=tuple(int(x) for x in shape), + dtype=int(dtype), + tag=int(tag), + region=region, + ) + self._entries[handle_id] = entry + self._used_bytes += nbytes + return self._make_handle(handle_id, entry) + + def free(self, handle: dispatch_pb2.TensorHandle) -> None: + handle_id = self._checked_handle_id(handle) + entry = self._entries.pop(handle_id) + self._used_bytes -= entry.nbytes + self.transport_backend.unregister_region(entry.region) + + def refresh(self, handle: dispatch_pb2.TensorHandle, ttl_ms: Optional[int] = None) -> dispatch_pb2.TensorHandle: + handle_id = self._checked_handle_id(handle) + entry = self._entries[handle_id] + ttl = self.default_ttl_ms if ttl_ms is None or int(ttl_ms) == 0 else int(ttl_ms) + entry.expires_at_ms = _now_ms() + ttl + refresh_region = getattr(self.transport_backend, "refresh_region", None) + if refresh_region is not None: + entry.region = refresh_region(entry.region, entry.data, tag=f"{self.node_id}:{handle_id}:{entry.tag}") + return self._make_handle(handle_id, entry) + + def write_bytes(self, handle: dispatch_pb2.TensorHandle, data: bytes, *, offset: int = 0) -> None: + handle_id = self._checked_handle_id(handle) + entry = self._entries[handle_id] + offset = int(offset) + data = bytes(data) + 
end = offset + len(data) + if offset < 0 or end > entry.nbytes: + raise ValueError(f"write out of range: offset={offset}, size={len(data)}, nbytes={entry.nbytes}") + entry.data[offset:end] = data + + def read_bytes(self, handle: dispatch_pb2.TensorHandle, *, offset: int = 0, nbytes: Optional[int] = None) -> bytes: + handle_id = self._checked_handle_id(handle) + entry = self._entries[handle_id] + offset = int(offset) + size = entry.nbytes - offset if nbytes is None else int(nbytes) + end = offset + size + if offset < 0 or size < 0 or end > entry.nbytes: + raise ValueError(f"read out of range: offset={offset}, size={size}, nbytes={entry.nbytes}") + return bytes(entry.data[offset:end]) + + def put_bytes( + self, + data: bytes, + *, + shape: Iterable[int] = (), + dtype: int = 0, + tag: int = 0, + ttl_ms: Optional[int] = None, + force_handle: bool = False, + ) -> dispatch_pb2.TensorRef: + data = bytes(data) + ref = dispatch_pb2.TensorRef(shape=[int(x) for x in shape], dtype=int(dtype), tag=int(tag)) + if not force_handle and len(data) <= self.inline_threshold: + ref.inline_data = data + return ref + handle = self.alloc(len(data), ttl_ms=ttl_ms, shape=shape, dtype=dtype, tag=tag) + self.write_bytes(handle, data) + ref.handle.CopyFrom(handle) + return ref def get_bytes(self, handle: dispatch_pb2.TensorHandle) -> bytes: - if handle.node_id != self.node_id: - raise KeyError(f"tensor handle belongs to node {handle.node_id!r}, not {self.node_id!r}") - return self._data[int(handle.handle_id)] + return self.read_bytes(handle) + + def materialize_ref(self, ref: dispatch_pb2.TensorRef) -> bytes: + if ref.HasField("inline_data"): + return bytes(ref.inline_data) + if ref.HasField("handle"): + return self.get_bytes(ref.handle) + raise TensorPoolError("TensorRef has neither inline_data nor handle") + + def gc_expired(self) -> int: + now = _now_ms() + expired = [handle_id for handle_id, entry in self._entries.items() if entry.expires_at_ms <= now] + for handle_id in expired: + entry = 
self._entries.pop(handle_id) + self._used_bytes -= entry.nbytes + self.transport_backend.unregister_region(entry.region) + return len(expired) + + def close(self) -> None: + for handle_id in list(self._entries): + entry = self._entries.pop(handle_id) + self._used_bytes -= entry.nbytes + self.transport_backend.unregister_region(entry.region) + close = getattr(self.transport_backend, "close", None) + if close is not None: + close() def service(self) -> "TensorPoolService": return TensorPoolService(self) + def _checked_handle_id(self, handle: dispatch_pb2.TensorHandle) -> int: + self.gc_expired() + if handle.node_id != self.node_id: + raise KeyError(f"tensor handle belongs to node {handle.node_id!r}, not {self.node_id!r}") + handle_id = int(handle.handle_id) + if handle_id not in self._entries: + raise KeyError(f"tensor handle {handle_id} is not allocated") + return handle_id + + def _make_handle(self, handle_id: int, entry: _Entry) -> dispatch_pb2.TensorHandle: + return dispatch_pb2.TensorHandle( + node_id=self.node_id, + handle_id=int(handle_id), + remote_addr=int(entry.region.remote_addr), + rkey=int(entry.region.rkey), + nbytes=entry.nbytes, + lease_deadline_unix_ms=entry.expires_at_ms, + transport=entry.region.transport, + transport_desc=entry.region.transport_desc, + ) + class TensorPoolService(dispatch_pb2_grpc.TensorPoolServicer): def __init__(self, pool: TensorPool, *, chunk_size: int = 1024 * 1024) -> None: self._pool = pool self._chunk_size = int(chunk_size) + def AllocTensor(self, request, context): # noqa: N802, ANN001 + try: + return self._pool.alloc( + request.nbytes, + ttl_ms=request.ttl_ms, + shape=request.shape, + dtype=request.dtype, + tag=request.tag, + ) + except Exception as e: # noqa: BLE001 + _abort(context, grpc.StatusCode.RESOURCE_EXHAUSTED, str(e)) + + def FreeTensor(self, request, context): # noqa: N802, ANN001 + try: + self._pool.free(request.handle) + except Exception as e: # noqa: BLE001 + _abort(context, grpc.StatusCode.NOT_FOUND, 
str(e)) + return dispatch_pb2.Empty() + + def RefreshTensor(self, request, context): # noqa: N802, ANN001 + try: + return self._pool.refresh(request.handle, request.ttl_ms) + except Exception as e: # noqa: BLE001 + _abort(context, grpc.StatusCode.NOT_FOUND, str(e)) + def PullTensor(self, request, context): # noqa: N802, ANN001 try: data = self._pool.get_bytes(request) except Exception as e: # noqa: BLE001 - context.abort(grpc.StatusCode.NOT_FOUND, str(e)) + _abort(context, grpc.StatusCode.NOT_FOUND, str(e)) for offset in range(0, len(data), self._chunk_size): chunk = data[offset : offset + self._chunk_size] yield dispatch_pb2.TensorChunk( @@ -58,12 +261,43 @@ def PullTensor(self, request, context): # noqa: N802, ANN001 yield dispatch_pb2.TensorChunk(handle=request, offset=0, data=b"", last=True) def PushTensor(self, request_iterator: Iterable[dispatch_pb2.TensorChunk], context): # noqa: N802, ANN001 - parts = [] - for chunk in request_iterator: - parts.append(bytes(chunk.data)) - ref = self._pool.put_bytes(b"".join(parts)) - if not ref.HasField("handle"): - handle_id = next(self._pool._next_id) - self._pool._data[handle_id] = ref.inline_data - return dispatch_pb2.TensorHandle(node_id=self._pool.node_id, handle_id=handle_id) - return ref.handle + chunks = list(request_iterator) + if not chunks: + return self._pool.alloc(0) + handle = chunks[0].handle + payload = _join_chunks(chunks) + try: + if handle.handle_id: + self._pool.write_bytes(handle, payload) + return self._pool.refresh(handle) + ref = self._pool.put_bytes(payload, force_handle=True) + return ref.handle + except Exception as e: # noqa: BLE001 + _abort(context, grpc.StatusCode.INVALID_ARGUMENT, str(e)) + + +def _join_chunks(chunks: list[dispatch_pb2.TensorChunk]) -> bytes: + total = 0 + for chunk in chunks: + total = max(total, int(chunk.offset) + len(chunk.data)) + out = bytearray(total) + for chunk in chunks: + offset = int(chunk.offset) + out[offset : offset + len(chunk.data)] = chunk.data + return 
bytes(out) + + +def _buffer_addr(data: bytearray) -> int: + if not data: + return 0 + return ctypes.addressof(ctypes.c_char.from_buffer(data)) + + +def _now_ms() -> int: + return int(time.time() * 1000) + + +def _abort(context, code: grpc.StatusCode, message: str): # noqa: ANN001 + if context is None: + raise TensorPoolError(message) + context.abort(code, message) diff --git a/python/simpler/distributed/transport_backend.py b/python/simpler/distributed/transport_backend.py new file mode 100644 index 000000000..0fa121399 --- /dev/null +++ b/python/simpler/distributed/transport_backend.py @@ -0,0 +1,1653 @@ +"""Optional tensor data-plane transport backends. + +The default distributed data plane still uses gRPC chunk streaming. This +module adds a narrow backend boundary for transports that can expose registered +memory to the peer, and a first HCOMM C-API facade that can be enabled on +systems where HCOMM endpoint/channel resources are already available. +""" + +from __future__ import annotations + +import ctypes +import ctypes.util +import ipaddress +import json +import os +import shlex +import struct +import subprocess +import sysconfig +from dataclasses import dataclass +from pathlib import Path +from typing import Optional + + +class TransportBackendError(RuntimeError): + pass + + +class TransportUnavailable(TransportBackendError): + pass + + +@dataclass(frozen=True) +class RegisteredRegion: + remote_addr: int + rkey: int = 0 + transport: str = "grpc" + transport_desc: bytes = b"" + + +@dataclass(frozen=True) +class ImportedMemory: + remote_addr: int + nbytes: int + mem_type: int = 1 + + +@dataclass +class _HcommStagingBuffer: + data: Optional[bytearray] + addr: int + nbytes: int + mem_handle: int + owned: bool + + +@dataclass +class _RxeServerRegion: + region: RegisteredRegion + handle: int + + +class TensorTransportBackend: + """Storage-side transport hook used by ``TensorPool``.""" + + name = "grpc" + + @property + def available(self) -> bool: + return True + 
+ def unavailable_reason(self) -> str: + return "" + + def register_region(self, data: bytearray, *, tag: str) -> RegisteredRegion: + return RegisteredRegion(remote_addr=_buffer_addr(data), transport=self.name) + + def unregister_region(self, region: RegisteredRegion) -> None: + del region + + +class GrpcTensorTransport(TensorTransportBackend): + """Current Python byte-pool transport.""" + + name = "grpc" + + +class RxeTensorTransport(TensorTransportBackend): + """Registers TensorPool buffers for direct RXE/ibverbs RDMA writes.""" + + name = "rxe" + + def __init__(self, runtime: Optional["RxeRuntime"] = None) -> None: + self.runtime = runtime or RxeRuntime.from_env(required=False) + self._regions: dict[int, _RxeServerRegion] = {} + + @classmethod + def from_env(cls) -> "RxeTensorTransport": + return cls() + + @property + def available(self) -> bool: + return self.runtime.available and self.runtime.device is not None and self.runtime.server_ip is not None + + def unavailable_reason(self) -> str: + if not self.runtime.available: + return self.runtime.unavailable_reason() + if self.runtime.device is None: + return "no RXE device found; set SIMPLER_RXE_DEVICE" + if self.runtime.server_ip is None: + return "no IPv4 GID found for RXE; set SIMPLER_RXE_SERVER_IP and SIMPLER_RXE_GID_INDEX" + return "" + + def register_region(self, data: bytearray, *, tag: str) -> RegisteredRegion: + del tag + if not data: + return RegisteredRegion(remote_addr=0, rkey=0, transport=self.name) + if not self.available: + raise TransportUnavailable(self.unavailable_reason()) + addr = _buffer_addr(data) + desc, server_handle = self.runtime.server_start(addr, len(data)) + payload = _encode_rxe_desc(desc, self.runtime.device or "", self.runtime.gid_index) + region = RegisteredRegion( + remote_addr=int(desc.addr), + rkey=int(desc.rkey), + transport=self.name, + transport_desc=payload, + ) + self._regions[addr] = _RxeServerRegion(region=region, handle=server_handle) + return region + + def 
unregister_region(self, region: RegisteredRegion) -> None: + item = self._regions.pop(int(region.remote_addr), None) + if item is None: + return + self.runtime.server_stop(item.handle) + + def refresh_region(self, region: RegisteredRegion, data: bytearray, *, tag: str) -> RegisteredRegion: + self.unregister_region(region) + return self.register_region(data, tag=tag) + + def close(self) -> None: + for item in list(self._regions.values()): + self.unregister_region(item.region) + + def __del__(self) -> None: + try: + self.close() + except Exception: + pass + + +class HcommTensorTransport(TensorTransportBackend): + """Registers TensorPool byte buffers with HCOMM when configured. + + This backend is intentionally conservative: HCOMM needs an endpoint handle + created from a rank graph/endpoint description. For this first integration + layer we accept an already-created endpoint handle via environment or + constructor and expose HCOMM memory descriptors in TensorHandle metadata. + """ + + name = "hcomm" + + def __init__(self, runtime: Optional["HcommRuntime"] = None, endpoint_handle: int = 0) -> None: + self.runtime = runtime or HcommRuntime.from_env(required=False) + self.endpoint_handle = _parse_handle(endpoint_handle or os.getenv("SIMPLER_HCOMM_ENDPOINT_HANDLE", "0")) + self._owns_endpoint = False + if self.endpoint_handle == 0 and self.runtime.available and hasattr(self.runtime, "endpoint_create"): + endpoint = EndpointSpec.from_env() + if endpoint is not None: + self.endpoint_handle = self.runtime.endpoint_create(endpoint) + self._owns_endpoint = True + self._regions: dict[int, tuple[RegisteredRegion, int, int]] = {} + + @classmethod + def from_env(cls) -> "HcommTensorTransport": + return cls() + + @property + def available(self) -> bool: + return self.runtime.available and self.endpoint_handle != 0 + + def unavailable_reason(self) -> str: + if not self.runtime.available: + return self.runtime.unavailable_reason() + if self.endpoint_handle == 0: + return 
"SIMPLER_HCOMM_ENDPOINT_HANDLE is not set" + return "" + + def register_region(self, data: bytearray, *, tag: str) -> RegisteredRegion: + if not data: + return RegisteredRegion(remote_addr=0, rkey=0, transport=self.name) + if not self.available: + raise TransportUnavailable(self.unavailable_reason()) + addr = _buffer_addr(data) + size = len(data) + mem_handle = self.runtime.mem_reg(self.endpoint_handle, addr, size, tag=tag) + try: + desc = self.runtime.mem_export(self.endpoint_handle, mem_handle) + except Exception: + self.runtime.mem_unreg(self.endpoint_handle, mem_handle) + raise + region = RegisteredRegion(remote_addr=addr, rkey=0, transport=self.name, transport_desc=desc) + self._regions[addr] = (region, self.endpoint_handle, mem_handle) + return region + + def unregister_region(self, region: RegisteredRegion) -> None: + item = self._regions.pop(int(region.remote_addr), None) + if item is None: + return + _, endpoint_handle, mem_handle = item + self.runtime.mem_unreg(endpoint_handle, mem_handle) + + def close(self) -> None: + for region, _, _ in list(self._regions.values()): + self.unregister_region(region) + if self._owns_endpoint and self.endpoint_handle: + self.runtime.endpoint_destroy(self.endpoint_handle) + self.endpoint_handle = 0 + self._owns_endpoint = False + + def __del__(self) -> None: + try: + self.close() + except Exception: + pass + + +class HcommDataPlaneClient: + """Thin caller for HCOMM channel primitives on the L4 side.""" + + def __init__( + self, + runtime: Optional["HcommRuntime"] = None, + *, + endpoint_handle: int = 0, + channel_handle: int = 0, + socket_handle: int = 0, + local_mem_handle: int = 0, + remote_notify_idx: int = 0, + ) -> None: + self.runtime = runtime or HcommRuntime.from_env(required=False) + self.endpoint_handle = _parse_handle(endpoint_handle or os.getenv("SIMPLER_HCOMM_ENDPOINT_HANDLE", "0")) + self.channel_handle = _parse_handle(channel_handle or os.getenv("SIMPLER_HCOMM_CHANNEL_HANDLE", "0")) + self.socket_handle = 
_parse_handle(socket_handle or os.getenv("SIMPLER_HCOMM_SOCKET_HANDLE", "0")) + self.local_mem_handle = _parse_handle(local_mem_handle or os.getenv("SIMPLER_HCOMM_LOCAL_MEM_HANDLE", "0")) + self.remote_notify_idx = int(os.getenv("SIMPLER_HCOMM_REMOTE_NOTIFY_IDX", str(remote_notify_idx)), 0) + self.notify_num = int(os.getenv("SIMPLER_HCOMM_NOTIFY_NUM", "1"), 0) + self.engine = int(os.getenv("SIMPLER_HCOMM_ENGINE", "0"), 0) # COMM_ENGINE_CPU + self.channel_role = _hcomm_socket_role(os.getenv("SIMPLER_HCOMM_CHANNEL_ROLE", "client")) + self.channel_port = int(os.getenv("SIMPLER_HCOMM_CHANNEL_PORT", "60001"), 0) + self._owns_channel = False + self._owns_endpoint = False + if self.endpoint_handle == 0 and self.runtime.available and hasattr(self.runtime, "endpoint_create"): + endpoint = EndpointSpec.from_env() + if endpoint is not None: + self.endpoint_handle = self.runtime.endpoint_create(endpoint) + self._owns_endpoint = True + self._imports: dict[bytes, ImportedMemory] = {} + self._staging: Optional[_HcommStagingBuffer] = None + + @classmethod + def from_env(cls) -> "HcommDataPlaneClient": + return cls() + + @property + def available(self) -> bool: + return self.runtime.available and (self.channel_handle != 0 or self._can_create_channel) + + @property + def _can_create_channel(self) -> bool: + return self.endpoint_handle != 0 + + def unavailable_reason(self) -> str: + if not self.runtime.available: + return self.runtime.unavailable_reason() + if self.channel_handle == 0 and not self._can_create_channel: + return "SIMPLER_HCOMM_CHANNEL_HANDLE is not set, and automatic channel creation requires an endpoint handle" + return "" + + def resolve_remote_memory(self, handle) -> ImportedMemory: # noqa: ANN001 + desc = bytes(getattr(handle, "transport_desc", b"")) + if desc and self.endpoint_handle: + cached = self._imports.get(desc) + if cached is not None: + return cached + imported = self.runtime.mem_import(self.endpoint_handle, desc) + self._imports[desc] = imported + return 
imported + return ImportedMemory(remote_addr=int(handle.remote_addr), nbytes=int(handle.nbytes)) + + def write_handle(self, handle, local_addr: int, nbytes: int) -> None: # noqa: ANN001 + remote_mem = self.resolve_remote_memory(handle) + if int(nbytes) > int(remote_mem.nbytes): + raise TransportBackendError( + f"HCOMM remote memory too small: write={int(nbytes)}, remote={int(remote_mem.nbytes)}" + ) + staging = self._stage_local(local_addr, int(nbytes)) + self.ensure_channel(handle) + self.write_with_notify(remote_mem.remote_addr, staging.addr, nbytes) + + def ensure_channel(self, handle=None) -> int: # noqa: ANN001 + if self.channel_handle: + return self.channel_handle + if not self._can_create_channel: + raise TransportUnavailable(self.unavailable_reason()) + self._ensure_staging(1) + desc = bytes(getattr(handle, "transport_desc", b"")) if handle is not None else b"" + if not desc: + raise TransportUnavailable("HCOMM channel creation requires TensorHandle.transport_desc") + remote_endpoint = _endpoint_from_transport_desc(desc) + self.channel_handle = self.runtime.channel_create( + self.endpoint_handle, + remote_endpoint=remote_endpoint, + socket_handle=self.socket_handle, + local_mem_handles=[self._local_mem_handle()], + notify_num=self.notify_num, + engine=self.engine, + role=self.channel_role, + port=self.channel_port, + ) + self._owns_channel = True + return self.channel_handle + + def remote_mems(self) -> list[ImportedMemory]: + if not self.channel_handle: + raise TransportUnavailable(self.unavailable_reason()) + return self.runtime.channel_get_remote_mem(self.channel_handle) + + def write_with_notify(self, remote_addr: int, local_addr: int, nbytes: int) -> None: + if not self.available: + raise TransportUnavailable(self.unavailable_reason()) + self.runtime.write_with_notify( + self.channel_handle, + int(remote_addr), + int(local_addr), + int(nbytes), + self.remote_notify_idx, + ) + + def fence(self) -> None: + if not self.available: + raise 
TransportUnavailable(self.unavailable_reason()) + self.runtime.channel_fence(self.channel_handle) + + def close(self) -> None: + for desc in list(self._imports): + if self.endpoint_handle: + try: + self.runtime.mem_unimport(self.endpoint_handle, desc) + except Exception: + pass + self._imports.pop(desc, None) + if self._owns_channel and self.channel_handle: + try: + self.runtime.channel_destroy([self.channel_handle]) + except Exception: + pass + self.channel_handle = 0 + self._owns_channel = False + if self._staging is not None and self._staging.owned: + try: + self.runtime.mem_unreg(self.endpoint_handle, self._staging.mem_handle) + except Exception: + pass + self._staging = None + if self._owns_endpoint and self.endpoint_handle: + try: + self.runtime.endpoint_destroy(self.endpoint_handle) + except Exception: + pass + self.endpoint_handle = 0 + self._owns_endpoint = False + + def _stage_local(self, local_addr: int, nbytes: int) -> "_HcommStagingBuffer": + if self.local_mem_handle: + self._staging = _HcommStagingBuffer( + data=None, + addr=int(local_addr), + nbytes=int(nbytes), + mem_handle=self.local_mem_handle, + owned=False, + ) + return self._staging + staging = self._ensure_staging(nbytes) + if nbytes: + ctypes.memmove(staging.addr, int(local_addr), int(nbytes)) + return staging + + def _ensure_staging(self, nbytes: int) -> "_HcommStagingBuffer": + if self._staging is not None and self._staging.nbytes >= int(nbytes): + return self._staging + if self._staging is not None and self._staging.owned: + self.runtime.mem_unreg(self.endpoint_handle, self._staging.mem_handle) + self._staging = None + if self.local_mem_handle: + if nbytes <= 0: + raise TransportUnavailable("external SIMPLER_HCOMM_LOCAL_MEM_HANDLE requires nonzero staged writes") + # External channel users are responsible for ensuring this handle covers + # the local source address. Automatic channel creation uses owned staging. 
+ self._staging = _HcommStagingBuffer( + data=None, + addr=0, + nbytes=int(nbytes), + mem_handle=self.local_mem_handle, + owned=False, + ) + return self._staging + if not self.endpoint_handle: + raise TransportUnavailable("HCOMM staging requires an endpoint handle") + capacity = max(1, int(nbytes)) + data = bytearray(capacity) + addr = _buffer_addr(data) + mem_handle = self.runtime.mem_reg(self.endpoint_handle, addr, capacity, tag="simpler-l4-staging") + self._staging = _HcommStagingBuffer(data=data, addr=addr, nbytes=capacity, mem_handle=mem_handle, owned=True) + return self._staging + + def _local_mem_handle(self) -> int: + if self.local_mem_handle: + return self.local_mem_handle + if self._staging is None: + raise TransportUnavailable("HCOMM channel creation requires a staged local buffer") + return self._staging.mem_handle + + +class RxeDataPlaneClient: + """L4-side RXE/ibverbs writer for TensorHandle metadata.""" + + def __init__(self, runtime: Optional["RxeRuntime"] = None) -> None: + self.runtime = runtime or RxeRuntime.from_env(required=False) + + @classmethod + def from_env(cls) -> "RxeDataPlaneClient": + return cls() + + @property + def available(self) -> bool: + return self.runtime.available and self.runtime.device is not None + + def unavailable_reason(self) -> str: + if not self.runtime.available: + return self.runtime.unavailable_reason() + if self.runtime.device is None: + return "no RXE device found; set SIMPLER_RXE_DEVICE" + return "" + + def write_handle(self, handle, local_addr: int, nbytes: int) -> None: # noqa: ANN001 + if not self.available: + raise TransportUnavailable(self.unavailable_reason()) + desc = _decode_rxe_desc(bytes(getattr(handle, "transport_desc", b""))) + if int(nbytes) > int(desc.size): + raise TransportBackendError(f"RXE remote memory too small: write={int(nbytes)}, remote={int(desc.size)}") + self.runtime.write(desc.ip, desc.port, int(local_addr), int(nbytes), gid_index=desc.gid_index) + + def fence(self) -> None: + return + 
+ def close(self) -> None: + return + + +class RxeRuntime: + """Runtime loader for the Simpler-owned RXE/ibverbs helper.""" + + def __init__(self, lib_path: Optional[str] = None, *, required: bool = False) -> None: + self.device = os.getenv("SIMPLER_RXE_DEVICE") or _first_existing_rxe_device() + self.gid_index: int = int(os.getenv("SIMPLER_RXE_GID_INDEX", "0"), 0) + self.server_ip = os.getenv("SIMPLER_RXE_SERVER_IP") + if self.device and (not self.server_ip or "SIMPLER_RXE_GID_INDEX" not in os.environ): + inferred = _find_rxe_ipv4_gid(self.device) + if inferred is not None: + inferred_gid_index, inferred_ip = inferred + if "SIMPLER_RXE_GID_INDEX" not in os.environ: + self.gid_index = inferred_gid_index + self.server_ip = self.server_ip or inferred_ip + self._lib_path = lib_path or os.getenv("SIMPLER_RXE_HELPER_LIB") + self._lib = None + self._load_error = "" + try: + path = Path(self._lib_path).expanduser().resolve() if self._lib_path else _build_rxe_verbs_helper() + self._preload_dependencies() + self._lib = ctypes.CDLL(str(path), mode=getattr(os, "RTLD_LOCAL", 0) | getattr(os, "RTLD_NOW", 0)) + self._lib_path = str(path) + self._bind_symbols() + except (OSError, TransportBackendError, TransportUnavailable) as e: + self._load_error = str(e) + self._lib = None + if required and self._lib is None: + raise TransportUnavailable(self.unavailable_reason()) + + @classmethod + def from_env(cls, *, required: bool = False) -> "RxeRuntime": + return cls(required=required) + + @property + def available(self) -> bool: + return self._lib is not None + + def unavailable_reason(self) -> str: + if self._lib is not None: + return "" + return self._load_error or "RXE helper is unavailable" + + def server_start(self, addr: int, size: int) -> tuple["_RxeServerDesc", int]: + self._require() + if not self.device or not self.server_ip: + raise TransportUnavailable(self.unavailable_reason() or "RXE device/server IP is not configured") + desc = _RxeServerDesc() + handle = ctypes.c_void_p() 
+ ret = self._lib.simpler_rxe_server_start( + self.device.encode(), + ctypes.c_int(int(self.gid_index)), + self.server_ip.encode(), + ctypes.c_void_p(int(addr)), + ctypes.c_uint64(int(size)), + ctypes.byref(desc), + ctypes.byref(handle), + ) + _check_rxe(ret, "simpler_rxe_server_start") + return desc, int(handle.value or 0) + + def server_stop(self, handle: int) -> None: + if self._lib is None or not handle: + return + self._lib.simpler_rxe_server_stop(ctypes.c_void_p(int(handle))) + + def write(self, ip: str, port: int, local_addr: int, size: int, *, gid_index: Optional[int] = None) -> None: + self._require() + if not self.device: + raise TransportUnavailable("no RXE device found; set SIMPLER_RXE_DEVICE") + ret = self._lib.simpler_rxe_write( + self.device.encode(), + ctypes.c_int(int(self.gid_index if gid_index is None else gid_index)), + str(ip).encode(), + ctypes.c_uint16(int(port)), + ctypes.c_void_p(int(local_addr)), + ctypes.c_uint64(int(size)), + ) + _check_rxe(ret, "simpler_rxe_write") + + def _require(self) -> None: + if self._lib is None: + raise TransportUnavailable(self.unavailable_reason()) + + def _bind_symbols(self) -> None: + assert self._lib is not None + self._lib.simpler_rxe_server_start.argtypes = [ + ctypes.c_char_p, + ctypes.c_int, + ctypes.c_char_p, + ctypes.c_void_p, + ctypes.c_uint64, + ctypes.POINTER(_RxeServerDesc), + ctypes.POINTER(ctypes.c_void_p), + ] + self._lib.simpler_rxe_server_start.restype = ctypes.c_int + self._lib.simpler_rxe_server_stop.argtypes = [ctypes.c_void_p] + self._lib.simpler_rxe_server_stop.restype = None + self._lib.simpler_rxe_write.argtypes = [ + ctypes.c_char_p, + ctypes.c_int, + ctypes.c_char_p, + ctypes.c_uint16, + ctypes.c_void_p, + ctypes.c_uint64, + ] + self._lib.simpler_rxe_write.restype = ctypes.c_int + + def _preload_dependencies(self) -> None: + lib_dir = _rxe_lib_dir() + if lib_dir is not None: + _prepend_env_path("LD_LIBRARY_PATH", lib_dir) + driver_dir = lib_dir / "libibverbs" + if 
class HcommRuntime:
    """Runtime loader for the HCOMM experimental C API.

    Wraps the HCOMM C entry points (endpoint lifecycle, memory
    registration/export/import, channel setup and RDMA write-with-notify)
    behind ctypes. The shared library is resolved from SIMPLER_HCOMM_LIB or
    the system linker; load failures are captured so callers can probe
    ``available`` unless ``required=True``.
    """

    def __init__(self, lib_path: Optional[str] = None, *, required: bool = False) -> None:
        self._lib_path = lib_path or os.getenv("SIMPLER_HCOMM_LIB") or _find_hcomm_library()
        self._lib = None
        self._acl = None
        self._load_error = ""
        self._acl_load_error = ""
        if self._lib_path:
            try:
                # ACL must be initialised before HCOMM endpoint creation;
                # sidecar libs and the ABI shim are preloaded so a local
                # build without rpaths still dlopens cleanly.
                self._ensure_acl_runtime()
                _preload_hcomm_dependencies(self._lib_path)
                _preload_hcomm_abi_shim(self._lib_path)
                self._lib = ctypes.CDLL(self._lib_path, mode=_hcomm_dlopen_mode())
                self._bind_symbols()
            except (AttributeError, OSError) as e:
                self._load_error = str(e)
                self._lib = None
        elif required:
            self._load_error = "HCOMM shared library not found; set SIMPLER_HCOMM_LIB"
        if required and self._lib is None:
            raise TransportUnavailable(self.unavailable_reason())

    @classmethod
    def from_env(cls, *, required: bool = False) -> "HcommRuntime":
        """Build a runtime configured purely from environment variables."""
        return cls(required=required)

    @property
    def available(self) -> bool:
        return self._lib is not None

    def unavailable_reason(self) -> str:
        if self._lib is not None:
            return ""
        return self._load_error or "HCOMM shared library not found; set SIMPLER_HCOMM_LIB"

    def endpoint_create(self, endpoint: "EndpointSpec") -> int:
        """Create an HCOMM endpoint and return its opaque handle."""
        self._require()
        handle = ctypes.c_void_p()
        rc = self._lib.HcommEndpointCreate(ctypes.byref(endpoint.to_ctypes()), ctypes.byref(handle))
        _check_hcomm(rc, "HcommEndpointCreate")
        return int(handle.value or 0)

    def endpoint_destroy(self, endpoint_handle: int) -> None:
        self._require()
        rc = self._lib.HcommEndpointDestroy(ctypes.c_void_p(int(endpoint_handle)))
        _check_hcomm(rc, "HcommEndpointDestroy")

    def endpoint_start_listen(self, endpoint_handle: int, port: int) -> None:
        self._require()
        rc = self._lib.HcommEndpointStartListen(ctypes.c_void_p(int(endpoint_handle)), ctypes.c_uint32(int(port)), None)
        _check_hcomm(rc, "HcommEndpointStartListen")

    def endpoint_stop_listen(self, endpoint_handle: int, port: int) -> None:
        self._require()
        rc = self._lib.HcommEndpointStopListen(ctypes.c_void_p(int(endpoint_handle)), ctypes.c_uint32(int(port)))
        _check_hcomm(rc, "HcommEndpointStopListen")

    def mem_reg(self, endpoint_handle: int, addr: int, size: int, *, tag: str, mem_type: str = "host") -> int:
        """Register ``[addr, addr+size)`` under *tag*; returns the mem handle."""
        self._require()
        mem = _HcommMem(
            type=_mem_type_value(mem_type),
            addr=ctypes.c_void_p(int(addr)),
            size=int(size),
        )
        handle = ctypes.c_void_p()
        rc = self._lib.HcommMemReg(
            ctypes.c_void_p(int(endpoint_handle)),
            tag.encode(),
            ctypes.byref(mem),
            ctypes.byref(handle),
        )
        _check_hcomm(rc, "HcommMemReg")
        return int(handle.value or 0)

    def mem_unreg(self, endpoint_handle: int, mem_handle: int) -> None:
        self._require()
        rc = self._lib.HcommMemUnreg(ctypes.c_void_p(int(endpoint_handle)), ctypes.c_void_p(int(mem_handle)))
        _check_hcomm(rc, "HcommMemUnreg")

    def mem_export(self, endpoint_handle: int, mem_handle: int) -> bytes:
        """Serialise a registered memory region into an opaque descriptor."""
        self._require()
        desc_ptr = ctypes.c_void_p()
        desc_len = ctypes.c_uint32()
        rc = self._lib.HcommMemExport(
            ctypes.c_void_p(int(endpoint_handle)),
            ctypes.c_void_p(int(mem_handle)),
            ctypes.byref(desc_ptr),
            ctypes.byref(desc_len),
        )
        _check_hcomm(rc, "HcommMemExport")
        if not desc_ptr.value or desc_len.value == 0:
            return b""
        return ctypes.string_at(desc_ptr.value, desc_len.value)

    def mem_import(self, endpoint_handle: int, desc: bytes) -> ImportedMemory:
        """Import a peer's exported descriptor, yielding its remote address."""
        self._require()
        payload = bytes(desc)
        out = _HcommMem()
        rc = self._lib.HcommMemImport(
            ctypes.c_void_p(int(endpoint_handle)),
            ctypes.c_char_p(payload),
            ctypes.c_uint32(len(payload)),
            ctypes.byref(out),
        )
        _check_hcomm(rc, "HcommMemImport")
        return ImportedMemory(remote_addr=int(out.addr or 0), nbytes=int(out.size), mem_type=int(out.type))

    def mem_unimport(self, endpoint_handle: int, desc: bytes) -> None:
        self._require()
        payload = bytes(desc)
        rc = self._lib.HcommMemUnimport(
            ctypes.c_void_p(int(endpoint_handle)),
            ctypes.c_char_p(payload),
            ctypes.c_uint32(len(payload)),
        )
        _check_hcomm(rc, "HcommMemUnimport")

    def channel_create(
        self,
        endpoint_handle: int,
        *,
        remote_endpoint: "_EndpointDesc",
        socket_handle: int,
        local_mem_handles: list[int],
        notify_num: int = 1,
        engine: int = 0,
        role: int = 0,
        port: int = 60001,
        exchange_all_mems: bool = False,
    ) -> int:
        """Create a single channel to *remote_endpoint* and return its id."""
        self._require()
        if not local_mem_handles and not exchange_all_mems:
            raise TransportBackendError("HCOMM channel creation requires at least one local mem handle")
        mem_array = None
        if local_mem_handles:
            array_type = ctypes.c_void_p * len(local_mem_handles)
            mem_array = array_type(*(ctypes.c_void_p(int(h)) for h in local_mem_handles))
        desc = _HcommChannelDesc()
        _init_hcomm_channel_desc(desc)
        desc.remoteEndpoint = remote_endpoint
        desc.notifyNum = int(notify_num)
        desc.exchangeAllMems = bool(exchange_all_mems)
        desc.memHandles = ctypes.cast(mem_array, ctypes.POINTER(ctypes.c_void_p)) if mem_array is not None else None
        desc.memHandleNum = len(local_mem_handles)
        desc.socket = ctypes.c_void_p(int(socket_handle))
        desc.role = int(role)
        desc.port = int(port)
        channels = (ctypes.c_uint64 * 1)()
        rc = self._lib.HcommChannelCreate(
            ctypes.c_void_p(int(endpoint_handle)),
            ctypes.c_int(int(engine)),
            ctypes.byref(desc),
            ctypes.c_uint32(1),
            channels,
        )
        _check_hcomm(rc, "HcommChannelCreate")
        return int(channels[0])

    def channel_destroy(self, channels: list[int]) -> None:
        """Destroy the given channel ids; an empty list is a no-op."""
        self._require()
        if not channels:
            return
        arr = (ctypes.c_uint64 * len(channels))(*(int(c) for c in channels))
        rc = self._lib.HcommChannelDestroy(arr, ctypes.c_uint32(len(channels)))
        _check_hcomm(rc, "HcommChannelDestroy")

    def channel_get_remote_mem(self, channel_handle: int) -> list[ImportedMemory]:
        """Unsupported: the public HCOMM C API exposes no remote-mem query."""
        self._require()
        raise TransportUnavailable(
            "remote memory query is not available in the public HCOMM C API; "
            "use TensorHandle.transport_desc with HcommMemImport instead"
        )

    def write_with_notify(
        self,
        channel_handle: int,
        remote_addr: int,
        local_addr: int,
        nbytes: int,
        remote_notify_idx: int,
    ) -> None:
        """Non-blocking RDMA write that also rings the peer's notify slot."""
        self._require()
        rc = self._lib.HcommWriteWithNotifyNbi(
            ctypes.c_uint64(int(channel_handle)),
            ctypes.c_void_p(int(remote_addr)),
            ctypes.c_void_p(int(local_addr)),
            ctypes.c_uint64(int(nbytes)),
            ctypes.c_uint32(int(remote_notify_idx)),
        )
        _check_hcomm(rc, "HcommWriteWithNotifyNbi")

    def channel_fence(self, channel_handle: int) -> None:
        self._require()
        rc = self._lib.HcommChannelFence(ctypes.c_uint64(int(channel_handle)))
        _check_hcomm(rc, "HcommChannelFence")

    def _require(self) -> None:
        if self._lib is None:
            raise TransportUnavailable(self.unavailable_reason())

    def _ensure_acl_runtime(self) -> None:
        # HCOMM endpoint creation needs an initialised ACL device context.
        if os.getenv("SIMPLER_HCOMM_ACL_AUTO_INIT", "1").lower() in {"0", "false", "no", "off"}:
            return
        acl = self._load_acl_runtime()
        if acl is None:
            raise TransportUnavailable(self._acl_load_error or "libascendcl.so not found for HCOMM endpoint creation")

        device = ctypes.c_int(-1)
        if acl.aclrtGetDevice(ctypes.byref(device)) == 0:
            return  # a device is already selected in this process

        rc = acl.aclInit(None)
        if rc not in (0, 100002):  # ACL_ERROR_REPEAT_INITIALIZE
            raise TransportBackendError(f"aclInit failed before HCOMM endpoint creation with aclError={rc}")

        device_id = int(os.getenv("SIMPLER_HCOMM_DEVICE_ID", os.getenv("ASCEND_DEVICE_ID", "0")), 0)
        rc = acl.aclrtSetDevice(ctypes.c_int(device_id))
        if rc != 0:
            raise TransportBackendError(
                f"aclrtSetDevice failed before HCOMM endpoint creation with aclError={rc}, device_id={device_id}"
            )

    def _load_acl_runtime(self):  # noqa: ANN202
        """dlopen libascendcl.so once and bind the few symbols we call."""
        if self._acl is not None:
            return self._acl
        search_dirs = _hcomm_dependency_dirs(self._lib_path or "")
        acl_path = _find_dependency_library("libascendcl.so", search_dirs)
        if acl_path is None:
            self._acl_load_error = "libascendcl.so not found; set ASCEND_HOME_PATH or SIMPLER_HCOMM_DEP_LIB_DIRS"
            return None
        _ensure_cann_runtime_env(acl_path)
        try:
            mode = getattr(os, "RTLD_GLOBAL", 0) | getattr(os, "RTLD_NOW", 0)
            try:
                acl = ctypes.CDLL(str(acl_path), mode=mode)
            except OSError:
                # Resolve transitive NEEDED entries explicitly, then retry.
                _preload_needed(acl_path, search_dirs, set(), mode)
                acl = ctypes.CDLL(str(acl_path), mode=mode)
            acl.aclInit.argtypes = [ctypes.c_char_p]
            acl.aclInit.restype = ctypes.c_int
            acl.aclrtSetDevice.argtypes = [ctypes.c_int]
            acl.aclrtSetDevice.restype = ctypes.c_int
            acl.aclrtGetDevice.argtypes = [ctypes.POINTER(ctypes.c_int)]
            acl.aclrtGetDevice.restype = ctypes.c_int
            self._acl = acl
            return self._acl
        except OSError as e:
            self._acl_load_error = str(e)
            return None

    def _bind_symbols(self) -> None:
        # Declare ctypes signatures so bad calls fail loudly, not silently.
        assert self._lib is not None
        lib = self._lib
        lib.HcommEndpointCreate.argtypes = [ctypes.POINTER(_EndpointDesc), ctypes.POINTER(ctypes.c_void_p)]
        lib.HcommEndpointCreate.restype = ctypes.c_int
        lib.HcommEndpointDestroy.argtypes = [ctypes.c_void_p]
        lib.HcommEndpointDestroy.restype = ctypes.c_int
        lib.HcommEndpointStartListen.argtypes = [ctypes.c_void_p, ctypes.c_uint32, ctypes.c_void_p]
        lib.HcommEndpointStartListen.restype = ctypes.c_int
        lib.HcommEndpointStopListen.argtypes = [ctypes.c_void_p, ctypes.c_uint32]
        lib.HcommEndpointStopListen.restype = ctypes.c_int
        lib.HcommMemReg.argtypes = [
            ctypes.c_void_p,
            ctypes.c_char_p,
            ctypes.POINTER(_HcommMem),
            ctypes.POINTER(ctypes.c_void_p),
        ]
        lib.HcommMemReg.restype = ctypes.c_int
        lib.HcommMemUnreg.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
        lib.HcommMemUnreg.restype = ctypes.c_int
        lib.HcommMemExport.argtypes = [
            ctypes.c_void_p,
            ctypes.c_void_p,
            ctypes.POINTER(ctypes.c_void_p),
            ctypes.POINTER(ctypes.c_uint32),
        ]
        lib.HcommMemExport.restype = ctypes.c_int
        lib.HcommMemImport.argtypes = [
            ctypes.c_void_p,
            ctypes.c_void_p,
            ctypes.c_uint32,
            ctypes.POINTER(_HcommMem),
        ]
        lib.HcommMemImport.restype = ctypes.c_int
        lib.HcommMemUnimport.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_uint32]
        lib.HcommMemUnimport.restype = ctypes.c_int
        lib.HcommChannelCreate.argtypes = [
            ctypes.c_void_p,
            ctypes.c_int,
            ctypes.POINTER(_HcommChannelDesc),
            ctypes.c_uint32,
            ctypes.POINTER(ctypes.c_uint64),
        ]
        lib.HcommChannelCreate.restype = ctypes.c_int
        lib.HcommChannelDestroy.argtypes = [ctypes.POINTER(ctypes.c_uint64), ctypes.c_uint32]
        lib.HcommChannelDestroy.restype = ctypes.c_int
        lib.HcommWriteWithNotifyNbi.argtypes = [
            ctypes.c_uint64,
            ctypes.c_void_p,
            ctypes.c_void_p,
            ctypes.c_uint64,
            ctypes.c_uint32,
        ]
        lib.HcommWriteWithNotifyNbi.restype = ctypes.c_int
        lib.HcommChannelFence.argtypes = [ctypes.c_uint64]
        lib.HcommChannelFence.restype = ctypes.c_int


@dataclass(frozen=True)
class EndpointSpec:
    """Declarative description of an HCOMM endpoint (address + location)."""

    ip: str
    loc_type: str = "host"
    host_id: int = 0
    dev_phy_id: int = 0
    super_dev_id: int = 0
    server_idx: int = 0
    super_pod_idx: int = 0

    @classmethod
    def from_env(cls) -> Optional["EndpointSpec"]:
        """Build a spec from SIMPLER_HCOMM_ENDPOINT_* variables, if set."""
        ip = os.getenv("SIMPLER_HCOMM_ENDPOINT_IP")
        if not ip:
            return None
        return cls(
            ip=ip,
            loc_type=os.getenv("SIMPLER_HCOMM_ENDPOINT_LOC_TYPE", "host"),
            host_id=int(os.getenv("SIMPLER_HCOMM_ENDPOINT_HOST_ID", "0"), 0),
            dev_phy_id=int(os.getenv("SIMPLER_HCOMM_ENDPOINT_DEV_PHY_ID", "0"), 0),
            super_dev_id=int(os.getenv("SIMPLER_HCOMM_ENDPOINT_SUPER_DEV_ID", "0"), 0),
            server_idx=int(os.getenv("SIMPLER_HCOMM_ENDPOINT_SERVER_IDX", "0"), 0),
            super_pod_idx=int(os.getenv("SIMPLER_HCOMM_ENDPOINT_SUPER_POD_IDX", "0"), 0),
        )

    def to_ctypes(self) -> "_EndpointDesc":
        """Materialise the spec as the C ABI struct (unused bytes = 0xFF)."""
        endpoint = _EndpointDesc()
        ctypes.memset(ctypes.byref(endpoint), 0xFF, ctypes.sizeof(endpoint))
        endpoint.protocol = 1  # COMM_PROTOCOL_ROCE
        ip = ipaddress.ip_address(self.ip)
        if ip.version == 4:
            endpoint.commAddr.type = 0  # COMM_ADDR_TYPE_IP_V4
            endpoint.commAddr.raws[:4] = ip.packed
        else:
            endpoint.commAddr.type = 1  # COMM_ADDR_TYPE_IP_V6
            endpoint.commAddr.raws[:16] = ip.packed
        if self.loc_type == "device":
            endpoint.loc.locType = 0  # ENDPOINT_LOC_TYPE_DEVICE
            endpoint.loc.words[0] = int(self.dev_phy_id)
            endpoint.loc.words[1] = int(self.super_dev_id)
            endpoint.loc.words[2] = int(self.server_idx)
            endpoint.loc.words[3] = int(self.super_pod_idx)
        else:
            endpoint.loc.locType = 1  # ENDPOINT_LOC_TYPE_HOST
            endpoint.loc.words[0] = int(self.host_id)
        return endpoint
"_EndpointDesc": + endpoint = _EndpointDesc() + ctypes.memset(ctypes.byref(endpoint), 0xFF, ctypes.sizeof(endpoint)) + endpoint.protocol = 1 # COMM_PROTOCOL_ROCE + ip = ipaddress.ip_address(self.ip) + if ip.version == 4: + endpoint.commAddr.type = 0 # COMM_ADDR_TYPE_IP_V4 + endpoint.commAddr.raws[:4] = ip.packed + else: + endpoint.commAddr.type = 1 # COMM_ADDR_TYPE_IP_V6 + endpoint.commAddr.raws[:16] = ip.packed + if self.loc_type == "device": + endpoint.loc.locType = 0 # ENDPOINT_LOC_TYPE_DEVICE + endpoint.loc.words[0] = int(self.dev_phy_id) + endpoint.loc.words[1] = int(self.super_dev_id) + endpoint.loc.words[2] = int(self.server_idx) + endpoint.loc.words[3] = int(self.super_pod_idx) + else: + endpoint.loc.locType = 1 # ENDPOINT_LOC_TYPE_HOST + endpoint.loc.words[0] = int(self.host_id) + return endpoint + + +class _CommAddr(ctypes.Structure): + _fields_ = [ + ("type", ctypes.c_int), + ("raws", ctypes.c_uint8 * 36), + ] + + +class _EndpointLoc(ctypes.Structure): + _fields_ = [ + ("locType", ctypes.c_int), + ("words", ctypes.c_uint32 * 15), + ] + + +class _EndpointDesc(ctypes.Structure): + _fields_ = [ + ("protocol", ctypes.c_int), + ("commAddr", _CommAddr), + ("loc", _EndpointLoc), + ("raws", ctypes.c_uint8 * 52), + ] + + +class _HcommMem(ctypes.Structure): + _fields_ = [ + ("type", ctypes.c_int), + ("addr", ctypes.c_void_p), + ("size", ctypes.c_uint64), + ] + + +class _CommAbiHeader(ctypes.Structure): + _fields_ = [ + ("version", ctypes.c_uint32), + ("magicWord", ctypes.c_uint32), + ("size", ctypes.c_uint32), + ("reserved", ctypes.c_uint32), + ] + + +class _HcommRoceAttr(ctypes.Structure): + _fields_ = [ + ("queueNum", ctypes.c_uint32), + ("retryCnt", ctypes.c_uint32), + ("retryInterval", ctypes.c_uint32), + ("tc", ctypes.c_uint8), + ("sl", ctypes.c_uint8), + ] + + +class _HcommChannelAttr(ctypes.Union): + _fields_ = [ + ("raws", ctypes.c_uint8 * 128), + ("roceAttr", _HcommRoceAttr), + ] + + +class _HcommChannelDesc(ctypes.Structure): + _fields_ = [ + ("header", 
_CommAbiHeader), + ("remoteEndpoint", _EndpointDesc), + ("notifyNum", ctypes.c_uint32), + ("exchangeAllMems", ctypes.c_bool), + ("memHandles", ctypes.POINTER(ctypes.c_void_p)), + ("memHandleNum", ctypes.c_uint32), + ("socket", ctypes.c_void_p), + ("role", ctypes.c_int), + ("port", ctypes.c_uint16), + ("attr", _HcommChannelAttr), + ] + + +class _RxeServerDesc(ctypes.Structure): + _fields_ = [ + ("ip", ctypes.c_char * 64), + ("port", ctypes.c_uint16), + ("rkey", ctypes.c_uint32), + ("addr", ctypes.c_uint64), + ("size", ctypes.c_uint32), + ] + + +@dataclass(frozen=True) +class _DecodedRxeDesc: + ip: str + port: int + rkey: int + addr: int + size: int + device: str + gid_index: int + + +_RXE_DESC_MAGIC = b"SRXE" +_RXE_DESC_VERSION = 2 +_RXE_DESC_STRUCT = struct.Struct("<4sHHHHIQQ64s64s") + + +def build_tensor_transport(name: str) -> TensorTransportBackend: + selected = (name or "grpc").lower() + if selected == "grpc": + return GrpcTensorTransport() + if selected == "rxe": + backend = RxeTensorTransport.from_env() + if not backend.available: + raise TransportUnavailable(backend.unavailable_reason()) + return backend + if selected == "hcomm": + backend = HcommTensorTransport.from_env() + if not backend.available: + raise TransportUnavailable(backend.unavailable_reason()) + return backend + if selected == "auto": + if os.getenv("SIMPLER_RXE_AUTO", "").lower() in {"1", "true", "yes", "on"}: + rxe_backend = RxeTensorTransport.from_env() + if rxe_backend.available: + return rxe_backend + backend = HcommTensorTransport.from_env() + return backend if backend.available else GrpcTensorTransport() + raise ValueError(f"unknown tensor transport backend {name!r}") + + +def _check_hcomm(ret: int, op: str) -> None: + if int(ret) != 0: + raise TransportBackendError(f"{op} failed with HcclResult={ret}") + + +def _check_rxe(ret: int, op: str) -> None: + if int(ret) != 0: + raise TransportBackendError(f"{op} failed with errno-style rc={int(ret)}") + + +def _encode_rxe_desc(desc: 
"_RxeServerDesc", device: str, gid_index: int) -> bytes: + ip = bytes(desc.ip).split(b"\0", 1)[0] + device_bytes = str(device).encode("ascii") + if len(ip) >= 64: + raise TransportBackendError(f"RXE descriptor IP is too long: {ip!r}") + if len(device_bytes) >= 64: + raise TransportBackendError(f"RXE descriptor device name is too long: {device!r}") + return _RXE_DESC_STRUCT.pack( + _RXE_DESC_MAGIC, + _RXE_DESC_VERSION, + _RXE_DESC_STRUCT.size, + int(desc.port), + int(gid_index), + int(desc.rkey), + int(desc.addr), + int(desc.size), + ip.ljust(64, b"\0"), + device_bytes.ljust(64, b"\0"), + ) + + +def _decode_rxe_desc(desc: bytes) -> "_DecodedRxeDesc": + if not desc: + raise TransportBackendError("RXE TensorHandle.transport_desc is empty") + if desc.startswith(_RXE_DESC_MAGIC): + if len(desc) < _RXE_DESC_STRUCT.size: + raise TransportBackendError( + f"RXE binary transport_desc is too short: {len(desc)} < {_RXE_DESC_STRUCT.size}" + ) + magic, version, header_size, port, gid_index, rkey, addr, size, ip_raw, device_raw = _RXE_DESC_STRUCT.unpack( + desc[: _RXE_DESC_STRUCT.size] + ) + if magic != _RXE_DESC_MAGIC or version != _RXE_DESC_VERSION or header_size != _RXE_DESC_STRUCT.size: + raise TransportBackendError( + f"unsupported RXE transport_desc header: magic={magic!r}, version={version}, size={header_size}" + ) + return _DecodedRxeDesc( + ip=ip_raw.split(b"\0", 1)[0].decode("ascii"), + port=int(port), + rkey=int(rkey), + addr=int(addr), + size=int(size), + device=device_raw.split(b"\0", 1)[0].decode("ascii"), + gid_index=int(gid_index), + ) + try: + payload = json.loads(desc.decode("ascii")) + except (UnicodeDecodeError, json.JSONDecodeError) as e: + raise TransportBackendError(f"invalid RXE transport_desc: {e}") from e + if payload.get("transport") != "rxe": + raise TransportBackendError(f"RXE transport_desc has unexpected transport {payload.get('transport')!r}") + return _DecodedRxeDesc( + ip=str(payload["ip"]), + port=int(payload["port"]), + 
rkey=int(payload.get("rkey", 0)), + addr=int(payload.get("addr", 0)), + size=int(payload["size"]), + device=str(payload.get("device") or os.getenv("SIMPLER_RXE_DEVICE") or "rxe0"), + gid_index=int(payload.get("gid_index", os.getenv("SIMPLER_RXE_GID_INDEX", "0"))), + ) + + +def _build_rxe_verbs_helper() -> Path: + src = Path(__file__).with_name("rxe_verbs_helper.c") + if not src.exists(): + raise TransportUnavailable(f"RXE verbs helper source not found: {src}") + + build_dir = _repo_root() / ".cache" / "rxe_verbs_helper" + build_dir.mkdir(parents=True, exist_ok=True) + suffix = sysconfig.get_config_var("EXT_SUFFIX") or ".so" + out = build_dir / f"libsimpler_rxe_verbs_helper{suffix}" + stamp = build_dir / "rxe_verbs_helper.stamp" + include_dir = _rxe_include_dir() + lib_dir = _rxe_lib_dir() + signature = ( + f"{src.resolve()}:{src.stat().st_mtime_ns}:{src.stat().st_size}\n" + f"include={include_dir}\nlib={lib_dir}\n" + ) + if out.exists() and stamp.exists() and stamp.read_text() == signature: + return out + if include_dir is None or lib_dir is None: + raise TransportUnavailable("RXE helper needs rdma-core headers/libs; set SIMPLER_RXE_INCLUDE_DIR and SIMPLER_RXE_LIB_DIR") + + compiler = os.getenv("CC") or "cc" + cmd = [ + compiler, + "-shared", + "-fPIC", + "-O2", + f"-I{include_dir}", + str(src), + f"-L{lib_dir}", + f"-Wl,-rpath,{lib_dir}", + "-libverbs", + "-lpthread", + "-o", + str(out), + ] + try: + subprocess.check_call(cmd) + except (OSError, subprocess.CalledProcessError) as e: + raise TransportUnavailable(f"failed to build RXE helper with {' '.join(cmd)}: {e}") from e + stamp.write_text(signature) + return out + + +def _rxe_include_dir() -> Optional[Path]: + for value in ( + os.getenv("SIMPLER_RXE_INCLUDE_DIR"), + "/home/ntlab/rdma-build/rdma-core-50.0/build/include", + "/home/ntlab/local/include", + "/usr/include", + ): + if not value: + continue + path = Path(value).expanduser() + if (path / "infiniband" / "verbs.h").exists(): + return path.resolve() + 
return None + + +def _rxe_lib_dir() -> Optional[Path]: + for value in ( + os.getenv("SIMPLER_RXE_LIB_DIR"), + "/home/ntlab/rdma-build/rdma-core-50.0/build/lib", + "/home/ntlab/local/lib64", + "/home/ntlab/local/lib", + "/usr/lib64", + "/usr/lib", + ): + if not value: + continue + path = Path(value).expanduser() + if (path / "libibverbs.so").exists() or (path / "libibverbs.so.1").exists(): + return path.resolve() + return None + + +def _first_existing_rxe_device() -> Optional[str]: + infiniband = Path("/sys/class/infiniband") + if not infiniband.exists(): + return None + for path in sorted(infiniband.iterdir()): + if path.name.startswith("rxe"): + return path.name + return None + + +def _find_rxe_ipv4_gid(device: str) -> Optional[tuple[int, str]]: + gid_dir = Path("/sys/class/infiniband") / device / "ports" / "1" / "gids" + if not gid_dir.exists(): + return None + for path in sorted(gid_dir.iterdir(), key=lambda item: int(item.name) if item.name.isdigit() else item.name): + try: + text = path.read_text(encoding="ascii").strip() + except OSError: + continue + ip = _ipv4_from_gid(text) + if ip: + return int(path.name), ip + return None + + +def _ipv4_from_gid(gid: str) -> Optional[str]: + parts = gid.strip().split(":") + if len(parts) != 8 or parts[5].lower() != "ffff": + return None + try: + hi = int(parts[6], 16) + lo = int(parts[7], 16) + except ValueError: + return None + return ".".join(str(octet) for octet in (hi >> 8, hi & 0xFF, lo >> 8, lo & 0xFF)) + + +def _find_hcomm_library() -> Optional[str]: + for name in ("hcomm", "hccl", "ascendcl"): + found = ctypes.util.find_library(name) + if found: + return found + return None + + +def _hcomm_dlopen_mode() -> int: + mode = getattr(os, "RTLD_LOCAL", 0) + # Local HCOMM builds can contain unresolved C++ symbols in paths unrelated + # to the public C data-plane API used here. Lazy binding keeps those paths + # from blocking endpoint/memory smoke tests. 
def _preload_hcomm_dependencies(lib_path: str) -> None:
    """Load build-tree sidecar libs before libhcomm.so is dlopened.

    The HCOMM build emits `libhcomm.so`, `libhccl_alg.so`, `libhccl_plf.so`,
    and `libhccl_v2.so` into sibling directories without an rpath between
    them. Preloading by absolute path keeps the Python smoke tests usable
    against a fresh local build without asking callers to hand-maintain a long
    `LD_LIBRARY_PATH`.
    """

    flags = getattr(os, "RTLD_GLOBAL", 0) | getattr(os, "RTLD_NOW", 0)
    search_dirs = _hcomm_dependency_dirs(lib_path)
    root = Path(lib_path).expanduser().resolve()
    _preload_needed(root, search_dirs, set(), flags)


def _preload_hcomm_abi_shim(lib_path: str) -> None:
    """Preload Simpler-owned ABI compatibility symbols for local HCOMM builds."""

    if os.getenv("SIMPLER_HCOMM_ABI_SHIM", "1").lower() in {"0", "false", "no", "off"}:
        return
    missing = _hcomm_missing_abi_symbols(lib_path)
    if not missing:
        return
    shim = _build_hcomm_abi_shim()
    ctypes.CDLL(str(shim), mode=getattr(os, "RTLD_GLOBAL", 0) | getattr(os, "RTLD_NOW", 0))


# Mangled C++ symbols the shim must provide when the local libhcomm.so
# leaves them undefined.
_HCOMM_ABI_SHIM_SYMBOLS = {
    "_ZN4hccl16HcclCommunicator15GenIbvAiRMAInfoI13HcclAiRMAInfoEE10HcclResultjRKSt10shared_ptrINS_9TransportEERKSsPT_",
}


def _hcomm_missing_abi_symbols(lib_path: str) -> set[str]:
    """Return the shim symbols that *lib_path* leaves undefined (via nm -D)."""
    try:
        output = subprocess.check_output(
            ["nm", "-D", "--undefined-only", str(Path(lib_path).expanduser())],
            text=True,
            stderr=subprocess.DEVNULL,
        )
    except (OSError, subprocess.CalledProcessError):
        # nm unavailable or unreadable library: assume nothing is missing.
        return set()
    return {
        symbol
        for line in output.splitlines()
        for symbol in _HCOMM_ABI_SHIM_SYMBOLS
        if symbol in line
    }


def _build_hcomm_abi_shim() -> Path:
    """Compile the ABI shim into the repo cache, rebuilding on source change."""
    src = Path(__file__).with_name("hcomm_abi_shim.cc")
    if not src.exists():
        raise TransportUnavailable(f"HCOMM ABI shim source not found: {src}")

    build_dir = _repo_root() / ".cache" / "hcomm_abi_shim"
    build_dir.mkdir(parents=True, exist_ok=True)
    suffix = sysconfig.get_config_var("EXT_SUFFIX") or ".so"
    out = build_dir / f"libsimple_hcomm_abi_shim{suffix}"
    stamp = build_dir / "hcomm_abi_shim.stamp"
    signature = f"{src.resolve()}:{src.stat().st_mtime_ns}:{src.stat().st_size}\n"
    if out.exists() and stamp.exists() and stamp.read_text() == signature:
        return out

    compiler = os.getenv("CXX") or "c++"
    hcomm_root = _hcomm_source_root()
    cmd = [
        compiler,
        "-shared",
        "-fPIC",
        "-O2",
        "-std=c++14",
        "-DLOG_CPP",
        "-DOPEN_BUILD_PROJECT",
        "-D_GLIBCXX_USE_CXX11_ABI=0",
        *[f"-I{path}" for path in _hcomm_abi_shim_include_dirs(hcomm_root)],
        str(src),
        "-o",
        str(out),
    ]
    try:
        subprocess.check_call(cmd)
    except (OSError, subprocess.CalledProcessError) as e:
        raise TransportUnavailable(f"failed to build HCOMM ABI shim with {' '.join(cmd)}: {e}") from e
    stamp.write_text(signature)
    return out


def _hcomm_source_root() -> Optional[Path]:
    """Resolve the HCOMM source tree (env override, then ../3rd/hcomm)."""
    explicit = os.getenv("SIMPLER_HCOMM_SRC_DIR")
    if explicit:
        root = Path(explicit).expanduser().resolve()
        if root.exists():
            return root
    root = _repo_root().parent / "3rd" / "hcomm"
    if root.exists():
        return root.resolve()
    return None


def _hcomm_abi_shim_include_dirs(hcomm_root: Optional[Path]) -> list[Path]:
    """Include dirs for the shim build: prefer the generated flags.make."""
    if hcomm_root is None:
        raise TransportUnavailable("HCOMM ABI shim needs HCOMM source headers; set SIMPLER_HCOMM_SRC_DIR")
    flags_make = hcomm_root / "build" / "src" / "framework" / "CMakeFiles" / "hcomm.dir" / "flags.make"
    if flags_make.exists():
        includes = _parse_make_include_dirs(flags_make)
        if includes:
            return includes
    # Fallback: hand-maintained list of the headers the shim needs.
    src = hcomm_root / "src"
    return [
        hcomm_root / "include",
        hcomm_root / "include" / "hccl",
        hcomm_root / "pkg_inc",
        hcomm_root / "pkg_inc" / "hccl",
        src / "framework" / "communicator" / "impl",
        src / "framework" / "common" / "src",
        src / "framework" / "common" / "src" / "h2d_dto",
        src / "framework" / "inc",
        src / "pub_inc",
        src / "pub_inc" / "inner",
        src / "pub_inc" / "new",
        src / "algorithm" / "pub_inc",
        src / "algorithm" / "base" / "inc",
    ]


def _parse_make_include_dirs(flags_make: Path) -> list[Path]:
    """Extract -I directories from a CMake-generated flags.make file.

    Handles backslash line continuations of the CXX_INCLUDES assignment;
    only existing directories are returned, deduplicated in order.
    """
    value = ""
    collecting = False
    for line in flags_make.read_text().splitlines():
        if line.startswith("CXX_INCLUDES = "):
            value = line.split("=", 1)[1].strip()
            collecting = line.endswith("\\")
        elif collecting:
            value += " " + line.strip()
            collecting = line.endswith("\\")
        elif value:
            break
    includes: list[Path] = []
    seen: set[Path] = set()
    for token in shlex.split(value.replace("\\\n", " ")):
        if not token.startswith("-I"):
            continue
        path = Path(token[2:]).expanduser()
        if not path.is_dir():
            continue
        resolved = path.resolve()
        if resolved in seen:
            continue
        seen.add(resolved)
        includes.append(resolved)
    return includes


def _repo_root() -> Path:
    """Walk up from this file to the directory holding pyproject.toml."""
    here = Path(__file__).resolve()
    for parent in here.parents:
        if (parent / "pyproject.toml").exists():
            return parent
    return here.parents[3]


# Libraries the dynamic loader always finds on its own; never preload these.
_SYSTEM_LIB_PREFIXES = ("libc.so", "libdl.so", "libm.so", "libpthread.so", "librt.so", "ld-linux-", "libgcc_s.so")


def _preload_needed(path: Path, search_dirs: list[Path], loaded: set[Path], mode: int) -> None:
    """Depth-first dlopen of *path*'s NEEDED entries from *search_dirs*."""
    for lib_name in _elf_needed(path):
        if _is_system_library(lib_name):
            continue
        dep_path = _find_dependency_library(lib_name, search_dirs)
        if dep_path is None or dep_path in loaded:
            continue
        _preload_needed(dep_path, search_dirs, loaded, mode)
        ctypes.CDLL(str(dep_path), mode=mode)
        loaded.add(dep_path)


def _elf_needed(path: Path) -> list[str]:
    """List the DT_NEEDED sonames of an ELF file via readelf; [] on failure."""
    try:
        output = subprocess.check_output(["readelf", "-d", str(path)], text=True, stderr=subprocess.DEVNULL)
    except (OSError, subprocess.CalledProcessError):
        return []
    needed: list[str] = []
    for line in output.splitlines():
        if "(NEEDED)" not in line:
            continue
        start = line.find("[")
        end = line.find("]", start + 1)
        if start >= 0 and end > start:
            needed.append(line[start + 1 : end])
    return needed


def _is_system_library(lib_name: str) -> bool:
    """True for core runtime libs the loader resolves without our help."""
    return lib_name.startswith(_SYSTEM_LIB_PREFIXES) or lib_name in {"libstdc++.so.6"}


def _find_dependency_library(lib_name: str, search_dirs: list[Path]) -> Optional[Path]:
    """Find *lib_name* in *search_dirs*, falling back to the system linker."""
    for directory in _ordered_dependency_dirs(lib_name, search_dirs):
        candidate = directory / lib_name
        if candidate.exists():
            return candidate.resolve()
    found = ctypes.util.find_library(lib_name.removeprefix("lib").removesuffix(".so"))
    if found:
        candidate = Path(found)
        if candidate.exists():
            return candidate.resolve()
    return None


# Libraries produced by a local HCOMM build; for these, prefer build dirs
# over installed CANN copies.
_HCOMM_BUILD_LIBS = {
    "libhcomm.so",
    "libhccl_alg.so",
    "libhccl_plf.so",
    "libhccl_v2.so",
    "libhccl_legacy.so",
    "libccl_dpu.so",
    "libra.so",
    "libra_hdc.so",
    "libra_peer.so",
    "librs.so",
    "libtls_adp.so",
    "libtopoaddrinfo.so",
}


def _ordered_dependency_dirs(lib_name: str, search_dirs: list[Path]) -> list[Path]:
    """Reorder search dirs so HCOMM build outputs shadow installed copies."""
    if lib_name not in _HCOMM_BUILD_LIBS:
        return search_dirs
    build_dirs = [directory for directory in search_dirs if "/3rd/hcomm/build/" in str(directory)]
    other_dirs = [directory for directory in search_dirs if directory not in build_dirs]
    return build_dirs + other_dirs


def _hcomm_dependency_dirs(hcomm_lib_path: str) -> list[Path]:
    """Collect candidate library directories for HCOMM's dependencies.

    Order: explicit SIMPLER_HCOMM_DEP_LIB_DIRS, CANN roots from the
    environment, well-known CANN install paths, then directories derived
    from the HCOMM library's own location in a build tree.
    """
    explicit = os.getenv("SIMPLER_HCOMM_DEP_LIB_DIRS", "")
    dirs = [Path(item).expanduser() for item in explicit.split(os.pathsep) if item]

    for env_name in ("ASCEND_HOME_PATH", "ASCEND_TOOLKIT_HOME", "CANN_HOME", "ASCEND_HOME"):
        root = os.getenv(env_name)
        if root:
            dirs.extend(_cann_dependency_dirs(Path(root).expanduser()))
    # NOTE(review): the ntlab path is a site-specific fallback for current
    # dev hosts; the env vars above take precedence.
    for root in (Path("/home/ntlab/zcy/cann/cann-9.0.0"), Path("/usr/local/Ascend/ascend-toolkit/latest")):
        dirs.extend(_cann_dependency_dirs(root))

    lib = Path(hcomm_lib_path).expanduser()
    if lib.parent:
        dirs.extend([lib.parent, lib.parent / "legacy"])
    if lib.parent.name == "framework":
        # Inside a build tree: add the sibling component output directories.
        src = lib.parent.parent
        build = src.parent
        dirs.extend(
            [
                src / "algorithm",
                src / "platform",
                src / "platform" / "hccp" / "rdma_agent" / "hdc",
                src / "platform" / "hccp" / "rdma_agent" / "peer",
                src / "framework" / "legacy",
                build / "stub",
                build / "_CPack_Packages" / "makeself_staging" / "aarch64-linux" / "lib64",
            ]
        )

    # Deduplicate in order, keeping only directories that actually exist.
    result: list[Path] = []
    seen: set[Path] = set()
    for directory in dirs:
        resolved = directory.resolve()
        if resolved in seen or not resolved.is_dir():
            continue
        seen.add(resolved)
        result.append(resolved)
    return result
"framework": + src = lib.parent.parent + build = src.parent + dirs.extend( + [ + src / "algorithm", + src / "platform", + src / "platform" / "hccp" / "rdma_agent" / "hdc", + src / "platform" / "hccp" / "rdma_agent" / "peer", + src / "framework" / "legacy", + build / "stub", + build / "_CPack_Packages" / "makeself_staging" / "aarch64-linux" / "lib64", + ] + ) + + result: list[Path] = [] + seen: set[Path] = set() + for directory in dirs: + resolved = directory.resolve() + if resolved in seen or not resolved.is_dir(): + continue + seen.add(resolved) + result.append(resolved) + return result + + +def _cann_dependency_dirs(root: Path) -> list[Path]: + return [ + root / "aarch64-linux" / "lib64", + root / "aarch64-linux" / "lib64" / "device" / "lib64", + root / "aarch64-linux" / "devlib", + root / "aarch64-linux" / "devlib" / "device", + root / "aarch64-linux" / "devlib" / "linux" / "aarch64", + root / "lib64", + root / "lib64" / "device" / "lib64", + root / "devlib", + ] + + +def _ensure_cann_runtime_env(acl_path: Path) -> None: + root = _cann_root_from_acl_path(acl_path) + if root is None: + return + os.environ.setdefault("ASCEND_HOME_PATH", str(root)) + os.environ.setdefault("ASCEND_TOOLKIT_HOME", str(root)) + os.environ.setdefault("ASCEND_OPP_PATH", str(root / "opp")) + os.environ.setdefault("ASCEND_AICPU_PATH", str(root)) + os.environ.setdefault("TOOLCHAIN_HOME", str(root / "toolkit")) + + arch = os.uname().machine + for directory in ( + root / "lib64", + root / "lib64" / "plugin" / "opskernel", + root / "lib64" / "plugin" / "nnengine", + root / "opp" / "built-in" / "op_impl" / "ai_core" / "tbe" / "op_tiling" / "lib" / "linux" / arch, + Path("/usr/local/Ascend/driver/lib64"), + Path("/usr/local/Ascend/driver/lib64/common"), + Path("/usr/local/Ascend/driver/lib64/driver"), + root / "devlib", + ): + _prepend_env_path("LD_LIBRARY_PATH", directory) + + +def _cann_root_from_acl_path(acl_path: Path) -> Optional[Path]: + resolved = acl_path.expanduser().resolve() + for 
parent in resolved.parents: + if (parent / "opp").is_dir() and (parent / "lib64").is_dir(): + return parent + if parent.name.endswith("-linux") and (parent.parent / "opp").is_dir(): + return parent.parent + return None + + +def _prepend_env_path(name: str, directory: Path) -> None: + if not directory.is_dir(): + return + value = str(directory) + parts = [part for part in os.environ.get(name, "").split(os.pathsep) if part] + if value in parts: + return + os.environ[name] = os.pathsep.join([value, *parts]) + + +def _mem_type_value(mem_type: str) -> int: + return 0 if mem_type == "device" else 1 + + +def _parse_handle(value) -> int: # noqa: ANN001 + if isinstance(value, int): + return value + return int(str(value), 0) + + +def _hcomm_socket_role(value) -> int: # noqa: ANN001 + if isinstance(value, int): + return value + role = str(value).strip().lower() + if role in {"client", "c", "0"}: + return 0 + if role in {"server", "s", "1"}: + return 1 + if role in {"reserved", "auto", "-1"}: + return -1 + raise ValueError(f"unknown HCOMM socket role {value!r}") + + +def _init_hcomm_channel_desc(desc: "_HcommChannelDesc") -> None: + ctypes.memset(ctypes.byref(desc), 0xFF, ctypes.sizeof(desc)) + desc.header.version = 1 + desc.header.magicWord = 0x0FCF0F0F + desc.header.size = ctypes.sizeof(_HcommChannelDesc) + desc.header.reserved = 0 + desc.notifyNum = 0 + desc.exchangeAllMems = False + desc.memHandles = None + desc.memHandleNum = 0 + desc.socket = None + desc.role = -1 + desc.port = 0 + endpoint = EndpointSpec(ip="0.0.0.0").to_ctypes() + endpoint.protocol = -1 + endpoint.commAddr.type = -1 + endpoint.loc.locType = -1 + desc.remoteEndpoint = endpoint + + +def _endpoint_from_transport_desc(desc: bytes) -> "_EndpointDesc": + size = ctypes.sizeof(_EndpointDesc) + if len(desc) < size: + raise TransportBackendError( + f"HCOMM transport_desc is too short to contain EndpointDesc: {len(desc)} < {size}" + ) + return _EndpointDesc.from_buffer_copy(desc[-size:]) + + +def 
_buffer_addr(data: bytearray) -> int: + if not data: + return 0 + return ctypes.addressof(ctypes.c_char.from_buffer(data)) diff --git a/python/simpler/worker.py b/python/simpler/worker.py index 31a4b459b..72256f4ec 100644 --- a/python/simpler/worker.py +++ b/python/simpler/worker.py @@ -78,6 +78,7 @@ def my_l4_orch(orch, args, config): ContinuousTensor, DataType, TaskArgs, + TensorArgType, _ChipWorker, _Worker, ) @@ -140,6 +141,7 @@ def my_l4_orch(orch, args, config): _CTRL_OFF_ARG1 = 24 _CTRL_OFF_ARG2 = 32 _CTRL_OFF_RESULT = 40 +_TASK_ARGS_TAGS_MAGIC = 0x534C3454 # "SL4T" def _mailbox_addr(shm: SharedMemory) -> int: @@ -195,6 +197,7 @@ def _read_args_from_mailbox(buf) -> TaskArgs: Blob layout at _OFF_ARGS: int32 tensor_count (T), int32 scalar_count (S), ContinuousTensor[T] (40 B each), uint64_t[S] (8 B each). + Optional extension: uint32 magic "SL4T", int32 tags[T]. """ base = _OFF_ARGS t_count = struct.unpack_from("i", buf, base)[0] @@ -207,6 +210,13 @@ def _read_args_from_mailbox(buf) -> TaskArgs: f"args blob ({blob_bytes} bytes) exceeds mailbox capacity ({_MAILBOX_ARGS_CAPACITY} bytes); " f"tensors={t_count}, scalars={s_count} — likely a corrupt header or a writer bug" ) + tag_values = None + tag_magic_off = base + blob_bytes + tag_blob_bytes = blob_bytes + 4 + t_count * 4 + if tag_blob_bytes <= _MAILBOX_ARGS_CAPACITY: + magic = struct.unpack_from("I", buf, tag_magic_off)[0] + if magic == _TASK_ARGS_TAGS_MAGIC: + tag_values = [struct.unpack_from("i", buf, tag_magic_off + 4 + i * 4)[0] for i in range(t_count)] args = TaskArgs() ct_off = base + 8 @@ -217,7 +227,10 @@ def _read_args_from_mailbox(buf) -> TaskArgs: ndims = struct.unpack_from("I", buf, off + 28)[0] dtype_val = struct.unpack_from("B", buf, off + 32)[0] ct = ContinuousTensor.make(data, tuple(shapes[:ndims]), DataType(dtype_val)) - args.add_tensor(ct) + if tag_values is None: + args.add_tensor(ct) + else: + args.add_tensor(ct, TensorArgType(tag_values[i])) sc_off = ct_off + t_count * 40 for i in 
range(s_count): diff --git a/src/common/hierarchical/worker_manager.cpp b/src/common/hierarchical/worker_manager.cpp index 6636e921b..dcdcad7a2 100644 --- a/src/common/hierarchical/worker_manager.cpp +++ b/src/common/hierarchical/worker_manager.cpp @@ -20,6 +20,8 @@ namespace { +constexpr uint32_t TASK_ARGS_TAGS_MAGIC = 0x534C3454; // "SL4T" + // Read the child-written error message from the mailbox, guaranteeing // NUL-termination even if the child wrote exactly MAILBOX_ERROR_MSG_SIZE // bytes without a terminator. @@ -155,6 +157,7 @@ void WorkerThread::dispatch_thread(TaskSlotState &s, int32_t group_index) { void WorkerThread::dispatch_process(TaskSlotState &s, int32_t group_index) { uint64_t callable = (s.worker_type == WorkerType::SUB) ? static_cast(s.callable_id) : s.callable; TaskArgsView view = s.args_view(group_index); + const TaskArgs &source_args = s.is_group() ? s.task_args_list[static_cast(group_index)] : s.task_args; // Clear the child-writable error fields so stale bytes from a prior // dispatch cannot masquerade as a fresh failure. @@ -169,8 +172,12 @@ void WorkerThread::dispatch_process(TaskSlotState &s, int32_t group_index) { std::memcpy(mbox() + MAILBOX_OFF_CONFIG, &s.config, sizeof(CallConfig)); // Write length-prefixed TaskArgs blob: [T][S][tensors][scalars]. + // A tagged extension [magic][int32 tags[T]] is appended for Python + // mailbox consumers. C++ read_blob ignores the trailing bytes and remains + // compatible with the historical tag-less TaskArgsView wire format. 
size_t blob_bytes = TASK_ARGS_BLOB_HEADER_SIZE + static_cast(view.tensor_count) * sizeof(ContinuousTensor) + - static_cast(view.scalar_count) * sizeof(uint64_t); + static_cast(view.scalar_count) * sizeof(uint64_t) + sizeof(uint32_t) + + static_cast(view.tensor_count) * sizeof(int32_t); if (blob_bytes > MAILBOX_ARGS_CAPACITY) { throw std::runtime_error("WorkerThread::dispatch_process: args blob exceeds mailbox capacity"); } @@ -189,6 +196,13 @@ void WorkerThread::dispatch_process(TaskSlotState &s, int32_t group_index) { view.scalars, static_cast(view.scalar_count) * sizeof(uint64_t) ); } + size_t tag_off = TASK_ARGS_BLOB_HEADER_SIZE + static_cast(view.tensor_count) * sizeof(ContinuousTensor) + + static_cast(view.scalar_count) * sizeof(uint64_t); + std::memcpy(d + tag_off, &TASK_ARGS_TAGS_MAGIC, sizeof(uint32_t)); + for (int32_t i = 0; i < view.tensor_count; ++i) { + int32_t tag = static_cast(source_args.tag(i)); + std::memcpy(d + tag_off + sizeof(uint32_t) + static_cast(i) * sizeof(int32_t), &tag, sizeof(int32_t)); + } // Signal child process. 
write_mailbox_state(MailboxState::TASK_READY); diff --git a/tests/ut/py/test_distributed/test_catalog.py b/tests/ut/py/test_distributed/test_catalog.py index 407f84e53..3667b5625 100644 --- a/tests/ut/py/test_distributed/test_catalog.py +++ b/tests/ut/py/test_distributed/test_catalog.py @@ -19,23 +19,6 @@ def fn(args): assert got is not None -def test_catalog_pull_mock_install(): - l4 = Catalog() - cid, version = l4.register(lambda args: args.scalar(0) * 2) - - class MockClient: - def call_unary(self, method, req, timeout=None): - assert method == "Catalog.PullCallable" - return type("Payload", (), {"callable_id": cid, "version": version, "pickled": l4.export_payload(cid, version)}) - - l3 = Catalog() - req = type("Req", (), {"callable_id": cid, "version": version})() - payload = MockClient().call_unary("Catalog.PullCallable", req) - l3.install_from_payload(payload.callable_id, payload.version, payload.pickled) - - assert l3.lookup(cid, version) is not None - - def test_catalog_version_mismatch(): catalog = Catalog() cid, version = catalog.register(lambda args: None) diff --git a/tests/ut/py/test_distributed/test_hcomm_e2e_real.py b/tests/ut/py/test_distributed/test_hcomm_e2e_real.py new file mode 100644 index 000000000..254d8380b --- /dev/null +++ b/tests/ut/py/test_distributed/test_hcomm_e2e_real.py @@ -0,0 +1,180 @@ +import ctypes +import multiprocessing as mp +import os +import queue +import socket +import time +import traceback + +import pytest + +from simpler.distributed.transport_backend import EndpointSpec, HcommRuntime + + +REAL_HCOMM_E2E = pytest.mark.skipif( + os.getenv("SIMPLER_HCOMM_E2E_REAL_TEST") != "1", + reason="set SIMPLER_HCOMM_E2E_REAL_TEST=1 to run the real HCOMM channel smoke test", +) + + +def _unused_tcp_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("127.0.0.1", 0)) + return int(sock.getsockname()[1]) + + +def _server(q, lib_path: str, ip: str, port: int, payload_len: int) -> None: # noqa: 
ANN001 + runtime = None + endpoint = 0 + mem_handle = 0 + listening = False + try: + runtime = HcommRuntime(lib_path=lib_path, required=True) + endpoint = runtime.endpoint_create(EndpointSpec(ip=ip)) + target = bytearray(payload_len) + mem_handle = runtime.mem_reg( + endpoint, + ctypes.addressof(ctypes.c_char.from_buffer(target)), + len(target), + tag="server-target", + ) + desc = runtime.mem_export(endpoint, mem_handle) + runtime.endpoint_start_listen(endpoint, port) + listening = True + q.put(("ready", desc)) + remote_endpoint = EndpointSpec(ip=ip).to_ctypes() + channel = runtime.channel_create( + endpoint, + remote_endpoint=remote_endpoint, + socket_handle=0, + local_mem_handles=[mem_handle], + notify_num=0, + engine=0, + role=1, + port=port, + ) + deadline = time.time() + 30 + while bytes(target) == b"\x00" * payload_len and time.time() < deadline: + time.sleep(0.05) + runtime.channel_destroy([channel]) + q.put(("result", bytes(target))) + except Exception: + q.put(("error", "server", traceback.format_exc())) + finally: + if runtime is not None and listening: + runtime.endpoint_stop_listen(endpoint, port) + if runtime is not None and mem_handle: + runtime.mem_unreg(endpoint, mem_handle) + if runtime is not None and endpoint: + runtime.endpoint_destroy(endpoint) + + +def _client(q, lib_path: str, ip: str, port: int, payload: bytes, desc: bytes) -> None: # noqa: ANN001 + runtime = None + endpoint = 0 + staging_handle = 0 + try: + runtime = HcommRuntime(lib_path=lib_path, required=True) + endpoint = runtime.endpoint_create(EndpointSpec(ip=ip)) + staging = bytearray(payload) + staging_handle = runtime.mem_reg( + endpoint, + ctypes.addressof(ctypes.c_char.from_buffer(staging)), + len(staging), + tag="client-staging", + ) + remote = runtime.mem_import(endpoint, desc) + remote_endpoint = EndpointSpec(ip=ip).to_ctypes() + channel = runtime.channel_create( + endpoint, + remote_endpoint=remote_endpoint, + socket_handle=0, + local_mem_handles=[staging_handle], + 
notify_num=0, + engine=0, + role=0, + port=port, + ) + runtime.write_with_notify( + channel, + remote.remote_addr, + ctypes.addressof(ctypes.c_char.from_buffer(staging)), + len(staging), + 0, + ) + runtime.channel_fence(channel) + runtime.channel_destroy([channel]) + runtime.mem_unimport(endpoint, desc) + except Exception: + q.put(("error", "client", traceback.format_exc())) + finally: + if runtime is not None and staging_handle: + runtime.mem_unreg(endpoint, staging_handle) + if runtime is not None and endpoint: + runtime.endpoint_destroy(endpoint) + + +@REAL_HCOMM_E2E +def test_real_hcomm_cpu_roce_channel_write_smoke(): + lib_path = os.getenv("SIMPLER_HCOMM_LIB") + if not lib_path: + pytest.skip("set SIMPLER_HCOMM_LIB") + ip = os.getenv("SIMPLER_HCOMM_ENDPOINT_IP") or os.getenv("SIMPLER_RXE_SERVER_IP") or "192.168.0.243" + port = int(os.getenv("SIMPLER_HCOMM_CHANNEL_PORT") or _unused_tcp_port()) + ready_timeout = int(os.getenv("SIMPLER_HCOMM_E2E_READY_TIMEOUT", "180"), 0) + join_timeout = int(os.getenv("SIMPLER_HCOMM_E2E_JOIN_TIMEOUT", "180"), 0) + payload = b"simpler-hcomm-e2e-smoke" + q = mp.Queue() + server = mp.Process(target=_server, args=(q, lib_path, ip, port, len(payload))) + server.start() + try: + kind, *items = q.get(timeout=ready_timeout) + except queue.Empty: + server.terminate() + server.join(timeout=5) + pytest.fail(f"server did not publish HCOMM endpoint descriptor within {ready_timeout}s") + if kind == "error": + server.join(timeout=5) + pytest.fail(f"{items[0]} failed before ready:\n{items[1]}") + assert kind == "ready" + desc = items[0] + client = mp.Process(target=_client, args=(q, lib_path, ip, port, payload, desc)) + client.start() + client.join(timeout=join_timeout) + server.join(timeout=join_timeout) + if client.is_alive(): + client.terminate() + client.join(timeout=5) + if server.is_alive(): + server.terminate() + server.join(timeout=5) + result = None + errors = [] + while True: + try: + kind, *items = q.get_nowait() + except 
queue.Empty: + break + if kind == "result": + result = items[0] + elif kind == "error": + errors.append(tuple(items)) + _xfail_if_stock_hcomm_host_roce_unsupported(errors) + assert client.exitcode == 0, errors + assert server.exitcode == 0, errors + if errors: + role, tb = errors[0] + pytest.fail(f"{role} failed:\n{tb}") + if result is None: + pytest.fail("server did not publish HCOMM smoke result") + assert result == payload + + +def _xfail_if_stock_hcomm_host_roce_unsupported(errors): + if os.getenv("SIMPLER_HCOMM_E2E_REQUIRE_CHANNEL") == "1": + return + for role, tb in errors: + if "HcommChannelCreate failed with HcclResult=5" in tb: + pytest.xfail( + f"stock HCOMM Host CPU RoCE channel is unsupported in this environment ({role}: HCCL_E_NOT_SUPPORT)" + ) diff --git a/tests/ut/py/test_distributed/test_l4_l3_remote.py b/tests/ut/py/test_distributed/test_l4_l3_remote.py index b2b0e9128..3ccd79f9d 100644 --- a/tests/ut/py/test_distributed/test_l4_l3_remote.py +++ b/tests/ut/py/test_distributed/test_l4_l3_remote.py @@ -1,10 +1,18 @@ +import ctypes import struct from multiprocessing.shared_memory import SharedMemory +import pytest + from simpler.distributed.l3_daemon import L3Daemon -from simpler.task_interface import CallConfig, TaskArgs +from simpler.distributed.catalog import Catalog +from simpler.distributed.proto import dispatch_pb2 +from simpler.distributed.remote_proxy import RemoteUnavailable, RemoteWorkerProxy +from simpler.distributed.transport_backend import TransportBackendError, TransportUnavailable +from simpler.task_interface import CallConfig, ContinuousTensor, DataType, TaskArgs, TensorArgType from simpler.worker import Worker + def _scalar_value(args: TaskArgs) -> int: return int(args.scalar(0)) if args is not None and args.scalar_count() else 1 @@ -179,6 +187,226 @@ def l4_orch(orch, args, config): daemon.stop() +def test_l4_remote_inline_tensor_dispatch(tmp_path): + result = tmp_path / "remote_tensor_sum.txt" + payload = b"abcdef" + buf = 
ctypes.create_string_buffer(payload) + daemon, endpoint = _start_daemon() + + try: + w4 = Worker(level=4, num_sub_workers=0) + + def l3_orch(orch, args, config): + tensor = args.tensor(0) + data = ctypes.string_at(int(tensor.data), int(tensor.nbytes())) + result.write_text(str(sum(data))) + + l3_cid = w4.register(l3_orch) + w4.add_remote_worker(endpoint) + w4.init() + + def l4_orch(orch, args, config): + sub_args = TaskArgs() + tensor = ContinuousTensor.make(ctypes.addressof(buf), (len(payload),), DataType.UINT8) + sub_args.add_tensor(tensor, TensorArgType.INPUT) + orch.submit_next_level(l3_cid, sub_args, CallConfig()) + + w4.run(l4_orch) + w4.close() + assert int(result.read_text()) == sum(payload) + finally: + daemon.stop() + + +def test_l4_remote_handle_tensor_dispatch(tmp_path): + result = tmp_path / "remote_tensor_sum.txt" + payload = bytes(range(256)) * 32 + buf = ctypes.create_string_buffer(payload) + daemon, endpoint = _start_daemon() + + try: + w4 = Worker(level=4, num_sub_workers=0) + + def l3_orch(orch, args, config): + tensor = args.tensor(0) + data = ctypes.string_at(int(tensor.data), int(tensor.nbytes())) + result.write_text(str(len(data)) + ":" + str(sum(data))) + + l3_cid = w4.register(l3_orch) + w4.add_remote_worker(endpoint) + w4.init() + + def l4_orch(orch, args, config): + sub_args = TaskArgs() + tensor = ContinuousTensor.make(ctypes.addressof(buf), (len(payload),), DataType.UINT8) + sub_args.add_tensor(tensor, TensorArgType.INPUT) + orch.submit_next_level(l3_cid, sub_args, CallConfig()) + + w4.run(l4_orch) + w4.close() + assert result.read_text() == f"{len(payload)}:{sum(payload)}" + finally: + daemon.stop() + + +def test_l4_remote_inline_output_tensor_writeback(): + buf = ctypes.create_string_buffer(b"\x00" * 6) + daemon, endpoint = _start_daemon() + + try: + w4 = Worker(level=4, num_sub_workers=0) + + def l3_orch(orch, args, config): + tensor = args.tensor(0) + ctypes.memmove(int(tensor.data), b"fedcba", 6) + + l3_cid = w4.register(l3_orch) + 
w4.add_remote_worker(endpoint) + w4.init() + + def l4_orch(orch, args, config): + sub_args = TaskArgs() + tensor = ContinuousTensor.make(ctypes.addressof(buf), (6,), DataType.UINT8) + sub_args.add_tensor(tensor, TensorArgType.OUTPUT_EXISTING) + orch.submit_next_level(l3_cid, sub_args, CallConfig()) + + w4.run(l4_orch) + w4.close() + assert bytes(buf.raw) == b"fedcba\x00" + finally: + daemon.stop() + + +def test_l4_remote_handle_output_tensor_writeback(): + payload = bytes((255 - (i % 256) for i in range(8192))) + buf = ctypes.create_string_buffer(b"\x00" * len(payload)) + daemon, endpoint = _start_daemon() + + try: + w4 = Worker(level=4, num_sub_workers=0) + + def l3_orch(orch, args, config): + tensor = args.tensor(0) + ctypes.memmove(int(tensor.data), payload, len(payload)) + + l3_cid = w4.register(l3_orch) + w4.add_remote_worker(endpoint) + w4.init() + + def l4_orch(orch, args, config): + sub_args = TaskArgs() + tensor = ContinuousTensor.make(ctypes.addressof(buf), (len(payload),), DataType.UINT8) + sub_args.add_tensor(tensor, TensorArgType.OUTPUT_EXISTING) + orch.submit_next_level(l3_cid, sub_args, CallConfig()) + + w4.run(l4_orch) + w4.close() + assert bytes(buf.raw) == payload + b"\x00" + finally: + daemon.stop() + + +def test_l4_remote_inout_tensor_writeback(): + payload = bytearray(b"abcde") + buf = ctypes.create_string_buffer(bytes(payload)) + daemon, endpoint = _start_daemon() + + try: + w4 = Worker(level=4, num_sub_workers=0) + + def l3_orch(orch, args, config): + tensor = args.tensor(0) + data = bytearray(ctypes.string_at(int(tensor.data), int(tensor.nbytes()))) + data.reverse() + ctypes.memmove(int(tensor.data), bytes(data), len(data)) + + l3_cid = w4.register(l3_orch) + w4.add_remote_worker(endpoint) + w4.init() + + def l4_orch(orch, args, config): + sub_args = TaskArgs() + tensor = ContinuousTensor.make(ctypes.addressof(buf), (len(payload),), DataType.UINT8) + sub_args.add_tensor(tensor, TensorArgType.INOUT) + orch.submit_next_level(l3_cid, sub_args, 
CallConfig()) + + w4.run(l4_orch) + w4.close() + assert bytes(buf.raw) == b"edcba\x00" + finally: + daemon.stop() + + +def test_l4_remote_tensor_input_reaches_l3_sub(tmp_path): + result = tmp_path / "remote_sub_tensor_sum.txt" + payload = b"subtensor" + buf = ctypes.create_string_buffer(payload) + daemon, endpoint = _start_daemon() + + try: + w4 = Worker(level=4, num_sub_workers=0) + + def l3_sub(args): + tensor = args.tensor(0) + data = ctypes.string_at(int(tensor.data), int(tensor.nbytes())) + result.write_text(str(sum(data))) + + l3_sub_cid = w4.register(l3_sub) + + def l3_orch(orch, args, config): + orch.submit_sub(l3_sub_cid, args) + + l3_cid = w4.register(l3_orch) + w4.add_remote_worker(endpoint) + w4.init() + + def l4_orch(orch, args, config): + sub_args = TaskArgs() + tensor = ContinuousTensor.make(ctypes.addressof(buf), (len(payload),), DataType.UINT8) + sub_args.add_tensor(tensor, TensorArgType.INPUT) + orch.submit_next_level(l3_cid, sub_args, CallConfig()) + + w4.run(l4_orch) + w4.close() + assert int(result.read_text()) == sum(payload) + finally: + daemon.stop() + + +def test_l4_remote_tensor_output_from_l3_sub_writeback(): + payload = b"sub-output" + buf = ctypes.create_string_buffer(b"\x00" * len(payload)) + daemon, endpoint = _start_daemon() + + try: + w4 = Worker(level=4, num_sub_workers=0) + + def l3_sub(args): + tensor = args.tensor(0) + ctypes.memmove(int(tensor.data), payload, len(payload)) + + l3_sub_cid = w4.register(l3_sub) + + def l3_orch(orch, args, config): + orch.submit_sub(l3_sub_cid, args) + + l3_cid = w4.register(l3_orch) + w4.add_remote_worker(endpoint) + w4.init() + + def l4_orch(orch, args, config): + sub_args = TaskArgs() + tensor = ContinuousTensor.make(ctypes.addressof(buf), (len(payload),), DataType.UINT8) + sub_args.add_tensor(tensor, TensorArgType.OUTPUT_EXISTING) + orch.submit_next_level(l3_cid, sub_args, CallConfig()) + + w4.run(l4_orch) + w4.close() + assert bytes(buf.raw) == payload + b"\x00" + finally: + daemon.stop() + + 
def test_l4_remote_l3_multiple_subs(tmp_path): read_counter, add_counter = _make_file_counter(tmp_path / "remote_counter.txt") daemon, endpoint = _start_daemon() @@ -233,3 +461,4 @@ def l4_orch(orch, args, config): w4.close() finally: daemon.stop() + diff --git a/tests/ut/py/test_distributed/test_real_e2e_smoke.py b/tests/ut/py/test_distributed/test_real_e2e_smoke.py new file mode 100644 index 000000000..d7f37e39d --- /dev/null +++ b/tests/ut/py/test_distributed/test_real_e2e_smoke.py @@ -0,0 +1,233 @@ +import ctypes +import os +import shutil +import subprocess +import time +from pathlib import Path + +import pytest + +from simpler.distributed.l3_daemon import L3Daemon +from simpler.distributed.transport_backend import EndpointSpec, HcommRuntime, RxeRuntime, _EndpointDesc +from simpler.task_interface import CallConfig, ContinuousTensor, DataType, TaskArgs, TensorArgType +from simpler.worker import Worker + + +REAL_E2E = pytest.mark.skipif( + os.getenv("SIMPLER_REAL_E2E_TEST") != "1", + reason="set SIMPLER_REAL_E2E_TEST=1 to run the real distributed data-plane smoke", +) + + +def _start_daemon(): + daemon = L3Daemon(0, lambda: Worker(level=3, num_sub_workers=1)) + port = daemon.start() + return daemon, f"127.0.0.1:{port}" + + +def _start_daemon_with_transport(tensor_transport: str): + daemon = L3Daemon(0, lambda: Worker(level=3, num_sub_workers=1), tensor_transport=tensor_transport) + port = daemon.start() + return daemon, f"127.0.0.1:{port}" + + +@REAL_E2E +def test_real_l4_l3_tensorpool_handle_e2e(tmp_path): + result = tmp_path / "remote_tensor_sum.txt" + payload = bytes(range(256)) * 32 + in_buf = ctypes.create_string_buffer(payload) + out_payload = bytes((255 - (i % 256) for i in range(len(payload)))) + out_buf = ctypes.create_string_buffer(b"\x00" * len(out_payload)) + daemon, endpoint = _start_daemon() + + try: + w4 = Worker(level=4, num_sub_workers=0) + + def l3_orch(orch, args, config): + in_tensor = args.tensor(0) + out_tensor = args.tensor(1) + data = 
ctypes.string_at(int(in_tensor.data), int(in_tensor.nbytes())) + result.write_text(f"{len(data)}:{sum(data)}") + ctypes.memmove(int(out_tensor.data), out_payload, len(out_payload)) + + l3_cid = w4.register(l3_orch) + w4.add_remote_worker(endpoint) + w4.init() + + def l4_orch(orch, args, config): + sub_args = TaskArgs() + sub_args.add_tensor( + ContinuousTensor.make(ctypes.addressof(in_buf), (len(payload),), DataType.UINT8), + TensorArgType.INPUT, + ) + sub_args.add_tensor( + ContinuousTensor.make(ctypes.addressof(out_buf), (len(out_payload),), DataType.UINT8), + TensorArgType.OUTPUT_EXISTING, + ) + orch.submit_next_level(l3_cid, sub_args, CallConfig()) + + w4.run(l4_orch) + w4.close() + assert result.read_text() == f"{len(payload)}:{sum(payload)}" + assert bytes(out_buf.raw) == out_payload + b"\x00" + finally: + daemon.stop() + + +@REAL_E2E +def test_real_rxe_ibverbs_smoke(tmp_path): + binary = os.getenv("SIMPLER_RXE_PINGPONG") or shutil.which("ibv_rc_pingpong") + if not binary: + pytest.skip("ibv_rc_pingpong is not available") + + device = os.getenv("SIMPLER_RXE_DEVICE") or _first_existing_rxe_device() + if not device: + pytest.skip("no rxe* device found under /sys/class/infiniband") + + gid_index = os.getenv("SIMPLER_RXE_GID_INDEX") + server_ip = os.getenv("SIMPLER_RXE_SERVER_IP") + if not gid_index or not server_ip: + inferred = _find_ipv4_gid(device) + if inferred is None: + pytest.skip(f"no IPv4-mapped GID found for {device}; set SIMPLER_RXE_GID_INDEX and SIMPLER_RXE_SERVER_IP") + inferred_gid_index, inferred_ip = inferred + gid_index = gid_index or inferred_gid_index + server_ip = server_ip or inferred_ip + + server_log = tmp_path / "rxe_rc_server.log" + client_log = tmp_path / "rxe_rc_client.log" + server_cmd = [binary, "-d", device, "-i", "1", "-g", gid_index] + client_cmd = [binary, "-d", device, "-i", "1", "-g", gid_index, server_ip] + + with server_log.open("wb") as server_out: + server = subprocess.Popen(server_cmd, stdout=server_out, 
stderr=subprocess.STDOUT) + try: + time.sleep(1.0) + with client_log.open("wb") as client_out: + client = subprocess.run(client_cmd, stdout=client_out, stderr=subprocess.STDOUT, timeout=15, check=False) + try: + server_rc = server.wait(timeout=15) + except subprocess.TimeoutExpired: + server.kill() + server_rc = server.wait(timeout=5) + finally: + if server.poll() is None: + server.kill() + server.wait(timeout=5) + + server_text = server_log.read_text(encoding="utf-8", errors="replace") + client_text = client_log.read_text(encoding="utf-8", errors="replace") + assert client.returncode == 0 and server_rc == 0, ( + f"RXE RC pingpong failed for device={device}, gid_index={gid_index}, server_ip={server_ip}\n" + f"server rc={server_rc}\n{server_text}\nclient rc={client.returncode}\n{client_text}" + ) + assert "bytes in" in server_text + assert "bytes in" in client_text + + +@REAL_E2E +def test_real_l4_l3_rxe_tensor_transport_e2e(tmp_path, monkeypatch): + runtime = RxeRuntime.from_env(required=False) + if not runtime.available or not runtime.device or not runtime.server_ip: + pytest.skip(runtime.unavailable_reason() or "RXE runtime is not configured") + + monkeypatch.setenv("SIMPLER_TENSOR_TRANSPORT", "rxe") + result = tmp_path / "remote_rxe_tensor_sum.txt" + payload = bytes((i * 7) % 251 for i in range(12 * 1024)) + in_buf = ctypes.create_string_buffer(payload, len(payload)) + out_payload = bytes((i * 11) % 253 for i in range(len(payload))) + out_buf = ctypes.create_string_buffer(b"\x00" * len(out_payload), len(out_payload)) + daemon, endpoint = _start_daemon_with_transport("rxe") + + try: + w4 = Worker(level=4, num_sub_workers=0) + + def l3_orch(orch, args, config): + in_tensor = args.tensor(0) + out_tensor = args.tensor(1) + data = ctypes.string_at(int(in_tensor.data), int(in_tensor.nbytes())) + result.write_text(f"{len(data)}:{sum(data)}") + ctypes.memmove(int(out_tensor.data), out_payload, len(out_payload)) + + l3_cid = w4.register(l3_orch) + 
w4.add_remote_worker(endpoint, tensor_transport="rxe") + w4.init() + + def l4_orch(orch, args, config): + sub_args = TaskArgs() + sub_args.add_tensor( + ContinuousTensor.make(ctypes.addressof(in_buf), (len(payload),), DataType.UINT8), + TensorArgType.INPUT, + ) + sub_args.add_tensor( + ContinuousTensor.make(ctypes.addressof(out_buf), (len(out_payload),), DataType.UINT8), + TensorArgType.OUTPUT_EXISTING, + ) + orch.submit_next_level(l3_cid, sub_args, CallConfig()) + + w4.run(l4_orch) + w4.close() + assert result.read_text() == f"{len(payload)}:{sum(payload)}" + assert bytes(out_buf.raw) == out_payload + finally: + daemon.stop() + + +@REAL_E2E +def test_real_hcomm_endpoint_mem_export_smoke(): + lib_path = os.getenv("SIMPLER_HCOMM_LIB") + endpoint = EndpointSpec.from_env() + if not lib_path or endpoint is None: + pytest.skip("set SIMPLER_HCOMM_LIB and SIMPLER_HCOMM_ENDPOINT_IP") + + runtime = HcommRuntime.from_env(required=True) + endpoint_handle = runtime.endpoint_create(endpoint) + data = bytearray(b"simpler-real-e2e-hcomm") + mem_handle = 0 + try: + mem_handle = runtime.mem_reg( + endpoint_handle, + ctypes.addressof(ctypes.c_char.from_buffer(data)), + len(data), + tag="simpler-real-e2e", + ) + desc = runtime.mem_export(endpoint_handle, mem_handle) + assert desc + assert len(desc) >= ctypes.sizeof(_EndpointDesc) + finally: + if mem_handle: + runtime.mem_unreg(endpoint_handle, mem_handle) + runtime.endpoint_destroy(endpoint_handle) + + +def _first_existing_rxe_device() -> str | None: + infiniband = Path("/sys/class/infiniband") + if not infiniband.exists(): + return None + for path in sorted(infiniband.iterdir()): + if path.name.startswith("rxe"): + return path.name + return None + + +def _ipv4_from_gid(gid: str) -> str | None: + parts = gid.strip().split(":") + if len(parts) != 8 or parts[5].lower() != "ffff": + return None + try: + hi = int(parts[6], 16) + lo = int(parts[7], 16) + except ValueError: + return None + return ".".join(str(octet) for octet in (hi >> 8, 
hi & 0xFF, lo >> 8, lo & 0xFF)) + + +def _find_ipv4_gid(device: str) -> tuple[str, str] | None: + gid_dir = Path("/sys/class/infiniband") / device / "ports" / "1" / "gids" + if not gid_dir.exists(): + return None + for path in sorted(gid_dir.iterdir(), key=lambda item: int(item.name) if item.name.isdigit() else item.name): + ip = _ipv4_from_gid(path.read_text(encoding="ascii").strip()) + if ip: + return path.name, ip + return None diff --git a/tests/ut/py/test_distributed/test_rxe_real.py b/tests/ut/py/test_distributed/test_rxe_real.py new file mode 100644 index 000000000..d169b9157 --- /dev/null +++ b/tests/ut/py/test_distributed/test_rxe_real.py @@ -0,0 +1,97 @@ +import os +import shutil +import subprocess +import time +from pathlib import Path + +import pytest + + +REAL_RXE = pytest.mark.skipif( + os.getenv("SIMPLER_RXE_REAL_TEST") != "1", + reason="set SIMPLER_RXE_REAL_TEST=1 to run the local RXE/ibverbs smoke test", +) + + +def _first_existing_rxe_device() -> str | None: + infiniband = Path("/sys/class/infiniband") + if not infiniband.exists(): + return None + for path in sorted(infiniband.iterdir()): + if path.name.startswith("rxe"): + return path.name + return None + + +def _ipv4_from_gid(gid: str) -> str | None: + parts = gid.strip().split(":") + if len(parts) != 8 or parts[5].lower() != "ffff": + return None + try: + hi = int(parts[6], 16) + lo = int(parts[7], 16) + except ValueError: + return None + return ".".join(str(octet) for octet in (hi >> 8, hi & 0xFF, lo >> 8, lo & 0xFF)) + + +def _find_ipv4_gid(device: str) -> tuple[str, str] | None: + gid_dir = Path("/sys/class/infiniband") / device / "ports" / "1" / "gids" + if not gid_dir.exists(): + return None + for path in sorted(gid_dir.iterdir(), key=lambda item: int(item.name) if item.name.isdigit() else item.name): + ip = _ipv4_from_gid(path.read_text(encoding="ascii").strip()) + if ip: + return path.name, ip + return None + + +@REAL_RXE +def test_real_rxe_rc_pingpong_smoke(tmp_path): + binary = 
os.getenv("SIMPLER_RXE_PINGPONG") or shutil.which("ibv_rc_pingpong") + if not binary: + pytest.skip("ibv_rc_pingpong is not available") + + device = os.getenv("SIMPLER_RXE_DEVICE") or _first_existing_rxe_device() + if not device: + pytest.skip("no rxe* device found under /sys/class/infiniband") + + gid_index = os.getenv("SIMPLER_RXE_GID_INDEX") + server_ip = os.getenv("SIMPLER_RXE_SERVER_IP") + if not gid_index or not server_ip: + inferred = _find_ipv4_gid(device) + if inferred is None: + pytest.skip(f"no IPv4-mapped GID found for {device}; set SIMPLER_RXE_GID_INDEX and SIMPLER_RXE_SERVER_IP") + inferred_gid_index, inferred_ip = inferred + gid_index = gid_index or inferred_gid_index + server_ip = server_ip or inferred_ip + + server_log = tmp_path / "rxe_rc_server.log" + client_log = tmp_path / "rxe_rc_client.log" + server_cmd = [binary, "-d", device, "-i", "1", "-g", gid_index] + client_cmd = [binary, "-d", device, "-i", "1", "-g", gid_index, server_ip] + + with server_log.open("wb") as server_out: + server = subprocess.Popen(server_cmd, stdout=server_out, stderr=subprocess.STDOUT) + try: + time.sleep(1.0) + with client_log.open("wb") as client_out: + client = subprocess.run(client_cmd, stdout=client_out, stderr=subprocess.STDOUT, timeout=15, check=False) + try: + server_rc = server.wait(timeout=15) + except subprocess.TimeoutExpired: + server.kill() + server_rc = server.wait(timeout=5) + finally: + if server.poll() is None: + server.kill() + server.wait(timeout=5) + + server_text = server_log.read_text(encoding="utf-8", errors="replace") + client_text = client_log.read_text(encoding="utf-8", errors="replace") + assert client.returncode == 0 and server_rc == 0, ( + f"RXE RC pingpong failed for device={device}, gid_index={gid_index}, server_ip={server_ip}\n" + f"server rc={server_rc}\n{server_text}\nclient rc={client.returncode}\n{client_text}" + ) + assert "bytes in" in server_text + assert "bytes in" in client_text diff --git 
a/tests/ut/py/test_distributed/test_tensor_pool.py b/tests/ut/py/test_distributed/test_tensor_pool.py index 0dfaa5868..2f144d725 100644 --- a/tests/ut/py/test_distributed/test_tensor_pool.py +++ b/tests/ut/py/test_distributed/test_tensor_pool.py @@ -1,5 +1,20 @@ +import ctypes +import time + +import pytest + from simpler.distributed.proto import dispatch_pb2 -from simpler.distributed.tensor_pool import TensorPool +from simpler.distributed.serialization import ( + decode_task_args_with_tensor_refs, + encode_output_tensor_refs, + encode_tensor_ref, +) +from simpler.distributed.tensor_pool import DEFAULT_INLINE_THRESHOLD, TensorPool, TensorPoolFull +from simpler.distributed.transport_backend import ( + TransportUnavailable, + build_tensor_transport, +) +from simpler.task_interface import ContinuousTensor, DataType, TaskArgs, TensorArgType def test_tensor_pool_inline_bytes(): @@ -13,6 +28,66 @@ def test_tensor_pool_handle_bytes(): ref = pool.put_bytes(b"abcdef") assert ref.HasField("handle") assert pool.get_bytes(ref.handle) == b"abcdef" + assert ref.handle.nbytes == 6 + assert ref.handle.remote_addr != 0 + assert ref.handle.transport == "grpc" + assert ref.handle.lease_deadline_unix_ms > 0 + + +def test_build_tensor_transport_auto_falls_back_to_grpc(monkeypatch): + monkeypatch.delenv("SIMPLER_HCOMM_LIB", raising=False) + monkeypatch.delenv("SIMPLER_HCOMM_ENDPOINT_HANDLE", raising=False) + backend = build_tensor_transport("auto") + assert backend.name == "grpc" + + +def test_build_tensor_transport_explicit_hcomm_requires_configuration(monkeypatch): + monkeypatch.delenv("SIMPLER_HCOMM_LIB", raising=False) + monkeypatch.delenv("SIMPLER_HCOMM_ENDPOINT_HANDLE", raising=False) + with pytest.raises(TransportUnavailable): + build_tensor_transport("hcomm") + + +def test_tensor_pool_default_inline_threshold_is_four_kb(): + pool = TensorPool() + assert pool.inline_threshold == DEFAULT_INLINE_THRESHOLD + assert pool.put_bytes(b"x" * 
DEFAULT_INLINE_THRESHOLD).HasField("inline_data") + assert pool.put_bytes(b"x" * (DEFAULT_INLINE_THRESHOLD + 1)).HasField("handle") + + +def test_tensor_pool_alloc_write_read_free(): + pool = TensorPool(capacity_bytes=16) + handle = pool.alloc(6, shape=(2, 3), dtype=DataType.UINT8.value, tag=TensorArgType.INPUT.value) + pool.write_bytes(handle, b"abcdef") + assert pool.read_bytes(handle, offset=1, nbytes=3) == b"bcd" + assert pool.get_bytes(handle) == b"abcdef" + pool.free(handle) + with pytest.raises(KeyError): + pool.get_bytes(handle) + + +def test_tensor_pool_capacity_and_lease_gc(): + pool = TensorPool(capacity_bytes=4, default_ttl_ms=10) + handle = pool.alloc(4, ttl_ms=1000) + with pytest.raises(TensorPoolFull): + pool.alloc(1) + pool.free(handle) + handle = pool.alloc(4, ttl_ms=1) + time.sleep(0.01) + assert pool.gc_expired() == 1 + assert pool.used_bytes == 0 + with pytest.raises(KeyError): + pool.get_bytes(handle) + + +def test_tensor_pool_refresh_extends_lease(): + pool = TensorPool(default_ttl_ms=100) + handle = pool.alloc(1, ttl_ms=100) + refreshed = pool.refresh(handle, ttl_ms=1000) + assert refreshed.handle_id == handle.handle_id + assert refreshed.lease_deadline_unix_ms >= handle.lease_deadline_unix_ms + time.sleep(0.01) + assert pool.gc_expired() == 0 def test_tensor_pool_service_pull(): @@ -37,3 +112,65 @@ def test_tensor_pool_service_push(): None, ) assert pool.get_bytes(handle) == b"abcdef" + + +def test_tensor_pool_service_alloc_free_refresh_push_to_handle(): + pool = TensorPool(capacity_bytes=16) + service = pool.service() + handle = service.AllocTensor(dispatch_pb2.TensorAllocReq(nbytes=6, ttl_ms=100), None) + refreshed = service.RefreshTensor(dispatch_pb2.TensorRefreshReq(handle=handle, ttl_ms=1000), None) + assert refreshed.handle_id == handle.handle_id + out = service.PushTensor( + iter( + [ + dispatch_pb2.TensorChunk(handle=handle, offset=0, data=b"abc"), + dispatch_pb2.TensorChunk(handle=handle, offset=3, data=b"def", last=True), + ] + ), 
+ None, + ) + assert out.handle_id == handle.handle_id + assert pool.get_bytes(handle) == b"abcdef" + assert list(service.PullTensor(handle, None))[-1].last + service.FreeTensor(dispatch_pb2.TensorFreeReq(handle=handle), None) + with pytest.raises(KeyError): + pool.get_bytes(handle) + + +def test_encode_decode_tensor_refs_inline_and_handle(): + pool = TensorPool(inline_threshold=4) + inline = encode_tensor_ref( + b"abc", + shape=(3,), + dtype=DataType.UINT8, + tag=TensorArgType.INPUT, + pool=pool, + ) + handled = encode_tensor_ref( + b"abcdef", + shape=(6,), + dtype=DataType.UINT8, + tag=TensorArgType.INPUT, + pool=pool, + force_handle=True, + ) + args, keepalive = decode_task_args_with_tensor_refs([inline, handled], [7], pool) + assert args.tensor_count() == 2 + assert args.scalar(0) == 7 + assert args.tensor(0).nbytes() == 3 + assert args.tensor(1).nbytes() == 6 + assert args.tag(0) == TensorArgType.INPUT + assert len(keepalive) == 2 + + +def test_encode_output_tensor_refs(): + pool = TensorPool(inline_threshold=4) + in_buf = ctypes.create_string_buffer(b"abc") + out_buf = ctypes.create_string_buffer(b"abcdef") + args = TaskArgs() + args.add_tensor(ContinuousTensor.make(ctypes.addressof(in_buf), (3,), DataType.UINT8), TensorArgType.INPUT) + args.add_tensor(ContinuousTensor.make(ctypes.addressof(out_buf), (6,), DataType.UINT8), TensorArgType.OUTPUT) + refs = encode_output_tensor_refs(args, pool) + assert len(refs) == 1 + assert refs[0].HasField("handle") + assert pool.get_bytes(refs[0].handle) == b"abcdef" diff --git a/tests/ut/py/test_distributed/test_transport_backend.py b/tests/ut/py/test_distributed/test_transport_backend.py new file mode 100644 index 000000000..403494286 --- /dev/null +++ b/tests/ut/py/test_distributed/test_transport_backend.py @@ -0,0 +1,163 @@ +import ctypes +import os + +import pytest + +from simpler.distributed.proto import dispatch_pb2 +from simpler.distributed.transport_backend import ( + EndpointSpec, + HcommDataPlaneClient, + 
HcommRuntime, + RxeDataPlaneClient, + RxeTensorTransport, + _CommAbiHeader, + _CommAddr, + _EndpointDesc, + _EndpointLoc, + _HcommChannelDesc, + _HcommMem, + _RxeServerDesc, + _RXE_DESC_MAGIC, + _decode_rxe_desc, + _encode_rxe_desc, +) + + +REAL_HCOMM = pytest.mark.skipif( + os.getenv("SIMPLER_HCOMM_REAL_TEST") != "1", + reason="set SIMPLER_HCOMM_REAL_TEST=1 to run against a real HCOMM library/environment", +) + + +def test_hcomm_ctypes_struct_layout_matches_public_headers(): + assert ctypes.sizeof(_CommAddr) == 40 + assert ctypes.sizeof(_EndpointLoc) == 64 + assert ctypes.sizeof(_EndpointDesc) == 160 + assert ctypes.sizeof(_HcommMem) == 24 + assert ctypes.sizeof(_CommAbiHeader) == 16 + assert ctypes.sizeof(_HcommChannelDesc) == 344 + assert _HcommChannelDesc.header.offset == 0 + assert _HcommChannelDesc.remoteEndpoint.offset == 16 + assert _HcommChannelDesc.notifyNum.offset == 176 + assert _HcommChannelDesc.memHandles.offset == 184 + assert _HcommChannelDesc.memHandleNum.offset == 192 + assert _HcommChannelDesc.socket.offset == 200 + assert _HcommChannelDesc.role.offset == 208 + assert _HcommChannelDesc.port.offset == 212 + assert _HcommChannelDesc.attr.offset == 216 + + +def test_rxe_desc_roundtrip(): + desc = _RxeServerDesc() + desc.ip = b"192.168.0.243" + desc.port = 12345 + desc.rkey = 678 + desc.addr = 0xABCDEF + desc.size = 4096 + + payload = _encode_rxe_desc(desc, "rxe0", 1) + assert payload.startswith(_RXE_DESC_MAGIC) + decoded = _decode_rxe_desc(payload) + assert decoded.ip == "192.168.0.243" + assert decoded.port == 12345 + assert decoded.rkey == 678 + assert decoded.addr == 0xABCDEF + assert decoded.size == 4096 + assert decoded.device == "rxe0" + assert decoded.gid_index == 1 + + +def test_rxe_legacy_json_desc_is_still_accepted(): + payload = ( + b'{"addr":11259375,"device":"rxe0","gid_index":1,"ip":"192.168.0.243",' + b'"port":12345,"rkey":678,"size":4096,"transport":"rxe","version":1}' + ) + decoded = _decode_rxe_desc(payload) + assert decoded.ip == 
"192.168.0.243" + assert decoded.port == 12345 + assert decoded.rkey == 678 + assert decoded.addr == 0xABCDEF + assert decoded.size == 4096 + assert decoded.device == "rxe0" + assert decoded.gid_index == 1 + + +def test_rxe_client_rejects_empty_transport_desc(): + client = RxeDataPlaneClient() + handle = dispatch_pb2.TensorHandle(transport="rxe", nbytes=1) + with pytest.raises(Exception, match="transport_desc"): + client.write_handle(handle, 1, 1) + + +@REAL_HCOMM +def test_real_hcomm_runtime_loads_required_symbols(): + runtime = HcommRuntime.from_env(required=True) + assert runtime.available + for symbol in ( + "HcommEndpointCreate", + "HcommEndpointDestroy", + "HcommMemReg", + "HcommMemUnreg", + "HcommMemExport", + "HcommMemImport", + "HcommMemUnimport", + "HcommChannelCreate", + "HcommChannelDestroy", + "HcommWriteWithNotifyNbi", + "HcommChannelFence", + ): + assert hasattr(runtime._lib, symbol) + + +@REAL_HCOMM +def test_real_hcomm_endpoint_mem_reg_export_smoke(): + endpoint = EndpointSpec.from_env() + if endpoint is None and os.getenv("SIMPLER_HCOMM_ENDPOINT_HANDLE"): + endpoint_handle = int(os.environ["SIMPLER_HCOMM_ENDPOINT_HANDLE"], 0) + owns_endpoint = False + elif endpoint is not None: + runtime = HcommRuntime.from_env(required=True) + endpoint_handle = runtime.endpoint_create(endpoint) + owns_endpoint = True + else: + pytest.skip("set SIMPLER_HCOMM_ENDPOINT_IP or SIMPLER_HCOMM_ENDPOINT_HANDLE") + + runtime = HcommRuntime.from_env(required=True) + data = bytearray(b"simpler-hcomm-smoke") + addr = ctypes.addressof(ctypes.c_char.from_buffer(data)) + mem_handle = 0 + try: + mem_handle = runtime.mem_reg(endpoint_handle, addr, len(data), tag="simpler-real-smoke") + desc = runtime.mem_export(endpoint_handle, mem_handle) + assert desc + assert len(desc) >= ctypes.sizeof(_EndpointDesc) + finally: + if mem_handle: + runtime.mem_unreg(endpoint_handle, mem_handle) + if owns_endpoint: + runtime.endpoint_destroy(endpoint_handle) + + +@REAL_HCOMM +def 
test_real_hcomm_client_precreated_channel_write_smoke(): + required = [ + "SIMPLER_HCOMM_ENDPOINT_HANDLE", + "SIMPLER_HCOMM_CHANNEL_HANDLE", + "SIMPLER_HCOMM_REMOTE_ADDR", + "SIMPLER_HCOMM_REMOTE_NBYTES", + ] + missing = [name for name in required if not os.getenv(name)] + if missing: + pytest.skip(f"missing real channel smoke env: {', '.join(missing)}") + + runtime = HcommRuntime.from_env(required=True) + client = HcommDataPlaneClient(runtime=runtime) + remote_addr = int(os.environ["SIMPLER_HCOMM_REMOTE_ADDR"], 0) + remote_nbytes = int(os.environ["SIMPLER_HCOMM_REMOTE_NBYTES"], 0) + payload = bytes(range(min(remote_nbytes, 64))) + local = ctypes.create_string_buffer(payload) + handle = dispatch_pb2.TensorHandle(remote_addr=remote_addr, nbytes=remote_nbytes) + + client.write_handle(handle, ctypes.addressof(local), len(payload)) + client.fence() + client.close() diff --git a/tools/benchmark_rxe_data_plane.py b/tools/benchmark_rxe_data_plane.py new file mode 100755 index 000000000..27d9c1fff --- /dev/null +++ b/tools/benchmark_rxe_data_plane.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python3 +"""Benchmark gRPC chunking vs RXE data-plane L4/L3 tensor dispatch.""" + +from __future__ import annotations + +import argparse +import ctypes +import statistics +import time +from collections.abc import Iterable + +from simpler.distributed.l3_daemon import L3Daemon +from simpler.task_interface import CallConfig, ContinuousTensor, DataType, TaskArgs, TensorArgType +from simpler.worker import Worker + + +def main() -> int: + parser = argparse.ArgumentParser() + parser.add_argument("--sizes", default="8192,65536,1048576") + parser.add_argument("--repeats", type=int, default=10) + parser.add_argument("--warmup", type=int, default=2) + parser.add_argument("--transports", default="grpc,rxe") + parser.add_argument("--inline-threshold", type=int, default=4096) + args = parser.parse_args() + + sizes = [int(item) for item in args.sizes.split(",") if item] + transports = [item.strip().lower() 
for item in args.transports.split(",") if item.strip()] + + print("transport,size_bytes,repeats,mean_ms,p50_ms,p95_ms,min_ms,max_ms") + for transport in transports: + for size in sizes: + samples = _benchmark_one( + transport, + size, + repeats=args.repeats, + warmup=args.warmup, + inline_threshold=args.inline_threshold, + ) + print(_format_row(transport, size, samples), flush=True) + return 0 + + +def _benchmark_one(transport: str, size: int, *, repeats: int, warmup: int, inline_threshold: int) -> list[float]: + daemon = L3Daemon(0, lambda: Worker(level=3, num_sub_workers=1), tensor_transport=transport) + endpoint = f"127.0.0.1:{daemon.start()}" + w4 = Worker(level=4, num_sub_workers=0) + + def l3_orch(orch, args, config): # noqa: ANN001 + in_tensor = args.tensor(0) + out_tensor = args.tensor(1) + data = ctypes.string_at(int(in_tensor.data), int(in_tensor.nbytes())) + out = bytes(value ^ 0x5A for value in data) + ctypes.memmove(int(out_tensor.data), out, len(out)) + + l3_cid = w4.register(l3_orch) + w4.add_remote_worker(endpoint, tensor_transport=transport, tensor_inline_threshold=inline_threshold) + w4.init() + + try: + samples: list[float] = [] + for iteration in range(warmup + repeats): + payload = _payload(size, iteration) + expected = bytes(value ^ 0x5A for value in payload) + in_buf = ctypes.create_string_buffer(payload, len(payload)) + out_buf = ctypes.create_string_buffer(b"\0" * len(payload), len(payload)) + + def l4_orch(orch, args, config): # noqa: ANN001 + sub_args = TaskArgs() + sub_args.add_tensor( + ContinuousTensor.make(ctypes.addressof(in_buf), (len(payload),), DataType.UINT8), + TensorArgType.INPUT, + ) + sub_args.add_tensor( + ContinuousTensor.make(ctypes.addressof(out_buf), (len(payload),), DataType.UINT8), + TensorArgType.OUTPUT_EXISTING, + ) + orch.submit_next_level(l3_cid, sub_args, CallConfig()) + + start = time.perf_counter() + w4.run(l4_orch) + elapsed_ms = (time.perf_counter() - start) * 1000.0 + if bytes(out_buf.raw) != expected: + 
raise RuntimeError(f"{transport} output mismatch for size={size}, iteration={iteration}") + if iteration >= warmup: + samples.append(elapsed_ms) + return samples + finally: + w4.close() + daemon.stop() + + +def _payload(size: int, salt: int) -> bytes: + return bytes((index + salt) % 251 for index in range(size)) + + +def _format_row(transport: str, size: int, samples: Iterable[float]) -> str: + values = list(samples) + sorted_values = sorted(values) + p50 = statistics.median(sorted_values) + p95 = sorted_values[min(len(sorted_values) - 1, int(len(sorted_values) * 0.95))] + return ( + f"{transport},{size},{len(values)}," + f"{statistics.mean(values):.3f},{p50:.3f},{p95:.3f},{min(values):.3f},{max(values):.3f}" + ) + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tools/test_rxe_data_plane.sh b/tools/test_rxe_data_plane.sh new file mode 100755 index 000000000..0b88f01b2 --- /dev/null +++ b/tools/test_rxe_data_plane.sh @@ -0,0 +1,49 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)" +REPO_ROOT="$(cd -- "${SCRIPT_DIR}/.." >/dev/null 2>&1 && pwd)" + +cd "${REPO_ROOT}" + +export PYTHONPATH="${REPO_ROOT}/python${PYTHONPATH:+:${PYTHONPATH}}" + +# Default to the local rdma-core build that is known to back ibv_rc_pingpong on this host. 
+export SIMPLER_RXE_INCLUDE_DIR="${SIMPLER_RXE_INCLUDE_DIR:-/home/ntlab/rdma-build/rdma-core-50.0/build/include}" +export SIMPLER_RXE_LIB_DIR="${SIMPLER_RXE_LIB_DIR:-/home/ntlab/rdma-build/rdma-core-50.0/build/lib}" + +echo "[1/4] Python import/compile sanity" +python -m py_compile \ + python/simpler/distributed/transport_backend.py \ + python/simpler/distributed/remote_proxy.py \ + python/simpler/distributed/l3_daemon.py \ + tests/ut/py/test_distributed/test_real_e2e_smoke.py \ + tests/ut/py/test_distributed/test_transport_backend.py \ + tools/benchmark_rxe_data_plane.py + +echo "[2/4] Distributed unit tests" +python -m pytest \ + tests/ut/py/test_distributed/test_catalog.py \ + tests/ut/py/test_distributed/test_heartbeat.py \ + tests/ut/py/test_distributed/test_import.py \ + tests/ut/py/test_distributed/test_l4_l3_remote.py \ + tests/ut/py/test_distributed/test_rpc_roundtrip.py \ + tests/ut/py/test_distributed/test_tensor_pool.py \ + tests/ut/py/test_distributed/test_transport_backend.py \ + -q + +echo "[3/4] RXE/ibverbs RC pingpong smoke" +SIMPLER_RXE_REAL_TEST=1 \ +python -m pytest tests/ut/py/test_distributed/test_rxe_real.py -q -s + +echo "[4/4] L4 -> L3 RXE tensor data-plane E2E" +SIMPLER_REAL_E2E_TEST=1 \ +SIMPLER_TENSOR_TRANSPORT=rxe \ +python -m pytest tests/ut/py/test_distributed/test_real_e2e_smoke.py -q -s -k "rxe" + +if [[ "${SIMPLER_RUN_RXE_BENCHMARK:-0}" == "1" ]]; then + echo "[optional] RXE vs gRPC short benchmark" + tools/benchmark_rxe_data_plane.py --sizes 8192,65536 --repeats 3 --warmup 1 +fi + +echo "RXE data-plane tests passed." 
From 92abca766e0093e77e1cb1420f0529ad86103417 Mon Sep 17 00:00:00 2001 From: PKUZHOU Date: Fri, 8 May 2026 10:04:40 +0800 Subject: [PATCH 3/6] docs(distributed): fix remote result example --- docs/distributed-l4-implementation.zh.md | 7 +++++- examples/distributed/l4_l3_remote/README.md | 7 +++++- .../distributed/l4_l3_remote/l4_master.py | 25 +++++++++---------- 3 files changed, 24 insertions(+), 15 deletions(-) diff --git a/docs/distributed-l4-implementation.zh.md b/docs/distributed-l4-implementation.zh.md index a2d16e64f..89faefd0e 100644 --- a/docs/distributed-l4-implementation.zh.md +++ b/docs/distributed-l4-implementation.zh.md @@ -386,9 +386,14 @@ python examples/distributed/l4_l3_remote/l4_master.py --remotes 127.0.0.1:5050 期望输出: ```text -remote counter=7 +remote result=7 ``` +注意:L4 注册的 callable 会通过 catalog 序列化后在 L3 daemon/backend +进程中执行。闭包里捕获的 Python 对象会变成远端反序列化副本,不会修改 L4 +进程内的原对象。示例通过 `OUTPUT_EXISTING` tensor 把结果写回 L4 本地 buffer, +而不是依赖远端 callable 修改 L4 本地闭包状态。 + ## 当前测试方式 安装或构建: diff --git a/examples/distributed/l4_l3_remote/README.md b/examples/distributed/l4_l3_remote/README.md index 6867657c2..885d52dd6 100644 --- a/examples/distributed/l4_l3_remote/README.md +++ b/examples/distributed/l4_l3_remote/README.md @@ -15,5 +15,10 @@ python examples/distributed/l4_l3_remote/l4_master.py --remotes 127.0.0.1:5050 Expected output: ```text -remote counter=7 +remote result=7 ``` + +The callable registered on L4 is serialized and executed inside the L3 daemon, +so it must not rely on mutating Python objects captured from the L4 process. +This example returns the distributed result through an `OUTPUT_EXISTING` tensor, +which is copied back into the L4-local buffer after dispatch completes. 
diff --git a/examples/distributed/l4_l3_remote/l4_master.py b/examples/distributed/l4_l3_remote/l4_master.py index aee49911d..d7b0b19ff 100644 --- a/examples/distributed/l4_l3_remote/l4_master.py +++ b/examples/distributed/l4_l3_remote/l4_master.py @@ -1,27 +1,22 @@ import argparse +import ctypes -from simpler.task_interface import CallConfig, TaskArgs +from simpler.task_interface import CallConfig, ContinuousTensor, DataType, TaskArgs, TensorArgType from simpler.worker import Worker -class Counter: - def __init__(self) -> None: - self.value = 0 - - def add(self, amount: int) -> None: - self.value += int(amount) - - def main() -> int: parser = argparse.ArgumentParser() parser.add_argument("--remotes", default="127.0.0.1:5050") args = parser.parse_args() - counter = Counter() + result = ctypes.c_int64(0) endpoints = [item.strip() for item in args.remotes.split(",") if item.strip()] def l3_sub(task_args): - counter.add(task_args.scalar(0)) + output = task_args.tensor(1) + current = ctypes.c_int64.from_address(int(output.data)) + current.value += int(task_args.scalar(0)) w4 = Worker(level=4, num_sub_workers=0) sub_cid = w4.register(l3_sub) @@ -38,14 +33,18 @@ def l4_orch(orch, task_args, config): for value in (2, 5): sub_args = TaskArgs() sub_args.add_scalar(value) + sub_args.add_tensor( + ContinuousTensor.make(ctypes.addressof(result), (1,), DataType.INT64), + TensorArgType.OUTPUT_EXISTING, + ) orch.submit_next_level(l3_cid, sub_args, CallConfig()) w4.run(l4_orch) finally: w4.close() - print(f"remote counter={counter.value}") - return 0 if counter.value == 7 else 1 + print(f"remote result={result.value}") + return 0 if result.value == 7 else 1 if __name__ == "__main__": From b6cc6ddcb186a6d7c6f36b1656c09d2bb309486e Mon Sep 17 00:00:00 2001 From: PKUZHOU Date: Fri, 8 May 2026 10:53:12 +0800 Subject: [PATCH 4/6] docs(distributed): add L4 L3 review guide --- docs/l4-l3-distributed-review-guide.zh.md | 917 ++++++++++++++++++++++ 1 file changed, 917 insertions(+) create 
mode 100644 docs/l4-l3-distributed-review-guide.zh.md diff --git a/docs/l4-l3-distributed-review-guide.zh.md b/docs/l4-l3-distributed-review-guide.zh.md new file mode 100644 index 000000000..75cdb893c --- /dev/null +++ b/docs/l4-l3-distributed-review-guide.zh.md @@ -0,0 +1,917 @@ +# L4/L3 分布式实现 Review Guide + +本文面向代码 review。目标不是重复使用说明,而是把当前 L4 到 L3 分布式实现按模块、流程和风险点拆开,帮助 reviewer 判断每个模块的设计是否合理、功能是否完整、边界是否清楚。 + +当前实现包含两层: + +- **控制面**:L4 本地 scheduler/mailbox 到远端 L3 daemon 的 gRPC dispatch。 +- **数据面**:tensor payload 的 inline、gRPC TensorPool handle、RXE/ibverbs RDMA write、HCOMM 可选适配。 + +主要代码入口: + +```text +python/simpler/worker.py +python/simpler/distributed/ + catalog.py + rpc.py + remote_proxy.py + l3_daemon.py + serialization.py + tensor_pool.py + transport_backend.py + rxe_verbs_helper.c + hcomm_abi_shim.cc + proto/dispatch.proto + +src/common/hierarchical/worker_manager.cpp +tests/ut/py/test_distributed/ +tools/test_rxe_data_plane.sh +tools/benchmark_rxe_data_plane.py +``` + +## 1. 总览 + +### 1.1 为什么要这样做 + +原有 L4->L3 路径假设 L4 可以在本机 fork L3 child worker,并通过共享内存 mailbox 与该 child 交互。跨 host 后这个假设失效: + +- 远端 L3 不能由 L4 直接 `fork()`。 +- L4 本地 pointer 不能直接发给远端。 +- L4 本地 callable registry 不能通过 fork copy-on-write 自动继承。 +- tensor payload 需要显式数据面传输。 + +当前设计把远端 L3 包装成一个“看起来像本地 PROCESS child”的 mailbox shim: + +```mermaid +flowchart LR + A[L4 C++ scheduler] --> B[本地 remote mailbox] + B --> C[Python remote shim thread] + C --> D[RemoteWorkerProxy] + D -->|gRPC| E[L3Daemon] + E -->|Pipe| F[L3 backend process] + F --> G[Worker level=3] +``` + +L4 C++ scheduler 不知道远端存在;它只看到一个普通 PROCESS-mode next-level mailbox。Python shim thread 负责把 mailbox 中的 task 转成远端 RPC。 + +### 1.2 控制面和数据面的边界 + +```mermaid +flowchart TB + subgraph ControlPlane[控制面] + A[Callable Catalog] + B[DispatchReq / DispatchResp] + C[Heartbeat] + D[TensorPool Alloc/Free/Refresh] + end + + subgraph DataPlane[数据面] + E[inline bytes] + F[gRPC PushTensor/PullTensor chunks] + G[RXE ibverbs RDMA write] + H[HCOMM optional write facade] + 
end + + B --> E + D --> F + D --> G + D --> H +``` + +控制面只传元数据、handle、callable id、错误信息和生命周期操作。数据面负责 tensor bytes 的实际搬运。 + +## 2. 模块拆解 + +### 2.1 Worker remote child 接入 + +代码: + +- `python/simpler/worker.py` +- `src/common/hierarchical/worker_manager.cpp` + +关键入口: + +- `Worker.add_remote_worker(endpoint, **options)` +- `Worker._ensure_distributed_catalog()` +- `Worker._init_hierarchical()` +- `Worker._start_hierarchical()` +- `_remote_worker_loop(buf, proxy)` + +职责: + +1. L4 用户调用 `add_remote_worker()` 注册远端 endpoint。 +2. `Worker.init()` 时创建本地 shared-memory mailbox。 +3. 创建 `RemoteWorkerProxy(endpoint, catalog, **options)`。 +4. `_start_hierarchical()` 时先 `proxy.handshake()`,再启动 `_remote_worker_loop` thread。 +5. 把 remote mailbox 注册给 C++ `_Worker.add_next_level_process(...)`。 +6. C++ scheduler 发布 `TASK_READY` 后,remote shim thread 读取 mailbox,调用 `proxy.dispatch(...)`。 +7. dispatch 成功或失败后,shim 写回 mailbox error 区和 `TASK_DONE`。 + +流程: + +```mermaid +sequenceDiagram + participant User + participant W4 as Worker(level=4) + participant Cpp as C++ scheduler + participant Shim as remote shim thread + participant Proxy as RemoteWorkerProxy + participant L3 as L3Daemon + + User->>W4: add_remote_worker(endpoint) + User->>W4: init() + W4->>Proxy: create proxy + User->>W4: run(l4_orch) + W4->>Cpp: submit_next_level(...) 
+ Cpp->>Shim: TASK_READY in mailbox + Shim->>Proxy: dispatch(cid,args,cfg) + Proxy->>L3: gRPC Dispatch + L3-->>Proxy: DispatchResp + Proxy-->>Shim: return / raise + Shim->>Cpp: TASK_DONE + error code +``` + +Review 关注点: + +- `add_remote_worker()` 是否应该限制只能在 `level >= 4`、`init()` 前调用。 +- remote mailbox 与本地 child worker mailbox 是否共享同一状态协议。 +- `_remote_worker_loop()` 是否能正确传播异常到 mailbox error 区。 +- remote proxy close 是否和 `_SHUTDOWN` 时序匹配。 +- `worker_manager.cpp` 的 tensor tag 扩展是否保持旧 blob 格式兼容。 + +### 2.2 Callable Catalog + +代码: + +- `python/simpler/distributed/catalog.py` +- `python/simpler/worker.py` +- `python/simpler/distributed/l3_daemon.py` + +职责: + +- L4 侧把注册过的 callable 用 `cloudpickle` 序列化。 +- callable id 保持和 L4 registry 一致。 +- 版本号用 payload hash 表示,dispatch 时带上 `callable_version`。 +- L3 daemon 和 backend process 都维护一份 catalog mirror。 + +流程: + +```mermaid +sequenceDiagram + participant W4 as Worker.register + participant Cat as L4 Catalog + participant Proxy as RemoteWorkerProxy.handshake + participant D as L3Daemon CatalogService + participant B as backend catalog mirror + + W4->>Cat: register(fn, callable_id) + Proxy->>D: PushCallable(payload) + D->>D: install in daemon catalog + D->>B: pipe ("push", cid, version, payload) + B->>B: cloudpickle.loads(payload) +``` + +Review 关注点: + +- `cloudpickle` 是受信任集群内机制,不能暴露给不可信客户端。 +- callable closure 被序列化到 L3,远端修改的是反序列化副本,不会修改 L4 本地 Python 对象。 +- version mismatch 时当前行为是 L3 lookup 失败并返回 error。 + +### 2.3 RPC/protobuf 控制协议 + +代码: + +- `python/simpler/distributed/proto/dispatch.proto` +- `python/simpler/distributed/rpc.py` +- generated `dispatch_pb2*.py` + +核心 service: + +```protobuf +service L3Worker { + rpc Dispatch(DispatchReq) returns (DispatchResp); + rpc Heartbeat(Empty) returns (Health); +} + +service Catalog { + rpc PullCallable(CallableRef) returns (CallablePayload); + rpc PushCallable(CallablePayload) returns (Empty); +} + +service TensorPool { + rpc AllocTensor(TensorAllocReq) returns (TensorHandle); + rpc 
FreeTensor(TensorFreeReq) returns (Empty); + rpc RefreshTensor(TensorRefreshReq) returns (TensorHandle); + rpc PullTensor(TensorHandle) returns (stream TensorChunk); + rpc PushTensor(stream TensorChunk) returns (TensorHandle); +} +``` + +重要消息: + +- `DispatchReq.tensor_args`:旧地址式 tensor metadata,主要用于没有 tensor_refs 的兼容路径。 +- `DispatchReq.tensor_refs`:当前分布式 tensor 数据面路径,小 tensor inline,大 tensor handle。 +- `DispatchResp.output_tensors`:L3 output 回传,可能是 inline、L3 TensorPool handle,或 L4 local output ACK handle。 +- `TensorHandle.transport`:`grpc` / `rxe` / `hcomm`。 +- `TensorHandle.transport_desc`:transport 私有描述。 + +Review 关注点: + +- protobuf 字段是否后向兼容。 +- `tensor_args` 和 `tensor_refs` 同时存在时的语义是否明确。当前 L4 生成 tensor_refs 时会把 `tensor_args=[]`。 +- `TensorHandle.node_id` 用于区分 L3 pool handle 和 L4 local output ACK。 + +### 2.4 L4 RemoteWorkerProxy + +代码: + +- `python/simpler/distributed/remote_proxy.py` + +职责: + +- 执行 heartbeat/catalog handshake。 +- 把本地 `TaskArgs` staged 成 wire `TensorRef`。 +- 分配/释放 L3 TensorPool handle。 +- 对 input tensor 选择 gRPC/HCOMM/RXE push。 +- 对 output tensor 选择 L4 local RXE writeback 或普通 PullTensor。 +- 处理错误和资源释放。 + +关键函数: + +- `handshake()` +- `dispatch()` +- `_stage_tensor_args()` +- `_stage_local_output_tensor()` +- `_push_remote_tensor_rxe()` +- `_push_remote_tensor_hcomm()` +- `_push_remote_tensor_grpc()` +- `_write_output_tensors()` +- `_free_remote_handles()` +- `_close_local_output_regions()` + +#### 2.4.1 input tensor staging + +```mermaid +flowchart TD + A[TaskArgs tensor] --> B{nbytes <= inline_threshold?} + B -->|yes| C[TensorRef.inline_data] + B -->|no| D[TensorPool.AllocTensor on L3] + D --> E{handle.transport} + E -->|rxe| F[RDMA write into L3 buffer] + E -->|hcomm| G[HCOMM write facade] + E -->|grpc| H[PushTensor chunks] + F --> I[TensorRef.handle] + G --> I + H --> I +``` + +#### 2.4.2 output tensor staging + +当前只有大 `OUTPUT / OUTPUT_EXISTING` 且 transport 是 `rxe` 或 `auto` 时,会走 L4 local RXE region: + +```mermaid +flowchart TD + 
A[OUTPUT/OUTPUT_EXISTING tensor] --> B{rxe/auto and large?} + B -->|yes| C[L4 RxeRuntime.server_start(local output ptr)] + C --> D[TensorRef.handle node_id=l4-rxe-*] + B -->|no| E[普通 input-style staging / inline] +``` + +`INOUT` 没有走 local output RXE fast path。它仍然走 input staging,因为它需要把初始值发给 L3。 + +#### 2.4.3 dispatch 返回处理 + +```mermaid +flowchart TD + A[DispatchResp.output_tensors] --> B[zip with local output tensor indexes] + B --> C{is local RXE ACK?} + C -->|yes| D[skip copy, remote already wrote local buffer] + C -->|no| E[inline or PullTensor] + E --> F[memmove into local output ptr] + D --> G[free L3 input handles] + F --> G + G --> H[close local output RXE servers] +``` + +Review 关注点: + +- `_stage_tensor_args()` 会对所有 large INPUT 分配远端 handle,并在异常时 best-effort free。 +- local output RXE server lifetime 覆盖 dispatch 整个远端执行和写回。 +- `auto` 模式下 RXE/HCOMM push 失败会回退 gRPC;显式 `rxe`/`hcomm` 失败会让远端不可用。 +- `OUTPUT_EXISTING` 在当前设计中不会把旧值发送给 L3。如果用户依赖旧值,应使用 `INOUT`。 +- `_write_output_tensors()` 依赖 output tensor 顺序与 L3 返回顺序一致。 + +### 2.5 L3Daemon 与 backend process + +代码: + +- `python/simpler/distributed/l3_daemon.py` + +职责: + +- 提供 gRPC service。 +- 在 gRPC server 启动前 fork backend process。 +- 通过 Pipe 把 Catalog/TensorPool/Dispatch 操作转发给 backend。 +- backend process 拥有真实 `TensorPool` 和 `Worker(level=3)`。 + +进程结构: + +```mermaid +flowchart LR + subgraph Daemon[L3Daemon process] + A[gRPC threads] + B[_BackendTensorPoolService facade] + C[_BackendCatalogService facade] + D[L3Worker.Dispatch] + end + + subgraph Backend[backend process] + E[Catalog mirror] + F[TensorPool] + G[transport backend] + H[Worker level=3] + end + + A --> B + A --> C + A --> D + B -->|Pipe op| F + C -->|Pipe push| E + D -->|Pipe dispatch| H + F --> G +``` + +为什么需要 backend process: + +- gRPC server 有线程。 +- `Worker(level=3)` 内部还可能 fork sub/chip worker。 +- 在有活跃 gRPC 线程的进程里 fork 风险较高。 +- backend process 在 gRPC server 启动前创建,后续 Worker fork 发生在 backend 中。 + +#### 2.5.1 backend op + +`_backend_loop()` 当前处理: + +```text 
+stop +push +tensor_alloc +tensor_free +tensor_refresh +tensor_pull +tensor_push +dispatch +``` + +TensorPool gRPC facade 不直接持有 pool,只把请求 serialize 后通过 Pipe 发送给 backend。 + +#### 2.5.2 dispatch 执行策略 + +```mermaid +flowchart TD + A[backend dispatch] --> B{req.tensor_refs?} + B -->|no| C[decode_task_args tensor_args] + C --> D{persistent inner exists?} + D -->|no| E[create Worker level=3] + D -->|yes| F[reuse inner] + E --> G[run_inner.run] + F --> G + + B -->|yes| H[decode tensor_refs into mmap buffers] + H --> I[create ephemeral Worker level=3] + I --> J[run_inner.run] + J --> K[encode_output_tensor_refs] + K --> L[close ephemeral Worker] +``` + +重点:有 tensor_refs 的 dispatch 使用 **ephemeral inner Worker**。原因是 tensor payload 被 materialize 到 backend process 的 mmap buffer,L3 sub/chip children 必须在这些 mmap 存在后 fork,才能继承同一映射。 + +Review 关注点: + +- scalar-only dispatch 复用 persistent inner,提高普通控制面路径效率。 +- tensor dispatch 每次创建/关闭 inner Worker,语义安全但性能较重。 +- Pipe 同步调用由 `_backend_lock` 串行化,目前没有并发 dispatch 并行执行。 +- backend 异常会被序列化为 `(False, traceback)`,daemon handler 转为 error resp 或 gRPC abort。 + +### 2.6 Tensor serialization + +代码: + +- `python/simpler/distributed/serialization.py` + +职责: + +- `CallConfig` encode/decode。 +- 旧 `ContinuousTensorRef` decode。 +- 新 `TensorRef` materialize。 +- output tensor encode/writeback。 + +关键函数: + +- `encode_task_args()` +- `decode_task_args()` +- `decode_task_args_with_tensor_refs_and_writebacks()` +- `encode_output_tensor_refs()` + +#### 2.6.1 L3 materialize TensorRef + +```mermaid +flowchart TD + A[TensorRef] --> B{local RXE output handle?} + B -->|yes| C[allocate empty mmap buffer sized by shape*dtype] + B -->|no| D[pool.materialize_ref inline or L3 handle] + C --> E[TaskArgs ContinuousTensor ptr=mmap addr] + D --> E + C --> F[record RemoteTensorWriteback] +``` + +`RemoteTensorWriteback` 记录 tensor index 和 L4 output handle,供执行后写回。 + +#### 2.6.2 output encoding + +```mermaid +flowchart TD + A[TaskArgs output tensor] --> B{has writeback?} + B 
-->|yes| C[try RxeDataPlaneClient.write_handle to L4] + C -->|success| D[return ACK TensorRef(handle=L4 handle)] + C -->|failure| E[fallback pool.put_bytes] + B -->|no| E + E --> F[inline or L3 TensorPool handle] +``` + +Review 关注点: + +- `OUTPUT / OUTPUT_EXISTING` 可作为 remote output writeback。 +- `INOUT` 当前不在 `_REMOTE_OUTPUT_TAGS`,避免丢初始输入值。 +- RXE writeback 失败被吞掉并 fallback 到 pool path;这保证语义,但可能隐藏性能路径失败,后续可以加日志。 +- `_shape_nbytes()` 依赖 `get_element_size(dtype)`。 + +### 2.7 TensorPool + +代码: + +- `python/simpler/distributed/tensor_pool.py` + +职责: + +- 管理 backend process 内的 bytearray storage。 +- 提供 handle/lease/capacity/GC。 +- 暴露 TensorPool gRPC servicer 形状。 +- 调用 `TensorTransportBackend` 注册/注销 region。 + +核心对象: + +```text +TensorPool._entries[handle_id] = _Entry( + data=bytearray, + nbytes, + expires_at_ms, + shape, + dtype, + tag, + region=RegisteredRegion(...) +) +``` + +生命周期: + +```mermaid +flowchart TD + A[AllocTensor] --> B[bytearray] + B --> C[transport_backend.register_region] + C --> D[TensorHandle] + D --> E[PushTensor/RDMA write] + E --> F[RefreshTensor] + F --> G{backend has refresh_region?} + G -->|yes RXE| H[close old one-shot server, start new server] + G -->|no| I[lease only] + D --> J[FreeTensor or GC] + J --> K[transport_backend.unregister_region] +``` + +Review 关注点: + +- capacity 与 lease 是否满足远端执行时间。 +- `refresh_region()` 目前主要服务 RXE one-shot server 重建。 +- `bytearray` 地址在 entry 生命周期内稳定;但 Python 对象生命周期必须由 `_Entry.data` 持有。 +- gRPC `PushTensor` fallback 仍写同一个 pool buffer。 + +### 2.8 Transport backend + +代码: + +- `python/simpler/distributed/transport_backend.py` +- `python/simpler/distributed/rxe_verbs_helper.c` +- `python/simpler/distributed/hcomm_abi_shim.cc` + +抽象: + +```python +class TensorTransportBackend: + def register_region(self, data: bytearray, *, tag: str) -> RegisteredRegion: ... + def unregister_region(self, region: RegisteredRegion) -> None: ... 
+``` + +#### 2.8.1 gRPC backend + +`GrpcTensorTransport` 只是返回本地 buffer 地址,实际传输走 `PushTensor/PullTensor`。 + +适合作为 fallback 和默认路径。 + +#### 2.8.2 RXE backend + +Python 层: + +- `RxeTensorTransport` +- `RxeRuntime` +- `RxeDataPlaneClient` + +C 层: + +- `simpler_rxe_server_start` +- `simpler_rxe_server_stop` +- `simpler_rxe_write` + +RXE input 流程: + +```mermaid +sequenceDiagram + participant L3 as L3 TensorPool/RxeTensorTransport + participant S as rxe server helper + participant L4 as RxeDataPlaneClient + participant C as rxe client helper + + L3->>S: server_start(buffer addr, size) + S-->>L3: ip, port, rkey, addr, size + L3-->>L4: TensorHandle transport_desc + L4->>C: simpler_rxe_write(local addr, desc) + C->>S: TCP connect, exchange QP info + C->>S: IBV_WR_RDMA_WRITE + C-->>L4: CQ success + L4->>L3: RefreshTensor + L3->>S: stop old server, start new server +``` + +RXE desc v2: + +```text +magic="SRXE" +version=2 +header_size +port +gid_index +rkey +addr +size +ip[64] +device[64] +``` + +Review 关注点: + +- C helper 是 MVP:每个 region/write 还是临时 QP/TCP 控制连接,不是连接池。 +- `server_stop()` 会 shutdown listen fd 和 accepted fd,避免失败路径挂住。 +- `_build_rxe_verbs_helper()` 动态编译 C helper 到 `.cache`,依赖本机 rdma-core include/lib。 +- desc parser 兼容旧 JSON desc。 + +#### 2.8.3 HCOMM backend + +HCOMM 只做 Simpler 侧适配: + +- 不修改 `3rd/hcomm` 源码。 +- `HcommRuntime` 加载 `libhcomm.so` 和 CANN/HCOMM 依赖。 +- `hcomm_abi_shim.cc` 只补 stock HCOMM 本地 build 可能缺的 ABI 符号。 +- CPU RoCE channel E2E 在当前 910B1 host 环境不是主路径。 + +Review 关注点: + +- HCOMM explicit 模式失败应清晰报错。 +- `auto` 默认不会因为 HCOMM/RXE 不可用而破坏 gRPC。 +- HCOMM source tree 不应被本项目提交依赖修改。 + +## 3. 
端到端流程详解 + +### 3.1 初始化流程 + +```mermaid +sequenceDiagram + participant User + participant W4 + participant Catalog + participant Proxy + participant L3 + participant Backend + participant Cpp + + User->>W4: register(l3_sub/l3_orch) + W4->>Catalog: register callable payload + User->>W4: add_remote_worker(endpoint) + User->>W4: init() + W4->>Proxy: create RemoteWorkerProxy + User->>W4: run() + W4->>Proxy: handshake() + Proxy->>L3: Heartbeat + Proxy->>L3: PushCallable for every payload + L3->>Backend: pipe push catalog + W4->>Cpp: add_next_level_process(remote mailbox) +``` + +### 3.2 scalar-only dispatch + +```mermaid +sequenceDiagram + participant Cpp as L4 scheduler + participant Shim + participant Proxy + participant L3 + participant Backend + participant Inner as persistent Worker(level=3) + + Cpp->>Shim: TASK_READY callable,args,cfg + Shim->>Proxy: dispatch() + Proxy->>L3: DispatchReq tensor_args/scalars + L3->>Backend: pipe dispatch + Backend->>Inner: create if missing + Backend->>Inner: run(orch_fn,args,cfg) + Inner-->>Backend: done + Backend-->>L3: DispatchResp error_code=0 + L3-->>Proxy: DispatchResp + Proxy-->>Shim: return + Shim->>Cpp: TASK_DONE +``` + +### 3.3 large input with RXE + +```mermaid +sequenceDiagram + participant Proxy as L4 RemoteWorkerProxy + participant TP as L3 TensorPool facade + participant Backend as L3 backend TensorPool + participant RxeS as L3 RXE server + participant RxeC as L4 RXE client + participant L3 as L3 Dispatch + + Proxy->>TP: AllocTensor(nbytes,shape,dtype,tag) + TP->>Backend: pipe tensor_alloc + Backend->>RxeS: register_region / server_start + Backend-->>Proxy: TensorHandle transport=rxe + Proxy->>RxeC: write_handle(handle, local bytes) + RxeC->>RxeS: RDMA write + Proxy->>TP: RefreshTensor(handle) + TP->>Backend: pipe tensor_refresh + Backend->>RxeS: refresh_region stop/start + Proxy->>L3: DispatchReq TensorRef(handle) +``` + +### 3.4 large output with RXE writeback + +```mermaid +sequenceDiagram + participant Proxy as L4 
RemoteWorkerProxy + participant RxeS as L4 RXE server on output ptr + participant L3 as L3 backend + participant W as Worker(level=3) + participant RxeC as L3 RXE client + + Proxy->>RxeS: server_start(local output ptr) + Proxy->>L3: DispatchReq TensorRef(handle=node_id l4-rxe-*) + L3->>L3: allocate mmap temp output buffer + L3->>W: run with temp output tensor + W-->>L3: output bytes in temp buffer + L3->>RxeC: write_handle(L4 output handle) + RxeC->>RxeS: RDMA write into L4 buffer + L3-->>Proxy: DispatchResp TensorRef(handle ACK) + Proxy->>Proxy: ACK recognized, skip PullTensor + Proxy->>RxeS: close local output server +``` + +Fallback: + +```mermaid +flowchart TD + A[L3 RXE writeback fails] --> B[catch TransportBackendError/Unavailable] + B --> C[pool.put_bytes output in L3 TensorPool] + C --> D[DispatchResp TensorRef inline or L3 handle] + D --> E[L4 _read_tensor_ref] + E --> F[memmove into local output ptr] +``` + +## 4. Error handling and lifecycle + +### 4.1 remote unavailable + +- heartbeat RPC 失败:`RemoteUnavailable`。 +- dispatch RPC 失败:mark proxy unavailable,free already allocated remote handles。 +- remote response `error_code != 0`:free remote handles and close local output regions, raise `RuntimeError` with remote traceback. + +### 4.2 tensor handle cleanup + +L4 owns two kinds of temporary resources: + +```text +remote_handles: + L3 TensorPool allocated handles for large input or fallback output. + cleanup: TensorPool.FreeTensor best-effort. + +local_output_regions: + L4 RXE servers opened on local output buffer. + cleanup: RxeRuntime.server_stop best-effort. +``` + +L3 backend owns: + +```text +TensorPool entries: + freed by FreeTensor, GC, or TensorPool.close(). + +ephemeral Worker for tensor dispatch: + closed in finally. + +mmap keepalive buffers: + kept alive until run and output encoding finish, then cleared. 
+``` + +### 4.3 fallback policy + +```text +L4 -> L3 input: + explicit rxe/hcomm failure: error + auto mode rxe/hcomm failure: fallback to gRPC PushTensor + +L3 -> L4 output: + RXE writeback failure: fallback to L3 TensorPool / gRPC output response +``` + +Review 关注点: + +- fallback 是否应该记录 warning/log。当前 output fallback 是静默的。 +- explicit `rxe` input 失败不 fallback,这是为了避免用户以为走了真实 RDMA。 +- cleanup 是 best-effort,失败不会覆盖原始 dispatch 错误。 + +## 5. 当前测试覆盖 + +主要测试: + +```text +tests/ut/py/test_distributed/test_l4_l3_remote.py + - scalar remote dispatch + - inline tensor input + - large handle tensor input + - output tensor writeback + - INOUT writeback + - sub-worker tensor path + - heartbeat behavior + +tests/ut/py/test_distributed/test_tensor_pool.py + - TensorPool alloc/free/refresh + - inline/handle encode/decode + - output TensorRef encode + +tests/ut/py/test_distributed/test_transport_backend.py + - HCOMM struct layout + - RXE desc v2 roundtrip + - RXE legacy JSON desc compatibility + +tests/ut/py/test_distributed/test_real_e2e_smoke.py + - real L4/L3 TensorPool handle E2E + - real RXE ibverbs smoke + - real L4/L3 RXE transport E2E + - HCOMM endpoint/mem export smoke + +tests/ut/py/test_distributed/test_rxe_real.py + - ibv_rc_pingpong smoke +``` + +脚本: + +```bash +tools/test_rxe_data_plane.sh +``` + +当前已验证结果: + +```text +38 passed, 3 skipped +1 passed +2 passed, 2 deselected +RXE data-plane tests passed. +``` + +Benchmark: + +```bash +PYTHONPATH=python tools/benchmark_rxe_data_plane.py \ + --sizes 8192,65536,1048576 \ + --repeats 10 \ + --warmup 2 \ + --transports grpc,rxe +``` + +## 6. 已知局限 + +1. **RXE helper 不是连接池** + + 当前是 TCP 控制连接 + RC QP + MR 的 MVP,每个 region/write 成本较高。`RefreshTensor` 重建 server 解决可用性,不解决性能上限。 + +2. **tensor dispatch 使用 ephemeral inner Worker** + + 这是为了让 sub/chip child fork 后继承 mmap tensor storage。语义安全,但性能重。要复用 persistent L3 worker,需要更强的跨进程 tensor 注入机制。 + +3. 
**INOUT 未做完整双向 RXE fast path** + + `INOUT` 需要先传初始值到 L3,再把结果写回 L4。当前走 input staging + response output 路径,不走 local output RXE ACK。 + +4. **Pipe backend 串行化** + + `_backend_lock` 让 L3 daemon 到 backend 的操作串行执行。当前简单可靠,但限制并发。 + +5. **Catalog 是受信任执行模型** + + `cloudpickle` payload 是可执行代码反序列化,只能用于受信任集群。 + +6. **output RXE fallback 静默** + + L3->L4 output RXE writeback 失败时会 fallback 到 TensorPool response,功能正确但不容易发现性能路径失效。 + +7. **跨节点和多并发覆盖不足** + + 当前实机验证主要是单机 RXE。跨 host RoCE、多 remote worker、多并发 dispatch、长时间压测还需要补。 + +8. **HCOMM 不是当前主路径** + + HCOMM adapter 可加载 endpoint/mem,CPU RoCE channel E2E 在当前 910B1 host 环境不是主验证路径。 + +## 7. Review Checklist + +### 控制面 + +- `Worker.add_remote_worker()` 的 API 语义是否和 `add_worker()` 一致。 +- remote mailbox shim 是否正确模拟 PROCESS child worker。 +- remote worker thread 与 C++ scheduler 的状态机是否有 race。 +- heartbeat 失败后 `_available` 的语义是否符合调度期望。 +- Catalog version/hash 是否足以避免 stale callable。 + +### 数据面 + +- inline threshold 的默认值和配置入口是否合理。 +- L4 input staging 是否会错误读取 OUTPUT-only tensor 的旧值。 +- `OUTPUT_EXISTING` 不发送旧值、`INOUT` 发送旧值的语义是否清晰。 +- L3 output 顺序与 L4 output tensor index zip 是否可靠。 +- TensorPool lease/GC 是否可能在长任务中提前释放 handle。 + +### RXE + +- RXE desc v2 字段是否足够支持跨 host。 +- RXE helper 的 direct struct/TCP control protocol 是否可接受为 MVP。 +- `server_stop()` 在 client failure、timeout、FreeTensor 时是否不会 hang。 +- refresh 重建 server 是否有重复写/并发写 race。 +- 显式 `rxe` 失败是否应该 fallback,还是继续 fail-fast。 + +### L3 backend + +- 有 tensor_refs 时创建 ephemeral Worker 的成本是否可接受。 +- backend Pipe 串行是否满足当前需求。 +- backend exception 到 `DispatchResp.remote_traceback` 的传播是否足够。 +- TensorPool facade 把所有操作转发 backend 是否有死锁风险。 + +### 测试 + +- 是否需要把 example 加入自动测试。 +- 是否需要新增多 remote worker 分发测试。 +- 是否需要新增 RXE output fallback 强制失败测试。 +- 是否需要跨 host RoCE smoke。 +- 是否需要 benchmark 阈值或只保留观测输出。 + +## 8. 建议 reviewer 先看的文件顺序 + +1. `python/simpler/worker.py` + + 先看 `add_remote_worker()`、`_remote_worker_loop()`、`_init_hierarchical()`、`_start_hierarchical()`,理解 remote worker 如何伪装成本地 PROCESS child。 + +2. 
`python/simpler/distributed/proto/dispatch.proto`
+
+   看清楚 `DispatchReq`、`DispatchResp`、`TensorHandle`、`TensorRef`。
+
+3. `python/simpler/distributed/remote_proxy.py`
+
+   这是 L4 侧最核心逻辑,重点看 `_stage_tensor_args()`、`_push_remote_tensor_*()`、`_write_output_tensors()`。
+
+4. `python/simpler/distributed/l3_daemon.py`
+
+   看 daemon/backend process 分离、backend op、`_backend_dispatch()`。
+
+5. `python/simpler/distributed/serialization.py`
+
+   看 TensorRef 到 mmap buffer 的 materialize,以及 output writeback/fallback。
+
+6. `python/simpler/distributed/tensor_pool.py`
+
+   看 handle 生命周期和 refresh hook。
+
+7. `python/simpler/distributed/transport_backend.py`
+
+   看 RXE/HCOMM/gRPC backend 边界和 `build_tensor_transport()` 策略。
+
+8. `python/simpler/distributed/rxe_verbs_helper.c`
+
+   最后看真实 ibverbs MVP 是否符合你对数据面的预期。

From a3ea12e8109cab5a464b93bcce601e94d1a35400 Mon Sep 17 00:00:00 2001
From: PKUZHOU
Date: Fri, 8 May 2026 11:15:40 +0800
Subject: [PATCH 5/6] docs(distributed): align L4 review guide with serving design

---
 docs/l4-l3-distributed-review-guide.zh.md | 216 ++++++++++++++++----
 1 file changed, 176 insertions(+), 40 deletions(-)

diff --git a/docs/l4-l3-distributed-review-guide.zh.md b/docs/l4-l3-distributed-review-guide.zh.md
index 75cdb893c..19199476c 100644
--- a/docs/l4-l3-distributed-review-guide.zh.md
+++ b/docs/l4-l3-distributed-review-guide.zh.md
@@ -1,13 +1,149 @@
 # L4/L3 分布式实现 Review Guide
 
-本文面向代码 review。目标不是重复使用说明,而是把当前 L4 到 L3 分布式实现按模块、流程和风险点拆开,帮助 reviewer 判断每个模块的设计是否合理、功能是否完整、边界是否清楚。
+本文面向代码 review。它先说明 `pypto_top_level_documents/UBL128_serving.md` 要做的完整 serving 系统是什么,再说明当前代码里的 L4/L3 原型实现了其中哪些基础能力、没有实现哪些 serving 功能,最后按模块拆解代码路径和 review 风险点。
 
-当前实现包含两层:
+先给结论:**UBL128_serving.md 描述的是完整的 prefill/decode 解耦推理服务系统;当前 Simpler L4 代码不是完整 serving 系统,而是实现了一个可被 serving 系统复用的“分层 Worker 远程派发 + Tensor 数据面传输基座”。**
 
-- **控制面**:L4 本地 scheduler/mailbox 到远端 L3 daemon 的 gRPC dispatch。
-- **数据面**:tensor payload 的 inline、gRPC TensorPool handle、RXE/ibverbs RDMA 
write、HCOMM 可选适配。 +当前实现已经打通: -主要代码入口: +- **L4 -> remote L3 控制面**:L4 本地 scheduler/mailbox 到远端 L3 daemon 的 dispatch。 +- **Callable 分发**:L4 注册的 Python callable 通过 catalog 推送到 L3 daemon/backend。 +- **Tensor 数据面抽象**:small tensor inline,大 tensor 走 TensorPool handle。 +- **Transport backend 边界**:默认 gRPC,RXE/ibverbs RDMA write MVP,HCOMM 可选适配层。 +- **Output writeback**:大 `OUTPUT / OUTPUT_EXISTING` 可由 L3 通过 RXE 写回 L4 本地 output buffer,失败时 fallback 到 response/TensorPool 路径。 + +当前实现没有做: + +- 外部 HTTP/gRPC serving frontend。 +- prefill/decode 角色拆分、continuous batching、request scheduler/router。 +- KV Meta Server、prefix radix tree、SSU LBA 分配、KV block 生命周期。 +- NPU -> SSU 的 SO/UB Urma KV 数据面。 +- 生产设计中的 SO uRPC hot-path 协议栈。 +- UBL128/PC16/SSU 的配置驱动拓扑管理。 + +## 1. UBL128 Serving 总设计先做什么 + +`UBL128_serving.md` 的目标是一个带 prefix cache 的 prefill/decode 解耦推理服务。它不是单个 runtime API,而是一套跨 CPU、NPU、SSU 存储和多张网络的完整 serving 系统。 + +核心角色如下: + +| 角色 | UBL128 文档中的名字 | 物理载体 | 主要职责 | +|------|--------------------|----------|----------| +| 前端/调度器 | `frontend` / F | 鲲鹏 CPU service_access | 外部 ingress、tokenize、PrefixMatch、选择 prefill/decode worker、回流 token | +| KV 元数据 | `kv-meta` / M | 鲲鹏 CPU | prefix radix tree、ChunkRecord、SSU LBA 分配、引用计数 | +| Prefill host | PC | PC16 host CPU | 代 NPU 做 metadata RPC、给 NPU doorbell/mailbox | +| Prefill NPU | PN | Ascend NPU | 计算 prompt/unmatched tail,写新 KV | +| Decode host | DC | PC16 host CPU | 管理 decode slot、代 NPU 做 metadata RPC | +| Decode NPU | DN | Ascend NPU | 读 KV、decode、continuous batching | +| KV 存储 | SSU / S | SSU12 存储框 | LBA-direct KV/prefix block 存储 | + +典型请求路径: + +```mermaid +flowchart LR + U[external user] -->|HTTP/gRPC on DCN| F[frontend / scheduler] + F -->|PrefixMatch on SO uRPC| M[KV Meta Server] + M -->|ChunkRecord list| F + F -->|DispatchPrefill on SO uRPC| PC[Prefill host CPU] + PC -->|doorbell / mailbox| PN[Prefill NPU] + PN -->|SO Urma read/write| S[SSU KV store] + F -->|DispatchDecode on SO uRPC| DC[Decode host CPU] + DC -->|doorbell / mailbox| DN[Decode NPU] 
+ DN -->|SO Urma read/write| S + DN -->|token stream via host| F + F -->|HTTP/gRPC on DCN| U +``` + +网络/协议分层是这个设计的硬约束: + +| 网络 | 设计用途 | 协议选择 | +|------|----------|----------| +| DCN/RoCE | 外部 ingress、运维、健康检查、POSIX FS 等非 hot-path | gRPC 或 HTTP/JSON | +| SO/UBG | 跨 UBL128 hot-path RPC、KV bytes、prefill->decode 移交 | uRPC over UB Urma;数据面用 Urma read/write | +| SU | UBL128 域内 NPU 间 EP/DP 通信 | Urma,不承载普通 RPC | +| 服务器内 UB | host CPU 与本机 NPU doorbell/mailbox | 本地 mailbox/doorbell | + +所以,完整 serving 系统需要闭合三件事: + +1. 请求级控制面:frontend/scheduler/router,把请求拆成 prefill 和 decode 阶段。 +2. KV/prefix 语义:KV Meta Server、prefix radix tree、ChunkRecord、SSU LBA 分配与回收。 +3. 数据面:NPU/CPU/SSU 之间通过 SO/SU/UB 的真实数据搬运。 + +## 2. 当前代码里的 L4 是什么 + +当前代码里的 **L4** 指 Simpler runtime 的 level-4 orchestrator/parent worker 层,不等同于 UBL128 文档里的完整 frontend/scheduler。它现在解决的问题更底层: + +> L4 进程如何把一个 Simpler task 派发到远端 L3 daemon,并让远端 L3 能拿到 callable、scalar 参数、tensor 参数,执行后把 output tensor 返回或写回。 + +也就是说,当前 L4/L3 原型更接近 UBL128 文档中 **F -> PC/DC 的 Dispatch 基础能力**,外加一个临时的 host-side tensor transport MVP。它还没有实现请求级 serving 编排,也没有实现 KV/prefix 的业务语义。 + +当前原型的大图: + +```mermaid +flowchart LR + subgraph L4Host[L4 host process] + A[User L4 orchestrator] + B[L4 C++ scheduler] + C[remote PROCESS mailbox shim] + D[RemoteWorkerProxy] + E[local output RXE region optional] + end + + subgraph L3Host[L3 daemon host] + F[gRPC L3Daemon] + G[backend process] + H[TensorPool] + I[Worker level=3] + J[transport backend grpc/rxe/hcomm] + end + + A --> B + B --> C + C --> D + D -->|DispatchReq/Resp gRPC| F + D -->|TensorPool RPC / RXE write| H + F -->|Pipe ops| G + G --> H + G --> I + H --> J + I -->|output bytes| H + I -->|RXE writeback optional| E +``` + +当前代码把远端 L3 包装成一个“看起来像本地 PROCESS child”的 mailbox shim。这样 L4 C++ scheduler 的调度模型可以先不重写;Python shim thread 负责把 mailbox task 转成远端 RPC。 + +## 3. 
与 UBL128_serving.md 的功能映射 + +| UBL128 设计能力 | 当前代码状态 | 相关代码 | 说明 / 差距 | +|----------------|--------------|----------|-------------| +| 外部 HTTP/gRPC ingress | 未实现 | 无 | 当前没有服务接入层、tokenize/detokenize、用户请求协议。 | +| Frontend scheduler/router | 部分基础可复用 | `python/simpler/worker.py`, `remote_proxy.py` | 当前 L4 能向远端 L3 派发 task,但还不是请求级 scheduler,不懂 prefill/decode、slot、batch、quota。 | +| F -> PC/DC Dispatch | 已实现原型 | `Worker.add_remote_worker()`, `_remote_worker_loop()`, `RemoteWorkerProxy.dispatch()` | 现在走 gRPC 到 L3 daemon;设计文档中的 hot-path 目标是 SO uRPC。 | +| Callable/任务代码分发 | 已实现 | `catalog.py`, `l3_daemon.py` | 用 `cloudpickle` 推送 callable。closure 是远端副本,不能修改 L4 本地 Python 状态。 | +| Scalar 参数派发 | 已实现 | `serialization.py`, `dispatch.proto` | scalar-only dispatch 可复用 persistent L3 Worker。 | +| Tensor 参数派发 | 已实现原型 | `serialization.py`, `tensor_pool.py`, `transport_backend.py` | small inline,大 tensor handle;host-side tensor,不是 NPU HBM KV。 | +| L3 output 返回 | 已实现原型 | `remote_proxy.py`, `serialization.py` | 支持 inline、L3 TensorPool handle、RXE local output ACK。 | +| DCN gRPC 控制面 | 已实现原型 | `rpc.py`, `l3_daemon.py`, `dispatch.proto` | 当前 gRPC 适合 smoke/control path;与 UBL128 文档中 DCN 非 hot-path gRPC 一致。 | +| SO uRPC hot-path | 未实现 | 无 | 目前没有 uRPC over UB Urma runtime;gRPC 不能代表最终 SO hot-path。 | +| RXE/RoCE 数据面 smoke | 已实现 MVP | `transport_backend.py`, `rxe_verbs_helper.c` | 基于 ibverbs RC RDMA write;用于实机 smoke,不是 UBL128 SO/UBG 生产路径。 | +| HCOMM 适配边界 | 部分实现 | `transport_backend.py`, `hcomm_abi_shim.cc` | 只做 Simpler 外部适配,不依赖修改 `3rd/hcomm`。 | +| Prefill worker 角色 | 未实现 | 无 | 当前 L3 只是 generic `Worker(level=3)`,不区分 prefill。 | +| Decode worker 角色 | 未实现 | 无 | 当前没有 decode slot table、continuous batching、sampling loop。 | +| KV Meta Server | 未实现 | 无 | 没有 prefix radix tree、ChunkRecord、LBA free list、refcount。 | +| Prefix cache | 未实现 | 无 | 没有 token chunk hash、longest prefix match、cache hit/miss。 | +| SSU LBA-direct KV store | 未实现 | 无 | 当前 TensorPool 是进程内 bytearray storage,不是 SSU。 | +| NPU -> SSU SO Urma 
read/write | 未实现 | 无 | 当前 RXE 是 host memory RDMA write smoke;不涉及 NPU HBM/SSU。 | +| EP/DP on SU | 未实现 | 无 | Simpler 当前没有 UBL128 拓扑感知 EP/DP 通信调度。 | +| 配置驱动 A/B/C 三档硬件 | 未实现 | 无 | 当前 remote endpoint/transport 是局部配置,不是完整拓扑配置系统。 | + +Review 时建议把当前 PR/分支按“runtime 基座”来评估,而不是按“完整 UBL128 serving 系统”来验收。更具体地说: + +- 可以 review:L4 到远端 L3 的 task 派发语义、callable catalog、tensor handle 生命周期、RXE/HCOMM backend 边界、测试覆盖。 +- 不应误判为已完成:prefill/decode serving、KV cache、prefix cache、SO uRPC、NPU/SSU 数据面。 + +## 4. 主要代码入口 + +后续章节按当前实现拆解。主要代码入口: ```text python/simpler/worker.py @@ -29,9 +165,9 @@ tools/test_rxe_data_plane.sh tools/benchmark_rxe_data_plane.py ``` -## 1. 总览 +## 5. 当前 L4/L3 原型总览 -### 1.1 为什么要这样做 +### 5.1 为什么需要 remote L3 原有 L4->L3 路径假设 L4 可以在本机 fork L3 child worker,并通过共享内存 mailbox 与该 child 交互。跨 host 后这个假设失效: @@ -54,7 +190,7 @@ flowchart LR L4 C++ scheduler 不知道远端存在;它只看到一个普通 PROCESS-mode next-level mailbox。Python shim thread 负责把 mailbox 中的 task 转成远端 RPC。 -### 1.2 控制面和数据面的边界 +### 5.2 控制面和数据面的边界 ```mermaid flowchart TB @@ -80,9 +216,9 @@ flowchart TB 控制面只传元数据、handle、callable id、错误信息和生命周期操作。数据面负责 tensor bytes 的实际搬运。 -## 2. 模块拆解 +## 6. 
模块拆解 -### 2.1 Worker remote child 接入 +### 6.1 Worker remote child 接入 代码: @@ -139,7 +275,7 @@ Review 关注点: - remote proxy close 是否和 `_SHUTDOWN` 时序匹配。 - `worker_manager.cpp` 的 tensor tag 扩展是否保持旧 blob 格式兼容。 -### 2.2 Callable Catalog +### 6.2 Callable Catalog 代码: @@ -177,7 +313,7 @@ Review 关注点: - callable closure 被序列化到 L3,远端修改的是反序列化副本,不会修改 L4 本地 Python 对象。 - version mismatch 时当前行为是 L3 lookup 失败并返回 error。 -### 2.3 RPC/protobuf 控制协议 +### 6.3 RPC/protobuf 控制协议 代码: @@ -221,7 +357,7 @@ Review 关注点: - `tensor_args` 和 `tensor_refs` 同时存在时的语义是否明确。当前 L4 生成 tensor_refs 时会把 `tensor_args=[]`。 - `TensorHandle.node_id` 用于区分 L3 pool handle 和 L4 local output ACK。 -### 2.4 L4 RemoteWorkerProxy +### 6.4 L4 RemoteWorkerProxy 代码: @@ -249,7 +385,7 @@ Review 关注点: - `_free_remote_handles()` - `_close_local_output_regions()` -#### 2.4.1 input tensor staging +#### 6.4.1 input tensor staging ```mermaid flowchart TD @@ -265,7 +401,7 @@ flowchart TD H --> I ``` -#### 2.4.2 output tensor staging +#### 6.4.2 output tensor staging 当前只有大 `OUTPUT / OUTPUT_EXISTING` 且 transport 是 `rxe` 或 `auto` 时,会走 L4 local RXE region: @@ -279,7 +415,7 @@ flowchart TD `INOUT` 没有走 local output RXE fast path。它仍然走 input staging,因为它需要把初始值发给 L3。 -#### 2.4.3 dispatch 返回处理 +#### 6.4.3 dispatch 返回处理 ```mermaid flowchart TD @@ -301,7 +437,7 @@ Review 关注点: - `OUTPUT_EXISTING` 在当前设计中不会把旧值发送给 L3。如果用户依赖旧值,应使用 `INOUT`。 - `_write_output_tensors()` 依赖 output tensor 顺序与 L3 返回顺序一致。 -### 2.5 L3Daemon 与 backend process +### 6.5 L3Daemon 与 backend process 代码: @@ -348,7 +484,7 @@ flowchart LR - 在有活跃 gRPC 线程的进程里 fork 风险较高。 - backend process 在 gRPC server 启动前创建,后续 Worker fork 发生在 backend 中。 -#### 2.5.1 backend op +#### 6.5.1 backend op `_backend_loop()` 当前处理: @@ -365,7 +501,7 @@ dispatch TensorPool gRPC facade 不直接持有 pool,只把请求 serialize 后通过 Pipe 发送给 backend。 -#### 2.5.2 dispatch 执行策略 +#### 6.5.2 dispatch 执行策略 ```mermaid flowchart TD @@ -393,7 +529,7 @@ Review 关注点: - Pipe 同步调用由 `_backend_lock` 串行化,目前没有并发 dispatch 并行执行。 - backend 异常会被序列化为 `(False, 
traceback)`,daemon handler 转为 error resp 或 gRPC abort。 -### 2.6 Tensor serialization +### 6.6 Tensor serialization 代码: @@ -413,7 +549,7 @@ Review 关注点: - `decode_task_args_with_tensor_refs_and_writebacks()` - `encode_output_tensor_refs()` -#### 2.6.1 L3 materialize TensorRef +#### 6.6.1 L3 materialize TensorRef ```mermaid flowchart TD @@ -427,7 +563,7 @@ flowchart TD `RemoteTensorWriteback` 记录 tensor index 和 L4 output handle,供执行后写回。 -#### 2.6.2 output encoding +#### 6.6.2 output encoding ```mermaid flowchart TD @@ -446,7 +582,7 @@ Review 关注点: - RXE writeback 失败被吞掉并 fallback 到 pool path;这保证语义,但可能隐藏性能路径失败,后续可以加日志。 - `_shape_nbytes()` 依赖 `get_element_size(dtype)`。 -### 2.7 TensorPool +### 6.7 TensorPool 代码: @@ -496,7 +632,7 @@ Review 关注点: - `bytearray` 地址在 entry 生命周期内稳定;但 Python 对象生命周期必须由 `_Entry.data` 持有。 - gRPC `PushTensor` fallback 仍写同一个 pool buffer。 -### 2.8 Transport backend +### 6.8 Transport backend 代码: @@ -512,13 +648,13 @@ class TensorTransportBackend: def unregister_region(self, region: RegisteredRegion) -> None: ... ``` -#### 2.8.1 gRPC backend +#### 6.8.1 gRPC backend `GrpcTensorTransport` 只是返回本地 buffer 地址,实际传输走 `PushTensor/PullTensor`。 适合作为 fallback 和默认路径。 -#### 2.8.2 RXE backend +#### 6.8.2 RXE backend Python 层: @@ -574,7 +710,7 @@ Review 关注点: - `_build_rxe_verbs_helper()` 动态编译 C helper 到 `.cache`,依赖本机 rdma-core include/lib。 - desc parser 兼容旧 JSON desc。 -#### 2.8.3 HCOMM backend +#### 6.8.3 HCOMM backend HCOMM 只做 Simpler 侧适配: @@ -589,9 +725,9 @@ Review 关注点: - `auto` 默认不会因为 HCOMM/RXE 不可用而破坏 gRPC。 - HCOMM source tree 不应被本项目提交依赖修改。 -## 3. 端到端流程详解 +## 7. 
端到端流程详解 -### 3.1 初始化流程 +### 7.1 初始化流程 ```mermaid sequenceDiagram @@ -616,7 +752,7 @@ sequenceDiagram W4->>Cpp: add_next_level_process(remote mailbox) ``` -### 3.2 scalar-only dispatch +### 7.2 scalar-only dispatch ```mermaid sequenceDiagram @@ -640,7 +776,7 @@ sequenceDiagram Shim->>Cpp: TASK_DONE ``` -### 3.3 large input with RXE +### 7.3 large input with RXE ```mermaid sequenceDiagram @@ -663,7 +799,7 @@ sequenceDiagram Proxy->>L3: DispatchReq TensorRef(handle) ``` -### 3.4 large output with RXE writeback +### 7.4 large output with RXE writeback ```mermaid sequenceDiagram @@ -696,15 +832,15 @@ flowchart TD E --> F[memmove into local output ptr] ``` -## 4. Error handling and lifecycle +## 8. Error handling and lifecycle -### 4.1 remote unavailable +### 8.1 remote unavailable - heartbeat RPC 失败:`RemoteUnavailable`。 - dispatch RPC 失败:mark proxy unavailable,free already allocated remote handles。 - remote response `error_code != 0`:free remote handles and close local output regions, raise `RuntimeError` with remote traceback. -### 4.2 tensor handle cleanup +### 8.2 tensor handle cleanup L4 owns two kinds of temporary resources: @@ -731,7 +867,7 @@ mmap keepalive buffers: kept alive until run and output encoding finish, then cleared. ``` -### 4.3 fallback policy +### 8.3 fallback policy ```text L4 -> L3 input: @@ -748,7 +884,7 @@ Review 关注点: - explicit `rxe` input 失败不 fallback,这是为了避免用户以为走了真实 RDMA。 - cleanup 是 best-effort,失败不会覆盖原始 dispatch 错误。 -## 5. 当前测试覆盖 +## 9. 当前测试覆盖 主要测试: @@ -807,7 +943,7 @@ PYTHONPATH=python tools/benchmark_rxe_data_plane.py \ --transports grpc,rxe ``` -## 6. 已知局限 +## 10. 已知局限 1. **RXE helper 不是连接池** @@ -841,7 +977,7 @@ PYTHONPATH=python tools/benchmark_rxe_data_plane.py \ HCOMM adapter 可加载 endpoint/mem,CPU RoCE channel E2E 在当前 910B1 host 环境不是主验证路径。 -## 7. Review Checklist +## 11. Review Checklist ### 控制面 @@ -882,7 +1018,7 @@ PYTHONPATH=python tools/benchmark_rxe_data_plane.py \ - 是否需要跨 host RoCE smoke。 - 是否需要 benchmark 阈值或只保留观测输出。 -## 8. 
建议 reviewer 先看的文件顺序 +## 12. 建议 reviewer 先看的文件顺序 1. `python/simpler/worker.py` From 2dd89eeaff9164166a6b4f36edce3c4621777b53 Mon Sep 17 00:00:00 2001 From: PKUZHOU Date: Fri, 8 May 2026 11:20:53 +0800 Subject: [PATCH 6/6] docs(distributed): add L4 review glossary --- docs/l4-l3-distributed-review-guide.zh.md | 45 +++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/docs/l4-l3-distributed-review-guide.zh.md b/docs/l4-l3-distributed-review-guide.zh.md index 19199476c..b1c4e5db4 100644 --- a/docs/l4-l3-distributed-review-guide.zh.md +++ b/docs/l4-l3-distributed-review-guide.zh.md @@ -21,6 +21,51 @@ - 生产设计中的 SO uRPC hot-path 协议栈。 - UBL128/PC16/SSU 的配置驱动拓扑管理。 +## 术语表 + +这张表只解释本文 review 里会反复出现的词,重点是帮助 reviewer 判断“当前代码实现的是哪一层能力”。 + +| 术语 | 含义 | 在当前实现里的状态 | +|------|------|--------------------| +| L4 | Simpler runtime 的 level-4 orchestrator/parent worker 层,负责向下一层 worker 派发任务。 | 已实现远程 L3 派发入口;不是完整 serving frontend。 | +| L3 | Simpler runtime 的 level-3 worker 层,接收 L4 派发的 task 并继续执行/下发。 | 远端由 `L3Daemon` + backend process 承载。 | +| Worker | Simpler 的层级执行单元。L4/L3/L2 等层级通过 mailbox 或 RPC 组织任务。 | 当前复用本地 hierarchical worker 模型,并用 remote mailbox shim 接远端 L3。 | +| L3Daemon | 远端 L3 的 gRPC server 进程,接收 dispatch、catalog、TensorPool RPC。 | 已实现;真正执行在 daemon 启动前 fork 出来的 backend process。 | +| backend process | L3 daemon 后面的执行进程,持有 `Worker(level=3)`、`TensorPool` 和 transport backend。 | 已实现;避免在已有 gRPC 线程的进程里继续 fork worker。 | +| 控制面 | 传 task 元数据、callable id、tensor handle、错误码、heartbeat 等小消息的路径。 | 当前主要是 gRPC + protobuf。 | +| 数据面 | 传 tensor bytes 的路径。 | 当前支持 inline、gRPC chunk、RXE RDMA write,HCOMM 为可选适配。 | +| Dispatch | L4 把一个 callable + args + config 发给 L3 执行的动作。 | 已实现 `DispatchReq/DispatchResp`。 | +| Callable Catalog | callable 的分发/同步表。L4 注册 Python 函数后,把序列化 payload 推给 L3。 | 已实现;使用 `cloudpickle`,只适合受信任集群。 | +| closure | Python 函数捕获的外部变量。 | 会被序列化到 L3;远端修改的是副本,不会改 L4 本地对象。 | +| TensorPool | L3 backend 内部的 tensor byte storage 和 handle 管理器。 | 已实现;当前是进程内 bytearray storage,不是 UBL128 设计里的 
SSU KV 存储。 | +| TensorRef | dispatch 消息里的 tensor 参数描述,可能是 inline bytes,也可能指向 `TensorHandle`。 | 已实现。 | +| TensorHandle | TensorPool/transport 返回的远端数据句柄,包含 shape、dtype、size、transport、transport_desc 等。 | 已实现。 | +| inline tensor | 小 tensor 直接放进 protobuf 消息里传输。 | 已实现;由 inline threshold 控制。 | +| transport backend | TensorPool 注册内存和搬运 bytes 的后端抽象。 | 已实现 `grpc`、`rxe`、`hcomm` 三类入口。 | +| gRPC | 基于 HTTP/2 + protobuf 的通用 RPC。 | 当前控制面主路径;在 UBL128 设计里适合 DCN/运维/非 hot-path。 | +| RDMA | Remote Direct Memory Access,远端直接内存访问,允许一端把数据直接写入/读出另一端注册内存。 | 当前 RXE helper 用 RDMA write 做实机数据面 smoke。 | +| RoCE | RDMA over Converged Ethernet,在以太网上承载 RDMA verbs。 | 当前机器 `ibv_devices` 可见时用于 smoke;不是 UBL128 SO/UBG 生产路径。 | +| RXE | Linux Soft-RoCE provider,用软件方式把普通以太网设备暴露成 RDMA verbs 设备。 | 当前用于开发/烟测真实 ibverbs 流程;性能和生产硬件 RoCE/UB 不是一回事。 | +| ibverbs / verbs | Linux RDMA 用户态接口族,应用通过 verbs 创建 MR/QP/CQ 并发起 RDMA write/read。 | `rxe_verbs_helper.c` 直接使用。 | +| MR / QP / CQ | RDMA 里的 Memory Region、Queue Pair、Completion Queue。 | RXE helper 内部使用;文档里不要求 reviewer 深入 verbs 细节。 | +| HCOMM | 昇腾通信相关库/接口集合,这里作为可选通信后端候选。 | 当前只在 Simpler 侧做适配层,不依赖修改 `3rd/hcomm`。 | +| CANN | 昇腾软件栈基础组件。 | HCOMM runtime 可能依赖;当前 L4/L3 主路径不依赖 NPU CANN 执行。 | +| UB / UBG | UBL128 设计里的统一总线/Scale Out 物理协议族。 | 当前代码没有实现 UB/UBG。 | +| Urma | UB 上的可靠内存访问语义层,类似面向 UB 网络的 RDMA。 | UBL128 目标路径;当前未实现。 | +| uRPC | 基于 Urma 的轻量 RPC,用于 SO/SU hot-path 内部 RPC。 | UBL128 目标路径;当前未实现,当前用 gRPC 做原型。 | +| DCN | 数据中心 RoCE/TCP 网络,CPU 可见,承载外部接入、运维和非 hot-path 服务。 | 当前 gRPC 原型可类比 DCN 控制路径。 | +| SO | Scale Out 网络,跨 UBL128 的 NPU/CPU/SSU any-to-any 网络。 | UBL128 目标数据面/热控制面;当前未实现。 | +| SU | Scale Up 网络,UBL128 域内 NPU 间高带宽网络,主要承载 EP/DP。 | 当前未实现。 | +| UBL128 | 设计文档里的 128 NPU high bandwidth domain。 | 当前代码不感知该拓扑。 | +| PC16 | 设计文档里的 16 NPU 服务器单元。 | 当前代码不感知该拓扑。 | +| SSU / SSU12 | 设计文档里的 KV/prefix 持久化存储单元/机框。 | 当前未实现;不要和 TensorPool 混淆。 | +| KV cache | LLM attention 的 key/value 缓存。 | UBL128 核心设计内容;当前 L4/L3 原型未实现 KV 语义。 | +| Prefix cache | 按 token prefix 复用已计算 KV 的 cache。 | 当前未实现。 | +| 
ChunkRecord | UBL128 KV Meta Server 返回的 KV block 元数据,描述每层 KV 在哪个 SSU/LBA。 | 当前未实现。 | +| LBA | Logical Block Address,SSU 上无文件系统的块地址。 | 当前未实现。 | +| HBM | NPU 高带宽显存。 | 当前 RXE smoke 传的是 host memory,不是 NPU HBM。 | +| `INPUT` / `OUTPUT` / `INOUT` | Simpler tensor 参数方向 tag。 | 已支持远程参数语义;`INOUT` 尚未做完整双向 RXE fast path。 | + ## 1. UBL128 Serving 总设计先做什么 `UBL128_serving.md` 的目标是一个带 prefix cache 的 prefill/decode 解耦推理服务。它不是单个 runtime API,而是一套跨 CPU、NPU、SSU 存储和多张网络的完整 serving 系统。