diff --git a/.github/configs/amd-master.yaml b/.github/configs/amd-master.yaml index c362604f1..d0c992cdc 100644 --- a/.github/configs/amd-master.yaml +++ b/.github/configs/amd-master.yaml @@ -1031,6 +1031,81 @@ dsr1-fp8-mi355x-sglang-disagg-mtp: # - "DECODE_MTP_SIZE=0" +dsr1-fp8-mi355x-vllm-disagg: + image: vllm/vllm-openai-rocm:v0.17.1 + model: deepseek-ai/DeepSeek-R1-0528 + model-prefix: dsr1 + runner: mi355x-disagg + precision: fp8 + framework: vllm-disagg + multinode: true + disagg: true + seq-len-configs: + - isl: 1024 + osl: 1024 + search-space: + # 1P2D: 1 prefill node (co-located with proxy) + 2 decode nodes = 3 nodes total + - spec-decoding: "none" + conc-list: [ 8, 16, 32, 64, 128, 256, 512 ] + prefill: + num-worker: 1 + tp: 8 + ep: 1 + dp-attn: false + additional-settings: + - "PREFILL_NODES=1" + - "VLLM_MORIIO_CONNECTOR_READ_MODE=1" + decode: + num-worker: 2 + tp: 8 + ep: 8 + dp-attn: false + additional-settings: + - "DECODE_NODES=2" + + - isl: 8192 + osl: 1024 + search-space: + - spec-decoding: "none" + conc-list: [ 8, 16, 32, 64, 128, 256, 512 ] + prefill: + num-worker: 1 + tp: 8 + ep: 1 + dp-attn: false + additional-settings: + - "PREFILL_NODES=1" + - "VLLM_MORIIO_CONNECTOR_READ_MODE=1" + decode: + num-worker: 2 + tp: 8 + ep: 8 + dp-attn: false + additional-settings: + - "DECODE_NODES=2" + + - isl: 1024 + osl: 8192 + search-space: + - spec-decoding: "none" + conc-list: [ 8, 16, 32, 64, 128, 256, 512 ] + prefill: + num-worker: 1 + tp: 8 + ep: 1 + dp-attn: false + additional-settings: + - "PREFILL_NODES=1" + - "VLLM_MORIIO_CONNECTOR_READ_MODE=1" + decode: + num-worker: 2 + tp: 8 + ep: 8 + dp-attn: false + additional-settings: + - "DECODE_NODES=2" + + dsr1-fp4-mi355x-sglang-disagg: image: rocm/sgl-dev:sglang-0.5.9-rocm720-mi35x-mori-0227-3 model: amd/DeepSeek-R1-0528-MXFP4 diff --git a/benchmarks/multi_node/dsr1_fp8_mi355x_vllm-disagg.sh b/benchmarks/multi_node/dsr1_fp8_mi355x_vllm-disagg.sh new file mode 100755 index 000000000..b21e9204a --- /dev/null 
+++ b/benchmarks/multi_node/dsr1_fp8_mi355x_vllm-disagg.sh @@ -0,0 +1,79 @@ +#!/usr/bin/env bash + +source "$(dirname "$0")/../benchmark_lib.sh" + +check_env_vars \ + CONC_LIST \ + ISL \ + OSL \ + IMAGE \ + SPEC_DECODING \ + MODEL_PATH \ + PREFILL_NUM_WORKERS \ + PREFILL_TP \ + PREFILL_EP \ + PREFILL_DP_ATTN \ + DECODE_NUM_WORKERS \ + DECODE_TP \ + DECODE_EP \ + DECODE_DP_ATTN \ + PREFILL_NODES \ + DECODE_NODES \ + RANDOM_RANGE_RATIO + +if [[ -n "$SLURM_JOB_ID" ]]; then + echo "JOB $SLURM_JOB_ID running on $SLURMD_NODENAME" +fi + +set -x + +cd "$GITHUB_WORKSPACE/benchmarks/multi_node/vllm_disagg_utils" || exit 1 + +export TIME_LIMIT="08:00:00" +export MODEL_PATH=$MODEL_PATH +export MODEL_NAME=$MODEL_NAME +export CONTAINER_IMAGE=$IMAGE + +# Same EP/DP booleans as dsr1_fp8_mi355x_sglang-disagg.sh → amd_utils/submit.sh +if [[ "${PREFILL_EP:-1}" -eq 1 ]]; then + export PREFILL_ENABLE_EP=false +else + export PREFILL_ENABLE_EP=true +fi + +if [[ "$PREFILL_DP_ATTN" == "true" ]]; then + export PREFILL_ENABLE_DP=true +else + export PREFILL_ENABLE_DP=false +fi + +if [[ "${DECODE_EP:-1}" -eq 1 ]]; then + export DECODE_ENABLE_EP=false +else + export DECODE_ENABLE_EP=true +fi + +if [[ "$DECODE_DP_ATTN" == "true" ]]; then + export DECODE_ENABLE_DP=true +else + export DECODE_ENABLE_DP=false +fi + +# Parameter order matches SGLang disagg submit.sh; arg 16 is optional NODELIST. +JOB_ID=$(bash ./submit.sh $PREFILL_NODES \ + $PREFILL_NUM_WORKERS \ + $DECODE_NODES \ + $DECODE_NUM_WORKERS \ + $ISL $OSL "${CONC_LIST// /x}" inf \ + ${PREFILL_ENABLE_EP} ${PREFILL_ENABLE_DP} \ + ${DECODE_ENABLE_EP} ${DECODE_ENABLE_DP} \ + ${PREFILL_TP} ${DECODE_TP} \ + ${RANDOM_RANGE_RATIO} \ + "${NODELIST:-}") + +if [[ $? 
-ne 0 ]]; then + echo "Failed to submit job" >&2 + exit 1 +fi + +echo "$JOB_ID" diff --git a/benchmarks/multi_node/vllm_disagg_utils/bench.sh b/benchmarks/multi_node/vllm_disagg_utils/bench.sh new file mode 100755 index 000000000..5b9f5c772 --- /dev/null +++ b/benchmarks/multi_node/vllm_disagg_utils/bench.sh @@ -0,0 +1,75 @@ +#!/bin/bash +# vLLM Disaggregated Benchmark Runner +# +# Produces JSON result files via benchmark_serving.py (same as SGLang bench.sh) +# so that the CI pipeline can collect and process results. +# +# Usage: bash bench.sh \ +# \ +# + +n_prefill=$1 +n_decode=$2 +prefill_gpus=$3 +decode_gpus=$4 +model_path=$5 +model_name=$6 +MODEL_PATH="${MODEL_PATH:-${model_path}/${model_name}}" +log_path=$7 + +chosen_isl=${8:-1024} +chosen_osl=${9:-1024} +concurrency_list=${10:-"512x1"} +chosen_req_rate=${11:-inf} +random_range_ratio=${12:-0.8} +num_prompts_multiplier=${13:-10} + +IFS='x' read -r -a chosen_concurrencies <<< "$concurrency_list" + +ROUTER_PORT="${ROUTER_PORT:-30000}" + +echo "Config ${chosen_isl}; ${chosen_osl}; ${chosen_concurrencies[0]}; ${chosen_req_rate}" + +profile_folder="${log_path}/vllm_isl_${chosen_isl}_osl_${chosen_osl}" +mkdir -p "$profile_folder" + +source "$(dirname "$0")/../../benchmark_lib.sh" + +REPO_ROOT="$(cd "$(dirname "$0")/../../.." 
&& pwd)" + +for max_concurrency in "${chosen_concurrencies[@]}"; do + + export_file="${profile_folder}/concurrency_${max_concurrency}_req_rate_${chosen_req_rate}_gpus_$((prefill_gpus+decode_gpus))_ctx_${prefill_gpus}_gen_${decode_gpus}" + + num_prompts=$(( max_concurrency * num_prompts_multiplier )) + if [[ "$num_prompts" -lt 16 ]]; then + num_prompts=16 + fi + + echo "profile_folder: $profile_folder" + echo "max_concurrency: $max_concurrency" + echo "chosen_req_rate: $chosen_req_rate" + echo "MODEL_PATH: $MODEL_PATH" + echo "ROUTER_PORT: $ROUTER_PORT" + echo "chosen_isl: $chosen_isl" + echo "chosen_osl: $chosen_osl" + echo "num_prompts: $num_prompts" + echo "export_file: $export_file" + + run_benchmark_serving \ + --bench-serving-dir "$REPO_ROOT" \ + --model "$MODEL_PATH" \ + --port "$ROUTER_PORT" \ + --backend openai \ + --input-len "$chosen_isl" \ + --output-len "$chosen_osl" \ + --random-range-ratio "$random_range_ratio" \ + --num-prompts "$num_prompts" \ + --max-concurrency "$max_concurrency" \ + --result-filename "$export_file" \ + --result-dir /workspace/ + + echo "-----------------------------------------" + echo "[BENCH] Cooldown: waiting 10s for idle KV block reaper..." + sleep 10 +done diff --git a/benchmarks/multi_node/vllm_disagg_utils/env.sh b/benchmarks/multi_node/vllm_disagg_utils/env.sh new file mode 100755 index 000000000..e1cc2f6af --- /dev/null +++ b/benchmarks/multi_node/vllm_disagg_utils/env.sh @@ -0,0 +1,98 @@ +#!/bin/bash +# vLLM/Nixl environment setup for multi-node disaggregated serving. +# +# REQUIRED ENVIRONMENT VARIABLES: +# IBDEVICES - RDMA/InfiniBand device names (e.g., ionic_0,ionic_1,... or mlx5_0,mlx5_1,...) +# Set by runner or auto-detected from hostname. +# +# UCX and RIXL paths (LD_LIBRARY_PATH, PATH) are set by setup_deps.sh, which is +# sourced at the top of server.sh before this file. 
+ +set -x + +# IBDEVICES configuration +# Prefer IBDEVICES set by runner (runners/launch_mi355x-amds.sh) +# Fall back to hostname detection if not set (for direct script execution) +if [[ -z "$IBDEVICES" ]]; then + NODENAME=$(hostname -s) + if [[ $NODENAME == GPU* ]] || [[ $NODENAME == smci355-ccs-aus* ]]; then + export IBDEVICES=ionic_0,ionic_1,ionic_2,ionic_3,ionic_4,ionic_5,ionic_6,ionic_7 + elif [[ $NODENAME == mia1* ]]; then + export IBDEVICES=rdma0,rdma1,rdma2,rdma3,rdma4,rdma5,rdma6,rdma7 + else + DETECTED=$(ibv_devinfo 2>/dev/null | grep "hca_id:" | awk '{print $2}' | paste -sd',') + if [[ -n "$DETECTED" ]]; then + export IBDEVICES="$DETECTED" + else + echo "WARNING: Unable to detect RDMA devices. Set IBDEVICES explicitly." >&2 + fi + fi + echo "[INFO] Auto-detected IBDEVICES=$IBDEVICES from hostname $(hostname -s)" +else + echo "[INFO] Using IBDEVICES=$IBDEVICES (set by runner or environment)" +fi + +if [[ -z "$UCX_NET_DEVICES" ]]; then + # Use the first benic interface for UCX TCP transport (maps to ionic RDMA NIC). + # We use TCP device names (benicXp1) instead of IB device names (ionic_X:1) + # because ud_verbs/ionic crashes in ucp_request_memory_dereg (UCX bug with ionic provider). + UCX_NET_DEV=$(ip -o link show 2>/dev/null | awk -F': ' '/benic1p1/{print $2}' | head -1) + if [[ -n "$UCX_NET_DEV" ]]; then + export UCX_NET_DEVICES="$UCX_NET_DEV" + else + FIRST_IB=$(echo "$IBDEVICES" | cut -d',' -f1) + if [[ -n "$FIRST_IB" ]]; then + export UCX_NET_DEVICES="${FIRST_IB}:1" + fi + fi + echo "[INFO] Auto-set UCX_NET_DEVICES=$UCX_NET_DEVICES" +else + echo "[INFO] Using UCX_NET_DEVICES=$UCX_NET_DEVICES (set by environment)" +fi + +export NCCL_SOCKET_IFNAME=$(ip route | grep '^default' | awk '{print $5}' | head -n 1) +export NCCL_IB_HCA=${NCCL_IB_HCA:-$IBDEVICES} + +# RoCEv2: use IPv4-mapped GID (index 1) for inter-node RDMA routing +export UCX_IB_GID_INDEX=${UCX_IB_GID_INDEX:-1} + +# QoS/DSCP configuration for lossless RoCEv2 fabric. 
+# Priority order: 1) Set by runner, 2) Detect via nicctl, 3) Detect from hostname +if [[ -n "$UCX_IB_TRAFFIC_CLASS" ]]; then + echo "[INFO] Using UCX_IB_TRAFFIC_CLASS=$UCX_IB_TRAFFIC_CLASS (set by environment)" +elif command -v nicctl &> /dev/null; then + ND_PRIO=$(nicctl show qos 2>/dev/null | awk '/PFC no-drop priorities/ {print $NF; exit}') + ND_DSCP=$(nicctl show qos 2>/dev/null | awk -v p="$ND_PRIO" ' +$1 == "DSCP" && $2 == ":" && $NF == p { + print $3; exit +}') + if [[ -n "$ND_DSCP" ]] && [[ -n "$ND_PRIO" ]]; then + export UCX_IB_TRAFFIC_CLASS=$(( 4 * ND_DSCP )) + export UCX_IB_SL=$ND_PRIO + echo "[INFO] Detected QoS from nicctl: UCX_IB_TRAFFIC_CLASS=$UCX_IB_TRAFFIC_CLASS, UCX_IB_SL=$UCX_IB_SL" + else + echo "[WARN] nicctl available but QoS data unavailable; trying hostname detection." + NODENAME=$(hostname -s) + if [[ $NODENAME == GPU* ]] || [[ $NODENAME == smci355-ccs-aus* ]]; then + export UCX_IB_TRAFFIC_CLASS=96 + echo "[INFO] Auto-detected UCX_IB_TRAFFIC_CLASS=$UCX_IB_TRAFFIC_CLASS from hostname $NODENAME" + elif [[ $NODENAME == mia1* ]]; then + export UCX_IB_TRAFFIC_CLASS=104 + echo "[INFO] Auto-detected UCX_IB_TRAFFIC_CLASS=$UCX_IB_TRAFFIC_CLASS from hostname $NODENAME" + fi + fi +else + NODENAME=$(hostname -s) + if [[ $NODENAME == GPU* ]] || [[ $NODENAME == smci355-ccs-aus* ]]; then + export UCX_IB_TRAFFIC_CLASS=96 + echo "[INFO] Auto-detected UCX_IB_TRAFFIC_CLASS=$UCX_IB_TRAFFIC_CLASS from hostname $NODENAME" + elif [[ $NODENAME == mia1* ]]; then + export UCX_IB_TRAFFIC_CLASS=104 + echo "[INFO] Auto-detected UCX_IB_TRAFFIC_CLASS=$UCX_IB_TRAFFIC_CLASS from hostname $NODENAME" + else + echo "[INFO] No nicctl and unable to detect from hostname. Skipping QoS configuration." 
+ fi +fi + +set +x +echo "[INFO] IBDEVICES=$IBDEVICES UCX_NET_DEVICES=$UCX_NET_DEVICES NCCL_SOCKET_IFNAME=$NCCL_SOCKET_IFNAME UCX_IB_GID_INDEX=$UCX_IB_GID_INDEX UCX_IB_TRAFFIC_CLASS=${UCX_IB_TRAFFIC_CLASS:-unset}" diff --git a/benchmarks/multi_node/vllm_disagg_utils/job.slurm b/benchmarks/multi_node/vllm_disagg_utils/job.slurm new file mode 100644 index 000000000..e1cad0817 --- /dev/null +++ b/benchmarks/multi_node/vllm_disagg_utils/job.slurm @@ -0,0 +1,358 @@ +#!/bin/bash +#SBATCH --job-name=vllm-pd-bench +#SBATCH -N 3 # Overridden by submit.sh -N flag +#SBATCH -n 3 # Overridden by submit.sh -n flag +#SBATCH --ntasks-per-node=1 +#SBATCH --spread-job +#SBATCH --gres=gpu:8 +#SBATCH --time=24:00:00 +# --output and --error are set by submit.sh via BENCHMARK_LOGS_DIR + +echo "=== Job Start Time ===" +echo "UTC Time: $(TZ=UTC date '+%Y-%m-%d %H:%M:%S %Z')" +echo "PST Time: $(TZ=America/Los_Angeles date '+%Y-%m-%d %H:%M:%S %Z')" +echo "=======================" +echo "" + +# ============================================================================= +# Model Validation +# ============================================================================= + +# Use $(pwd) not BASH_SOURCE — sbatch copies the script to /var/spool/slurmd/ +# at runtime, but the CWD remains the submit-time directory (vllm_disagg_utils/). +MODELS_YAML="$(pwd)/models.yaml" + +if [[ ! -f "$MODELS_YAML" ]]; then + echo "Error: models.yaml not found at $MODELS_YAML" + exit 1 +fi + +if [[ -z "${DOCKER_IMAGE_NAME:-}" ]]; then + echo "Error: DOCKER_IMAGE_NAME is not set." + exit 1 +fi + +MODEL_NAME="${MODEL_NAME:-None}" +if ! grep -q "^${MODEL_NAME}:" "$MODELS_YAML"; then + echo "Error: Model '$MODEL_NAME' not found in models.yaml" + echo "Available models:" + grep -E '^[A-Za-z]' "$MODELS_YAML" | sed 's/:.*$//' | sed 's/^/ - /' + exit 1 +fi +echo "Model found: $MODEL_NAME" + +RUN_FILE="server.sh" +echo "Runfile set: $RUN_FILE" + +# DI_REPO_DIR points to the repo root. 
+# $(pwd) is vllm_disagg_utils/ (the sbatch submit dir); go up 3 levels to reach the repo root. +export DI_REPO_DIR=$(cd "$(pwd)/../../.." && pwd) + +xP="${xP:-1}" +yD="${yD:-1}" + +# Benchmark configuration +BENCH_INPUT_LEN="${BENCH_INPUT_LEN:-1024}" +BENCH_OUTPUT_LEN="${BENCH_OUTPUT_LEN:-1024}" +BENCH_RANDOM_RANGE_RATIO="${BENCH_RANDOM_RANGE_RATIO:-1}" +BENCH_NUM_PROMPTS_MULTIPLIER="${BENCH_NUM_PROMPTS_MULTIPLIER:-10}" +BENCH_MAX_CONCURRENCY="${BENCH_MAX_CONCURRENCY:-512}" +BENCH_REQUEST_RATE="${BENCH_REQUEST_RATE:-inf}" + +GPUS_PER_NODE="${GPUS_PER_NODE:-8}" + +# ============================================================================= +# Docker privilege detection +# ============================================================================= +# Detect on the batch host (used for post-srun cleanup). +# Per-node detection happens inside the srun inline script below because +# some nodes may require sudo while others do not. +if docker ps &>/dev/null; then + DOCKER_CMD="docker" +else + DOCKER_CMD="sudo docker" +fi +export DOCKER_CMD + +# ============================================================================= +# Model Path Resolution +# ============================================================================= + +# MODEL_DIR detection: prefer env var, fall back to hostname detection +if [[ -z "$MODEL_DIR" ]]; then + NODENAME=$(hostname -s) + if [[ $NODENAME == GPU* ]] || [[ $NODENAME == smci355-ccs-aus* ]]; then + MODEL_DIR="/nfsdata" + elif [[ $NODENAME == mia1* ]]; then + MODEL_DIR="/it-share/data" + else + MODEL_DIR="/nfsdata" + fi + echo "[INFO] Auto-detected MODEL_DIR=$MODEL_DIR from hostname $(hostname -s)" +fi +export MODEL_DIR + +# Extract hf_dir from models.yaml (the line after the model's top-level key) +DISK_DIR_NAME=$(awk '/^'"$MODEL_NAME"':/{found=1; next} + found && /^[^ ]/{exit} + found && /hf_dir:/{gsub(/[" ]/, "", $2); print $2; exit}' "$MODELS_YAML") +DISK_DIR_NAME="${DISK_DIR_NAME:-$MODEL_NAME}" +echo "Looking for model: 
$MODEL_NAME (disk dir: $DISK_DIR_NAME)" + +resolve_hf_cache_path() { + local base_path=$1 + if [[ -d "${base_path}/snapshots" ]]; then + local snapshot=$(ls -1 "${base_path}/snapshots" 2>/dev/null | head -1) + if [[ -n "$snapshot" ]]; then + echo "${base_path}/snapshots/${snapshot}" + return 0 + fi + fi + echo "$base_path" + return 1 +} + +MODEL_PATH="" +SEARCH_PATHS=( + "${MODEL_DIR}/${DISK_DIR_NAME}" + "${MODEL_DIR}/${MODEL_NAME}" + "/nfsdata/hf_hub_cache-0/${DISK_DIR_NAME}" + "/nfsdata/hf_hub_cache-0/${MODEL_NAME}" +) + +for search_path in "${SEARCH_PATHS[@]}"; do + if [[ -d "$search_path" ]]; then + RESOLVED=$(resolve_hf_cache_path "$search_path") + MODEL_PATH="$RESOLVED" + echo "Found MODEL_PATH: $MODEL_PATH" + break + fi +done + +if [[ -z "$MODEL_PATH" ]]; then + echo "FATAL: Model '$MODEL_NAME' not found. Searched:" + for p in "${SEARCH_PATHS[@]}"; do echo " - $p"; done + exit 1 +fi +echo "Final MODEL_PATH: $MODEL_PATH" + +# ============================================================================= +# Node Selection and vLLM-Specific NUM_NODES +# ============================================================================= + +# Router co-located with first prefill: xP + yD nodes total (same as SGLang) +NUM_NODES=$((xP + yD)) +echo "NUM_NODES: $NUM_NODES (xP=$xP + yD=$yD, proxy co-located with first prefill)" + +FULL_NODELIST=$(scontrol show hostnames "$SLURM_JOB_NODELIST") +SELECTED_NODES=$(echo "$FULL_NODELIST" | head -n $NUM_NODES) +SELECTED_NODELIST_STR=$(echo "$SELECTED_NODES" | tr '\n' ',' | sed 's/,$//') + +# Update SLURM environment variables +export SLURM_NNODES=$NUM_NODES +export SLURM_NTASKS=$NUM_NODES +export SLURM_JOB_NUM_NODES=$NUM_NODES +export SLURM_NPROCS=$NUM_NODES +export SLURM_JOB_NODELIST="$SELECTED_NODELIST_STR" +export SLURM_NODELIST="$SELECTED_NODELIST_STR" +export SLURM_TASKS_PER_NODE="1(x$NUM_NODES)" +export SLURM_NTASKS_PER_NODE=1 + +echo "" +echo "Selected nodes: $SELECTED_NODELIST_STR" + +# 
============================================================================= +# IP Resolution +# ============================================================================= + +USER_NAME=$(whoami) +MASTER_NODE=$(echo "$SELECTED_NODES" | head -n 1) +NODE0_ADDR=$(srun --nodes=1 --ntasks=1 --time=00:20:00 --nodelist="$MASTER_NODE" bash -c 'ip route get 1.1.1.1') +NODE0_ADDR=$(echo "$NODE0_ADDR" | awk '/src/ {print $7}') + +IPS=() +for NODE in $SELECTED_NODES; do + IP=$(srun --nodes=1 --ntasks=1 --time=00:20:00 --nodelist="$NODE" bash -c 'ip route get 1.1.1.1') + IP=$(echo "$IP" | awk '/src/ {print $7}') + IPS+=("$IP") +done + +echo "Node IPs: ${IPS[*]}" + +DOCKER_MOUNT_PATH="/workspace" +VLLM_WS_PATH="${DOCKER_MOUNT_PATH}/benchmarks/multi_node/vllm_disagg_utils" + +NNODES=$NUM_NODES + +echo "MASTER_NODE: ${MASTER_NODE}" +echo "NODE0_ADDR: ${NODE0_ADDR}" +echo "NNODES: ${NNODES}" +echo "REPO DIR: ${DI_REPO_DIR}" +echo "USER: ${USER_NAME}" + +# Reduce log spam +export TQDM_MININTERVAL=20 + +# Translate the host-resolved MODEL_PATH to the Docker mount namespace +DOCKER_MODEL_PATH="${MODEL_PATH/#$MODEL_DIR//models}" + +export DI_REPO_DIR=$DI_REPO_DIR +export VLLM_WS_PATH=$VLLM_WS_PATH +export NNODES=$NNODES +export NODE0_ADDR=$NODE0_ADDR +export MODEL_PATH=$MODEL_PATH +export MODEL_DIR=$MODEL_DIR +export xP=$xP +export yD=$yD +export MODEL_NAME=$MODEL_NAME +export USER_NAME=$USER_NAME +export IPADDRS="$(echo "${IPS[*]}" | sed 's/ /,/g')" +export GPUS_PER_NODE=$GPUS_PER_NODE +export BENCH_INPUT_LEN=$BENCH_INPUT_LEN +export BENCH_OUTPUT_LEN=$BENCH_OUTPUT_LEN +export BENCH_RANDOM_RANGE_RATIO=$BENCH_RANDOM_RANGE_RATIO +export BENCH_NUM_PROMPTS_MULTIPLIER=$BENCH_NUM_PROMPTS_MULTIPLIER +export BENCH_MAX_CONCURRENCY=$BENCH_MAX_CONCURRENCY +export BENCH_REQUEST_RATE=$BENCH_REQUEST_RATE +export DRY_RUN="${DRY_RUN:-0}" +export BENCHMARK_LOGS_DIR="${BENCHMARK_LOGS_DIR:-$(pwd)/benchmark_logs}" + +# TP / EP / DP (from vllm_disagg_utils/submit.sh; mirrors amd_utils disagg) +export 
PREFILL_ENABLE_EP="${PREFILL_ENABLE_EP:-false}" +export PREFILL_ENABLE_DP="${PREFILL_ENABLE_DP:-false}" +export DECODE_ENABLE_EP="${DECODE_ENABLE_EP:-false}" +export DECODE_ENABLE_DP="${DECODE_ENABLE_DP:-false}" +export PREFILL_TP="${PREFILL_TP:-8}" +export DECODE_TP="${DECODE_TP:-8}" + +SANITIZED_USER=$(echo "$USER_NAME" | tr -c 'a-zA-Z0-9_.-' '_') +export DOCKER_CONT_NAME="container_vllm_${SANITIZED_USER}_${MODEL_NAME}_${SLURM_JOB_ID}" +export RUN_FILE_FULL="$VLLM_WS_PATH/${RUN_FILE}" + +SELECTED_NODELIST_SRUN=$(echo "$SELECTED_NODES" | paste -sd,) + +cleanup() { + echo "[${SLURM_JOB_ID}] termination received on $(hostname); cleaning up..." + rm -rf ${SLURM_SUBMIT_DIR}/logs 2>/dev/null || true + echo "[${SLURM_JOB_ID}] cleanup done." +} + +trap cleanup INT TERM HUP + +# Force NFS cache refresh on all nodes +echo "Refreshing NFS caches on all nodes..." +srun --nodelist="$SELECTED_NODELIST_SRUN" bash -c ' + sync + ls -la '"$DI_REPO_DIR"'/benchmarks/multi_node/vllm_disagg_utils > /dev/null 2>&1 + stat '"$DI_REPO_DIR"'/benchmarks/multi_node/vllm_disagg_utils/server.sh > /dev/null 2>&1 + cat '"$DI_REPO_DIR"'/benchmarks/multi_node/vllm_disagg_utils/server.sh > /dev/null 2>&1 + echo 3 | sudo tee /proc/sys/vm/drop_caches > /dev/null 2>&1 || true + echo "NFS cache refreshed on $(hostname)" +' + +srun \ + --nodelist="$SELECTED_NODELIST_SRUN" \ + --kill-on-bad-exit=1 \ + --signal=TERM@30 \ + --unbuffered \ + bash -lc " +set -euo pipefail + +echo \"Rank \$SLURM_PROCID on \$(hostname)\" + +# Per-node Docker privilege detection (some nodes need sudo, others don't) +if docker ps &>/dev/null; then + _DCMD=docker +else + _DCMD='sudo docker' +fi + +# Pre-clean (idempotent) +\$_DCMD ps -aq --filter \"name=^container_vllm_\" | xargs -r \$_DCMD rm -f || true +\$_DCMD ps -aq | xargs -r \$_DCMD stop || true + +exec \$_DCMD run --rm \ + --init \ + --stop-timeout 10 \ + --device /dev/dri \ + --device /dev/kfd \ + --device /dev/infiniband \ + --device=/dev/infiniband/rdma_cm \ + 
--device=/dev/infiniband/uverbs0 \ + --device=/dev/infiniband/uverbs1 \ + --device=/dev/infiniband/uverbs2 \ + --device=/dev/infiniband/uverbs3 \ + --device=/dev/infiniband/uverbs4 \ + --device=/dev/infiniband/uverbs5 \ + --device=/dev/infiniband/uverbs6 \ + --device=/dev/infiniband/uverbs7 \ + --ulimit memlock=-1 \ + --ulimit stack=67108864 \ + --network host \ + --ipc host \ + --group-add video \ + --cap-add SYS_PTRACE \ + --security-opt seccomp=unconfined \ + --privileged \ + -v /sys:/sys \ + $(command -v nicctl >/dev/null 2>&1 && echo "-v $(which nicctl):/usr/sbin/nicctl") \ + -v ${MODEL_DIR}:/models \ + -v \$HOME/.ssh:/root/.ssh \ + --shm-size 128G \ + -v /tmp:/run_logs \ + -v ${BENCHMARK_LOGS_DIR}:/benchmark_logs \ + -v ${DI_REPO_DIR}:${DOCKER_MOUNT_PATH} \ + -e SLURM_JOB_ID=\$SLURM_JOB_ID \ + -e SLURM_JOB_NODELIST=\$SLURM_JOB_NODELIST \ + -e NNODES=\$NNODES \ + -e NODE_RANK=\$SLURM_PROCID \ + -e NODE0_ADDR=\$NODE0_ADDR \ + -e MODEL_DIR=/models \ + -e MODEL_NAME=\$MODEL_NAME \ + -e MODEL_PATH=$DOCKER_MODEL_PATH \ + -e VLLM_WS_PATH=${VLLM_WS_PATH} \ + -e GPUS_PER_NODE=\$GPUS_PER_NODE \ + -e xP=\$xP \ + -e yD=\$yD \ + -e IPADDRS=\$IPADDRS \ + -e BENCH_INPUT_LEN=\$BENCH_INPUT_LEN \ + -e BENCH_OUTPUT_LEN=\$BENCH_OUTPUT_LEN \ + -e BENCH_RANDOM_RANGE_RATIO=\$BENCH_RANDOM_RANGE_RATIO \ + -e BENCH_NUM_PROMPTS_MULTIPLIER=\$BENCH_NUM_PROMPTS_MULTIPLIER \ + -e BENCH_MAX_CONCURRENCY=\$BENCH_MAX_CONCURRENCY \ + -e BENCH_REQUEST_RATE=\$BENCH_REQUEST_RATE \ + -e TQDM_MININTERVAL=\$TQDM_MININTERVAL \ + -e DRY_RUN=\$DRY_RUN \ + -e BENCHMARK_LOGS_DIR=/benchmark_logs \ + -e UCX_TLS=tcp,self,shm,rocm_ipc,rocm_copy,cma \ + -e UCX_SOCKADDR_TLS_PRIORITY=tcp \ + -e UCX_MEMTYPE_CACHE=y \ + -e UCX_RNDV_SCHEME=get_zcopy \ + -e UCX_RNDV_THRESH=4k \ + -e UCX_ROCM_IPC_MIN_ZCOPY=0 \ + -e UCX_LOG_LEVEL=warn \ + -e HSA_ENABLE_SDMA=1 \ + -e PROXY_STREAM_IDLE_TIMEOUT=\${PROXY_STREAM_IDLE_TIMEOUT:-300} \ + -e VLLM_MORIIO_CONNECTOR_READ_MODE=\${VLLM_MORIIO_CONNECTOR_READ_MODE:-1} \ + -e 
PYTHONPYCACHEPREFIX=/tmp/pycache \ + -e PREFILL_ENABLE_EP=\$PREFILL_ENABLE_EP \ + -e PREFILL_ENABLE_DP=\$PREFILL_ENABLE_DP \ + -e DECODE_ENABLE_EP=\$DECODE_ENABLE_EP \ + -e DECODE_ENABLE_DP=\$DECODE_ENABLE_DP \ + -e PREFILL_TP=\$PREFILL_TP \ + -e DECODE_TP=\$DECODE_TP \ + --name \"$DOCKER_CONT_NAME\" \ + --entrypoint \"\" \ + \"$DOCKER_IMAGE_NAME\" bash -lc ' + mkdir -p /run_logs/slurm_job-'\"\$SLURM_JOB_ID\"' + '"$RUN_FILE_FULL"' 2>&1 | tee /run_logs/slurm_job-'\"\$SLURM_JOB_ID\"'/server_\$(hostname).log + ' + +DOCKER_EXIT_CODE=\$? +if [[ \$DOCKER_EXIT_CODE -ne 0 ]]; then + echo \"ERROR: docker exited rc=\$DOCKER_EXIT_CODE on \$(hostname)\" + exit \$DOCKER_EXIT_CODE +fi +" + +srun --nodelist="$SELECTED_NODELIST_SRUN" bash -c 'if docker ps &>/dev/null; then D=docker; else D="sudo docker"; fi; $D rm -f '"$DOCKER_CONT_NAME"' 2>/dev/null || true' diff --git a/benchmarks/multi_node/vllm_disagg_utils/models.yaml b/benchmarks/multi_node/vllm_disagg_utils/models.yaml new file mode 100644 index 000000000..ef062e5f4 --- /dev/null +++ b/benchmarks/multi_node/vllm_disagg_utils/models.yaml @@ -0,0 +1,41 @@ +# Model-specific vLLM server configurations for disaggregated inference. +# +# Each top-level key is a MODEL_NAME value (must match the model identifier +# used in amd-master.yaml and the directory/HF-cache name under MODEL_DIR). +# +# To add a new model: add a new top-level entry following the same schema. +# No script changes are required. +# +# Schema: +# : +# prefill_flags: str # vLLM CLI flags for prefill workers +# decode_flags: str # vLLM CLI flags for decode workers +# env: str # Space-separated KEY=VALUE pairs exported before vllm serve +# hf_dir: str # (optional) On-disk directory name if it differs from the key +# # e.g. 
HF cache layout: models--deepseek-ai--DeepSeek-R1-0528 + +Llama-3.1-405B-Instruct-FP8-KV: + prefill_flags: "--tensor-parallel-size 8 --kv-cache-dtype fp8" + decode_flags: "--tensor-parallel-size 8 --kv-cache-dtype fp8" + env: "VLLM_USE_V1=1 VLLM_V1_USE_PREFILL_DECODE_ATTENTION=1 AMDGCN_USE_BUFFER_OPS=1 VLLM_ROCM_USE_AITER=1 VLLM_ROCM_USE_AITER_RMSNORM=1 VLLM_USE_AITER_TRITON_ROPE=1 TRITON_HIP_ASYNC_COPY_BYPASS_PERMUTE=1 TRITON_HIP_USE_ASYNC_COPY=1 TRITON_HIP_USE_BLOCK_PINGPONG=1 TRITON_HIP_ASYNC_FAST_SWIZZLE=1" + +amd-Llama-3.3-70B-Instruct-FP8-KV: + prefill_flags: "--tensor-parallel-size 8 --max-model-len 65536 --kv-cache-dtype fp8" + decode_flags: "--tensor-parallel-size 8 --max-model-len 65536 --kv-cache-dtype fp8" + env: "VLLM_USE_V1=1 VLLM_V1_USE_PREFILL_DECODE_ATTENTION=1 AMDGCN_USE_BUFFER_OPS=1 VLLM_ROCM_USE_AITER=1 VLLM_ROCM_USE_AITER_RMSNORM=1 VLLM_USE_AITER_TRITON_ROPE=1 TRITON_HIP_ASYNC_COPY_BYPASS_PERMUTE=1 TRITON_HIP_USE_ASYNC_COPY=1 TRITON_HIP_USE_BLOCK_PINGPONG=1 TRITON_HIP_ASYNC_FAST_SWIZZLE=1" + +DeepSeek-V3: + prefill_flags: "--tensor-parallel-size 8 --compilation-config '{\"cudagraph_mode\":\"PIECEWISE\"}' --no-enable-prefix-caching --block-size 1" + decode_flags: "--tensor-parallel-size 8 --compilation-config '{\"cudagraph_mode\":\"PIECEWISE\"}' --no-enable-prefix-caching --block-size 1" + env: "VLLM_USE_V1=1 VLLM_ROCM_USE_AITER=1 VLLM_ROCM_USE_AITER_PAGED_ATTN=0 VLLM_ROCM_USE_AITER_RMSNORM=1 VLLM_USE_AITER_TRITON_SILU_MUL=0" + +DeepSeek-R1-0528: + prefill_flags: "--tensor-parallel-size 8 --compilation-config '{\"cudagraph_mode\":\"PIECEWISE\"}' --no-enable-prefix-caching --block-size 1" + decode_flags: "--tensor-parallel-size 8 --enable-expert-parallel --all2all-backend mori --compilation-config '{\"cudagraph_mode\":\"PIECEWISE\"}' --no-enable-prefix-caching --block-size 1" + env: "VLLM_USE_V1=1 VLLM_ROCM_USE_AITER=1 VLLM_ROCM_USE_AITER_PAGED_ATTN=0 VLLM_ROCM_USE_AITER_RMSNORM=1 VLLM_USE_AITER_TRITON_SILU_MUL=0 VLLM_ENGINE_READY_TIMEOUT_S=3600" 
+ hf_dir: "models--deepseek-ai--DeepSeek-R1-0528" + +gpt-oss-120b: + prefill_flags: "--tensor-parallel-size 8" + decode_flags: "--tensor-parallel-size 8" + env: "VLLM_USE_V1=1 VLLM_ROCM_USE_AITER=1 VLLM_ROCM_USE_AITER_TRITON_BF16_GEMM=0 VLLM_USE_AITER_UNIFIED_ATTENTION=1 VLLM_ROCM_USE_AITER_MHA=0 ROCM_TRITON_MOE_PRESHUFFLE_SCALES=0" diff --git a/benchmarks/multi_node/vllm_disagg_utils/moriio_proxy.py b/benchmarks/multi_node/vllm_disagg_utils/moriio_proxy.py new file mode 100644 index 000000000..b2162c98a --- /dev/null +++ b/benchmarks/multi_node/vllm_disagg_utils/moriio_proxy.py @@ -0,0 +1,326 @@ +#!/usr/bin/env python3 +# MoRI-IO proxy server for vLLM PD disaggregation. +# +# Based on vLLM's examples/online_serving/disaggregated_serving/moriio_toy_proxy_server.py +# with the following adaptations for production multi-node use: +# - Ports configurable via PROXY_HTTP_PORT / PROXY_PING_PORT env vars +# - /health endpoint for sync.py barrier readiness checks +# - Uses stdlib `re` instead of `regex` to avoid extra dep +# +# The proxy performs two roles that vllm-router cannot: +# 1. ZMQ service discovery — prefill/decode workers register their RDMA ports +# 2. 
Request enrichment — injects remote endpoint info into kv_transfer_params + +import asyncio +import copy +import logging +import os +import re +import socket +import threading +import time +import uuid + +import aiohttp +import msgpack +import zmq +from quart import Quart, make_response, request + +logger = logging.getLogger("moriio_proxy") +logger.setLevel(logging.DEBUG) +handler = logging.StreamHandler() +handler.setFormatter(logging.Formatter( + "%(asctime)s %(levelname)s [%(name)s] %(message)s")) +logger.addHandler(handler) + +prefill_instances: list[dict] = [] +decode_instances: list[dict] = [] +request_nums = 0 +app = Quart(__name__) + +STREAM_IDLE_TIMEOUT = int(os.environ.get("PROXY_STREAM_IDLE_TIMEOUT", "300")) + +IP_PORT_PATTERN = re.compile(r"//(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}):(\d+)") + +TRANSFER_TYPE = None + + +def _append_whole_dict_unique(target_list, data_dict): + new_filtered = {k: v for k, v in data_dict.items() if k != "index"} + for existed in target_list: + existed_filtered = {k: v for k, v in existed.items() if k != "index"} + if existed_filtered == new_filtered: + return False + logger.info("Registered instance: role=%s addr=%s hs_port=%s notify=%s dp=%s tp=%s", + data_dict.get("role"), data_dict.get("request_address"), + data_dict.get("handshake_port"), data_dict.get("notify_port"), + data_dict.get("dp_size"), data_dict.get("tp_size")) + target_list.append(data_dict) + transfer_mode = data_dict.get("transfer_mode", "unknown") + global TRANSFER_TYPE + + if TRANSFER_TYPE is None: + TRANSFER_TYPE = transfer_mode + logger.info("Transfer mode set to: %s", TRANSFER_TYPE) + elif transfer_mode != TRANSFER_TYPE: + raise ValueError(f"mismatched transfer mode {TRANSFER_TYPE} vs {transfer_mode}") + + return True + + +_list_lock = threading.RLock() + + +def _listen_for_register(hostname, port): + context = zmq.Context() + router_socket = context.socket(zmq.ROUTER) + router_socket.bind(f"tcp://{hostname}:{port}") + poller = zmq.Poller() + 
poller.register(router_socket, zmq.POLLIN) + global prefill_instances + global decode_instances + + while True: + socks = dict(poller.poll()) + if router_socket in socks: + remote_addr, msg = router_socket.recv_multipart() + data = msgpack.loads(msg) + if data["type"] == "HELLO": + pass + elif ( + data["type"] == "register" + and data["role"] == "P" + and data["request_address"] not in prefill_instances + ): + with _list_lock: + _append_whole_dict_unique(prefill_instances, data) + + elif ( + data["type"] == "register" + and data["role"] == "D" + and data["request_address"] not in decode_instances + ): + with _list_lock: + _append_whole_dict_unique(decode_instances, data) + + +def start_service_discovery(hostname, port): + if not hostname: + hostname = socket.gethostname() + if port == 0: + raise ValueError("Port cannot be 0") + + _listener_thread = threading.Thread( + target=_listen_for_register, args=(hostname, port), daemon=True + ) + _listener_thread.start() + logger.info("Service discovery listening on %s:%s", hostname, port) + return _listener_thread + + +async def send_request_to_prefill( + endpoint, req_data, request_id, d_endpoint, dip, dport, selected_prefill_dp_rank +): + req_data_copy = req_data + + req_data_copy["kv_transfer_params"].update( + { + "do_remote_decode": True, + "do_remote_prefill": False, + "remote_handshake_port": d_endpoint["handshake_port"], + "remote_notify_port": d_endpoint["notify_port"], + "remote_engine_id": None, + "remote_block_ids": None, + "remote_host": dip, + "remote_port": dport, + } + ) + req_data_copy["stream"] = False + req_data_copy["max_tokens"] = 1 + if "max_completion_tokens" in req_data_copy: + req_data_copy["max_completion_tokens"] = 1 + if "stream_options" in req_data_copy: + del req_data_copy["stream_options"] + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=6 * 6000 * 6000) + ) as session: + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + "X-Request-Id": 
request_id, + } + if selected_prefill_dp_rank is not None: + headers["X-data-parallel-rank"] = str(selected_prefill_dp_rank) + async with session.post( + url=endpoint, json=req_data_copy, headers=headers + ) as response: + if response.status == 200: + return await response.json() + else: + raise RuntimeError( + f"Prefill response status={response.status}" + ) + + +async def start_decode_request(endpoint, req_data, request_id): + session = aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=6 * 6000 * 6000) + ) + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + "X-Request-Id": request_id, + } + response = await session.post(url=endpoint, json=req_data, headers=headers) + return session, response + + +async def stream_decode_response(session, response, request_id): + try: + if response.status == 200: + chunk_iter = response.content.iter_chunked(1024).__aiter__() + while True: + try: + chunk_bytes = await asyncio.wait_for( + chunk_iter.__anext__(), timeout=STREAM_IDLE_TIMEOUT, + ) + yield chunk_bytes + except StopAsyncIteration: + break + except asyncio.TimeoutError: + logger.error( + "Decode stream %s idle for %ds, aborting", + request_id, STREAM_IDLE_TIMEOUT, + ) + break + else: + raise RuntimeError( + f"Decode response status={response.status}" + ) + finally: + await response.release() + await session.close() + + +@app.route("/health", methods=["GET"]) +async def health_check(): + with _list_lock: + p_count = len(prefill_instances) + d_count = len(decode_instances) + return await make_response( + ({"status": "ok", "prefill_instances": p_count, "decode_instances": d_count}, 200) + ) + + +@app.route("/v1/completions", methods=["POST"]) +@app.route("/v1/chat/completions", methods=["POST"]) +async def handle_request(): + try: + with _list_lock: + global request_nums + request_nums += 1 + + def extract_ip_port_fast(url): + match = IP_PORT_PATTERN.search(url) + if not match: + raise ValueError(f"Invalid URL format: {url}") + return 
match.groups() + + req_data = await request.get_json() + request_id = str(uuid.uuid4()) + + if not prefill_instances or not decode_instances: + return await make_response( + ("Service Unavailable: No prefill or decode instances registered.", 503) + ) + + pid = request_nums % len(prefill_instances) + did = request_nums % len(decode_instances) + prefill_instance_endpoint = prefill_instances[pid] + decode_instance_endpoint = decode_instances[did] + + selected_prefill_dp_rank = None + if prefill_instance_endpoint["dp_size"] > 1: + selected_prefill_dp_rank = request_nums % prefill_instance_endpoint["dp_size"] + + dip, dport = extract_ip_port_fast(decode_instance_endpoint["request_address"]) + + req_data_to_prefill = copy.deepcopy(req_data) + req_data_to_prefill["kv_transfer_params"] = {} + req_data["kv_transfer_params"] = {} + req_data_to_prefill["kv_transfer_params"]["remote_dp_size"] = ( + decode_instance_endpoint["dp_size"] + ) + req_data_to_prefill["kv_transfer_params"]["remote_tp_size"] = ( + decode_instance_endpoint["tp_size"] + ) + + send_prefill_task = asyncio.create_task( + send_request_to_prefill( + prefill_instance_endpoint["request_address"], + req_data_to_prefill, + request_id, + decode_instance_endpoint, + dip, + dport, + selected_prefill_dp_rank, + ) + ) + ip, port = extract_ip_port_fast(prefill_instance_endpoint["request_address"]) + + req_data["max_tokens"] -= 1 + + req_data["kv_transfer_params"] = { + "do_remote_decode": False, + "do_remote_prefill": True, + "remote_handshake_port": prefill_instance_endpoint["handshake_port"], + "remote_notify_port": prefill_instance_endpoint["notify_port"], + "remote_engine_id": None, + "remote_block_ids": None, + "remote_host": ip, + "remote_port": port, + } + if TRANSFER_TYPE == "READ": + prefill_response = await send_prefill_task + req_data["kv_transfer_params"]["remote_engine_id"] = prefill_response[ + "kv_transfer_params" + ]["remote_engine_id"] + req_data["kv_transfer_params"]["remote_block_ids"] = 
prefill_response[ + "kv_transfer_params" + ]["remote_block_ids"] + + req_data["kv_transfer_params"]["remote_dp_size"] = prefill_instance_endpoint[ + "dp_size" + ] + req_data["kv_transfer_params"]["remote_tp_size"] = prefill_instance_endpoint[ + "tp_size" + ] + + if selected_prefill_dp_rank is not None: + req_data["kv_transfer_params"]["remote_dp_rank"] = selected_prefill_dp_rank + + decode_request_task = asyncio.create_task( + start_decode_request( + decode_instance_endpoint["request_address"], req_data, request_id + ) + ) + + session, decode_response = await decode_request_task + stream_generator = stream_decode_response(session, decode_response, request_id) + response = await make_response(stream_generator) + return response + except Exception as e: + logger.exception("Error handling request: %s", e) + return await make_response((f"Internal Server Error: {e!s}", 500)) + + +if __name__ == "__main__": + http_port = int(os.environ.get("PROXY_HTTP_PORT", "30000")) + ping_port = int(os.environ.get("PROXY_PING_PORT", "36367")) + + t = start_service_discovery("0.0.0.0", ping_port) + app.debug = False + app.config["BODY_TIMEOUT"] = 360000 + app.config["RESPONSE_TIMEOUT"] = 360000 + + logger.info("MoRI-IO proxy starting: HTTP=%d, ZMQ=%d", http_port, ping_port) + app.run(host="0.0.0.0", port=http_port) + t.join() diff --git a/benchmarks/multi_node/vllm_disagg_utils/server.sh b/benchmarks/multi_node/vllm_disagg_utils/server.sh new file mode 100755 index 000000000..9b0ff2ebb --- /dev/null +++ b/benchmarks/multi_node/vllm_disagg_utils/server.sh @@ -0,0 +1,490 @@ +#!/bin/bash +# vLLM Disaggregated Server Launcher with Model-Specific Configurations +# ============================================================================= +# +# Node role assignment (by NODE_RANK): +# 0 -> Proxy/Router + first Prefill node (kv_producer) +# 1..xP-1 -> Additional Prefill nodes (kv_producer) +# xP..xP+yD-1 -> Decode nodes (kv_consumer) +# +# Total nodes = xP + yD (router co-located with first 
prefill, like SGLang). + +# ============================================================================= +# Dependency Setup (idempotent; required when using base vLLM image) +# ============================================================================= +source "$(dirname "${BASH_SOURCE[0]}")/setup_deps.sh" + +# ============================================================================= +# Environment Configuration +# ============================================================================= + +NODE0_ADDR="${NODE0_ADDR:-localhost}" +NODE_RANK="${NODE_RANK:-0}" +MODEL_DIR="${MODEL_DIR:-}" +MODEL_NAME="${MODEL_NAME:-}" + +xP="${xP:-1}" +yD="${yD:-1}" + +IPADDRS="${IPADDRS:-localhost}" + +# Benchmark Configuration +BENCH_INPUT_LEN="${BENCH_INPUT_LEN:-1024}" +BENCH_OUTPUT_LEN="${BENCH_OUTPUT_LEN:-1024}" +BENCH_RANDOM_RANGE_RATIO="${BENCH_RANDOM_RANGE_RATIO:-1}" +BENCH_REQUEST_RATE="${BENCH_REQUEST_RATE:-inf}" +BENCH_NUM_PROMPTS_MULTIPLIER="${BENCH_NUM_PROMPTS_MULTIPLIER:-10}" +BENCH_MAX_CONCURRENCY="${BENCH_MAX_CONCURRENCY:-512}" + +DRY_RUN="${DRY_RUN:-0}" +GPUS_PER_NODE="${GPUS_PER_NODE:-8}" + +ROUTER_PORT="${ROUTER_PORT:-30000}" +SERVER_PORT="${SERVER_PORT:-2584}" +ENGINE_ID="${ENGINE_ID:-${MODEL_NAME}-pd-run}" + +# Prefer MODEL_PATH from job.slurm (handles HF cache snapshot resolution) +MODEL_PATH="${MODEL_PATH:-${MODEL_DIR}/${MODEL_NAME}}" + +# ============================================================================= +# Dependencies and Environment Setup +# ============================================================================= +source $VLLM_WS_PATH/env.sh + +host_ip=$(ip route get 1.1.1.1 2>/dev/null | awk '/src/ {print $7}') +# RDMA IP for Nixl KV transfer (prefer 192.168.x.x subnet if available) +rdma_ip=$(hostname -I | tr ' ' '\n' | grep '^192\.168\.' 
| head -1) +rdma_ip="${rdma_ip:-$host_ip}" +host_name=$(hostname) + +echo "[INFO] Management IP (barriers/proxy): $host_ip" +echo "[INFO] RDMA IP (Nixl KV transfer): $rdma_ip" + +# ============================================================================= +# RDMA / Nixl Workarounds +# ============================================================================= + +setup_rdma_env() { + # Pensando ionic (RoCEv2) point-to-point /31 route fix. + # Each benic interface has a /31 to the TOR switch. Without explicit routes, + # traffic to other nodes' RDMA IPs falls through to the management network. + if [[ "$rdma_ip" =~ ^192\.168\.([0-9]+)\.([0-9]+)$ ]]; then + local rdma_subnet="${BASH_REMATCH[1]}" + local rdma_host="${BASH_REMATCH[2]}" + local rdma_gw="192.168.${rdma_subnet}.$(( rdma_host | 1 ))" + local rdma_iface + rdma_iface=$(ip -o addr show | awk -v ip="$rdma_ip" '$4 ~ ip {print $2}' | head -1) + if [[ -n "$rdma_iface" ]]; then + ip route replace "192.168.${rdma_subnet}.0/24" via "$rdma_gw" dev "$rdma_iface" 2>/dev/null && \ + echo "[RDMA-ROUTE] Added 192.168.${rdma_subnet}.0/24 via $rdma_gw dev $rdma_iface" || \ + echo "[RDMA-ROUTE] Route add failed for 192.168.${rdma_subnet}.0/24" + fi + fi + + # Patch Nixl UCX backend: set ucx_error_handling_mode=none. + # Required for ALL NIC types under high concurrency (C512+). Without this, + # UCX's default UCP_ERR_HANDLING_MODE_PEER triggers transport-level error + # recovery on ibv_post_send failures, preventing RIXL RDMA READ retries from + # recovering gracefully. This causes the prefill KV cache to fill to 100% + # and deadlock the pipeline. On ionic NICs this was already applied (rdmacm + # incompatibility); on mlx5 NICs it was incorrectly skipped. + local nixl_api + nixl_api=$(python3 -c "import rixl._api; print(rixl._api.__file__)" 2>/dev/null) + if [[ -n "$nixl_api" ]]; then + if ! 
grep -q 'ucx_error_handling_mode' "$nixl_api"; then + sed -i '/self\.create_backend(bknd, init)/i\ init["ucx_error_handling_mode"] = "none"' "$nixl_api" + echo "[PATCH] Added ucx_error_handling_mode=none to $nixl_api (IBDEVICES=${IBDEVICES:-unset})" + else + echo "[PATCH] ucx_error_handling_mode already set in $nixl_api" + fi + fi +} + +setup_rdma_env + +if [[ -z "$UCX_NET_DEVICES" ]]; then + echo "Error: UCX_NET_DEVICES is empty after env.sh detection" >&2 + exit 1 +fi + +# ============================================================================= +# Model-Specific Configuration from YAML +# ============================================================================= +MODELS_YAML="${VLLM_WS_PATH}/models.yaml" + +if [[ ! -f "$MODELS_YAML" ]]; then + echo "ERROR: models.yaml not found at $MODELS_YAML" + exit 1 +fi + +if [[ -z "$MODEL_NAME" ]]; then + echo "ERROR: MODEL_NAME is not set"; exit 1 +fi + +eval "$(python3 -c " +import yaml, sys + +with open('${MODELS_YAML}') as f: + models = yaml.safe_load(f) + +model_name = '${MODEL_NAME}' +if model_name not in models: + print(f'echo \"ERROR: Model {model_name} not in models.yaml\"; exit 1') + sys.exit(0) + +m = models[model_name] + +def bash_escape(s): + \"\"\"Escape a value for safe embedding in a bash double-quoted assignment.\"\"\" + return s.replace('\\\\', '\\\\\\\\').replace('\"', '\\\\\"').replace('\$', '\\\\\$').replace('\`', '\\\\\`') + +pf = bash_escape(m.get('prefill_flags', '--tensor-parallel-size 8')) +df = bash_escape(m.get('decode_flags', '--tensor-parallel-size 8')) +ev = bash_escape(m.get('env', '')) +dev = bash_escape(m.get('decode_env', '')) +print(f'PREFILL_SERVER_CONFIG=\"{pf}\"') +print(f'DECODE_SERVER_CONFIG=\"{df}\"') +print(f'MODEL_ENVS=\"{ev}\"') +print(f'DECODE_MODEL_ENVS=\"{dev}\"') +")" + +echo "Loaded model configuration for: $MODEL_NAME" + +# Apply tensor-parallel size and EP/DP flags from submit pipeline (YAML PREFILL_TP / dp-attn / ep). 
+if [[ -n "${PREFILL_TP:-}" ]]; then + if echo "$PREFILL_SERVER_CONFIG" | grep -q -- '--tensor-parallel-size'; then + PREFILL_SERVER_CONFIG=$(echo "$PREFILL_SERVER_CONFIG" | sed -E "s/--tensor-parallel-size[[:space:]]+[0-9]+/--tensor-parallel-size ${PREFILL_TP}/g") + else + PREFILL_SERVER_CONFIG+=" --tensor-parallel-size ${PREFILL_TP}" + fi +fi +if [[ -n "${DECODE_TP:-}" ]]; then + if echo "$DECODE_SERVER_CONFIG" | grep -q -- '--tensor-parallel-size'; then + DECODE_SERVER_CONFIG=$(echo "$DECODE_SERVER_CONFIG" | sed -E "s/--tensor-parallel-size[[:space:]]+[0-9]+/--tensor-parallel-size ${DECODE_TP}/g") + else + DECODE_SERVER_CONFIG+=" --tensor-parallel-size ${DECODE_TP}" + fi +fi +if [[ "${PREFILL_ENABLE_EP:-false}" == "true" ]] && ! echo "$PREFILL_SERVER_CONFIG" | grep -q -- '--enable-expert-parallel'; then + PREFILL_SERVER_CONFIG+=" --enable-expert-parallel" +fi +if [[ "${PREFILL_ENABLE_DP:-false}" == "true" ]] && ! echo "$PREFILL_SERVER_CONFIG" | grep -q -- '--enable-dp-attention'; then + PREFILL_SERVER_CONFIG+=" --enable-dp-attention" +fi +if [[ "${DECODE_ENABLE_EP:-false}" == "true" ]] && ! echo "$DECODE_SERVER_CONFIG" | grep -q -- '--enable-expert-parallel'; then + DECODE_SERVER_CONFIG+=" --enable-expert-parallel" +fi +if [[ "${DECODE_ENABLE_DP:-false}" == "true" ]] && ! 
echo "$DECODE_SERVER_CONFIG" | grep -q -- '--enable-dp-attention'; then + DECODE_SERVER_CONFIG+=" --enable-dp-attention" +fi + +echo "PREFILL_SERVER_CONFIG (after TP/EP/DP): $PREFILL_SERVER_CONFIG" +echo "DECODE_SERVER_CONFIG (after TP/EP/DP): $DECODE_SERVER_CONFIG" + +# ============================================================================= +# Container Synchronization +# ============================================================================= + +echo "Waiting at the container creation barrier on $host_name" +python3 $VLLM_WS_PATH/sync.py barrier \ + --local-ip ${host_ip} \ + --local-port 5000 \ + --enable-port \ + --node-ips ${IPADDRS} \ + --node-ports 5000 \ + --wait-for-all-ports \ + --timeout 600 + +# ============================================================================= +# ETCD Server Setup +# ============================================================================= + +echo "Proceeding to start etcd server on $host_name" +bash ${VLLM_WS_PATH}/start_etcd.sh > /dev/null 2>&1 & +etcd_pid=$! + +echo "Waiting at etcd server barrier on $host_name" +python3 $VLLM_WS_PATH/sync.py barrier \ + --node-ips ${IPADDRS} \ + --node-ports 2379 \ + --wait-for-all-ports \ + --timeout 300 + +echo "All etcd servers are up : $host_name" +sleep 3 + +echo "etcd endpoint health==================" +etcdctl endpoint health 2>&1 || /usr/local/bin/etcd/etcdctl endpoint health 2>&1 || true +echo "======================================" + +python3 $VLLM_WS_PATH/sync.py barrier \ + --node-ips ${IPADDRS} \ + --node-ports 2379 \ + --wait-for-all-ports \ + --timeout 300 + +# ============================================================================= +# Cluster Topology Configuration +# ============================================================================= +IFS=',' read -ra IP_ARRAY <<< "$IPADDRS" + +PREFILL_ARGS="" +DECODE_ARGS="" + +for ((i=0; i "$PROXY_LOG_FILE" 2>&1 & + set +x + proxy_pid=$! 
+ sleep 3 + fi + + PREFILL_CMD="vllm serve ${MODEL_PATH} \ + --port $SERVER_PORT \ + --trust-remote-code \ + --kv-transfer-config '{\"kv_connector\": \"MoRIIOConnector\", \"kv_role\": \"kv_producer\", \"kv_connector_extra_config\": {\"proxy_ip\": \"${NODE0_ADDR}\", \"proxy_ping_port\": \"${PROXY_PING_PORT}\", \"http_port\": \"${SERVER_PORT}\"}}' \ + ${PREFILL_SERVER_CONFIG}" + + if [[ "$DRY_RUN" -eq 1 ]]; then + echo "DRY RUN: $PREFILL_CMD" + else + PREFILL_LOG_FILE="/run_logs/slurm_job-${SLURM_JOB_ID}/prefill_${host_name}.log" + set -x + eval "$PREFILL_CMD" > "$PREFILL_LOG_FILE" 2>&1 & + set +x + prefill_pid=$! + fi + + echo "Waiting for all prefill and decode servers to be up . . ." + if [[ "$DRY_RUN" -eq 1 ]]; then + echo "DRY RUN: skipping barrier (wait-for-all-ports)" + else + python3 $VLLM_WS_PATH/sync.py barrier \ + --node-ips ${IPADDRS} \ + --node-ports $SERVER_PORT \ + --wait-for-all-ports \ + --timeout 1800 + fi + + echo "Congratulations!!! All prefill and decode servers are up . . ." 
+ + # Wait for proxy /health to confirm it is accepting requests + HEALTH_BARRIER_CMD="python3 $VLLM_WS_PATH/sync.py barrier \ + --node-ips ${NODE0_ADDR} \ + --node-ports ${ROUTER_PORT} \ + --wait-for-all-health \ + --health-endpoint /health \ + --timeout 1800" + + if [[ "$DRY_RUN" -eq 1 ]]; then + echo "DRY RUN: $HEALTH_BARRIER_CMD" + else + eval "$HEALTH_BARRIER_CMD" + echo "MoRI-IO proxy is ready for benchmarking" + fi + + echo "Ready for benchmarking on ${host_name}:${host_ip}" + echo "Benchmarking on ${host_name}:${host_ip}" + cd $VLLM_WS_PATH + + export ROUTER_PORT=$ROUTER_PORT + BENCH_CMD="bash $VLLM_WS_PATH/bench.sh ${xP} ${yD} $((GPUS_PER_NODE*xP)) $((GPUS_PER_NODE*yD)) \ + $MODEL_DIR $MODEL_NAME /run_logs/slurm_job-${SLURM_JOB_ID} ${BENCH_INPUT_LEN} \ + ${BENCH_OUTPUT_LEN} \"${BENCH_MAX_CONCURRENCY}\" ${BENCH_REQUEST_RATE} \ + ${BENCH_RANDOM_RANGE_RATIO} ${BENCH_NUM_PROMPTS_MULTIPLIER}" + + if [[ "$DRY_RUN" -eq 1 ]]; then + echo "DRY RUN: $BENCH_CMD" + else + set -x + eval "$BENCH_CMD" + set +x + fi + + # Copy benchmark results to BENCHMARK_LOGS_DIR (mounted from host) + LOGS_OUTPUT="${BENCHMARK_LOGS_DIR:-/run_logs}/logs" + mkdir -p "$LOGS_OUTPUT" + + if [[ "$DRY_RUN" -eq 0 ]]; then + cp -r /run_logs/slurm_job-${SLURM_JOB_ID} "$LOGS_OUTPUT/" + echo "Copied results to $LOGS_OUTPUT/slurm_job-${SLURM_JOB_ID}" + fi + + echo "Killing the proxy server and prefill server" + if [[ "$DRY_RUN" -eq 0 ]]; then + [[ -n "${proxy_pid:-}" ]] && kill $proxy_pid 2>/dev/null || true + [[ -n "${prefill_pid:-}" ]] && kill $prefill_pid 2>/dev/null || true + sleep 2 + # Fallback: ensure no orphaned processes keep ports open + pkill -f moriio_proxy 2>/dev/null || true + pkill -f "vllm serve" 2>/dev/null || true + fi + +elif [ "$NODE_RANK" -gt 0 ] && [ "$NODE_RANK" -lt "$xP" ]; then + echo "${host_name}:${host_ip} is Additional Prefill Node (Model: ${MODEL_NAME})" + echo "Using prefill config: $PREFILL_SERVER_CONFIG" + + setup_vllm_env + + PREFILL_CMD="vllm serve ${MODEL_PATH} \ 
+ --port $SERVER_PORT \ + --trust-remote-code \ + --kv-transfer-config '{\"kv_connector\": \"MoRIIOConnector\", \"kv_role\": \"kv_producer\", \"kv_connector_extra_config\": {\"proxy_ip\": \"${NODE0_ADDR}\", \"proxy_ping_port\": \"${PROXY_PING_PORT}\", \"http_port\": \"${SERVER_PORT}\"}}' \ + ${PREFILL_SERVER_CONFIG}" + + if [[ "$DRY_RUN" -eq 1 ]]; then + echo "DRY RUN: $PREFILL_CMD" + else + PREFILL_LOG_FILE="/run_logs/slurm_job-${SLURM_JOB_ID}/prefill_${host_name}.log" + set -x + eval "$PREFILL_CMD" > "$PREFILL_LOG_FILE" 2>&1 & + set +x + prefill_pid=$! + fi + + echo "Waiting for proxy server to be up..." + BARRIER_CMD="python3 $VLLM_WS_PATH/sync.py barrier \ + --node-ips ${NODE0_ADDR} \ + --node-ports ${ROUTER_PORT} \ + --wait-for-all-ports \ + --timeout 1800" + + if [[ "$DRY_RUN" -eq 1 ]]; then + echo "DRY RUN: $BARRIER_CMD" + else + eval "$BARRIER_CMD" + fi + + echo "Waiting until proxy server closes..." + WAIT_CMD="python3 $VLLM_WS_PATH/sync.py wait \ + --remote-ip ${NODE0_ADDR} \ + --remote-port ${ROUTER_PORT}" + + if [[ "$DRY_RUN" -eq 1 ]]; then + echo "DRY RUN: $WAIT_CMD" + else + eval "$WAIT_CMD" + fi + + echo "Killing the prefill server" + [[ "$DRY_RUN" -eq 0 ]] && kill $prefill_pid 2>/dev/null || true + +else + echo "${host_name}:${host_ip} is Decode Node (Model: ${MODEL_NAME})" + echo "Using decode config: $DECODE_SERVER_CONFIG" + + setup_vllm_env + + for env_pair in ${DECODE_MODEL_ENVS}; do + export "$env_pair" + echo "[DECODE_ENV] $env_pair" + done + + DECODE_CMD="vllm serve ${MODEL_PATH} \ + --port $SERVER_PORT \ + --trust-remote-code \ + --kv-transfer-config '{\"kv_connector\": \"MoRIIOConnector\", \"kv_role\": \"kv_consumer\", \"kv_connector_extra_config\": {\"proxy_ip\": \"${NODE0_ADDR}\", \"proxy_ping_port\": \"${PROXY_PING_PORT}\", \"http_port\": \"${SERVER_PORT}\"}}' \ + ${DECODE_SERVER_CONFIG}" + + if [[ "$DRY_RUN" -eq 1 ]]; then + echo "DRY RUN: $DECODE_CMD" + else + 
DECODE_LOG_FILE="/run_logs/slurm_job-${SLURM_JOB_ID}/decode_${host_name}.log" + set -x + eval "$DECODE_CMD" > "$DECODE_LOG_FILE" 2>&1 & + set +x + decode_pid=$! + fi + + echo "Waiting for proxy server to be up..." + BARRIER_CMD="python3 $VLLM_WS_PATH/sync.py barrier \ + --node-ips ${NODE0_ADDR} \ + --node-ports ${ROUTER_PORT} \ + --wait-for-all-ports \ + --timeout 1800" + + if [[ "$DRY_RUN" -eq 1 ]]; then + echo "DRY RUN: $BARRIER_CMD" + else + eval "$BARRIER_CMD" + fi + + echo "Waiting until proxy server closes..." + WAIT_CMD="python3 $VLLM_WS_PATH/sync.py wait \ + --remote-ip ${NODE0_ADDR} \ + --remote-port ${ROUTER_PORT}" + + if [[ "$DRY_RUN" -eq 1 ]]; then + echo "DRY RUN: $WAIT_CMD" + else + eval "$WAIT_CMD" + fi + + echo "Killing the decode server" + [[ "$DRY_RUN" -eq 0 ]] && kill $decode_pid 2>/dev/null || true +fi + +echo "Killing the etcd server" +kill $etcd_pid 2>/dev/null || true +pkill -f etcd 2>/dev/null || true + +echo "Script completed successfully" +exit 0 diff --git a/benchmarks/multi_node/vllm_disagg_utils/setup_deps.sh b/benchmarks/multi_node/vllm_disagg_utils/setup_deps.sh new file mode 100644 index 000000000..e8437a5c9 --- /dev/null +++ b/benchmarks/multi_node/vllm_disagg_utils/setup_deps.sh @@ -0,0 +1,848 @@ +#!/bin/bash +# ============================================================================= +# setup_deps.sh — Install missing vLLM disagg dependencies at container start. +# +# Base image: vllm/vllm-openai-rocm:v0.17.1 +# Sourced by server.sh so PATH / LD_LIBRARY_PATH exports persist. +# Idempotent: each component is skipped if already present. +# +# Build steps run in subshells to avoid CWD pollution between installers. 
+# ============================================================================= + +ROCM_PATH="${ROCM_PATH:-/opt/rocm}" +UCX_HOME="${UCX_HOME:-/usr/local/ucx}" +RIXL_HOME="${RIXL_HOME:-/usr/local/rixl}" + +_SETUP_START=$(date +%s) +_SETUP_INSTALLED=() + +git_clone_retry() { + local url="$1" dest="$2" max_tries=3 try=1 + while (( try <= max_tries )); do + if git clone --quiet "$url" "$dest" 2>/dev/null; then return 0; fi + echo "[SETUP] git clone attempt $try/$max_tries failed for $url, retrying in 10s..." + rm -rf "$dest" + sleep 10 + (( try++ )) + done + echo "[SETUP] git clone failed after $max_tries attempts: $url" + return 1 +} + +# --------------------------------------------------------------------------- +# 1. UCX (ROCm fork — required for GPU-direct RDMA via Nixl) +# --------------------------------------------------------------------------- +install_ucx() { + if [[ -x "${UCX_HOME}/bin/ucx_info" ]]; then + echo "[SETUP] UCX already present at ${UCX_HOME}" + return 0 + fi + + echo "[SETUP] Installing UCX build dependencies..." + apt-get update -q -y && apt-get install -q -y \ + autoconf automake libtool pkg-config \ + librdmacm-dev rdmacm-utils libibverbs-dev ibverbs-utils ibverbs-providers \ + infiniband-diags perftest ethtool rdma-core strace \ + && rm -rf /var/lib/apt/lists/* + + echo "[SETUP] Building UCX from source (ROCm/ucx @ da3fac2a)..." + ( + set -e + mkdir -p /usr/local/src && cd /usr/local/src + git_clone_retry https://github.com/ROCm/ucx.git ucx && cd ucx + git checkout da3fac2a + ./autogen.sh && mkdir -p build && cd build + ../configure \ + --prefix="${UCX_HOME}" \ + --enable-shared --disable-static \ + --disable-doxygen-doc --enable-optimizations \ + --enable-devel-headers --enable-mt \ + --with-rocm="${ROCM_PATH}" --with-verbs --with-dm + make -j"$(nproc)" && make install + ) + rm -rf /usr/local/src/ucx + + if [[ ! 
-x "${UCX_HOME}/bin/ucx_info" ]]; then + echo "[SETUP] ERROR: UCX build failed"; exit 1 + fi + _SETUP_INSTALLED+=("UCX") +} + +# --------------------------------------------------------------------------- +# 2. RIXL (ROCm fork of NIXL — KV cache transfer for disaggregated vLLM) +# --------------------------------------------------------------------------- +install_rixl() { + if python3 -c "import rixl" 2>/dev/null; then + echo "[SETUP] RIXL Python bindings already present" + return 0 + fi + + echo "[SETUP] Installing RIXL build dependencies..." + apt-get update -q -y && apt-get install -q -y \ + libgrpc-dev libgrpc++-dev libprotobuf-dev protobuf-compiler-grpc \ + libcpprest-dev libaio-dev \ + && rm -rf /var/lib/apt/lists/* + pip3 install --quiet meson "pybind11[global]" + + echo "[SETUP] Building RIXL from source (ROCm/RIXL @ f33a5599)..." + ( + set -e + git_clone_retry https://github.com/ROCm/RIXL.git /opt/rixl && cd /opt/rixl + git checkout f33a5599 + meson setup build --prefix="${RIXL_HOME}" \ + -Ducx_path="${UCX_HOME}" \ + -Drocm_path="${ROCM_PATH}" + cd build && ninja && ninja install + cd /opt/rixl + pip install --quiet \ + --config-settings=setup-args="-Drocm_path=${ROCM_PATH}" \ + --config-settings=setup-args="-Ducx_path=${UCX_HOME}" . + ) + rm -rf /opt/rixl + + if ! python3 -c "import rixl" 2>/dev/null; then + echo "[SETUP] ERROR: RIXL build failed"; exit 1 + fi + _SETUP_INSTALLED+=("RIXL") +} + +# --------------------------------------------------------------------------- +# 3. etcd (distributed KV store for vLLM disagg service discovery) +# --------------------------------------------------------------------------- +install_etcd() { + if [[ -x /usr/local/bin/etcd/etcd ]]; then + echo "[SETUP] etcd already present" + return 0 + fi + + local version="v3.6.0-rc.5" + echo "[SETUP] Downloading etcd ${version}..." 
+ wget -q "https://github.com/etcd-io/etcd/releases/download/${version}/etcd-${version}-linux-amd64.tar.gz" \ + -O /tmp/etcd.tar.gz + mkdir -p /usr/local/bin/etcd + tar -xf /tmp/etcd.tar.gz -C /usr/local/bin/etcd --strip-components=1 + rm /tmp/etcd.tar.gz + _SETUP_INSTALLED+=("etcd") +} + +# --------------------------------------------------------------------------- +# 4. libionic1 (Pensando ionic RDMA verbs provider for RoCEv2 KV transfer) +# Harmless on non-Pensando nodes (shared lib is simply unused). +# --------------------------------------------------------------------------- +install_libionic() { + if dpkg -l libionic1 2>/dev/null | grep -q '^ii'; then + echo "[SETUP] libionic1 already installed" + return 0 + fi + + echo "[SETUP] Downloading and installing libionic1..." + wget -q "https://repo.radeon.com/amdainic/pensando/ubuntu/1.117.5/pool/main/r/rdma-core/libionic1_54.0-149.g3304be71_amd64.deb" \ + -O /tmp/libionic1.deb + dpkg -i /tmp/libionic1.deb || true + rm -f /tmp/libionic1.deb + _SETUP_INSTALLED+=("libionic1") +} + +# --------------------------------------------------------------------------- +# 5. MoRI-IO proxy deps (Python packages for the MoRI-IO-aware proxy server) +# The proxy replaces vllm-router: it handles both HTTP routing AND the +# MoRI-IO ZMQ registration/request-enrichment protocol. +# Only needed on NODE_RANK=0 (proxy node). +# --------------------------------------------------------------------------- +install_mori_proxy_deps() { + if python3 -c "import quart, aiohttp, msgpack, zmq" 2>/dev/null; then + echo "[SETUP] MoRI-IO proxy Python deps already present" + return 0 + fi + + echo "[SETUP] Installing MoRI-IO proxy Python deps..." + pip install --quiet --ignore-installed blinker + pip install --quiet quart aiohttp msgpack pyzmq + + if ! 
python3 -c "import quart, aiohttp, msgpack, zmq" 2>/dev/null; then + echo "[SETUP] ERROR: MoRI-IO proxy deps install failed"; exit 1 + fi + _SETUP_INSTALLED+=("mori-proxy-deps") +} + +# --------------------------------------------------------------------------- +# 6. MoRI (Modular RDMA Interface — EP dispatch/combine kernels for MoE) +# Required for --all2all-backend mori (Expert Parallelism via RDMA). +# GPU kernels are JIT-compiled on first use; no hipcc needed at install. +# --------------------------------------------------------------------------- +install_mori() { + local MORI_TARGET_COMMIT="b645fc8" + local MORI_MARKER="/usr/local/lib/python3.*/dist-packages/.mori_commit_${MORI_TARGET_COMMIT}" + + # The pre-installed MoRI in vllm base images has a PCI topology bug: it + # only maps the secondary bus of each bridge instead of the full + # secondary-to-subordinate range (dsp2dev). This causes an assertion + # failure in TopoSystemPci::Load() on nodes with deeply-nested PCIe + # switch topologies (e.g. Broadcom PEX890xx on MI355X mia1 nodes). + # Always rebuild from the target commit unless the marker file proves + # the correct version was already installed in this container. + if ls $MORI_MARKER &>/dev/null; then + echo "[SETUP] MoRI @ $MORI_TARGET_COMMIT already installed (marker found)" + return 0 + fi + + echo "[SETUP] Installing MoRI build dependencies..." + apt-get update -q -y && apt-get install -q -y \ + libopenmpi-dev openmpi-bin libpci-dev \ + && rm -rf /var/lib/apt/lists/* + + echo "[SETUP] Building MoRI from source (ROCm/mori @ $MORI_TARGET_COMMIT)..." + echo "[SETUP] (overriding pre-installed version to fix PCI topology bug)" + ( + set -e + git_clone_retry https://github.com/ROCm/mori.git /opt/mori && cd /opt/mori + git checkout "$MORI_TARGET_COMMIT" + pip install --quiet --force-reinstall . + ) + rm -rf /opt/mori + + if ! 
python3 -c "import mori" 2>/dev/null; then + echo "[SETUP] ERROR: MoRI build failed"; exit 1 + fi + # Drop a marker so re-entry doesn't rebuild + touch $(python3 -c "import sysconfig; print(sysconfig.get_paths()['purelib'])")/.mori_commit_${MORI_TARGET_COMMIT} + _SETUP_INSTALLED+=("MoRI@$MORI_TARGET_COMMIT") +} + +# --------------------------------------------------------------------------- +# 7. Patch vLLM v0.17.1 MoRI-EP + FP8 incompatibility +# v0.17.1 asserts MoRI requires AITER fused_moe, but AITER's FP8 kernel +# uses defer_input_quant=True which MoRI's prepare/finalize rejects. +# Patch: remove both the AITER requirement assertion and the +# defer_input_quant NotImplementedError so non-AITER kernels work. +# --------------------------------------------------------------------------- +patch_mori_fp8_compat() { + python3 -c ' +import re, os, sys +patched = [] + +# 1. Patch layer.py: remove multi-line AITER assertion for MoRI +try: + import vllm.model_executor.layers.fused_moe.layer as lm + f = lm.__file__ + src = open(f).read() + if "Mori needs to be used with aiter" in src: + new = re.sub( + r"assert self\.rocm_aiter_fmoe_enabled,\s*\([^)]*Mori needs[^)]*\)", + "pass # [PATCHED] AITER requirement removed for MoRI-EP + FP8", + src, flags=re.DOTALL) + if new != src: + open(f, "w").write(new) + patched.append("layer.py") +except Exception as e: + print(f"[SETUP] WARN patch layer.py: {e}", file=sys.stderr) + +# 2. 
Patch mori_prepare_finalize.py: remove defer_input_quant restriction +try: + import vllm.model_executor.layers.fused_moe.mori_prepare_finalize as mm + f = mm.__file__ + src = open(f).read() + if "defer_input_quant" in src: + new = re.sub( + r"raise NotImplementedError\([^)]*defer_input_quant[^)]*\)", + "pass # [PATCHED] defer_input_quant check removed for MoRI-EP + FP8", + src) + if new != src: + open(f, "w").write(new) + patched.append("mori_prepare_finalize.py") +except Exception as e: + print(f"[SETUP] WARN patch mori_pf: {e}", file=sys.stderr) + +if patched: + print(f"[SETUP] Patched: {chr(44).join(patched)}") +else: + print("[SETUP] No MoRI-FP8 patches needed") +' + _SETUP_INSTALLED+=("MoRI-FP8-patch") +} + +# --------------------------------------------------------------------------- +# 8. Patch vLLM MoRI-IO save_kv_layer busy-spin (C128 tail-batch deadlock) +# In WRITE mode, save_kv_layer spins forever waiting for the handshake +# callback to set write_ready_flags. This blocks the model worker thread, +# preventing it from responding to EngineCore shm_broadcast, causing a +# TimeoutError cascade and crash. +# Patch: add time.sleep(0.001) and a 30s timeout to yield CPU and prevent +# the model worker from deadlocking. +# --------------------------------------------------------------------------- +patch_moriio_save_kv_timeout() { + python3 -c ' +import os, sys + +try: + import vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector as mc + f = mc.__file__ + src = open(f).read() + + # Already patched? 
+ if "[PATCHED] save_kv_layer timeout" in src: + print("[SETUP] save_kv_layer timeout patch already applied") + sys.exit(0) + + old = """ while True: + if ( + self._ready_requests.empty() + and remote_engine_id not in self.write_ready_flags + ): + continue""" + + if old not in src: + print("[SETUP] WARN: save_kv_layer busy-spin pattern not found, skipping patch") + sys.exit(0) + + new = """ # [PATCHED] save_kv_layer — null guard + timeout + sleep + if remote_engine_id is None: + return + import time as _time, os as _os + _wait_start = _time.monotonic() + _SAVE_KV_TIMEOUT = float(_os.environ.get("VLLM_MORIIO_HANDSHAKE_TIMEOUT", "30")) + while True: + if ( + self._ready_requests.empty() + and remote_engine_id not in self.write_ready_flags + ): + _elapsed = _time.monotonic() - _wait_start + if _elapsed > _SAVE_KV_TIMEOUT: + import logging as _logging + _logging.getLogger("vllm.moriio").warning( + "[HANGFIX] save_kv_layer: timeout (%.1fs) waiting for " + "write_ready_flags[%s], breaking to unblock model " + "worker", _elapsed, remote_engine_id) + break + _time.sleep(0.001) + continue""" + + new_src = src.replace(old, new) + if new_src == src: + print("[SETUP] WARN: replacement had no effect") + sys.exit(0) + + open(f, "w").write(new_src) + print("[SETUP] Patched save_kv_layer: null guard + timeout + sleep") +except Exception as e: + print(f"[SETUP] WARN patch save_kv_layer: {e}", file=sys.stderr) +' + _SETUP_INSTALLED+=("MoRIIO-save-kv-timeout-patch") +} + +# --------------------------------------------------------------------------- +# 9. Patch MoRIIO waiting_for_transfer_complete with bounded timeout +# The original status.Wait() blocks forever if an RDMA completion never +# arrives (e.g., NIC queue saturation at C256). This replaces the unbounded +# wait with a polling loop using status.Succeeded() + configurable timeout. +# Also adds error handling to the write worker loop so a single failed +# transfer doesn't kill the background thread. 
+# --------------------------------------------------------------------------- +patch_moriio_transfer_timeout() { + python3 -c ' +import os, sys, textwrap + +try: + import vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_engine as me + f = me.__file__ + src = open(f).read() + + if "[PATCHED] transfer completion timeout" in src: + print("[SETUP] transfer completion timeout patch already applied") + sys.exit(0) + + # --- Patch 1: Replace waiting_for_transfer_complete with polling + timeout --- + old_wait = """ def waiting_for_transfer_complete(self): + if not self.transfer_status: + return + + transfers_to_wait = [] + with self.lock: + transfers_to_wait = self.transfer_status[:] + self.transfer_status.clear() + + for status in transfers_to_wait: + try: + status.Wait() + if not status.Succeeded(): + logger.error( + "Transfer failed: %s, Code: %s", status.Message(), status.Code() + ) + raise TransferError("MoRIIO transfer failed!") + except Exception as e: + logger.error("Transfer %s failed: %s", status, e) + raise""" + + new_wait = """ def waiting_for_transfer_complete(self): + # [PATCHED] transfer completion timeout — bounded polling loop + import time as _time, os as _os + if not self.transfer_status: + return + + _timeout = float(_os.environ.get("VLLM_MORIIO_TRANSFER_TIMEOUT", "120")) + + transfers_to_wait = [] + with self.lock: + transfers_to_wait = self.transfer_status[:] + self.transfer_status.clear() + + _start = _time.monotonic() + remaining = list(transfers_to_wait) + _polls = 0 + _completed = 0 + + while remaining: + _elapsed = _time.monotonic() - _start + if _elapsed > _timeout: + logger.error( + "[HANGFIX] transfer_timeout elapsed=%.1fs " + "pending=%d/%d completed=%d polls=%d " + "action=raise_transfer_error", + _elapsed, len(remaining), len(transfers_to_wait), + _completed, _polls, + ) + raise TransferError( + f"RDMA transfer timeout after {_elapsed:.1f}s, " + f"{len(remaining)}/{len(transfers_to_wait)} pending" + ) + + still_waiting = [] + for 
status in remaining: + try: + if status.Succeeded(): + _completed += 1 + continue + still_waiting.append(status) + except Exception as e: + logger.error( + "[HANGFIX] transfer_poll_error error=%s", e) + raise TransferError( + f"Transfer failed during poll: {e}" + ) from e + + remaining = still_waiting + if remaining: + _time.sleep(0.005) + _polls += 1 + if _polls % 2000 == 0: + logger.warning( + "[HANGFIX] transfer_wait pending=%d " + "completed=%d elapsed=%.1fs timeout=%.0fs", + len(remaining), _completed, + _time.monotonic() - _start, _timeout, + )""" + + if old_wait not in src: + print("[SETUP] WARN: waiting_for_transfer_complete pattern not found") + sys.exit(0) + + new_src = src.replace(old_wait, new_wait) + + # --- Patch 2: Add error handling + cleanup to _write_worker_loop --- + old_loop = """ self._execute_write_task(task)""" + + new_loop = """ try: + self._execute_write_task(task) + except Exception as _e: + logger.error( + "[HANGFIX] req=%s write_task_failed error=%s " + "action=cleanup_and_mark_done", + task.request_id, _e, + ) + try: + _wr = self.worker.moriio_wrapper + with _wr.lock: + _wr.done_req_ids.append(task.request_id) + _wr.done_remote_allocate_req_dict.pop( + task.request_id, None + ) + except Exception: + pass""" + + if old_loop in new_src: + new_src = new_src.replace(old_loop, new_loop, 1) + else: + print("[SETUP] WARN: _write_worker_loop pattern not found for error handling") + + # --- Patch 3: Add deferred task timeout to _process_deferred_tasks --- + old_deferred = """ def _process_deferred_tasks(self) -> None: + \"\"\"Process tasks that were previously deferred.\"\"\" + if not self._deferred_tasks: + return + + still_deferred: list[WriteTask] = [] + for task in self._deferred_tasks: + if self._is_remote_ready(task): + self._execute_write_task(task) + else: + still_deferred.append(task) + + self._deferred_tasks = still_deferred""" + + new_deferred = """ def _process_deferred_tasks(self) -> None: + \"\"\"Process tasks that were previously 
deferred.\"\"\" + # [PATCHED] deferred task timeout — prune stale tasks + import time as _time, os as _os + if not self._deferred_tasks: + return + + _DEFER_TIMEOUT = float( + _os.environ.get("VLLM_MORIIO_DEFER_TIMEOUT", "60")) + + still_deferred: list[WriteTask] = [] + for task in self._deferred_tasks: + _age = _time.monotonic() - getattr(task, "_defer_ts", _time.monotonic()) + if _age > _DEFER_TIMEOUT: + logger.error( + "[HANGFIX] req=%s deferred_task_expired age=%.1fs " + "action=drop_and_mark_done", + task.request_id, _age, + ) + try: + _wr = self.worker.moriio_wrapper + with _wr.lock: + _wr.done_req_ids.append(task.request_id) + _wr.done_remote_allocate_req_dict.pop( + task.request_id, None) + except Exception: + pass + continue + if self._is_remote_ready(task): + try: + self._execute_write_task(task) + except Exception as _e: + logger.error( + "[HANGFIX] req=%s deferred_write_failed error=%s", + task.request_id, _e, + ) + try: + _wr = self.worker.moriio_wrapper + with _wr.lock: + _wr.done_req_ids.append(task.request_id) + _wr.done_remote_allocate_req_dict.pop( + task.request_id, None) + except Exception: + pass + else: + still_deferred.append(task) + + self._deferred_tasks = still_deferred""" + + if old_deferred in new_src: + new_src = new_src.replace(old_deferred, new_deferred, 1) + else: + print("[SETUP] WARN: _process_deferred_tasks pattern not found") + + # --- Patch 4: Stamp defer time when task is deferred --- + old_defer_add = """ self._deferred_tasks.append(task)""" + new_defer_add = """ import time as _time2 + if not hasattr(task, "_defer_ts"): + task._defer_ts = _time2.monotonic() + self._deferred_tasks.append(task)""" + if old_defer_add in new_src: + new_src = new_src.replace(old_defer_add, new_defer_add, 1) + else: + print("[SETUP] WARN: deferred task timestamp patch target not found") + + open(f, "w").write(new_src) + print("[SETUP] Patched: transfer timeout + writer error handling") + +except Exception as e: + print(f"[SETUP] WARN patch 
transfer_timeout: {e}", file=sys.stderr) +' + _SETUP_INSTALLED+=("MoRIIO-transfer-timeout-patch") +} + +# --------------------------------------------------------------------------- +# 10. Patch MoRIIO start_load_kv busy-spin (same pattern as save_kv_layer) +# The READ-mode spin loop in start_load_kv has the same unbounded-spin +# issue as save_kv_layer. Add timeout + sleep + null guard. +# --------------------------------------------------------------------------- +patch_moriio_load_kv_timeout() { + python3 -c ' +import os, sys + +try: + import vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector as mc + f = mc.__file__ + src = open(f).read() + + if "[PATCHED] start_load_kv timeout" in src: + print("[SETUP] start_load_kv timeout patch already applied") + sys.exit(0) + + old = """ while True: + if ( + self._ready_requests.empty() + and remote_engine_id not in self.load_ready_flag + and wait_handshake_readd_req + ): + continue""" + + if old not in src: + print("[SETUP] WARN: start_load_kv busy-spin pattern not found, skipping") + sys.exit(0) + + new = """ # [PATCHED] start_load_kv timeout — prevent model worker deadlock + if remote_engine_id is None and not wait_handshake_readd_req: + self._reqs_to_send.update(metadata.reqs_to_send) + return + import time as _time, os as _os + _wait_start = _time.monotonic() + _LOAD_KV_TIMEOUT = float(_os.environ.get("VLLM_MORIIO_HANDSHAKE_TIMEOUT", "30")) + while True: + if ( + self._ready_requests.empty() + and remote_engine_id not in self.load_ready_flag + and wait_handshake_readd_req + ): + if _time.monotonic() - _wait_start > _LOAD_KV_TIMEOUT: + import logging as _logging + _logging.getLogger("vllm.moriio").warning( + "[HANGFIX] start_load_kv: timeout (%.1fs) waiting for " + "load_ready_flag[%s]", _time.monotonic() - _wait_start, + remote_engine_id) + break + _time.sleep(0.001) + continue""" + + new_src = src.replace(old, new) + if new_src == src: + print("[SETUP] WARN: start_load_kv replacement had no effect") 
+ sys.exit(0) + + open(f, "w").write(new_src) + print("[SETUP] Patched start_load_kv busy-spin with timeout + sleep") +except Exception as e: + print(f"[SETUP] WARN patch start_load_kv: {e}", file=sys.stderr) +' + _SETUP_INSTALLED+=("MoRIIO-load-kv-timeout-patch") +} + +# --------------------------------------------------------------------------- +# 11. Fix READ-mode scheduler assertion in _update_from_kv_xfer_finished +# vLLM v0.17.1 asserts that a request in finished_recving must be either +# WAITING_FOR_REMOTE_KVS or finished. In READ mode the request can +# transition to RUNNING before the aggregated recv notification arrives, +# crashing the engine with AssertionError. +# --------------------------------------------------------------------------- +patch_scheduler_read_mode_fix() { + python3 -c ' +import os, sys + +try: + import vllm.v1.core.sched.scheduler as smod + f = smod.__file__ + src = open(f).read() + + if "[PATCHED] read-mode recv assertion" in src: + print("[SETUP] scheduler read-mode assertion fix already applied") + sys.exit(0) + + old_recv = """ for req_id in kv_connector_output.finished_recving or (): + logger.debug("Finished recving KV transfer for request %s", req_id) + assert req_id in self.requests + req = self.requests[req_id] + if req.status == RequestStatus.WAITING_FOR_REMOTE_KVS: + self.finished_recving_kv_req_ids.add(req_id) + else: + assert RequestStatus.is_finished(req.status) + self._free_blocks(self.requests[req_id])""" + + new_recv = """ # [PATCHED] read-mode recv assertion — handle intermediate states + for req_id in kv_connector_output.finished_recving or (): + logger.debug("Finished recving KV transfer for request %s", req_id) + if req_id not in self.requests: + logger.debug("Request %s already removed, skipping recv", req_id) + continue + req = self.requests[req_id] + if req.status == RequestStatus.WAITING_FOR_REMOTE_KVS: + self.finished_recving_kv_req_ids.add(req_id) + elif RequestStatus.is_finished(req.status): + 
self._free_blocks(self.requests[req_id]) + else: + logger.debug( + "Request %s recv finished but status=%s (not " + "WAITING_FOR_REMOTE_KVS or finished), skipping " + "block free — will be freed on request completion", + req_id, req.status.name)""" + + if old_recv not in src: + print("[SETUP] WARN: scheduler finished_recving pattern not found, skipping") + sys.exit(0) + + new_src = src.replace(old_recv, new_recv, 1) + + old_send = """ for req_id in kv_connector_output.finished_sending or (): + logger.debug("Finished sending KV transfer for request %s", req_id) + assert req_id in self.requests + self._free_blocks(self.requests[req_id])""" + + new_send = """ for req_id in kv_connector_output.finished_sending or (): + logger.debug("Finished sending KV transfer for request %s", req_id) + if req_id not in self.requests: + logger.debug("Request %s already removed, skipping send", req_id) + continue + self._free_blocks(self.requests[req_id])""" + + if old_send in new_src: + new_src = new_src.replace(old_send, new_send, 1) + else: + print("[SETUP] WARN: scheduler finished_sending pattern not found") + + open(f, "w").write(new_src) + print("[SETUP] Patched: scheduler _update_from_kv_xfer_finished read-mode fix") + +except Exception as e: + print(f"[SETUP] WARN patch scheduler read-mode: {e}", file=sys.stderr) +' + _SETUP_INSTALLED+=("scheduler-read-mode-fix") +} + +# --------------------------------------------------------------------------- +# 12. Idle KV block reaper for disaggregated prefill (READ mode) +# The RIXL notification path can lose `finished_sending` signals under +# high concurrency with ibv_post_send failures. This leaves KV blocks +# permanently allocated on the prefill engine even after the decode has +# finished reading. Over multiple benchmark rounds, leaked blocks +# accumulate and eventually saturate the prefill KV cache. 
+# +# Fix: instrument the scheduler's `schedule()` method to detect idle +# periods (0 running, 0 waiting for >5s) and force-free blocks for +# any remaining requests whose status is finished. +# --------------------------------------------------------------------------- +patch_prefill_idle_kv_reaper() { + python3 -c ' +import os, sys + +try: + import vllm.v1.core.sched.scheduler as smod + f = smod.__file__ + src = open(f).read() + + if "[PATCHED] idle-kv-reaper" in src: + print("[SETUP] idle KV block reaper already applied") + sys.exit(0) + + # Find the _update_from_kv_xfer_finished method end and add reaper logic + # We inject into the method that processes KV transfer completions. + marker = "[PATCHED] read-mode recv assertion" + if marker not in src: + print("[SETUP] WARN: scheduler read-mode patch not found, skipping reaper") + sys.exit(0) + + # Add reaper state initialization to __init__ + old_init_marker = "self.finished_recving_kv_req_ids" + if old_init_marker not in src: + print("[SETUP] WARN: finished_recving_kv_req_ids not found in scheduler") + sys.exit(0) + + # Find the first occurrence to insert reaper state + init_pos = src.find(old_init_marker) + # Find the line containing it + line_end = src.find("\n", init_pos) + init_line = src[init_pos:line_end] + + # Add reaper state after this line + reaper_init = init_line + """ + # [PATCHED] idle-kv-reaper state + self._idle_kv_reaper_ts = 0.0 + self._idle_kv_reaper_active = False""" + + src = src.replace(init_line, reaper_init, 1) + + # Now add the reaper logic at the end of _update_from_kv_xfer_finished + # Find the finished_sending handler we patched + send_handler = """ for req_id in kv_connector_output.finished_sending or (): + logger.debug("Finished sending KV transfer for request %s", req_id) + if req_id not in self.requests: + logger.debug("Request %s already removed, skipping send", req_id) + continue + self._free_blocks(self.requests[req_id])""" + + reaper_logic = send_handler + """ + + # [PATCHED] 
idle-kv-reaper — force-free leaked prefill KV blocks + import time as _time + _REAPER_IDLE_SECS = 5.0 + _num_running = sum(1 for r in self.requests.values() + if r.status == RequestStatus.RUNNING) + _should_reap = (_num_running == 0) + + if _should_reap: + if not self._idle_kv_reaper_active: + self._idle_kv_reaper_active = True + self._idle_kv_reaper_ts = _time.monotonic() + elif _time.monotonic() - self._idle_kv_reaper_ts > _REAPER_IDLE_SECS: + _reaped = 0 + _reap_ids = [] + for _rid, _req in list(self.requests.items()): + if RequestStatus.is_finished(_req.status): + _reap_ids.append(_rid) + for _rid in _reap_ids: + try: + _req = self.requests[_rid] + self._free_blocks(_req) + _reaped += 1 + except Exception as _e: + logger.debug("[KV-REAPER] free_blocks failed for %s: %s", _rid, _e) + if _reaped > 0: + logger.warning( + "[KV-REAPER] Force-freed blocks for %d finished " + "requests after %.1fs idle", + _reaped, _time.monotonic() - self._idle_kv_reaper_ts) + self._idle_kv_reaper_ts = _time.monotonic() + else: + self._idle_kv_reaper_active = False""" + + if send_handler in src: + src = src.replace(send_handler, reaper_logic, 1) + else: + print("[SETUP] WARN: send handler not found for reaper injection") + sys.exit(0) + + open(f, "w").write(src) + print("[SETUP] Patched: idle KV block reaper for prefill") + +except Exception as e: + print(f"[SETUP] WARN patch idle-kv-reaper: {e}", file=sys.stderr) +' + _SETUP_INSTALLED+=("idle-kv-reaper") +} + +# ============================================================================= +# Run installers +# ============================================================================= + +install_ucx +install_rixl +install_etcd +install_libionic +install_mori +patch_mori_fp8_compat +patch_moriio_save_kv_timeout +patch_moriio_transfer_timeout +patch_moriio_load_kv_timeout +patch_scheduler_read_mode_fix +patch_prefill_idle_kv_reaper + +if [[ "${NODE_RANK:-0}" -eq 0 ]]; then + install_mori_proxy_deps +fi + +# 
=============================================================================
+# Export paths (persists for server.sh since this file is sourced)
+# =============================================================================
+
+export ROCM_PATH="${ROCM_PATH}"
+export UCX_HOME="${UCX_HOME}"
+export RIXL_HOME="${RIXL_HOME}"
+export PATH="${UCX_HOME}/bin:/usr/local/bin/etcd:/root/.cargo/bin:${PATH}"
+export LD_LIBRARY_PATH="${UCX_HOME}/lib:${RIXL_HOME}/lib:${RIXL_HOME}/lib/x86_64-linux-gnu:${LD_LIBRARY_PATH:-}"
+
+_SETUP_END=$(date +%s)
+if [[ ${#_SETUP_INSTALLED[@]} -eq 0 ]]; then
+    # Report elapsed wallclock, not the raw epoch timestamp.
+    echo "[SETUP] All dependencies already present ($(( _SETUP_END - _SETUP_START ))s wallclock)"
+else
+    echo "[SETUP] Installed: ${_SETUP_INSTALLED[*]} in $(( _SETUP_END - _SETUP_START ))s"
+fi
diff --git a/benchmarks/multi_node/vllm_disagg_utils/start_etcd.sh b/benchmarks/multi_node/vllm_disagg_utils/start_etcd.sh
new file mode 100755
index 000000000..46bbd2964
--- /dev/null
+++ b/benchmarks/multi_node/vllm_disagg_utils/start_etcd.sh
@@ -0,0 +1,53 @@
+#!/bin/bash
+set -x
+
+IPADDRS="${IPADDRS:-localhost}"
+
+# Use management network IP (matching what the Slurm script resolved)
+host_ip=$(ip route get 1.1.1.1 2>/dev/null | sed -n 's/.*src \([^ ]*\).*/\1/p')
+if [[ -z "$host_ip" ]]; then
+    host_ip=$(hostname -I | awk '{print $1}')
+fi
+
+IFS=',' read -ra ADDR <<< "$IPADDRS"
+
+# Determine node name based on position in the IPADDRS list
+index=0
+for ip in "${ADDR[@]}"; do
+    if [[ "$ip" == "$host_ip" ]]; then
+        break
+    fi
+    index=$((index + 1))
+done
+# Fail fast if this host is not in IPADDRS: otherwise node_name would not
+# match any member of --initial-cluster and etcd would refuse to start.
+if [[ $index -ge ${#ADDR[@]} ]]; then
+    echo "ERROR: local IP $host_ip not found in IPADDRS=$IPADDRS" >&2
+    exit 1
+fi
+node_name="etcd-$((index+1))"
+
+# Build initial cluster string
+initial_cluster=""
+for i in "${!ADDR[@]}"; do
+    peer_name="etcd-$((i+1))"
+    initial_cluster+="$peer_name=http://${ADDR[i]}:2380"
+    if [[ $i -lt $((${#ADDR[@]} - 1)) ]]; then
+        initial_cluster+=","
+    fi
+done
+
+mkdir -p /var/lib/etcd
+rm -rf /var/lib/etcd/*
+
+/usr/local/bin/etcd/etcd \
+    --name "$node_name" \
+    --data-dir /var/lib/etcd \
+    --initial-advertise-peer-urls http://$host_ip:2380 \
+    
--listen-peer-urls http://0.0.0.0:2380 \ + --listen-client-urls http://0.0.0.0:2379 \ + --advertise-client-urls http://$host_ip:2379 \ + --initial-cluster-token etcd-cluster-1 \ + --initial-cluster "$initial_cluster" \ + --initial-cluster-state new \ + 2>&1 | tee /run_logs/slurm_job-${SLURM_JOB_ID}/etcd_NODE${NODE_RANK}.log diff --git a/benchmarks/multi_node/vllm_disagg_utils/submit.sh b/benchmarks/multi_node/vllm_disagg_utils/submit.sh new file mode 100755 index 000000000..ecb5a9876 --- /dev/null +++ b/benchmarks/multi_node/vllm_disagg_utils/submit.sh @@ -0,0 +1,166 @@ +#!/bin/bash +# +# Cluster Configuration Template for Multi-Node vLLM Disaggregated Serving +# +# This script submits a multi-node vLLM disaggregated benchmark job to SLURM. +# It must be configured for your specific cluster before use. +# +# Router is co-located with the first prefill node (same as SGLang), so +# NUM_NODES = PREFILL_NODES + DECODE_NODES. + +usage() { + cat << 'USAGE' +Usage: + bash submit.sh \ + \ + \ + \ + \ + [NODE_LIST] + +Arguments: + PREFILL_NODES Number of prefill nodes + PREFILL_WORKERS Number of prefill workers (usually 1) + DECODE_NODES Number of decode nodes + DECODE_WORKERS Number of decode workers (usually 1) + ISL Input sequence length + OSL Output sequence length + CONCURRENCIES Concurrency levels, delimited by 'x' (e.g., "8x16x32") + REQUEST_RATE Request rate ("inf" for max throughput) + PREFILL_ENABLE_EP true/false (from PREFILL_EP in YAML; false when EP==1) + PREFILL_ENABLE_DP true/false (data-parallel attention on prefill) + DECODE_ENABLE_EP true/false (from DECODE_EP in YAML) + DECODE_ENABLE_DP true/false (data-parallel attention on decode) + PREFILL_TP Tensor parallel size per prefill node + DECODE_TP Tensor parallel size per decode node + RANDOM_RANGE_RATIO Random range ratio for benchmark client + NODE_LIST Optional: comma-separated hostnames (must match NUM_NODES) + +Required environment variables: + SLURM_ACCOUNT SLURM account name + SLURM_PARTITION SLURM 
partition + TIME_LIMIT Job time limit (e.g., "08:00:00") + MODEL_PATH Path to model directory (e.g., /nfsdata) + MODEL_NAME Model name directory + CONTAINER_IMAGE Docker image name (e.g., vllm_disagg_pd:latest) + RUNNER_NAME Runner identifier (for job name) +USAGE +} + +check_env() { + local name="$1" + if [[ -z "${!name:-}" ]]; then + echo "Error: ${name} not specified" >&2 + usage >&2 + exit 1 + fi +} + +check_env SLURM_ACCOUNT +check_env SLURM_PARTITION +check_env TIME_LIMIT + +check_env MODEL_PATH +check_env MODEL_NAME +check_env CONTAINER_IMAGE +check_env RUNNER_NAME + +GPUS_PER_NODE="${GPUS_PER_NODE:-8}" + +# COMMAND_LINE ARGS (aligned with benchmarks/multi_node/amd_utils/submit.sh) +PREFILL_NODES=$1 +PREFILL_WORKERS=${2:-1} +DECODE_NODES=$3 +DECODE_WORKERS=${4:-1} +ISL=$5 +OSL=$6 +CONCURRENCIES=$7 +REQUEST_RATE=$8 +PREFILL_ENABLE_EP=${9:-false} +PREFILL_ENABLE_DP=${10:-false} +DECODE_ENABLE_EP=${11:-false} +DECODE_ENABLE_DP=${12:-false} +PREFILL_TP=${13:-8} +DECODE_TP=${14:-8} +RANDOM_RANGE_RATIO=${15:-0.8} +NODE_LIST=${16} + +# Router co-located with first prefill: xP + yD nodes total +NUM_NODES=$((PREFILL_NODES + DECODE_NODES)) +profiler_args="${ISL} ${OSL} ${CONCURRENCIES} ${REQUEST_RATE}" + +# Export variables for the SLURM job +export MODEL_DIR=$MODEL_PATH +export DOCKER_IMAGE_NAME=$CONTAINER_IMAGE +export PROFILER_ARGS=$profiler_args + +# For vLLM, each worker = 1 node (TP=8 per node). +# xP/yD must match the node counts so NUM_NODES = xP+yD is correct. 
+export xP=$PREFILL_NODES +export yD=$DECODE_NODES +export NUM_NODES=$NUM_NODES +export GPUS_PER_NODE=$GPUS_PER_NODE +export MODEL_NAME=$MODEL_NAME +export PREFILL_ENABLE_EP=${PREFILL_ENABLE_EP} +export PREFILL_ENABLE_DP=${PREFILL_ENABLE_DP} +export DECODE_ENABLE_EP=${DECODE_ENABLE_EP} +export DECODE_ENABLE_DP=${DECODE_ENABLE_DP} +export PREFILL_TP=${PREFILL_TP} +export DECODE_TP=${DECODE_TP} +export BENCH_INPUT_LEN=${ISL} +export BENCH_OUTPUT_LEN=${OSL} +export BENCH_NUM_PROMPTS_MULTIPLIER=${BENCH_NUM_PROMPTS_MULTIPLIER:-10} +export BENCH_MAX_CONCURRENCY=${CONCURRENCIES} +export BENCH_REQUEST_RATE=${REQUEST_RATE} +export BENCH_RANDOM_RANGE_RATIO=${RANDOM_RANGE_RATIO:-0.8} + +export PROXY_STREAM_IDLE_TIMEOUT=${PROXY_STREAM_IDLE_TIMEOUT:-300} +export VLLM_MORIIO_CONNECTOR_READ_MODE=${VLLM_MORIIO_CONNECTOR_READ_MODE:-1} + +# Log directory: must be on NFS (shared filesystem) so the submit host can read SLURM output. +export BENCHMARK_LOGS_DIR="${BENCHMARK_LOGS_DIR:-$(pwd)/benchmark_logs}" +mkdir -p "$BENCHMARK_LOGS_DIR" + +# Optional: pass an explicit node list to sbatch. +NODELIST_OPT=() +if [[ -n "${NODE_LIST//[[:space:]]/}" ]]; then + IFS=',' read -r -a NODE_ARR <<< "$NODE_LIST" + if [[ "${#NODE_ARR[@]}" -ne "$NUM_NODES" ]]; then + echo "Error: NODE_LIST has ${#NODE_ARR[@]} nodes but NUM_NODES=${NUM_NODES}" >&2 + echo "Error: NODE_LIST='${NODE_LIST}'" >&2 + exit 1 + fi + NODELIST_CSV="$(IFS=,; echo "${NODE_ARR[*]}")" + NODELIST_OPT=(--nodelist "$NODELIST_CSV") +fi + +# Optional: exclude specific nodes (e.g. nodes with broken Docker sockets). +# Set SLURM_EXCLUDE_NODES env var to a comma-separated list of hostnames. 
+EXCLUDE_OPT=() +if [[ -n "${SLURM_EXCLUDE_NODES:-}" ]]; then + EXCLUDE_OPT=(--exclude "$SLURM_EXCLUDE_NODES") +fi + +# Construct the sbatch command +sbatch_cmd=( + sbatch + --parsable + -N "$NUM_NODES" + -n "$NUM_NODES" + "${NODELIST_OPT[@]}" + "${EXCLUDE_OPT[@]}" + --time "$TIME_LIMIT" + --partition "$SLURM_PARTITION" + --account "$SLURM_ACCOUNT" + --job-name "$RUNNER_NAME" + --output "${BENCHMARK_LOGS_DIR}/slurm_job-%j.out" + --error "${BENCHMARK_LOGS_DIR}/slurm_job-%j.err" + "$(dirname "$0")/job.slurm" +) + +JOB_ID=$("${sbatch_cmd[@]}") +if [[ $? -ne 0 ]]; then + echo "Error: Failed to submit job with sbatch" >&2 + exit 1 +fi +echo "$JOB_ID" diff --git a/benchmarks/multi_node/vllm_disagg_utils/sync.py b/benchmarks/multi_node/vllm_disagg_utils/sync.py new file mode 100755 index 000000000..3678e7614 --- /dev/null +++ b/benchmarks/multi_node/vllm_disagg_utils/sync.py @@ -0,0 +1,201 @@ +#!/usr/bin/env python3 +""" +Multi-node synchronization utilities for disaggregated inference. + +Subcommands: + barrier - Wait until all specified nodes have opened their ports (TCP barrier) + Optionally wait for HTTP health endpoints to return 200 + wait - Block until a remote port closes (shutdown coordination) +""" + +import socket +import time +import threading +import argparse +import sys +import urllib.request +import urllib.error + + +def is_port_open(ip, port, timeout=2): + """Check if a given IP and port are accessible.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.settimeout(timeout) + return s.connect_ex((ip, port)) == 0 + + +def check_health(ip, port, path="/health", timeout=2): + """Return True if http://ip:port/path returns HTTP 200.""" + try: + url = f"http://{ip}:{port}{path}" + req = urllib.request.Request(url) + with urllib.request.urlopen(req, timeout=timeout) as resp: + return getattr(resp, "status", 200) == 200 + except (urllib.error.URLError, urllib.error.HTTPError, OSError): + return False + + +# 
============================================================================= +# barrier subcommand +# ============================================================================= + +def cmd_barrier(args): + """Wait until all nodes have opened the specified ports.""" + NODE_IPS = [ip.strip() for ip in args.node_ips.split(",") if ip.strip()] + NODE_PORTS = [int(p.strip()) for p in args.node_ports.split(",") if p.strip()] + + if not NODE_IPS: + print("Error: NODE_IPS argument is empty or not set.") + sys.exit(1) + + if len(NODE_PORTS) == 1: + NODE_PORTS *= len(NODE_IPS) + elif len(NODE_PORTS) != len(NODE_IPS): + print("Error: Number of ports must match number of node IPs or only one port should be given for all.") + sys.exit(1) + + server_socket = None + + def open_port(): + nonlocal server_socket + server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + server_socket.bind((args.local_ip, args.local_port)) + server_socket.listen(5) + print(f"Port {args.local_port} is now open on {args.local_ip}.") + while True: + conn, addr = server_socket.accept() + conn.close() + + def close_port(): + nonlocal server_socket + if server_socket: + server_socket.close() + print(f"Port {args.local_port} has been closed on {args.local_ip}.") + + if args.enable_port: + threading.Thread(target=open_port, daemon=True).start() + + # Wait for all ports (TCP check) + if args.wait_for_all_ports: + start_time = time.time() + timeout = args.timeout + + while True: + if timeout > 0: + elapsed = time.time() - start_time + if elapsed >= timeout: + not_open = [(ip, port) for ip, port in zip(NODE_IPS, NODE_PORTS) + if not is_port_open(ip, port)] + print(f"ERROR: Timeout after {timeout} seconds waiting for ports to open.", flush=True) + print("The following nodes/ports are still not responding:", flush=True) + for ip, port in not_open: + print(f" - {ip}:{port}", flush=True) + sys.exit(1) + + all_open = 
all(is_port_open(ip, port) for ip, port in zip(NODE_IPS, NODE_PORTS)) + if all_open: + break + + if timeout > 0: + remaining = timeout - (time.time() - start_time) + print(f"Waiting for nodes.{NODE_PORTS},{NODE_IPS} . . ({remaining:.0f}s remaining)", flush=True) + else: + print(f"Waiting for nodes.{NODE_PORTS},{NODE_IPS} . .", flush=True) + time.sleep(5) + + # Wait for all health endpoints (HTTP check) + if args.wait_for_all_health: + health_path = args.health_endpoint + start_time = time.time() + timeout = args.timeout + + while True: + if timeout > 0: + elapsed = time.time() - start_time + if elapsed >= timeout: + not_ready = [ + (ip, port) + for ip, port in zip(NODE_IPS, NODE_PORTS) + if not check_health(ip, port, health_path) + ] + print(f"ERROR: Timeout after {timeout} seconds waiting for health endpoints.", flush=True) + print(f"The following (http://ip:port{health_path}) are still not responding:", flush=True) + for ip, port in not_ready: + print(f" - http://{ip}:{port}{health_path}", flush=True) + sys.exit(1) + + all_ready = all( + check_health(ip, port, health_path) + for ip, port in zip(NODE_IPS, NODE_PORTS) + ) + if all_ready: + break + + if timeout > 0: + remaining = timeout - (time.time() - start_time) + print( + f"Waiting for health on {list(zip(NODE_IPS, NODE_PORTS))} ({health_path}) .. ({remaining:.0f}s remaining)", + flush=True, + ) + else: + print(f"Waiting for health on {list(zip(NODE_IPS, NODE_PORTS))} ({health_path}) ..", flush=True) + time.sleep(30) + + if args.enable_port: + # Keep the port open long enough for slow nodes to pass their barrier. + # The previous 30s was too short when setup times vary by minutes. 
+ grace = max(60, args.timeout // 2) if args.timeout > 0 else 300 + time.sleep(grace) + close_port() + + +# ============================================================================= +# wait subcommand +# ============================================================================= + +def cmd_wait(args): + """Wait while a remote port remains open, exit when it closes.""" + print(f"Waiting while port {args.remote_port} on {args.remote_ip} is open...") + while is_port_open(args.remote_ip, args.remote_port): + time.sleep(5) + print(f"Port {args.remote_port} on {args.remote_ip} is now closed.") + + +# ============================================================================= +# CLI +# ============================================================================= + +def main(): + parser = argparse.ArgumentParser(description="Multi-node synchronization utilities.") + subparsers = parser.add_subparsers(dest="command", required=True) + + # barrier subcommand + bp = subparsers.add_parser("barrier", help="Wait for all nodes to open specified ports.") + bp.add_argument("--local-ip", required=False, help="Local IP address to bind the server.") + bp.add_argument("--local-port", type=int, required=False, help="Port number to bind the server.") + bp.add_argument("--enable-port", action="store_true", help="Enable opening and closing of local port.") + bp.add_argument("--node-ips", required=True, help="Comma-separated list of node IPs.") + bp.add_argument("--node-ports", required=True, help="Comma-separated list of ports to check.") + bp.add_argument("--timeout", type=int, default=600, + help="Timeout in seconds (default: 600). 
Set to 0 for no timeout.") + bp.add_argument("--wait-for-all-ports", action="store_true", + help="Wait until all node ports are open (TCP).") + bp.add_argument("--wait-for-all-health", action="store_true", + help="Wait until http://ip:port/health returns 200 for all nodes.") + bp.add_argument("--health-endpoint", default="/health", + help="Path for health check (default: /health).") + bp.set_defaults(func=cmd_barrier) + + # wait subcommand + wp = subparsers.add_parser("wait", help="Wait while a remote port remains open.") + wp.add_argument("--remote-ip", required=True, help="Remote server IP address.") + wp.add_argument("--remote-port", type=int, required=True, help="Remote port number.") + wp.set_defaults(func=cmd_wait) + + args = parser.parse_args() + args.func(args) + + +if __name__ == "__main__": + main() diff --git a/runners/launch_mi355x-amds.sh b/runners/launch_mi355x-amds.sh index 52e28e9b8..e03606a0f 100644 --- a/runners/launch_mi355x-amds.sh +++ b/runners/launch_mi355x-amds.sh @@ -52,7 +52,7 @@ if [[ "$IS_MULTINODE" == "true" ]]; then sudo rm -rf "$BENCHMARK_LOGS_DIR/logs" 2>/dev/null || true SCRIPT_NAME="${EXP_NAME%%_*}_${PRECISION}_mi355x_${FRAMEWORK}.sh" - if [[ "$FRAMEWORK" == "sglang-disagg" ]]; then + if [[ "$FRAMEWORK" == "sglang-disagg" || "$FRAMEWORK" == "vllm-disagg" ]]; then BENCHMARK_SUBDIR="multi_node" else BENCHMARK_SUBDIR="single_node" @@ -103,8 +103,17 @@ if [[ "$IS_MULTINODE" == "true" ]]; then cat > collect_latest_results.py <<'PY' import os, sys -sgl_job_dir, isl, osl, nexp = sys.argv[1], int(sys.argv[2]), int(sys.argv[3]), int(sys.argv[4]) -for path in sorted([f"{sgl_job_dir}/logs/{name}/sglang_isl_{isl}_osl_{osl}" for name in os.listdir(f"{sgl_job_dir}/logs/") if os.path.isdir(f"{sgl_job_dir}/logs/{name}/sglang_isl_{isl}_osl_{osl}")], key=os.path.getmtime, reverse=True)[:nexp]: +job_dir, isl, osl, nexp = sys.argv[1], int(sys.argv[2]), int(sys.argv[3]), int(sys.argv[4]) +prefixes = ["sglang", "vllm"] +logs_root = f"{job_dir}/logs/" 
+candidates = [] +if os.path.isdir(logs_root): + for name in os.listdir(logs_root): + for pfx in prefixes: + subdir = f"{logs_root}{name}/{pfx}_isl_{isl}_osl_{osl}" + if os.path.isdir(subdir): + candidates.append(subdir) +for path in sorted(candidates, key=os.path.getmtime, reverse=True)[:nexp]: print(path) PY diff --git a/utils/bench_serving/backend_request_func.py b/utils/bench_serving/backend_request_func.py index 32331a398..5ba629c06 100644 --- a/utils/bench_serving/backend_request_func.py +++ b/utils/bench_serving/backend_request_func.py @@ -14,7 +14,7 @@ from transformers import (AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast) -AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) +AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=30 * 60) @dataclass @@ -49,12 +49,16 @@ class RequestFuncOutput: async def async_request_tgi( request_func_input: RequestFuncInput, pbar: Optional[tqdm] = None, + session: Optional[aiohttp.ClientSession] = None, ) -> RequestFuncOutput: api_url = request_func_input.api_url assert api_url.endswith("generate_stream") - async with aiohttp.ClientSession(trust_env=True, - timeout=AIOHTTP_TIMEOUT) as session: + _own_session = session is None + if _own_session: + session = aiohttp.ClientSession(trust_env=True, + timeout=AIOHTTP_TIMEOUT) + try: params = { "best_of": request_func_input.best_of, "max_new_tokens": request_func_input.output_len, @@ -62,7 +66,6 @@ async def async_request_tgi( "temperature": 0.01, # TGI does not accept 0.0 temperature. "top_p": 0.99, # TGI does not accept 1.0 top_p. "truncate": request_func_input.prompt_len, - # TGI does not accept ignore_eos flag. 
} payload = { "inputs": request_func_input.prompt, @@ -113,21 +116,28 @@ async def async_request_tgi( output.success = False exc_info = sys.exc_info() output.error = "".join(traceback.format_exception(*exc_info)) + finally: + if _own_session: + await session.close() - if pbar: - pbar.update(1) - return output + if pbar: + pbar.update(1) + return output async def async_request_trt_llm( request_func_input: RequestFuncInput, pbar: Optional[tqdm] = None, + session: Optional[aiohttp.ClientSession] = None, ) -> RequestFuncOutput: api_url = request_func_input.api_url assert api_url.endswith("generate_stream") - async with aiohttp.ClientSession(trust_env=True, - timeout=AIOHTTP_TIMEOUT) as session: + _own_session = session is None + if _own_session: + session = aiohttp.ClientSession(trust_env=True, + timeout=AIOHTTP_TIMEOUT) + try: assert request_func_input.best_of == 1 payload = { "accumulate_tokens": True, @@ -181,18 +191,25 @@ async def async_request_trt_llm( output.success = False exc_info = sys.exc_info() output.error = "".join(traceback.format_exception(*exc_info)) + finally: + if _own_session: + await session.close() - if pbar: - pbar.update(1) - return output + if pbar: + pbar.update(1) + return output async def async_request_deepspeed_mii( request_func_input: RequestFuncInput, pbar: Optional[tqdm] = None, + session: Optional[aiohttp.ClientSession] = None, ) -> RequestFuncOutput: - async with aiohttp.ClientSession(trust_env=True, - timeout=AIOHTTP_TIMEOUT) as session: + _own_session = session is None + if _own_session: + session = aiohttp.ClientSession(trust_env=True, + timeout=AIOHTTP_TIMEOUT) + try: assert request_func_input.best_of == 1 payload = { @@ -225,23 +242,30 @@ async def async_request_deepspeed_mii( output.success = False exc_info = sys.exc_info() output.error = "".join(traceback.format_exception(*exc_info)) + finally: + if _own_session: + await session.close() - if pbar: - pbar.update(1) - return output + if pbar: + pbar.update(1) + return output async 
def async_request_openai_completions( request_func_input: RequestFuncInput, pbar: Optional[tqdm] = None, + session: Optional[aiohttp.ClientSession] = None, ) -> RequestFuncOutput: api_url = request_func_input.api_url assert api_url.endswith( ("completions", "profile") ), "OpenAI Completions API URL must end with 'completions' or 'profile'." - async with aiohttp.ClientSession(trust_env=True, - timeout=AIOHTTP_TIMEOUT) as session: + _own_session = session is None + if _own_session: + session = aiohttp.ClientSession(trust_env=True, + timeout=AIOHTTP_TIMEOUT) + try: payload = { "model": request_func_input.model_name \ if request_func_input.model_name else request_func_input.model, @@ -281,33 +305,35 @@ async def async_request_openai_completions( chunk = chunk_bytes.decode("utf-8").removeprefix( "data: ") - if chunk != "[DONE]": - data = json.loads(chunk) - - # NOTE: Some completion API might have a last - # usage summary response without a token so we - # want to check a token was generated - if choices := data.get("choices"): - # Note that text could be empty here - # e.g. for special tokens - text = choices[0].get("text") - timestamp = time.perf_counter() - # First token - if not first_chunk_received: - first_chunk_received = True - ttft = time.perf_counter() - st - output.ttft = ttft - - # Decoding phase - else: - output.itl.append(timestamp - - most_recent_timestamp) - - most_recent_timestamp = timestamp - generated_text += text or "" - elif usage := data.get("usage"): - output.output_tokens = usage.get( - "completion_tokens") + if chunk == "[DONE]": + break + + data = json.loads(chunk) + + # NOTE: Some completion API might have a last + # usage summary response without a token so we + # want to check a token was generated + if choices := data.get("choices"): + # Note that text could be empty here + # e.g. 
for special tokens + text = choices[0].get("text") + timestamp = time.perf_counter() + # First token + if not first_chunk_received: + first_chunk_received = True + ttft = time.perf_counter() - st + output.ttft = ttft + + # Decoding phase + else: + output.itl.append(timestamp - + most_recent_timestamp) + + most_recent_timestamp = timestamp + generated_text += text or "" + elif usage := data.get("usage"): + output.output_tokens = usage.get( + "completion_tokens") if first_chunk_received: output.success = True else: @@ -324,6 +350,9 @@ async def async_request_openai_completions( output.success = False exc_info = sys.exc_info() output.error = "".join(traceback.format_exception(*exc_info)) + finally: + if _own_session: + await session.close() if pbar: pbar.update(1) @@ -333,14 +362,18 @@ async def async_request_openai_completions( async def async_request_openai_chat_completions( request_func_input: RequestFuncInput, pbar: Optional[tqdm] = None, + session: Optional[aiohttp.ClientSession] = None, ) -> RequestFuncOutput: api_url = request_func_input.api_url assert api_url.endswith( "chat/completions" ), "OpenAI Chat Completions API URL must end with 'chat/completions'." 
- async with aiohttp.ClientSession(trust_env=True, - timeout=AIOHTTP_TIMEOUT) as session: + _own_session = session is None + if _own_session: + session = aiohttp.ClientSession(trust_env=True, + timeout=AIOHTTP_TIMEOUT) + try: content = [{"type": "text", "text": request_func_input.prompt}] if request_func_input.multi_modal_content: content.append(request_func_input.multi_modal_content) @@ -387,28 +420,30 @@ async def async_request_openai_chat_completions( chunk = chunk_bytes.decode("utf-8").removeprefix( "data: ") - if chunk != "[DONE]": - timestamp = time.perf_counter() - data = json.loads(chunk) + if chunk == "[DONE]": + break - if choices := data.get("choices"): - content = choices[0]["delta"].get("content") - # First token - if ttft == 0.0: - ttft = timestamp - st - output.ttft = ttft + timestamp = time.perf_counter() + data = json.loads(chunk) - # Decoding phase - else: - output.itl.append(timestamp - - most_recent_timestamp) + if choices := data.get("choices"): + content = choices[0]["delta"].get("content") + # First token + if ttft == 0.0: + ttft = timestamp - st + output.ttft = ttft - generated_text += content or "" - elif usage := data.get("usage"): - output.output_tokens = usage.get( - "completion_tokens") + # Decoding phase + else: + output.itl.append(timestamp - + most_recent_timestamp) - most_recent_timestamp = timestamp + generated_text += content or "" + elif usage := data.get("usage"): + output.output_tokens = usage.get( + "completion_tokens") + + most_recent_timestamp = timestamp output.generated_text = generated_text output.success = True @@ -420,6 +455,9 @@ async def async_request_openai_chat_completions( output.success = False exc_info = sys.exc_info() output.error = "".join(traceback.format_exception(*exc_info)) + finally: + if _own_session: + await session.close() if pbar: pbar.update(1) diff --git a/utils/bench_serving/benchmark_serving.py b/utils/bench_serving/benchmark_serving.py index 38365dbfc..b841fe69f 100644 --- 
a/utils/bench_serving/benchmark_serving.py +++ b/utils/bench_serving/benchmark_serving.py @@ -26,6 +26,7 @@ import argparse import asyncio import base64 +import contextlib import gc import io import json @@ -37,9 +38,10 @@ from datetime import datetime from typing import Any, AsyncGenerator, Collection, Dict, List, Optional, Tuple +import aiohttp import numpy as np -from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput, - RequestFuncOutput) +from backend_request_func import (AIOHTTP_TIMEOUT, ASYNC_REQUEST_FUNCS, + RequestFuncInput, RequestFuncOutput) from tqdm.asyncio import tqdm from transformers import PreTrainedTokenizerBase @@ -348,11 +350,14 @@ async def benchmark( else: raise ValueError(f"Unknown backend: {backend}") + connector = aiohttp.TCPConnector(limit=0, enable_cleanup_closed=True) + shared_session = aiohttp.ClientSession( + trust_env=True, timeout=AIOHTTP_TIMEOUT, connector=connector) + print("Starting initial single prompt test run...") test_prompt, test_prompt_len, test_output_len, test_mm_content = ( input_requests[0]) if backend != "openai-chat" and test_mm_content is not None: - # multi-modal benchmark is only available on OpenAI Chat backend. 
raise ValueError( "Multi-modal content is only supported on 'openai-chat' backend.") test_input = RequestFuncInput( @@ -371,11 +376,13 @@ async def benchmark( if num_warmups > 0: print(f"Warming up with {num_warmups} requests...") warmup_pbar = None if disable_tqdm else tqdm(total=num_warmups) - warmup_semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else contextlib.nullcontext() + warmup_semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else asyncio.Semaphore(num_warmups) async def warmup_limited_req_fn(): async with warmup_semaphore: - return await request_func(request_func_input=test_input, pbar=warmup_pbar) + return await request_func( + request_func_input=test_input, pbar=warmup_pbar, + session=shared_session) warmup_tasks = [] for _ in range(num_warmups): @@ -388,7 +395,6 @@ async def warmup_limited_req_fn(): print("Warmup completed.") if lora_modules: - # For each input request, choose a LoRA module at random. lora_modules = iter( [random.choice(lora_modules) for _ in range(len(input_requests))]) @@ -405,7 +411,8 @@ async def warmup_limited_req_fn(): best_of=best_of, multi_modal_content=test_mm_content, ignore_eos=ignore_eos) - profile_output = await request_func(request_func_input=profile_input) + profile_output = await request_func( + request_func_input=profile_input, session=shared_session) if profile_output.success: print("Profiler started") @@ -420,20 +427,16 @@ async def warmup_limited_req_fn(): pbar = None if disable_tqdm else tqdm(total=len(input_requests)) - # This can be used once the minimum Python version is 3.10 or higher, - # and it will simplify the code in limited_request_func. 
- # semaphore = (asyncio.Semaphore(max_concurrency) - # if max_concurrency else contextlib.nullcontext()) semaphore = (asyncio.Semaphore(max_concurrency) if max_concurrency else None) async def limited_request_func(request_func_input, pbar): if semaphore is None: return await request_func(request_func_input=request_func_input, - pbar=pbar) + pbar=pbar, session=shared_session) async with semaphore: return await request_func(request_func_input=request_func_input, - pbar=pbar) + pbar=pbar, session=shared_session) print("Starting main benchmark run...") @@ -460,7 +463,28 @@ async def limited_request_func(request_func_input, pbar): asyncio.create_task( limited_request_func(request_func_input=request_func_input, pbar=pbar))) - outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks) + gather_timeout = max(7200, len(input_requests) * 30) + try: + outputs: List[RequestFuncOutput] = await asyncio.wait_for( + asyncio.gather(*tasks), timeout=gather_timeout) + except asyncio.TimeoutError: + completed = pbar.n if pbar else "?" + print(f"\n[WARNING] Benchmark timed out after {gather_timeout}s " + f"({completed}/{len(tasks)} requests completed). 
" + "Collecting partial results...") + for task in tasks: + if not task.done(): + task.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + outputs = [] + for task in tasks: + if task.done() and not task.cancelled(): + try: + outputs.append(task.result()) + except Exception: + outputs.append(RequestFuncOutput()) + else: + outputs.append(RequestFuncOutput()) if profile: print("Stopping profiler...") @@ -473,10 +497,14 @@ async def limited_request_func(request_func_input, pbar): logprobs=logprobs, best_of=best_of, ) - profile_output = await request_func(request_func_input=profile_input) + profile_output = await request_func( + request_func_input=profile_input, session=shared_session) if profile_output.success: print("Profiler stopped") + await shared_session.close() + await connector.close() + if pbar is not None: pbar.close()