From 8a0ea4759c7ac5b7ede5061f98d2fcac3fb28be2 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 9 Mar 2026 17:05:04 -0500 Subject: [PATCH 01/25] initial impl --- .github/workflows/rocm-ci.yml | 485 ++++++++++++++++----------- benchmarks/benchmark_gemm.py | 164 +++++++++ benchmarks/benchmark_grouped_gemm.py | 274 +++++++++++++++ benchmarks/compare_results.py | 132 ++++++++ 4 files changed, 864 insertions(+), 191 deletions(-) create mode 100755 benchmarks/benchmark_gemm.py create mode 100755 benchmarks/benchmark_grouped_gemm.py create mode 100644 benchmarks/compare_results.py diff --git a/.github/workflows/rocm-ci.yml b/.github/workflows/rocm-ci.yml index 32c3cb2a2..b573c5414 100644 --- a/.github/workflows/rocm-ci.yml +++ b/.github/workflows/rocm-ci.yml @@ -34,6 +34,9 @@ concurrency: group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: true +permissions: + pull-requests: write + jobs: build_and_test: name: Build and Test on GPU (${{ matrix.runner }}) @@ -221,216 +224,316 @@ jobs: EOF )" - - name: Run sGPU tests - id: sgpu-tests - continue-on-error: true - run: | - # Cleanup previous failure markers if any. Don't actually do anything on k8s pods - rm -f FAIL_* - - docker exec \ - -e TEST_SGPU=1 \ - -e TEST_LEVEL=${{ env.TEST_LEVEL }} \ - te-runner bash -c "$(cat <<'EOF' - #!/usr/bin/bash - set -x -o pipefail - ulimit -c 0 # Disable core dumps - - # debug output - ls -d /opt/rocm* - python --version - pip list | egrep "transformer_e|torch|jax|numpy|ml_dtypes|typing_ext" - - HIP_VISIBLE_DEVICES=1 ci/pytorch.sh > /workspace/torch_sgpu.log 2>&1 & - torch_pid=$!; echo Pytorch test pid $! + # - name: Run sGPU tests + # id: sgpu-tests + # continue-on-error: true + # run: | + # # Cleanup previous failure markers if any. Don't actually do anything on k8s pods + # rm -f FAIL_* + + # docker exec \ + # -e TEST_SGPU=1 \ + # -e TEST_LEVEL=${{ env.TEST_LEVEL }} \ + # te-runner bash -c "$(cat <<'EOF' + # #!/usr/bin/bash + # set -x -o pipefail + # ulimit -c 0 # Disable core dumps + + # # debug output + # ls -d /opt/rocm* + # python --version + # pip list | egrep "transformer_e|torch|jax|numpy|ml_dtypes|typing_ext" + + # HIP_VISIBLE_DEVICES=1 ci/pytorch.sh > /workspace/torch_sgpu.log 2>&1 & + # torch_pid=$!; echo Pytorch test pid $! - HIP_VISIBLE_DEVICES=2 ci/jax.sh > /workspace/jax_sgpu.log 2>&1 & - jax_pid=$!; echo JAX test pid $! + # HIP_VISIBLE_DEVICES=2 ci/jax.sh > /workspace/jax_sgpu.log 2>&1 & + # jax_pid=$!; echo JAX test pid $! - HIP_VISIBLE_DEVICES=3 ci/core.sh > /workspace/core_sgpu.log 2>&1 & - core_pid=$!; echo Core test pid $! + # HIP_VISIBLE_DEVICES=3 ci/core.sh > /workspace/core_sgpu.log 2>&1 & + # core_pid=$!; echo Core test pid $! - wait $core_pid; core_rc=$? - wait $jax_pid; jax_rc=$? - wait $torch_pid; torch_rc=$? + # wait $core_pid; core_rc=$? + # wait $jax_pid; jax_rc=$? + # wait $torch_pid; torch_rc=$? - # /workspace/FAIL_* files are for failure markers we can extract to the host runner and process later - # Check PyTorch - if [ $torch_rc -ne 0 ]; then - echo "::group::[FAILED] PyTorch sGPU Log" - cat /workspace/torch_sgpu.log - echo "::endgroup::" - echo "::error::Pytorch sGPU test FAILED." - touch /workspace/FAIL_TORCH_SGPU - fi - - # Check JAX - if [ $jax_rc -ne 0 ]; then - echo "::group::[FAILED] JAX sGPU Log" - cat /workspace/jax_sgpu.log - echo "::endgroup::" - echo "::error::JAX sGPU test FAILED." 
- touch /workspace/FAIL_JAX_SGPU - fi - - # Check Core - if [ $core_rc -ne 0 ]; then - echo "::group::[FAILED] Core sGPU Log" - cat /workspace/core_sgpu.log - echo "::endgroup::" - echo "::error::Core sGPU test FAILED." - touch /workspace/FAIL_CORE_SGPU - fi + # # /workspace/FAIL_* files are for failure markers we can extract to the host runner and process later + # # Check PyTorch + # if [ $torch_rc -ne 0 ]; then + # echo "::group::[FAILED] PyTorch sGPU Log" + # cat /workspace/torch_sgpu.log + # echo "::endgroup::" + # echo "::error::Pytorch sGPU test FAILED." + # touch /workspace/FAIL_TORCH_SGPU + # fi + + # # Check JAX + # if [ $jax_rc -ne 0 ]; then + # echo "::group::[FAILED] JAX sGPU Log" + # cat /workspace/jax_sgpu.log + # echo "::endgroup::" + # echo "::error::JAX sGPU test FAILED." + # touch /workspace/FAIL_JAX_SGPU + # fi + + # # Check Core + # if [ $core_rc -ne 0 ]; then + # echo "::group::[FAILED] Core sGPU Log" + # cat /workspace/core_sgpu.log + # echo "::endgroup::" + # echo "::error::Core sGPU test FAILED." + # touch /workspace/FAIL_CORE_SGPU + # fi - test $torch_rc -eq 0 -a $jax_rc -eq 0 -a $core_rc -eq 0 - EOF - )" + # test $torch_rc -eq 0 -a $jax_rc -eq 0 -a $core_rc -eq 0 + # EOF + # )" - # Export failed tests statuses to host runner - if [ -f FAIL_TORCH_SGPU ]; then echo "torch=fail" >> $GITHUB_OUTPUT; fi - if [ -f FAIL_JAX_SGPU ]; then echo "jax=fail" >> $GITHUB_OUTPUT; fi - if [ -f FAIL_CORE_SGPU ]; then echo "core=fail" >> $GITHUB_OUTPUT; fi - - - name: Run mGPU tests - id: mgpu-tests - continue-on-error: true - run: | - docker exec \ - -e TEST_MGPU=1 \ - -e TEST_LEVEL=${{ env.TEST_LEVEL }} \ - te-runner bash -c "$(cat <<'EOF' - #!/usr/bin/bash - set -x -o pipefail - ulimit -c 0 # Disable core dumps + # # Export failed tests statuses to host runner + # if [ -f FAIL_TORCH_SGPU ]; then echo "torch=fail" >> $GITHUB_OUTPUT; fi + # if [ -f FAIL_JAX_SGPU ]; then echo "jax=fail" >> $GITHUB_OUTPUT; fi + # if [ -f FAIL_CORE_SGPU ]; then echo "core=fail" >> $GITHUB_OUTPUT; fi + + # - name: Run mGPU tests + # id: mgpu-tests + # continue-on-error: true + # run: | + # docker exec \ + # -e TEST_MGPU=1 \ + # -e TEST_LEVEL=${{ env.TEST_LEVEL }} \ + # te-runner bash -c "$(cat <<'EOF' + # #!/usr/bin/bash + # set -x -o pipefail + # ulimit -c 0 # Disable core dumps - # Run PyTorch - ci/pytorch.sh > /workspace/torch_mgpu.log 2>&1 - torch_rc=$? + # # Run PyTorch + # ci/pytorch.sh > /workspace/torch_mgpu.log 2>&1 + # torch_rc=$? - # Run JAX - ci/jax.sh > /workspace/jax_mgpu.log 2>&1 - jax_rc=$? + # # Run JAX + # ci/jax.sh > /workspace/jax_mgpu.log 2>&1 + # jax_rc=$? - # /workspace/FAIL_* files are for failure markers we can extract to the host runner and process later - if [ $torch_rc -ne 0 ]; then - echo "::group::[FAILED] PyTorch mGPU Log" - cat /workspace/torch_mgpu.log - echo "::endgroup::" - echo "::error::Pytorch mGPU test FAILED." - touch /workspace/FAIL_TORCH_MGPU - fi - - if [ $jax_rc -ne 0 ]; then - echo "::group::[FAILED] JAX mGPU Log" - cat /workspace/jax_mgpu.log - echo "::endgroup::" - echo "::error::JAX mGPU test FAILED." - touch /workspace/FAIL_JAX_MGPU - fi + # # /workspace/FAIL_* files are for failure markers we can extract to the host runner and process later + # if [ $torch_rc -ne 0 ]; then + # echo "::group::[FAILED] PyTorch mGPU Log" + # cat /workspace/torch_mgpu.log + # echo "::endgroup::" + # echo "::error::Pytorch mGPU test FAILED." 
+ # touch /workspace/FAIL_TORCH_MGPU + # fi + + # if [ $jax_rc -ne 0 ]; then + # echo "::group::[FAILED] JAX mGPU Log" + # cat /workspace/jax_mgpu.log + # echo "::endgroup::" + # echo "::error::JAX mGPU test FAILED." + # touch /workspace/FAIL_JAX_MGPU + # fi - test $torch_rc -eq 0 -a $jax_rc -eq 0 - EOF + # test $torch_rc -eq 0 -a $jax_rc -eq 0 + # EOF + # )" + + # # Export failed tests statuses to host runner + # if [ -f FAIL_TORCH_MGPU ]; then echo "torch=fail" >> $GITHUB_OUTPUT; fi + # if [ -f FAIL_JAX_MGPU ]; then echo "jax=fail" >> $GITHUB_OUTPUT; fi + + # - name: Run Examples + # id: examples-tests + # continue-on-error: true + # env: + # HF_TOKEN: ${{ secrets.HF_TOKEN }} + # run: | + # docker exec -e HF_TOKEN="$HF_TOKEN" te-runner bash -c "$(cat <<'EOF' + # #!/usr/bin/bash + # set -ex -o pipefail + # ulimit -c 0 # Disable core dumps + + # # Check whether the HF_TOKEN is present + # python -c "import os; print('HF_TOKEN set:', bool(os.environ.get('HF_TOKEN')))" + + # cd /workspace/examples/pytorch/mnist + # python main.py 2>&1 | tee /workspace/examples.log + # python main.py --use-te 2>&1 | tee -a /workspace/examples.log + # python main.py --use-fp8 2>&1 | tee -a /workspace/examples.log + + # cd /workspace/examples/jax/mnist + # pip3 install -r requirements.txt + # python test_single_gpu_mnist.py 2>&1 | tee -a /workspace/examples.log + # python test_single_gpu_mnist.py --use-te 2>&1 | tee -a /workspace/examples.log + # python test_single_gpu_mnist.py --use-fp8 2>&1 | tee -a /workspace/examples.log + + # cd /workspace/examples/jax/encoder + # pip3 install -r requirements.txt + # python test_single_gpu_encoder.py 2>&1 | tee -a /workspace/examples.log + # python test_single_gpu_encoder.py --use-fp8 2>&1 | tee -a /workspace/examples.log + # EOF + # )" + + - name: "Performance regression check" + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + PR_NUMBER: ${{ github.event.pull_request.number }} + GH_REPO: ${{ github.repository }} + RUNNER_NAME: ${{ matrix.runner }} + run: | + set -ex + + # Restore PR checkout no matter how this step exits + trap 'git checkout ${{ github.sha }} && git submodule update --init --recursive' EXIT + + # Benchmark PR branch (already built) + docker exec te-runner bash -c "$(cat <<'OUTER' + set -ex + pip install pandas tabulate + cd /workspace + + mkdir -p perf_results/pr + for bench in benchmarks/benchmark_*.py; do + name=$(basename "$bench" .py) + echo "=== Running $name (PR) ===" + python "$bench" + mv "${name}.csv" perf_results/pr/ + done + + # Stash benchmark scripts so they survive the base branch checkout + mkdir -p .perf_stash + cp benchmarks/benchmark_*.py benchmarks/compare_results.py .perf_stash/ + OUTER )" - # Export failed tests statuses to host runner - if [ -f FAIL_TORCH_MGPU ]; then echo "torch=fail" >> $GITHUB_OUTPUT; fi - if [ -f FAIL_JAX_MGPU ]; then echo "jax=fail" >> $GITHUB_OUTPUT; fi + # Checkout base branch (on host, where git credentials exist) + git fetch origin ${{ github.base_ref }} --depth=1 + git checkout FETCH_HEAD + git submodule update --init --recursive - - name: Run Examples - id: examples-tests - continue-on-error: true - env: - HF_TOKEN: ${{ secrets.HF_TOKEN }} - run: | - docker exec -e HF_TOKEN="$HF_TOKEN" te-runner bash -c "$(cat <<'EOF' - #!/usr/bin/bash - set -ex -o pipefail - ulimit -c 0 # Disable core dumps - - # Check whether the HF_TOKEN is present - python -c "import os; print('HF_TOKEN set:', bool(os.environ.get('HF_TOKEN')))" - - cd /workspace/examples/pytorch/mnist - python main.py 2>&1 | tee /workspace/examples.log 
- python main.py --use-te 2>&1 | tee -a /workspace/examples.log - python main.py --use-fp8 2>&1 | tee -a /workspace/examples.log - - cd /workspace/examples/jax/mnist - pip3 install -r requirements.txt - python test_single_gpu_mnist.py 2>&1 | tee -a /workspace/examples.log - python test_single_gpu_mnist.py --use-te 2>&1 | tee -a /workspace/examples.log - python test_single_gpu_mnist.py --use-fp8 2>&1 | tee -a /workspace/examples.log - - cd /workspace/examples/jax/encoder - pip3 install -r requirements.txt - python test_single_gpu_encoder.py 2>&1 | tee -a /workspace/examples.log - python test_single_gpu_encoder.py --use-fp8 2>&1 | tee -a /workspace/examples.log - EOF + # Rebuild base, benchmark, compare, build report + docker exec \ + -e GPU_ARCH=${{ steps.container-diag.outputs.arch }} \ + te-runner bash -c "$(cat <<'OUTER' + set -ex + cd /workspace + + # Rebuild base branch + export HIP_PATH="" + export PYTORCH_ROCM_ARCH=$GPU_ARCH + export NVTE_ROCM_ARCH=$GPU_ARCH + export NVTE_AITER_PREBUILT_BASE_URL=https://compute-artifactory.amd.com:5000/artifactory/rocm-generic-local/te-ci/aiter-prebuilts + pip install ninja + git config --global --add safe.directory '*' + pip install --no-build-isolation . 2>&1 + + # Benchmark base branch + mkdir -p perf_results/base + for bench in .perf_stash/benchmark_*.py; do + name=$(basename "$bench" .py) + echo "=== Running $name (base) ===" + python "$bench" + mv "${name}.csv" perf_results/base/ + done + + # Compare and build report + FAIL=0 + { + echo "## Performance Report (\`${{ matrix.runner }}\`)" + echo "" + for pr_csv in perf_results/pr/benchmark_*.csv; do + name=$(basename "$pr_csv" .csv) + base_csv="perf_results/base/${name}.csv" + [ -f "$base_csv" ] || continue + echo "### ${name}" + echo '```' + python .perf_stash/compare_results.py "$base_csv" "$pr_csv" || FAIL=1 + echo '```' + echo "" + done + } > perf_results/report.md + cat perf_results/report.md + # exit $FAIL + OUTER )" - - name: Check Test Failure Status - if: always() - run: | - EXIT_STATUS=0 - # Check outcomes of the specific test steps - # "outcome" will be 'failure' even if continue-on-error was true - - # sGPU CHECKS - # We check for the file existence directly because the 'Run sGPU tests' step - # halts immediately on docker failure, skipping the lines that set step outputs. - if [[ -f FAIL_CORE_SGPU ]]; then - echo "::error::Core sGPU Tests Failed." - EXIT_STATUS=1 - fi - if [[ -f FAIL_TORCH_SGPU ]]; then - echo "::error::PyTorch sGPU Tests Failed." - EXIT_STATUS=1 - fi - if [[ -f FAIL_JAX_SGPU ]]; then - echo "::error::JAX sGPU Tests Failed." - EXIT_STATUS=1 - fi - - # mGPU CHECKS - if [[ -f FAIL_TORCH_MGPU ]]; then - echo "::error::PyTorch mGPU Tests Failed." - EXIT_STATUS=1 - fi - if [[ -f FAIL_JAX_MGPU ]]; then - echo "::error::JAX mGPU Tests Failed." - EXIT_STATUS=1 - fi + # Post report as PR comment + REPORT="perf_results/report.md" - # EXAMPLES CHECK - # Examples script does not use marker files, so we rely on step outcome - if [[ "${{ steps.examples-tests.outcome }}" == "failure" ]]; then - echo "::error::Example Tests Failed." 
- EXIT_STATUS=1 - fi + MARKER="" + BODY="${MARKER}"$'\n'"$(cat "$REPORT")" - # Fail the job if any errors were detected - if [[ "$EXIT_STATUS" == "1" ]]; then - exit 1 - fi + # Find and update existing comment, or create new one + EXISTING=$(gh api "repos/${GH_REPO}/issues/${PR_NUMBER}/comments" \ + --paginate --jq ".[] | select(.body | contains(\"${MARKER}\")) | .id" | head -1) - - name: Copy logs and reports from container - if: always() - run: | - docker cp te-runner:/workspace/torch_sgpu.log ./torch_sgpu.log || true - docker cp te-runner:/workspace/jax_sgpu.log ./jax_sgpu.log || true - docker cp te-runner:/workspace/core_sgpu.log ./core_sgpu.log || true - docker cp te-runner:/workspace/torch_mgpu.log ./torch_mgpu.log || true - docker cp te-runner:/workspace/jax_mgpu.log ./jax_mgpu.log || true + if [ -n "$EXISTING" ]; then + gh api "repos/${GH_REPO}/issues/comments/${EXISTING}" \ + --method PATCH --field body="$BODY" + else + gh pr comment "$PR_NUMBER" --repo "$GH_REPO" --body "$BODY" + fi - - name: Upload logs and test reports - if: always() - uses: actions/upload-artifact@v4 - with: - name: logs-and-reports-${{ matrix.runner }} - path: | - *.log - if-no-files-found: ignore - retention-days: 5 + # - name: Check Test Failure Status + # if: always() + # run: | + # EXIT_STATUS=0 + # # Check outcomes of the specific test steps + # # "outcome" will be 'failure' even if continue-on-error was true + + # # sGPU CHECKS + # # We check for the file existence directly because the 'Run sGPU tests' step + # # halts immediately on docker failure, skipping the lines that set step outputs. + # if [[ -f FAIL_CORE_SGPU ]]; then + # echo "::error::Core sGPU Tests Failed." + # EXIT_STATUS=1 + # fi + # if [[ -f FAIL_TORCH_SGPU ]]; then + # echo "::error::PyTorch sGPU Tests Failed." + # EXIT_STATUS=1 + # fi + # if [[ -f FAIL_JAX_SGPU ]]; then + # echo "::error::JAX sGPU Tests Failed." + # EXIT_STATUS=1 + # fi + + # # mGPU CHECKS + # if [[ -f FAIL_TORCH_MGPU ]]; then + # echo "::error::PyTorch mGPU Tests Failed." + # EXIT_STATUS=1 + # fi + # if [[ -f FAIL_JAX_MGPU ]]; then + # echo "::error::JAX mGPU Tests Failed." + # EXIT_STATUS=1 + # fi + + # # EXAMPLES CHECK + # # Examples script does not use marker files, so we rely on step outcome + # if [[ "${{ steps.examples-tests.outcome }}" == "failure" ]]; then + # echo "::error::Example Tests Failed." 
+ # EXIT_STATUS=1 + # fi + + # # Fail the job if any errors were detected + # if [[ "$EXIT_STATUS" == "1" ]]; then + # exit 1 + # fi + + # - name: Copy logs and reports from container + # if: always() + # run: | + # docker cp te-runner:/workspace/torch_sgpu.log ./torch_sgpu.log || true + # docker cp te-runner:/workspace/jax_sgpu.log ./jax_sgpu.log || true + # docker cp te-runner:/workspace/core_sgpu.log ./core_sgpu.log || true + # docker cp te-runner:/workspace/torch_mgpu.log ./torch_mgpu.log || true + # docker cp te-runner:/workspace/jax_mgpu.log ./jax_mgpu.log || true + + # - name: Upload logs and test reports + # if: always() + # uses: actions/upload-artifact@v4 + # with: + # name: logs-and-reports-${{ matrix.runner }} + # path: | + # *.log + # if-no-files-found: ignore + # retention-days: 5 - name: Cleanup container if: always() diff --git a/benchmarks/benchmark_gemm.py b/benchmarks/benchmark_gemm.py new file mode 100755 index 000000000..cd651c172 --- /dev/null +++ b/benchmarks/benchmark_gemm.py @@ -0,0 +1,164 @@ +#!/usr/bin/env python +############################################################################### +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### + + +import torch +import torch.utils.benchmark as benchmark + +import transformer_engine.pytorch as te + +# Sequence / batch-token sizes to sweep +M_SIZE_LIST = [1024, 2048, 4096, 8192] + +# Model configurations +# Sources: +# - Llama 3 8B (hidden=4096, intermediate=14336, heads=32, kv_heads=8, head_dim=128) +# https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json + +# - Llama 3 70B (hidden=8192, intermediate=28672, heads=64, kv_heads=8, head_dim=128) +# https://huggingface.co/meta-llama/Llama-3.1-70B/blob/main/config.json + +# - Llama 3 405B (hidden=16384, intermediate=53248, heads=128, kv_heads=8, head_dim=128) +# https://huggingface.co/meta-llama/Llama-3.1-405B/blob/main/config.json + +# - Qwen 2.5 7B (hidden=3584, intermediate=18944, heads=28, kv_heads=4, head_dim=128) +# https://huggingface.co/Qwen/Qwen2.5-7B-Instruct/blob/main/config.json + +# - Qwen 2.5 72B (hidden=8192, intermediate=29568, heads=64, kv_heads=8, head_dim=128) +# https://huggingface.co/Qwen/Qwen2.5-72B-Instruct/blob/main/config.json + +MODEL_CONFIGS = [ + # (name, hidden, intermediate, num_q_heads, num_kv_heads, head_dim, tp) + ("Llama3-8B/TP1", 4096, 14336, 32, 8, 128, 1), + ("Llama3-8B/TP8", 4096, 14336, 32, 8, 128, 8), + ("Llama3-70B/TP8", 8192, 28672, 64, 8, 128, 8), + ("Llama3-405B/TP8", 16384, 53248, 128, 8, 128, 8), + ("Qwen2.5-7B/TP1", 3584, 18944, 28, 4, 128, 1), + ("Qwen2.5-72B/TP8", 8192, 29568, 64, 8, 128, 8), +] + + +def _generate_gemm_test_cases(): + test_cases = [] + + for (name, hidden, intermediate, n_q, n_kv, hd, tp) in MODEL_CONFIGS: + shapes = { + f"{name}-QKV": ((n_q * hd + 2 * n_kv * hd) // tp, hidden), + f"{name}-AttnOut": (hidden, (n_q * hd) // tp), + f"{name}-GateUp": ((2 * intermediate) // tp, hidden), + f"{name}-Down": (hidden, intermediate // tp), + } + + for M in M_SIZE_LIST: + for case_name, (N, K) in shapes.items(): + test_cases.append({ + "Case": case_name, + "M": M, + "N": N, + "K": K, + "dtype": torch.bfloat16, + }) + return test_cases + + +def bench_gemm(M, N, K, dtype): + device = "cuda" + + linear = te.Linear(K, N, bias=False).to(device=device, dtype=dtype) + x = torch.randn(M, K, dtype=dtype, device=device, requires_grad=True) + + fwd_func = lambda: 
linear(x) + out = fwd_func() + grad_out = torch.randn_like(out) + + def bwd_func(): + out = linear(x) + out.backward(grad_out) + # Clear grads so they don't accumulate across iterations + x.grad = None + linear.weight.grad = None + + bwd_func() + + fwd_flops = 2 * M * N * K + bwd_flops = 2 * fwd_flops # dX + dW + + # Warmup + for _ in range(20): + fwd_func() + bwd_func() + torch.cuda.synchronize() + + # Benchmark + n_iters = 100 + + fwd_ms = benchmark.Timer(stmt="fn()", globals={"fn": fwd_func}).timeit(n_iters).mean * 1e3 + fwd_bwd_ms = benchmark.Timer(stmt="fn()", globals={"fn": bwd_func}).timeit(n_iters).mean * 1e3 + + bwd_ms = max(fwd_bwd_ms - fwd_ms, 0.0) + + fwd_tflops = fwd_flops / (fwd_ms * 1e-3) / 1e12 + bwd_tflops = bwd_flops / (bwd_ms * 1e-3) / 1e12 if bwd_ms > 0 else 0.0 + + print(f" Forward {fwd_ms:.3f} ms | {fwd_tflops:.2f} TFLOPS") + print(f" Backward {bwd_ms:.3f} ms | {bwd_tflops:.2f} TFLOPS (derived)") + + return { + "TE Forward Time (ms)": f"{fwd_ms:.2f}", + "TE Forward TFLOPS": f"{fwd_tflops:.2f}", + "TE Backward Time (ms)": f"{bwd_ms:.2f}", + "TE Backward TFLOPS": f"{bwd_tflops:.2f}", + } + + +if __name__ == "__main__": + import pandas as pd + + test_cases = _generate_gemm_test_cases() + + columns = [ + "Case", "M", "N", "K", "dtype", + "TE Forward Time (ms)", + "TE Forward TFLOPS", + "TE Backward Time (ms)", + "TE Backward TFLOPS", + ] + rows = [] + + # Warmup run + c = test_cases[0] + print(f"\n{'='*60}") + print(f"WARMUP: {c}") + print(f"{'='*60}") + bench_gemm(M=c["M"], N=c["N"], K=c["K"], dtype=c["dtype"]) + + for case in test_cases: + print(f"\n{'='*60}") + print(f"Testing: {case}") + print(f"{'='*60}") + try: + metrics = bench_gemm( + M=case["M"], N=case["N"], K=case["K"], dtype=case["dtype"] + ) + row = { + "Case": case["Case"], + "M": case["M"], + "N": case["N"], + "K": case["K"], + "dtype": str(case["dtype"]), + **metrics, + } + rows.append(row) + except Exception as e: + print(f"FAILED: {case}: {e}") + raise + + results = pd.DataFrame(rows, columns=columns) + + out_csv = "benchmark_gemm.csv" + results.to_csv(out_csv, index=False) + print(f"\nResults saved to {out_csv}") diff --git a/benchmarks/benchmark_grouped_gemm.py b/benchmarks/benchmark_grouped_gemm.py new file mode 100755 index 000000000..7cee6edd4 --- /dev/null +++ b/benchmarks/benchmark_grouped_gemm.py @@ -0,0 +1,274 @@ +#!/usr/bin/env python +############################################################################### +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. 
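+#
+# Layout cheat sheet for the three grouped GEMM calls driven below (per
+# group i, with x_i: [M_i, K], w_i: [N, K], grad_out_i: [M_i, N]):
+#   forward ("TN"): out_i = x_i @ w_i^T        -> [M_i, N]
+#   dgrad   ("NN"): dx_i  = grad_out_i @ w_i   -> [M_i, K]
+#   wgrad   ("NT"): dw_i  = grad_out_i^T @ x_i -> [N, K]
+# This is the usual fwd/dgrad/wgrad convention for a Linear layer; the shapes
+# match the out/dx/dw buffers allocated in make_fwd_bwd_funcs_te() below, but
+# the layout-to-transpose mapping is our reading of general_grouped_gemm,
+# not an authoritative spec.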
+############################################################################### + +import os +import torch +import torch.utils.benchmark as benchmark + +def generate_grouped_gemm_group_lens(b, m, balance: bool): + if balance: + return torch.full((b,), m, dtype=torch.int64) + else: + dist = 0.2 + 0.8 * torch.rand(b) + dist /= dist.sum() + group_lens = (dist * b * m).to(torch.int64) + error = b * m - group_lens.sum() + group_lens[-1] += error + return group_lens + +M_SIZE_LIST = [512, 1024, 2048, 4096]#, 8192, 16384] +EP_SIZE_LIST = [32, 16, 8] + + +def _generate_moe_test_cases( + name_prefix: str, + n_routed_experts: int, + moe_intermediate_size: int, + hidden_size: int, +): + test_cases = [] + shapes_dict = { + f"{name_prefix}-GateUP": (2 * moe_intermediate_size, hidden_size), + f"{name_prefix}-Down": (hidden_size, moe_intermediate_size), + } + + for ep in EP_SIZE_LIST: + if n_routed_experts % ep != 0: + continue + B = n_routed_experts // ep + if B < 1: + continue + for M in M_SIZE_LIST: + for name, (N, K) in shapes_dict.items(): + for dtype in [torch.bfloat16]: + test_cases.append( + { + "Case": name, + "B": B, + "M": M, + "N": N, + "K": K, + "dtype": dtype, + } + ) + return test_cases + + +def generate_deepseekv3_test_cases(): + return _generate_moe_test_cases( + "DSV3", n_routed_experts=256, moe_intermediate_size=2048, hidden_size=7168 + ) + + +def generate_deepseekv2_test_cases(): + return _generate_moe_test_cases( + "DSV2", n_routed_experts=160, moe_intermediate_size=1536, hidden_size=5120 + ) + + +def generate_deepseekv2_lite_test_cases(): + return _generate_moe_test_cases( + "DSV2-Lite", n_routed_experts=64, moe_intermediate_size=1408, hidden_size=2048 + ) + + +def generate_grok_v2_test_cases(): + return _generate_moe_test_cases( + "Grok-V2", n_routed_experts=8, moe_intermediate_size=16384, hidden_size=8192 + ) + + +def make_fwd_bwd_funcs_te(x, w, group_lens, activation_dtype): + from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace + from transformer_engine.pytorch.cpp_extensions import general_grouped_gemm + + B = int(group_lens.numel()) + N = int(w.shape[1]) + K = int(w.shape[2]) + + m_splits = [int(v) for v in group_lens.tolist()] + assert len(m_splits) == B + sum_M = sum(m_splits) + assert x.numel() > 0 and x.shape[0] == sum_M + + x_view = x.reshape(-1, x.shape[-1]) + xs = list(torch.split(x_view, m_splits)) + weights = [w[i] for i in range(B)] + + workspaces = get_multi_stream_cublas_workspace() + + # Forward output buffer + out = torch.empty((sum_M, N), device=x.device, dtype=activation_dtype) + + def fwd_func_te(): + general_grouped_gemm( + A=weights, + B=xs, + out=[out], + out_dtype=activation_dtype, + workspaces=workspaces, + single_output=True, + m_splits=m_splits, + use_bias=False, + bias=None, + layout="TN", + ) + return out + + # dx buffers + dx = torch.empty((sum_M, K), device=x.device, dtype=activation_dtype) + dxs = list(torch.split(dx, m_splits)) + + # dw buffers + dw_stacked = torch.empty((B, N, K), device=x.device, dtype=activation_dtype) + dws = [dw_stacked[i] for i in range(B)] + + def bwd_func_te(grad_out): + go = grad_out.view(-1, grad_out.shape[-1]) + splits = torch.split(go, m_splits) + + general_grouped_gemm( + A=weights, + B=splits, + out=dxs, + out_dtype=activation_dtype, + workspaces=workspaces, + single_output=False, + layout="NN", + m_splits=m_splits, + grad=False, + use_bias=False, + bias=None, + ) + + general_grouped_gemm( + A=xs, + B=splits, + out=dws, + out_dtype=activation_dtype, + workspaces=workspaces, + 
single_output=False, + layout="NT", + m_splits=m_splits, + grad=False, + use_bias=False, + bias=None, + accumulate=False, + ) + + return dx, dw_stacked + + return fwd_func_te, bwd_func_te + + +def bench_grouped_gemm(B, M, N, K, dtype): + device = "cuda" + + x = torch.randn((B * M, K), dtype=dtype, device=device, requires_grad=True) + w = torch.randn((B, N, K), dtype=dtype, device=device, requires_grad=True) + group_lens = generate_grouped_gemm_group_lens(B, M, balance=True).to(device) + print("group_lens: ", group_lens) + + os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1" + os.environ["NVTE_CUTLASS_GROUPED_GEMM_WARN_FALLBACK"] = "1" + + # TE grouped (CK_Tile) + x_te = x.clone().detach() + w_te = w.clone().detach() + fwd_func_te, bwd_func_te_inner = make_fwd_bwd_funcs_te( + x_te, w_te, group_lens, activation_dtype=dtype + ) + + out_te = fwd_func_te() + grad_out = torch.randn_like(out_te) + bwd_func_te = lambda: bwd_func_te_inner(grad_out) + dx_te, dw_te = bwd_func_te() + + # FLOPs + fwd_total_flops = 2 * B * M * N * K + bwd_total_flops = 2 * fwd_total_flops + + # Warmup + for _ in range(20): + fwd_func_te() + bwd_func_te() + + torch.cuda.synchronize() + + # Benchmark + n_iters = 100 + + fwd_te_ms = benchmark.Timer(stmt="fn()", globals={"fn": fwd_func_te}).timeit(n_iters).mean * 1e3 + bwd_te_ms = benchmark.Timer(stmt="fn()", globals={"fn": bwd_func_te}).timeit(n_iters).mean * 1e3 + + fwd_te_tflops = fwd_total_flops / (fwd_te_ms * 1e-3) / 1e12 + bwd_te_tflops = bwd_total_flops / (bwd_te_ms * 1e-3) / 1e12 + + print(f"TE (CK_Tile) Forward {fwd_te_ms:.3f} ms | {fwd_te_tflops:.2f} TFLOPS") + print(f"TE (CK_Tile) Backward {bwd_te_ms:.3f} ms | {bwd_te_tflops:.2f} TFLOPS") + + return { + "TE (CK_Tile) Forward Time (ms)": f"{fwd_te_ms:.2f}", + "TE (CK_Tile) Forward TFLOPS": f"{fwd_te_tflops:.2f}", + "TE (CK_Tile) Backward Time (ms)": f"{bwd_te_ms:.2f}", + "TE (CK_Tile) Backward TFLOPS": f"{bwd_te_tflops:.2f}", + } + + +if __name__ == "__main__": + import pandas as pd + + test_cases = ( + generate_deepseekv2_lite_test_cases() + + generate_deepseekv2_test_cases() + + generate_deepseekv3_test_cases() + + generate_grok_v2_test_cases() + ) + + columns = [ + "Case", "B", "M", "N", "K", "dtype", + "TE (CK_Tile) Forward Time (ms)", + "TE (CK_Tile) Forward TFLOPS", + "TE (CK_Tile) Backward Time (ms)", + "TE (CK_Tile) Backward TFLOPS", + ] + rows = [] + + # Warmup run + c = test_cases[0] + print(f"\n{'='*50}") + print(f"WARMUP: {c}") + print(f"{'='*50}") + bench_grouped_gemm(B=c["B"], M=c["M"], N=c["N"], K=c["K"], dtype=c["dtype"]) + + for case in test_cases: + print(f"\n{'='*50}") + print(f"Testing: {case}") + print(f"{'='*50}") + try: + metrics = bench_grouped_gemm( + B=case["B"], M=case["M"], N=case["N"], K=case["K"], dtype=case["dtype"] + ) + row = { + "Case": case["Case"], + "B": case["B"], + "M": case["M"], + "N": case["N"], + "K": case["K"], + "dtype": str(case["dtype"]), + **metrics, + } + rows.append(row) + except Exception as e: + print(f"FAILED: {case}: {e}") + raise + + results = pd.DataFrame(rows, columns=columns) + + out_csv = "benchmark_grouped_gemm.csv" + results.to_csv(out_csv, index=False) + print(f"\nResults saved to {out_csv}") diff --git a/benchmarks/compare_results.py b/benchmarks/compare_results.py new file mode 100644 index 000000000..ca6f7f70d --- /dev/null +++ b/benchmarks/compare_results.py @@ -0,0 +1,132 @@ +#!/usr/bin/env python +############################################################################### +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. 
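+#
+# Usage sketch (CSV paths are illustrative; any two CSVs whose key columns
+# match will work):
+#   python compare_results.py perf_results/base/benchmark_gemm.csv \
+#       perf_results/pr/benchmark_gemm.csv --threshold 5.0
+# Columns containing "TFLOPS" are auto-detected as metrics; the remaining
+# non-"Time" columns form the row key. Exit status is 1 if any metric drops
+# by more than the threshold percentage, else 0.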
+# +# See LICENSE for license information. +############################################################################### + +import argparse +import sys + +import pandas as pd + +SKIP_COLS = {"TestID", "Label"} + + +def auto_detect_columns(df: pd.DataFrame): + """Split columns into (key_cols, metric_cols) by naming convention.""" + metric_cols = [c for c in df.columns if "TFLOPS" in c] + key_cols = [ + c for c in df.columns + if c not in metric_cols and c not in SKIP_COLS + and "Time" not in c # skip timing columns, only compare TFLOPS + ] + return key_cols, metric_cols + + +def main(): + parser = argparse.ArgumentParser(description="Compare benchmark CSVs") + parser.add_argument("base_csv", help="Base branch CSV") + parser.add_argument("pr_csv", help="PR branch CSV") + parser.add_argument("--threshold", type=float, default=5.0, + help="Regression threshold %% (default: 5.0)") + parser.add_argument("--key-cols", default=None, + help="Comma-separated key columns (auto-detected if omitted)") + parser.add_argument("--metric-cols", default=None, + help="Comma-separated metric columns (auto-detected if omitted)") + args = parser.parse_args() + + base_df = pd.read_csv(args.base_csv) + pr_df = pd.read_csv(args.pr_csv) + + # Determine columns + if args.key_cols: + key_cols = [c.strip() for c in args.key_cols.split(",")] + else: + key_cols, _ = auto_detect_columns(base_df) + + if args.metric_cols: + metric_cols = [c.strip() for c in args.metric_cols.split(",")] + else: + _, metric_cols = auto_detect_columns(base_df) + + if not metric_cols: + print("No metric columns found — nothing to compare.") + return 0 + + print(f"Key columns: {key_cols}") + print(f"Metric columns: {metric_cols}") + print(f"Threshold: {args.threshold}%") + print(f"Base rows: {len(base_df)}, PR rows: {len(pr_df)}") + print() + + # Ensure metric columns are numeric + for col in metric_cols: + base_df[col] = pd.to_numeric(base_df[col], errors="coerce") + pr_df[col] = pd.to_numeric(pr_df[col], errors="coerce") + + # Match rows + merged = base_df.merge(pr_df, on=key_cols, suffixes=("_base", "_pr"), how="inner") + if merged.empty: + print("WARNING: No matching rows between base and PR.") + return 0 + + print(f"Matched rows: {len(merged)}") + print() + + # Compare + regressions = [] + for metric in metric_cols: + bc = f"{metric}_base" + pc = f"{metric}_pr" + if bc not in merged.columns or pc not in merged.columns: + continue + + bv = merged[bc] + pv = merged[pc] + delta_pct = ((pv - bv) / bv) * 100.0 + + for idx in merged.index: + if pd.isna(bv[idx]) or pd.isna(pv[idx]) or bv[idx] < 0.5: + continue + if delta_pct[idx] < -args.threshold: + key_info = " | ".join(f"{k}={merged.loc[idx, k]}" for k in key_cols) + regressions.append({ + "keys": key_info, + "metric": metric, + "base": bv[idx], + "pr": pv[idx], + "delta": delta_pct[idx], + }) + + # Print summary per metric + for metric in metric_cols: + bc = f"{metric}_base" + pc = f"{metric}_pr" + if bc not in merged.columns: + continue + bv = merged[bc].dropna() + pv = merged[pc].dropna() + if bv.empty: + continue + deltas = ((pv - bv) / bv) * 100.0 + print(f" {metric}:") + print(f" mean base={bv.mean():.2f} pr={pv.mean():.2f} delta={deltas.mean():+.2f}%") + print(f" min delta={deltas.min():+.2f}% max delta={deltas.max():+.2f}%") + print() + + if regressions: + print(f"REGRESSIONS DETECTED: {len(regressions)}") + print("-" * 80) + for r in regressions: + print(f" [{r['metric']}] {r['keys']}") + print(f" base={r['base']:.2f} pr={r['pr']:.2f} delta={r['delta']:+.2f}%") + print("-" * 80) + 
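+        # A non-zero exit is how the CI wrapper detects a regression:
+        # rocm-ci.yml runs `compare_results.py ... || FAIL=1` and currently
+        # keeps `exit $FAIL` commented out, so regressions are reported in
+        # the logs but do not yet fail the job.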
return 1 + else: + print("No regressions detected.") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) From 6ddb77d99f128c31604140f5811675fbc98a0402 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 11 Mar 2026 14:18:20 -0500 Subject: [PATCH 02/25] put into benchmarks subfolder --- .github/workflows/rocm-ci.yml | 8 +- benchmarks/benchmark_gemm.py | 164 ---------------- benchmarks/benchmark_grouped_gemm.py | 274 --------------------------- benchmarks/compare_results.py | 132 ------------- 4 files changed, 5 insertions(+), 573 deletions(-) delete mode 100755 benchmarks/benchmark_gemm.py delete mode 100755 benchmarks/benchmark_grouped_gemm.py delete mode 100644 benchmarks/compare_results.py diff --git a/.github/workflows/rocm-ci.yml b/.github/workflows/rocm-ci.yml index b573c5414..44af6fdb7 100644 --- a/.github/workflows/rocm-ci.yml +++ b/.github/workflows/rocm-ci.yml @@ -380,7 +380,9 @@ jobs: run: | set -ex - # Restore PR checkout no matter how this step exits + # Restore PR checkout no matter how this step exits, + # in case a later step needs to access the PR code. + # Note that the PR code is *not* recompiled. trap 'git checkout ${{ github.sha }} && git submodule update --init --recursive' EXIT # Benchmark PR branch (already built) @@ -390,7 +392,7 @@ jobs: cd /workspace mkdir -p perf_results/pr - for bench in benchmarks/benchmark_*.py; do + for bench in benchmarks/microbenchmarks/benchmark_*.py; do name=$(basename "$bench" .py) echo "=== Running $name (PR) ===" python "$bench" @@ -399,7 +401,7 @@ jobs: # Stash benchmark scripts so they survive the base branch checkout mkdir -p .perf_stash - cp benchmarks/benchmark_*.py benchmarks/compare_results.py .perf_stash/ + cp benchmarks/microbenchmarks/benchmark_*.py benchmarks/microbenchmarks/compare_results.py .perf_stash/ OUTER )" diff --git a/benchmarks/benchmark_gemm.py b/benchmarks/benchmark_gemm.py deleted file mode 100755 index cd651c172..000000000 --- a/benchmarks/benchmark_gemm.py +++ /dev/null @@ -1,164 +0,0 @@ -#!/usr/bin/env python -############################################################################### -# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. -# -# See LICENSE for license information. 
-############################################################################### - - -import torch -import torch.utils.benchmark as benchmark - -import transformer_engine.pytorch as te - -# Sequence / batch-token sizes to sweep -M_SIZE_LIST = [1024, 2048, 4096, 8192] - -# Model configurations -# Sources: -# - Llama 3 8B (hidden=4096, intermediate=14336, heads=32, kv_heads=8, head_dim=128) -# https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json - -# - Llama 3 70B (hidden=8192, intermediate=28672, heads=64, kv_heads=8, head_dim=128) -# https://huggingface.co/meta-llama/Llama-3.1-70B/blob/main/config.json - -# - Llama 3 405B (hidden=16384, intermediate=53248, heads=128, kv_heads=8, head_dim=128) -# https://huggingface.co/meta-llama/Llama-3.1-405B/blob/main/config.json - -# - Qwen 2.5 7B (hidden=3584, intermediate=18944, heads=28, kv_heads=4, head_dim=128) -# https://huggingface.co/Qwen/Qwen2.5-7B-Instruct/blob/main/config.json - -# - Qwen 2.5 72B (hidden=8192, intermediate=29568, heads=64, kv_heads=8, head_dim=128) -# https://huggingface.co/Qwen/Qwen2.5-72B-Instruct/blob/main/config.json - -MODEL_CONFIGS = [ - # (name, hidden, intermediate, num_q_heads, num_kv_heads, head_dim, tp) - ("Llama3-8B/TP1", 4096, 14336, 32, 8, 128, 1), - ("Llama3-8B/TP8", 4096, 14336, 32, 8, 128, 8), - ("Llama3-70B/TP8", 8192, 28672, 64, 8, 128, 8), - ("Llama3-405B/TP8", 16384, 53248, 128, 8, 128, 8), - ("Qwen2.5-7B/TP1", 3584, 18944, 28, 4, 128, 1), - ("Qwen2.5-72B/TP8", 8192, 29568, 64, 8, 128, 8), -] - - -def _generate_gemm_test_cases(): - test_cases = [] - - for (name, hidden, intermediate, n_q, n_kv, hd, tp) in MODEL_CONFIGS: - shapes = { - f"{name}-QKV": ((n_q * hd + 2 * n_kv * hd) // tp, hidden), - f"{name}-AttnOut": (hidden, (n_q * hd) // tp), - f"{name}-GateUp": ((2 * intermediate) // tp, hidden), - f"{name}-Down": (hidden, intermediate // tp), - } - - for M in M_SIZE_LIST: - for case_name, (N, K) in shapes.items(): - test_cases.append({ - "Case": case_name, - "M": M, - "N": N, - "K": K, - "dtype": torch.bfloat16, - }) - return test_cases - - -def bench_gemm(M, N, K, dtype): - device = "cuda" - - linear = te.Linear(K, N, bias=False).to(device=device, dtype=dtype) - x = torch.randn(M, K, dtype=dtype, device=device, requires_grad=True) - - fwd_func = lambda: linear(x) - out = fwd_func() - grad_out = torch.randn_like(out) - - def bwd_func(): - out = linear(x) - out.backward(grad_out) - # Clear grads so they don't accumulate across iterations - x.grad = None - linear.weight.grad = None - - bwd_func() - - fwd_flops = 2 * M * N * K - bwd_flops = 2 * fwd_flops # dX + dW - - # Warmup - for _ in range(20): - fwd_func() - bwd_func() - torch.cuda.synchronize() - - # Benchmark - n_iters = 100 - - fwd_ms = benchmark.Timer(stmt="fn()", globals={"fn": fwd_func}).timeit(n_iters).mean * 1e3 - fwd_bwd_ms = benchmark.Timer(stmt="fn()", globals={"fn": bwd_func}).timeit(n_iters).mean * 1e3 - - bwd_ms = max(fwd_bwd_ms - fwd_ms, 0.0) - - fwd_tflops = fwd_flops / (fwd_ms * 1e-3) / 1e12 - bwd_tflops = bwd_flops / (bwd_ms * 1e-3) / 1e12 if bwd_ms > 0 else 0.0 - - print(f" Forward {fwd_ms:.3f} ms | {fwd_tflops:.2f} TFLOPS") - print(f" Backward {bwd_ms:.3f} ms | {bwd_tflops:.2f} TFLOPS (derived)") - - return { - "TE Forward Time (ms)": f"{fwd_ms:.2f}", - "TE Forward TFLOPS": f"{fwd_tflops:.2f}", - "TE Backward Time (ms)": f"{bwd_ms:.2f}", - "TE Backward TFLOPS": f"{bwd_tflops:.2f}", - } - - -if __name__ == "__main__": - import pandas as pd - - test_cases = _generate_gemm_test_cases() - - columns = [ - "Case", "M", "N", 
"K", "dtype", - "TE Forward Time (ms)", - "TE Forward TFLOPS", - "TE Backward Time (ms)", - "TE Backward TFLOPS", - ] - rows = [] - - # Warmup run - c = test_cases[0] - print(f"\n{'='*60}") - print(f"WARMUP: {c}") - print(f"{'='*60}") - bench_gemm(M=c["M"], N=c["N"], K=c["K"], dtype=c["dtype"]) - - for case in test_cases: - print(f"\n{'='*60}") - print(f"Testing: {case}") - print(f"{'='*60}") - try: - metrics = bench_gemm( - M=case["M"], N=case["N"], K=case["K"], dtype=case["dtype"] - ) - row = { - "Case": case["Case"], - "M": case["M"], - "N": case["N"], - "K": case["K"], - "dtype": str(case["dtype"]), - **metrics, - } - rows.append(row) - except Exception as e: - print(f"FAILED: {case}: {e}") - raise - - results = pd.DataFrame(rows, columns=columns) - - out_csv = "benchmark_gemm.csv" - results.to_csv(out_csv, index=False) - print(f"\nResults saved to {out_csv}") diff --git a/benchmarks/benchmark_grouped_gemm.py b/benchmarks/benchmark_grouped_gemm.py deleted file mode 100755 index 7cee6edd4..000000000 --- a/benchmarks/benchmark_grouped_gemm.py +++ /dev/null @@ -1,274 +0,0 @@ -#!/usr/bin/env python -############################################################################### -# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. -# -# See LICENSE for license information. -############################################################################### - -import os -import torch -import torch.utils.benchmark as benchmark - -def generate_grouped_gemm_group_lens(b, m, balance: bool): - if balance: - return torch.full((b,), m, dtype=torch.int64) - else: - dist = 0.2 + 0.8 * torch.rand(b) - dist /= dist.sum() - group_lens = (dist * b * m).to(torch.int64) - error = b * m - group_lens.sum() - group_lens[-1] += error - return group_lens - -M_SIZE_LIST = [512, 1024, 2048, 4096]#, 8192, 16384] -EP_SIZE_LIST = [32, 16, 8] - - -def _generate_moe_test_cases( - name_prefix: str, - n_routed_experts: int, - moe_intermediate_size: int, - hidden_size: int, -): - test_cases = [] - shapes_dict = { - f"{name_prefix}-GateUP": (2 * moe_intermediate_size, hidden_size), - f"{name_prefix}-Down": (hidden_size, moe_intermediate_size), - } - - for ep in EP_SIZE_LIST: - if n_routed_experts % ep != 0: - continue - B = n_routed_experts // ep - if B < 1: - continue - for M in M_SIZE_LIST: - for name, (N, K) in shapes_dict.items(): - for dtype in [torch.bfloat16]: - test_cases.append( - { - "Case": name, - "B": B, - "M": M, - "N": N, - "K": K, - "dtype": dtype, - } - ) - return test_cases - - -def generate_deepseekv3_test_cases(): - return _generate_moe_test_cases( - "DSV3", n_routed_experts=256, moe_intermediate_size=2048, hidden_size=7168 - ) - - -def generate_deepseekv2_test_cases(): - return _generate_moe_test_cases( - "DSV2", n_routed_experts=160, moe_intermediate_size=1536, hidden_size=5120 - ) - - -def generate_deepseekv2_lite_test_cases(): - return _generate_moe_test_cases( - "DSV2-Lite", n_routed_experts=64, moe_intermediate_size=1408, hidden_size=2048 - ) - - -def generate_grok_v2_test_cases(): - return _generate_moe_test_cases( - "Grok-V2", n_routed_experts=8, moe_intermediate_size=16384, hidden_size=8192 - ) - - -def make_fwd_bwd_funcs_te(x, w, group_lens, activation_dtype): - from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace - from transformer_engine.pytorch.cpp_extensions import general_grouped_gemm - - B = int(group_lens.numel()) - N = int(w.shape[1]) - K = int(w.shape[2]) - - m_splits = [int(v) for v in group_lens.tolist()] - assert len(m_splits) 
== B - sum_M = sum(m_splits) - assert x.numel() > 0 and x.shape[0] == sum_M - - x_view = x.reshape(-1, x.shape[-1]) - xs = list(torch.split(x_view, m_splits)) - weights = [w[i] for i in range(B)] - - workspaces = get_multi_stream_cublas_workspace() - - # Forward output buffer - out = torch.empty((sum_M, N), device=x.device, dtype=activation_dtype) - - def fwd_func_te(): - general_grouped_gemm( - A=weights, - B=xs, - out=[out], - out_dtype=activation_dtype, - workspaces=workspaces, - single_output=True, - m_splits=m_splits, - use_bias=False, - bias=None, - layout="TN", - ) - return out - - # dx buffers - dx = torch.empty((sum_M, K), device=x.device, dtype=activation_dtype) - dxs = list(torch.split(dx, m_splits)) - - # dw buffers - dw_stacked = torch.empty((B, N, K), device=x.device, dtype=activation_dtype) - dws = [dw_stacked[i] for i in range(B)] - - def bwd_func_te(grad_out): - go = grad_out.view(-1, grad_out.shape[-1]) - splits = torch.split(go, m_splits) - - general_grouped_gemm( - A=weights, - B=splits, - out=dxs, - out_dtype=activation_dtype, - workspaces=workspaces, - single_output=False, - layout="NN", - m_splits=m_splits, - grad=False, - use_bias=False, - bias=None, - ) - - general_grouped_gemm( - A=xs, - B=splits, - out=dws, - out_dtype=activation_dtype, - workspaces=workspaces, - single_output=False, - layout="NT", - m_splits=m_splits, - grad=False, - use_bias=False, - bias=None, - accumulate=False, - ) - - return dx, dw_stacked - - return fwd_func_te, bwd_func_te - - -def bench_grouped_gemm(B, M, N, K, dtype): - device = "cuda" - - x = torch.randn((B * M, K), dtype=dtype, device=device, requires_grad=True) - w = torch.randn((B, N, K), dtype=dtype, device=device, requires_grad=True) - group_lens = generate_grouped_gemm_group_lens(B, M, balance=True).to(device) - print("group_lens: ", group_lens) - - os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1" - os.environ["NVTE_CUTLASS_GROUPED_GEMM_WARN_FALLBACK"] = "1" - - # TE grouped (CK_Tile) - x_te = x.clone().detach() - w_te = w.clone().detach() - fwd_func_te, bwd_func_te_inner = make_fwd_bwd_funcs_te( - x_te, w_te, group_lens, activation_dtype=dtype - ) - - out_te = fwd_func_te() - grad_out = torch.randn_like(out_te) - bwd_func_te = lambda: bwd_func_te_inner(grad_out) - dx_te, dw_te = bwd_func_te() - - # FLOPs - fwd_total_flops = 2 * B * M * N * K - bwd_total_flops = 2 * fwd_total_flops - - # Warmup - for _ in range(20): - fwd_func_te() - bwd_func_te() - - torch.cuda.synchronize() - - # Benchmark - n_iters = 100 - - fwd_te_ms = benchmark.Timer(stmt="fn()", globals={"fn": fwd_func_te}).timeit(n_iters).mean * 1e3 - bwd_te_ms = benchmark.Timer(stmt="fn()", globals={"fn": bwd_func_te}).timeit(n_iters).mean * 1e3 - - fwd_te_tflops = fwd_total_flops / (fwd_te_ms * 1e-3) / 1e12 - bwd_te_tflops = bwd_total_flops / (bwd_te_ms * 1e-3) / 1e12 - - print(f"TE (CK_Tile) Forward {fwd_te_ms:.3f} ms | {fwd_te_tflops:.2f} TFLOPS") - print(f"TE (CK_Tile) Backward {bwd_te_ms:.3f} ms | {bwd_te_tflops:.2f} TFLOPS") - - return { - "TE (CK_Tile) Forward Time (ms)": f"{fwd_te_ms:.2f}", - "TE (CK_Tile) Forward TFLOPS": f"{fwd_te_tflops:.2f}", - "TE (CK_Tile) Backward Time (ms)": f"{bwd_te_ms:.2f}", - "TE (CK_Tile) Backward TFLOPS": f"{bwd_te_tflops:.2f}", - } - - -if __name__ == "__main__": - import pandas as pd - - test_cases = ( - generate_deepseekv2_lite_test_cases() - + generate_deepseekv2_test_cases() - + generate_deepseekv3_test_cases() - + generate_grok_v2_test_cases() - ) - - columns = [ - "Case", "B", "M", "N", "K", "dtype", - "TE (CK_Tile) Forward 
Time (ms)", - "TE (CK_Tile) Forward TFLOPS", - "TE (CK_Tile) Backward Time (ms)", - "TE (CK_Tile) Backward TFLOPS", - ] - rows = [] - - # Warmup run - c = test_cases[0] - print(f"\n{'='*50}") - print(f"WARMUP: {c}") - print(f"{'='*50}") - bench_grouped_gemm(B=c["B"], M=c["M"], N=c["N"], K=c["K"], dtype=c["dtype"]) - - for case in test_cases: - print(f"\n{'='*50}") - print(f"Testing: {case}") - print(f"{'='*50}") - try: - metrics = bench_grouped_gemm( - B=case["B"], M=case["M"], N=case["N"], K=case["K"], dtype=case["dtype"] - ) - row = { - "Case": case["Case"], - "B": case["B"], - "M": case["M"], - "N": case["N"], - "K": case["K"], - "dtype": str(case["dtype"]), - **metrics, - } - rows.append(row) - except Exception as e: - print(f"FAILED: {case}: {e}") - raise - - results = pd.DataFrame(rows, columns=columns) - - out_csv = "benchmark_grouped_gemm.csv" - results.to_csv(out_csv, index=False) - print(f"\nResults saved to {out_csv}") diff --git a/benchmarks/compare_results.py b/benchmarks/compare_results.py deleted file mode 100644 index ca6f7f70d..000000000 --- a/benchmarks/compare_results.py +++ /dev/null @@ -1,132 +0,0 @@ -#!/usr/bin/env python -############################################################################### -# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. -# -# See LICENSE for license information. -############################################################################### - -import argparse -import sys - -import pandas as pd - -SKIP_COLS = {"TestID", "Label"} - - -def auto_detect_columns(df: pd.DataFrame): - """Split columns into (key_cols, metric_cols) by naming convention.""" - metric_cols = [c for c in df.columns if "TFLOPS" in c] - key_cols = [ - c for c in df.columns - if c not in metric_cols and c not in SKIP_COLS - and "Time" not in c # skip timing columns, only compare TFLOPS - ] - return key_cols, metric_cols - - -def main(): - parser = argparse.ArgumentParser(description="Compare benchmark CSVs") - parser.add_argument("base_csv", help="Base branch CSV") - parser.add_argument("pr_csv", help="PR branch CSV") - parser.add_argument("--threshold", type=float, default=5.0, - help="Regression threshold %% (default: 5.0)") - parser.add_argument("--key-cols", default=None, - help="Comma-separated key columns (auto-detected if omitted)") - parser.add_argument("--metric-cols", default=None, - help="Comma-separated metric columns (auto-detected if omitted)") - args = parser.parse_args() - - base_df = pd.read_csv(args.base_csv) - pr_df = pd.read_csv(args.pr_csv) - - # Determine columns - if args.key_cols: - key_cols = [c.strip() for c in args.key_cols.split(",")] - else: - key_cols, _ = auto_detect_columns(base_df) - - if args.metric_cols: - metric_cols = [c.strip() for c in args.metric_cols.split(",")] - else: - _, metric_cols = auto_detect_columns(base_df) - - if not metric_cols: - print("No metric columns found — nothing to compare.") - return 0 - - print(f"Key columns: {key_cols}") - print(f"Metric columns: {metric_cols}") - print(f"Threshold: {args.threshold}%") - print(f"Base rows: {len(base_df)}, PR rows: {len(pr_df)}") - print() - - # Ensure metric columns are numeric - for col in metric_cols: - base_df[col] = pd.to_numeric(base_df[col], errors="coerce") - pr_df[col] = pd.to_numeric(pr_df[col], errors="coerce") - - # Match rows - merged = base_df.merge(pr_df, on=key_cols, suffixes=("_base", "_pr"), how="inner") - if merged.empty: - print("WARNING: No matching rows between base and PR.") - return 0 - - print(f"Matched rows: {len(merged)}") - 
print() - - # Compare - regressions = [] - for metric in metric_cols: - bc = f"{metric}_base" - pc = f"{metric}_pr" - if bc not in merged.columns or pc not in merged.columns: - continue - - bv = merged[bc] - pv = merged[pc] - delta_pct = ((pv - bv) / bv) * 100.0 - - for idx in merged.index: - if pd.isna(bv[idx]) or pd.isna(pv[idx]) or bv[idx] < 0.5: - continue - if delta_pct[idx] < -args.threshold: - key_info = " | ".join(f"{k}={merged.loc[idx, k]}" for k in key_cols) - regressions.append({ - "keys": key_info, - "metric": metric, - "base": bv[idx], - "pr": pv[idx], - "delta": delta_pct[idx], - }) - - # Print summary per metric - for metric in metric_cols: - bc = f"{metric}_base" - pc = f"{metric}_pr" - if bc not in merged.columns: - continue - bv = merged[bc].dropna() - pv = merged[pc].dropna() - if bv.empty: - continue - deltas = ((pv - bv) / bv) * 100.0 - print(f" {metric}:") - print(f" mean base={bv.mean():.2f} pr={pv.mean():.2f} delta={deltas.mean():+.2f}%") - print(f" min delta={deltas.min():+.2f}% max delta={deltas.max():+.2f}%") - print() - - if regressions: - print(f"REGRESSIONS DETECTED: {len(regressions)}") - print("-" * 80) - for r in regressions: - print(f" [{r['metric']}] {r['keys']}") - print(f" base={r['base']:.2f} pr={r['pr']:.2f} delta={r['delta']:+.2f}%") - print("-" * 80) - return 1 - else: - print("No regressions detected.") - return 0 - - -if __name__ == "__main__": - sys.exit(main()) From fb2b3f3256375e12fa5275940931e092a01bd617 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 11 Mar 2026 15:57:45 -0500 Subject: [PATCH 03/25] restructure comment --- .github/workflows/rocm-ci.yml | 108 +++++-- benchmarks/microbenchmarks/benchmark_gemm.py | 164 +++++++++++ .../microbenchmarks/benchmark_grouped_gemm.py | 274 ++++++++++++++++++ benchmarks/microbenchmarks/compare_results.py | 135 +++++++++ 4 files changed, 652 insertions(+), 29 deletions(-) create mode 100755 benchmarks/microbenchmarks/benchmark_gemm.py create mode 100755 benchmarks/microbenchmarks/benchmark_grouped_gemm.py create mode 100644 benchmarks/microbenchmarks/compare_results.py diff --git a/.github/workflows/rocm-ci.yml b/.github/workflows/rocm-ci.yml index 44af6fdb7..2de97632d 100644 --- a/.github/workflows/rocm-ci.yml +++ b/.github/workflows/rocm-ci.yml @@ -380,6 +380,13 @@ jobs: run: | set -ex + # Map runner names to display names + case "${RUNNER_NAME}" in + linux-te-mi325-8) DISPLAY_NAME="MI325" ;; + linux-te-mi355-8) DISPLAY_NAME="MI355" ;; + *) DISPLAY_NAME="${RUNNER_NAME}" ;; + esac + # Restore PR checkout no matter how this step exits, # in case a later step needs to access the PR code. # Note that the PR code is *not* recompiled. 
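The hunk below collapses per-runner reporting into one shared PR comment: the comment is located on re-runs via ${COMMENT_MARKER}, and each runner splices only its own ${SECTION_START}…${SECTION_END} block in place (via the inline Python snippet), so runners never clobber each other's sections. A sketch of the comment body this builds, with marker values left symbolic (they are defined as the SECTION_START/SECTION_END/COMMENT_MARKER variables in the hunk) and the runner name illustrative:

    ${COMMENT_MARKER}
    ## Performance Regression Report

    ${SECTION_START}   <- section owned by this runner (${DISPLAY_NAME})
    ...summary and per-benchmark details for this runner...
    ${SECTION_END}

    ...sections for other runners, left untouched...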
@@ -436,43 +443,86 @@ jobs: done # Compare and build report - FAIL=0 - { - echo "## Performance Report (\`${{ matrix.runner }}\`)" - echo "" - for pr_csv in perf_results/pr/benchmark_*.csv; do - name=$(basename "$pr_csv" .csv) - base_csv="perf_results/base/${name}.csv" - [ -f "$base_csv" ] || continue - echo "### ${name}" - echo '```' - python .perf_stash/compare_results.py "$base_csv" "$pr_csv" || FAIL=1 - echo '```' - echo "" - done - } > perf_results/report.md - cat perf_results/report.md - # exit $FAIL + mkdir -p perf_results/reports + SUMMARY="perf_results/reports/summary.md" + DETAILS="perf_results/reports/details.md" + : > "$SUMMARY" + : > "$DETAILS" + + for pr_csv in perf_results/pr/benchmark_*.csv; do + name=$(basename "$pr_csv" .csv) + base_csv="perf_results/base/${name}.csv" + [ -f "$base_csv" ] || continue + echo "========== Comparing: $name ==========" + python .perf_stash/compare_results.py "$base_csv" "$pr_csv" \ + --bench-name "$name" \ + --summary-file "$SUMMARY" \ + >> "$DETAILS" + done OUTER )" - # Post report as PR comment - REPORT="perf_results/report.md" + # Assemble this runner's section + SUMMARY="perf_results/reports/summary.md" + DETAILS="perf_results/reports/details.md" + [ -f "$SUMMARY" ] || exit 0 + + SECTION_START="" + SECTION_END="" - MARKER="" - BODY="${MARKER}"$'\n'"$(cat "$REPORT")" + SECTION=$(cat < /tmp/perf_section.md + + # Post or update the single shared PR comment + COMMENT_MARKER="" + + COMMENT_ID=$(gh api "repos/${GH_REPO}/issues/${PR_NUMBER}/comments" \ + --paginate --jq ".[] | select(.body | contains(\"${COMMENT_MARKER}\")) | .id" \ + | head -1) + + if [ -n "$COMMENT_ID" ]; then + gh api "repos/${GH_REPO}/issues/comments/${COMMENT_ID}" --jq .body \ + > /tmp/perf_existing.md + + python -c " + existing = open('/tmp/perf_existing.md').read() + section = open('/tmp/perf_section.md').read() + start = '${SECTION_START}' + end = '${SECTION_END}' + if start in existing: + i = existing.index(start) + j = existing.index(end) + len(end) + result = existing[:i] + section + existing[j:] + else: + result = existing.rstrip() + '\n\n' + section + open('/tmp/perf_comment.md', 'w').write(result) + " + + gh api "repos/${GH_REPO}/issues/comments/${COMMENT_ID}" \ + --method PATCH --field body=@/tmp/perf_comment.md else - gh pr comment "$PR_NUMBER" --repo "$GH_REPO" --body "$BODY" - fi + { + echo "${COMMENT_MARKER}" + echo "## Performance Regression Report" + echo "" + cat /tmp/perf_section.md + } > /tmp/perf_comment.md + gh pr comment "$PR_NUMBER" --repo "$GH_REPO" \ + --body-file /tmp/perf_comment.md + fi # - name: Check Test Failure Status # if: always() # run: | diff --git a/benchmarks/microbenchmarks/benchmark_gemm.py b/benchmarks/microbenchmarks/benchmark_gemm.py new file mode 100755 index 000000000..cd651c172 --- /dev/null +++ b/benchmarks/microbenchmarks/benchmark_gemm.py @@ -0,0 +1,164 @@ +#!/usr/bin/env python +############################################################################### +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. 
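+#
+# Worked example of the shapes swept below (Llama3-8B/TP8: hidden=4096,
+# intermediate=14336, n_q=32, n_kv=8, head_dim=128, tp=8):
+#   QKV:     N = (32*128 + 2*8*128)/8 = 768,   K = 4096
+#   AttnOut: N = 4096,                         K = (32*128)/8 = 512
+#   GateUp:  N = (2*14336)/8 = 3584,           K = 4096
+#   Down:    N = 4096,                         K = 14336/8 = 1792
+# Each (M, N, K) case then times a bf16 te.Linear forward and a derived
+# backward (fwd+bwd time minus fwd time).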
+############################################################################### + + +import torch +import torch.utils.benchmark as benchmark + +import transformer_engine.pytorch as te + +# Sequence / batch-token sizes to sweep +M_SIZE_LIST = [1024, 2048, 4096, 8192] + +# Model configurations +# Sources: +# - Llama 3 8B (hidden=4096, intermediate=14336, heads=32, kv_heads=8, head_dim=128) +# https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json + +# - Llama 3 70B (hidden=8192, intermediate=28672, heads=64, kv_heads=8, head_dim=128) +# https://huggingface.co/meta-llama/Llama-3.1-70B/blob/main/config.json + +# - Llama 3 405B (hidden=16384, intermediate=53248, heads=128, kv_heads=8, head_dim=128) +# https://huggingface.co/meta-llama/Llama-3.1-405B/blob/main/config.json + +# - Qwen 2.5 7B (hidden=3584, intermediate=18944, heads=28, kv_heads=4, head_dim=128) +# https://huggingface.co/Qwen/Qwen2.5-7B-Instruct/blob/main/config.json + +# - Qwen 2.5 72B (hidden=8192, intermediate=29568, heads=64, kv_heads=8, head_dim=128) +# https://huggingface.co/Qwen/Qwen2.5-72B-Instruct/blob/main/config.json + +MODEL_CONFIGS = [ + # (name, hidden, intermediate, num_q_heads, num_kv_heads, head_dim, tp) + ("Llama3-8B/TP1", 4096, 14336, 32, 8, 128, 1), + ("Llama3-8B/TP8", 4096, 14336, 32, 8, 128, 8), + ("Llama3-70B/TP8", 8192, 28672, 64, 8, 128, 8), + ("Llama3-405B/TP8", 16384, 53248, 128, 8, 128, 8), + ("Qwen2.5-7B/TP1", 3584, 18944, 28, 4, 128, 1), + ("Qwen2.5-72B/TP8", 8192, 29568, 64, 8, 128, 8), +] + + +def _generate_gemm_test_cases(): + test_cases = [] + + for (name, hidden, intermediate, n_q, n_kv, hd, tp) in MODEL_CONFIGS: + shapes = { + f"{name}-QKV": ((n_q * hd + 2 * n_kv * hd) // tp, hidden), + f"{name}-AttnOut": (hidden, (n_q * hd) // tp), + f"{name}-GateUp": ((2 * intermediate) // tp, hidden), + f"{name}-Down": (hidden, intermediate // tp), + } + + for M in M_SIZE_LIST: + for case_name, (N, K) in shapes.items(): + test_cases.append({ + "Case": case_name, + "M": M, + "N": N, + "K": K, + "dtype": torch.bfloat16, + }) + return test_cases + + +def bench_gemm(M, N, K, dtype): + device = "cuda" + + linear = te.Linear(K, N, bias=False).to(device=device, dtype=dtype) + x = torch.randn(M, K, dtype=dtype, device=device, requires_grad=True) + + fwd_func = lambda: linear(x) + out = fwd_func() + grad_out = torch.randn_like(out) + + def bwd_func(): + out = linear(x) + out.backward(grad_out) + # Clear grads so they don't accumulate across iterations + x.grad = None + linear.weight.grad = None + + bwd_func() + + fwd_flops = 2 * M * N * K + bwd_flops = 2 * fwd_flops # dX + dW + + # Warmup + for _ in range(20): + fwd_func() + bwd_func() + torch.cuda.synchronize() + + # Benchmark + n_iters = 100 + + fwd_ms = benchmark.Timer(stmt="fn()", globals={"fn": fwd_func}).timeit(n_iters).mean * 1e3 + fwd_bwd_ms = benchmark.Timer(stmt="fn()", globals={"fn": bwd_func}).timeit(n_iters).mean * 1e3 + + bwd_ms = max(fwd_bwd_ms - fwd_ms, 0.0) + + fwd_tflops = fwd_flops / (fwd_ms * 1e-3) / 1e12 + bwd_tflops = bwd_flops / (bwd_ms * 1e-3) / 1e12 if bwd_ms > 0 else 0.0 + + print(f" Forward {fwd_ms:.3f} ms | {fwd_tflops:.2f} TFLOPS") + print(f" Backward {bwd_ms:.3f} ms | {bwd_tflops:.2f} TFLOPS (derived)") + + return { + "TE Forward Time (ms)": f"{fwd_ms:.2f}", + "TE Forward TFLOPS": f"{fwd_tflops:.2f}", + "TE Backward Time (ms)": f"{bwd_ms:.2f}", + "TE Backward TFLOPS": f"{bwd_tflops:.2f}", + } + + +if __name__ == "__main__": + import pandas as pd + + test_cases = _generate_gemm_test_cases() + + columns = [ + "Case", "M", "N", 
"K", "dtype", + "TE Forward Time (ms)", + "TE Forward TFLOPS", + "TE Backward Time (ms)", + "TE Backward TFLOPS", + ] + rows = [] + + # Warmup run + c = test_cases[0] + print(f"\n{'='*60}") + print(f"WARMUP: {c}") + print(f"{'='*60}") + bench_gemm(M=c["M"], N=c["N"], K=c["K"], dtype=c["dtype"]) + + for case in test_cases: + print(f"\n{'='*60}") + print(f"Testing: {case}") + print(f"{'='*60}") + try: + metrics = bench_gemm( + M=case["M"], N=case["N"], K=case["K"], dtype=case["dtype"] + ) + row = { + "Case": case["Case"], + "M": case["M"], + "N": case["N"], + "K": case["K"], + "dtype": str(case["dtype"]), + **metrics, + } + rows.append(row) + except Exception as e: + print(f"FAILED: {case}: {e}") + raise + + results = pd.DataFrame(rows, columns=columns) + + out_csv = "benchmark_gemm.csv" + results.to_csv(out_csv, index=False) + print(f"\nResults saved to {out_csv}") diff --git a/benchmarks/microbenchmarks/benchmark_grouped_gemm.py b/benchmarks/microbenchmarks/benchmark_grouped_gemm.py new file mode 100755 index 000000000..7cee6edd4 --- /dev/null +++ b/benchmarks/microbenchmarks/benchmark_grouped_gemm.py @@ -0,0 +1,274 @@ +#!/usr/bin/env python +############################################################################### +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### + +import os +import torch +import torch.utils.benchmark as benchmark + +def generate_grouped_gemm_group_lens(b, m, balance: bool): + if balance: + return torch.full((b,), m, dtype=torch.int64) + else: + dist = 0.2 + 0.8 * torch.rand(b) + dist /= dist.sum() + group_lens = (dist * b * m).to(torch.int64) + error = b * m - group_lens.sum() + group_lens[-1] += error + return group_lens + +M_SIZE_LIST = [512, 1024, 2048, 4096]#, 8192, 16384] +EP_SIZE_LIST = [32, 16, 8] + + +def _generate_moe_test_cases( + name_prefix: str, + n_routed_experts: int, + moe_intermediate_size: int, + hidden_size: int, +): + test_cases = [] + shapes_dict = { + f"{name_prefix}-GateUP": (2 * moe_intermediate_size, hidden_size), + f"{name_prefix}-Down": (hidden_size, moe_intermediate_size), + } + + for ep in EP_SIZE_LIST: + if n_routed_experts % ep != 0: + continue + B = n_routed_experts // ep + if B < 1: + continue + for M in M_SIZE_LIST: + for name, (N, K) in shapes_dict.items(): + for dtype in [torch.bfloat16]: + test_cases.append( + { + "Case": name, + "B": B, + "M": M, + "N": N, + "K": K, + "dtype": dtype, + } + ) + return test_cases + + +def generate_deepseekv3_test_cases(): + return _generate_moe_test_cases( + "DSV3", n_routed_experts=256, moe_intermediate_size=2048, hidden_size=7168 + ) + + +def generate_deepseekv2_test_cases(): + return _generate_moe_test_cases( + "DSV2", n_routed_experts=160, moe_intermediate_size=1536, hidden_size=5120 + ) + + +def generate_deepseekv2_lite_test_cases(): + return _generate_moe_test_cases( + "DSV2-Lite", n_routed_experts=64, moe_intermediate_size=1408, hidden_size=2048 + ) + + +def generate_grok_v2_test_cases(): + return _generate_moe_test_cases( + "Grok-V2", n_routed_experts=8, moe_intermediate_size=16384, hidden_size=8192 + ) + + +def make_fwd_bwd_funcs_te(x, w, group_lens, activation_dtype): + from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace + from transformer_engine.pytorch.cpp_extensions import general_grouped_gemm + + B = int(group_lens.numel()) + N = int(w.shape[1]) + K = int(w.shape[2]) + + m_splits = [int(v) for v in 
group_lens.tolist()] + assert len(m_splits) == B + sum_M = sum(m_splits) + assert x.numel() > 0 and x.shape[0] == sum_M + + x_view = x.reshape(-1, x.shape[-1]) + xs = list(torch.split(x_view, m_splits)) + weights = [w[i] for i in range(B)] + + workspaces = get_multi_stream_cublas_workspace() + + # Forward output buffer + out = torch.empty((sum_M, N), device=x.device, dtype=activation_dtype) + + def fwd_func_te(): + general_grouped_gemm( + A=weights, + B=xs, + out=[out], + out_dtype=activation_dtype, + workspaces=workspaces, + single_output=True, + m_splits=m_splits, + use_bias=False, + bias=None, + layout="TN", + ) + return out + + # dx buffers + dx = torch.empty((sum_M, K), device=x.device, dtype=activation_dtype) + dxs = list(torch.split(dx, m_splits)) + + # dw buffers + dw_stacked = torch.empty((B, N, K), device=x.device, dtype=activation_dtype) + dws = [dw_stacked[i] for i in range(B)] + + def bwd_func_te(grad_out): + go = grad_out.view(-1, grad_out.shape[-1]) + splits = torch.split(go, m_splits) + + general_grouped_gemm( + A=weights, + B=splits, + out=dxs, + out_dtype=activation_dtype, + workspaces=workspaces, + single_output=False, + layout="NN", + m_splits=m_splits, + grad=False, + use_bias=False, + bias=None, + ) + + general_grouped_gemm( + A=xs, + B=splits, + out=dws, + out_dtype=activation_dtype, + workspaces=workspaces, + single_output=False, + layout="NT", + m_splits=m_splits, + grad=False, + use_bias=False, + bias=None, + accumulate=False, + ) + + return dx, dw_stacked + + return fwd_func_te, bwd_func_te + + +def bench_grouped_gemm(B, M, N, K, dtype): + device = "cuda" + + x = torch.randn((B * M, K), dtype=dtype, device=device, requires_grad=True) + w = torch.randn((B, N, K), dtype=dtype, device=device, requires_grad=True) + group_lens = generate_grouped_gemm_group_lens(B, M, balance=True).to(device) + print("group_lens: ", group_lens) + + os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1" + os.environ["NVTE_CUTLASS_GROUPED_GEMM_WARN_FALLBACK"] = "1" + + # TE grouped (CK_Tile) + x_te = x.clone().detach() + w_te = w.clone().detach() + fwd_func_te, bwd_func_te_inner = make_fwd_bwd_funcs_te( + x_te, w_te, group_lens, activation_dtype=dtype + ) + + out_te = fwd_func_te() + grad_out = torch.randn_like(out_te) + bwd_func_te = lambda: bwd_func_te_inner(grad_out) + dx_te, dw_te = bwd_func_te() + + # FLOPs + fwd_total_flops = 2 * B * M * N * K + bwd_total_flops = 2 * fwd_total_flops + + # Warmup + for _ in range(20): + fwd_func_te() + bwd_func_te() + + torch.cuda.synchronize() + + # Benchmark + n_iters = 100 + + fwd_te_ms = benchmark.Timer(stmt="fn()", globals={"fn": fwd_func_te}).timeit(n_iters).mean * 1e3 + bwd_te_ms = benchmark.Timer(stmt="fn()", globals={"fn": bwd_func_te}).timeit(n_iters).mean * 1e3 + + fwd_te_tflops = fwd_total_flops / (fwd_te_ms * 1e-3) / 1e12 + bwd_te_tflops = bwd_total_flops / (bwd_te_ms * 1e-3) / 1e12 + + print(f"TE (CK_Tile) Forward {fwd_te_ms:.3f} ms | {fwd_te_tflops:.2f} TFLOPS") + print(f"TE (CK_Tile) Backward {bwd_te_ms:.3f} ms | {bwd_te_tflops:.2f} TFLOPS") + + return { + "TE (CK_Tile) Forward Time (ms)": f"{fwd_te_ms:.2f}", + "TE (CK_Tile) Forward TFLOPS": f"{fwd_te_tflops:.2f}", + "TE (CK_Tile) Backward Time (ms)": f"{bwd_te_ms:.2f}", + "TE (CK_Tile) Backward TFLOPS": f"{bwd_te_tflops:.2f}", + } + + +if __name__ == "__main__": + import pandas as pd + + test_cases = ( + generate_deepseekv2_lite_test_cases() + + generate_deepseekv2_test_cases() + + generate_deepseekv3_test_cases() + + generate_grok_v2_test_cases() + ) + + columns = [ + "Case", "B", "M", 
"N", "K", "dtype", + "TE (CK_Tile) Forward Time (ms)", + "TE (CK_Tile) Forward TFLOPS", + "TE (CK_Tile) Backward Time (ms)", + "TE (CK_Tile) Backward TFLOPS", + ] + rows = [] + + # Warmup run + c = test_cases[0] + print(f"\n{'='*50}") + print(f"WARMUP: {c}") + print(f"{'='*50}") + bench_grouped_gemm(B=c["B"], M=c["M"], N=c["N"], K=c["K"], dtype=c["dtype"]) + + for case in test_cases: + print(f"\n{'='*50}") + print(f"Testing: {case}") + print(f"{'='*50}") + try: + metrics = bench_grouped_gemm( + B=case["B"], M=case["M"], N=case["N"], K=case["K"], dtype=case["dtype"] + ) + row = { + "Case": case["Case"], + "B": case["B"], + "M": case["M"], + "N": case["N"], + "K": case["K"], + "dtype": str(case["dtype"]), + **metrics, + } + rows.append(row) + except Exception as e: + print(f"FAILED: {case}: {e}") + raise + + results = pd.DataFrame(rows, columns=columns) + + out_csv = "benchmark_grouped_gemm.csv" + results.to_csv(out_csv, index=False) + print(f"\nResults saved to {out_csv}") diff --git a/benchmarks/microbenchmarks/compare_results.py b/benchmarks/microbenchmarks/compare_results.py new file mode 100644 index 000000000..c64c124fa --- /dev/null +++ b/benchmarks/microbenchmarks/compare_results.py @@ -0,0 +1,135 @@ +#!/usr/bin/env python +############################################################################### +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### +""" +Compare two CSVs from the same benchmark (base branch vs PR branch). + +Auto-detects metric columns (containing "TFLOPS") and key columns. +Outputs a markdown
 <details> block to stdout with per-config results,
+and optionally appends a summary table row to --summary-file.
+
+Usage:
+    python compare_results.py base.csv pr.csv --bench-name NAME --summary-file FILE
+"""
+
+import argparse
+import sys
+
+import numpy as np
+import pandas as pd
+
+SKIP_COLS = {"TestID", "Label"}
+
+
+def auto_detect_columns(df):
+    metric_cols = [c for c in df.columns if "TFLOPS" in c]
+    key_cols = [
+        c for c in df.columns
+        if c not in metric_cols and c not in SKIP_COLS
+        and "Time" not in c
+    ]
+    return key_cols, metric_cols
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Compare benchmark CSVs")
+    parser.add_argument("base_csv", help="Base branch CSV")
+    parser.add_argument("pr_csv", help="PR branch CSV")
+    parser.add_argument("--bench-name", default="benchmark",
+                        help="Benchmark name for markdown output")
+    parser.add_argument("--summary-file", default=None,
+                        help="Append a summary table row (markdown) to this file")
+    args = parser.parse_args()
+
+    base_df = pd.read_csv(args.base_csv)
+    pr_df = pd.read_csv(args.pr_csv)
+
+    key_cols, metric_cols = auto_detect_columns(base_df)
+
+    if not metric_cols:
+        print("No metric columns found.")
+        return 0
+
+    for col in metric_cols:
+        base_df[col] = pd.to_numeric(base_df[col], errors="coerce")
+        pr_df[col] = pd.to_numeric(pr_df[col], errors="coerce")
+
+    merged = base_df.merge(pr_df, on=key_cols, suffixes=("_base", "_pr"), how="inner")
+    if merged.empty:
+        print("WARNING: No matching rows between base and PR.")
+        return 0
+
+    all_speedups = []
+    per_row_data = []
+
+    for idx in merged.index:
+        row_keys = {k: merged.loc[idx, k] for k in key_cols}
+        row_metrics = {}
+
+        for metric in metric_cols:
+            bc, pc = f"{metric}_base", f"{metric}_pr"
+            bv = merged.loc[idx, bc]
+            pv = merged.loc[idx, pc]
+
+            if pd.isna(bv) or pd.isna(pv) or bv < 0.5:
+                continue
+
+            speedup = pv / bv
+            all_speedups.append(speedup)
+            row_metrics[metric] = {"base": bv, "pr": pv, "speedup": speedup}
+
+        if row_metrics:
+            per_row_data.append({"keys": row_keys, "metrics": row_metrics})
+
+    if not all_speedups:
+        print("WARNING: No valid comparisons found.")
+        return 0
+
+    speedups = np.array(all_speedups)
+    median_sp = float(np.median(speedups))
+    min_sp = float(np.min(speedups))
+    max_sp = float(np.max(speedups))
+
+    # Details block
+    print("<details>")
+    print(f"<summary>{args.bench_name} "
+          f"(median {median_sp:.3f}x, min {min_sp:.3f}x, max {max_sp:.3f}x)</summary>")
+    print()
+
+    header_cols = list(key_cols)
+    for m in metric_cols:
+        short = m.replace(" TFLOPS", "")
+        header_cols.extend([f"{short} Base", f"{short} PR", f"{short} Speedup"])
+
+    print("| " + " | ".join(header_cols) + " |")
+    print("|" + "|".join(["---"] * len(header_cols)) + "|")
+
+    for row in per_row_data:
+        cells = [str(row["keys"].get(k, "")) for k in key_cols]
+        for metric in metric_cols:
+            if metric in row["metrics"]:
+                v = row["metrics"][metric]
+                cells.append(f"{v['base']:.2f}")
+                cells.append(f"{v['pr']:.2f}")
+                cells.append(f"{v['speedup']:.3f}x")
+            else:
+                cells.extend(["", "", ""])
+        print("| " + " | ".join(cells) + " |")
+
+    print()
+    print("</details>
") + print() + + # Summary row + if args.summary_file: + with open(args.summary_file, "a") as f: + f.write(f"| {args.bench_name} | {median_sp:.3f}x | {min_sp:.3f}x | {max_sp:.3f}x |\n") + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) From d4e9b1e1febcec48b9b845807f45df73e34ffcc1 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 11 Mar 2026 16:55:57 -0500 Subject: [PATCH 04/25] misc updates --- .github/workflows/rocm-ci.yml | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/.github/workflows/rocm-ci.yml b/.github/workflows/rocm-ci.yml index 2de97632d..8c9891783 100644 --- a/.github/workflows/rocm-ci.yml +++ b/.github/workflows/rocm-ci.yml @@ -382,9 +382,9 @@ jobs: # Map runner names to display names case "${RUNNER_NAME}" in - linux-te-mi325-8) DISPLAY_NAME="MI325" ;; - linux-te-mi355-8) DISPLAY_NAME="MI355" ;; - *) DISPLAY_NAME="${RUNNER_NAME}" ;; + linux-te-mi325*) DISPLAY_NAME="MI325" ;; + linux-te-mi355*) DISPLAY_NAME="MI355" ;; + *) DISPLAY_NAME="${RUNNER_NAME}" ;; esac # Restore PR checkout no matter how this step exits, @@ -467,14 +467,14 @@ jobs: DETAILS="perf_results/reports/details.md" [ -f "$SUMMARY" ] || exit 0 - SECTION_START="" - SECTION_END="" + SECTION_START="" + SECTION_END="" SECTION=$(cat < /tmp/perf_section.md - # Post or update the single shared PR comment + echo "" + echo "========== Performance Report ==========" + cat /tmp/perf_section.md + echo "========================================" + + # Post or update the single shared PR comment (skip under nektos act) + if [ -n "${ACT:-}" ]; then + echo "Running under nektos act, skipping PR comment." + exit 0 + fi + COMMENT_MARKER="" COMMENT_ID=$(gh api "repos/${GH_REPO}/issues/${PR_NUMBER}/comments" \ @@ -496,7 +506,7 @@ jobs: gh api "repos/${GH_REPO}/issues/comments/${COMMENT_ID}" --jq .body \ > /tmp/perf_existing.md - python -c " + python3 -c " existing = open('/tmp/perf_existing.md').read() section = open('/tmp/perf_section.md').read() start = '${SECTION_START}' @@ -515,7 +525,7 @@ jobs: else { echo "${COMMENT_MARKER}" - echo "## Performance Regression Report" + echo "## Performance Report" echo "" cat /tmp/perf_section.md } > /tmp/perf_comment.md @@ -523,6 +533,7 @@ jobs: gh pr comment "$PR_NUMBER" --repo "$GH_REPO" \ --body-file /tmp/perf_comment.md fi + # - name: Check Test Failure Status # if: always() # run: | From 95358f478649d4cbf4ab0612b590f1c29a57b3d0 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 11 Mar 2026 17:42:15 -0500 Subject: [PATCH 05/25] python fix --- .github/workflows/rocm-ci.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/rocm-ci.yml b/.github/workflows/rocm-ci.yml index 8c9891783..25943931b 100644 --- a/.github/workflows/rocm-ci.yml +++ b/.github/workflows/rocm-ci.yml @@ -506,11 +506,11 @@ jobs: gh api "repos/${GH_REPO}/issues/comments/${COMMENT_ID}" --jq .body \ > /tmp/perf_existing.md - python3 -c " + python3 - "$SECTION_START" "$SECTION_END" << 'PYEOF' + import sys existing = open('/tmp/perf_existing.md').read() section = open('/tmp/perf_section.md').read() - start = '${SECTION_START}' - end = '${SECTION_END}' + start, end = sys.argv[1], sys.argv[2] if start in existing: i = existing.index(start) j = existing.index(end) + len(end) @@ -518,7 +518,7 @@ jobs: else: result = existing.rstrip() + '\n\n' + section open('/tmp/perf_comment.md', 'w').write(result) - " + PYEOF gh api "repos/${GH_REPO}/issues/comments/${COMMENT_ID}" \ --method PATCH --field 
body=@/tmp/perf_comment.md From d0a320dbd01676deba551319e26afa1ef7df1a9b Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 12 Mar 2026 10:43:51 -0500 Subject: [PATCH 06/25] another embedded python fix --- .github/workflows/rocm-ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/rocm-ci.yml b/.github/workflows/rocm-ci.yml index 25943931b..f2078386d 100644 --- a/.github/workflows/rocm-ci.yml +++ b/.github/workflows/rocm-ci.yml @@ -506,7 +506,7 @@ jobs: gh api "repos/${GH_REPO}/issues/comments/${COMMENT_ID}" --jq .body \ > /tmp/perf_existing.md - python3 - "$SECTION_START" "$SECTION_END" << 'PYEOF' + python3 -c " import sys existing = open('/tmp/perf_existing.md').read() section = open('/tmp/perf_section.md').read() @@ -518,7 +518,7 @@ jobs: else: result = existing.rstrip() + '\n\n' + section open('/tmp/perf_comment.md', 'w').write(result) - PYEOF + " "$SECTION_START" "$SECTION_END" gh api "repos/${GH_REPO}/issues/comments/${COMMENT_ID}" \ --method PATCH --field body=@/tmp/perf_comment.md From 6f458536d7f72e1043a0fd7941b24fe88cacee9f Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 12 Mar 2026 11:11:54 -0500 Subject: [PATCH 07/25] replace py code --- .github/workflows/rocm-ci.yml | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/.github/workflows/rocm-ci.yml b/.github/workflows/rocm-ci.yml index f2078386d..ba61c574c 100644 --- a/.github/workflows/rocm-ci.yml +++ b/.github/workflows/rocm-ci.yml @@ -506,19 +506,15 @@ jobs: gh api "repos/${GH_REPO}/issues/comments/${COMMENT_ID}" --jq .body \ > /tmp/perf_existing.md - python3 -c " - import sys - existing = open('/tmp/perf_existing.md').read() - section = open('/tmp/perf_section.md').read() - start, end = sys.argv[1], sys.argv[2] - if start in existing: - i = existing.index(start) - j = existing.index(end) + len(end) - result = existing[:i] + section + existing[j:] - else: - result = existing.rstrip() + '\n\n' + section - open('/tmp/perf_comment.md', 'w').write(result) - " "$SECTION_START" "$SECTION_END" + if grep -qF "$SECTION_START" /tmp/perf_existing.md; then + awk -v start="$SECTION_START" -v end="$SECTION_END" -v sf="/tmp/perf_section.md" ' + $0 ~ start { skip=1; while((getline l < sf)>0) print l; next } + $0 ~ end { skip=0; next } + !skip { print } + ' /tmp/perf_existing.md > /tmp/perf_comment.md + else + { cat /tmp/perf_existing.md; echo ""; cat /tmp/perf_section.md; } > /tmp/perf_comment.md + fi gh api "repos/${GH_REPO}/issues/comments/${COMMENT_ID}" \ --method PATCH --field body=@/tmp/perf_comment.md From 55e7eb5c203a2d42be0881fb1d85ae7574fe0a7b Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 13 Mar 2026 10:16:18 -0500 Subject: [PATCH 08/25] restore disabled parts of CI --- .github/workflows/rocm-ci.yml | 392 +++++++++++++++++----------------- 1 file changed, 196 insertions(+), 196 deletions(-) diff --git a/.github/workflows/rocm-ci.yml b/.github/workflows/rocm-ci.yml index ba61c574c..f81a3e347 100644 --- a/.github/workflows/rocm-ci.yml +++ b/.github/workflows/rocm-ci.yml @@ -224,152 +224,152 @@ jobs: EOF )" - # - name: Run sGPU tests - # id: sgpu-tests - # continue-on-error: true - # run: | - # # Cleanup previous failure markers if any. 
Don't actually do anything on k8s pods - # rm -f FAIL_* - - # docker exec \ - # -e TEST_SGPU=1 \ - # -e TEST_LEVEL=${{ env.TEST_LEVEL }} \ - # te-runner bash -c "$(cat <<'EOF' - # #!/usr/bin/bash - # set -x -o pipefail - # ulimit -c 0 # Disable core dumps - - # # debug output - # ls -d /opt/rocm* - # python --version - # pip list | egrep "transformer_e|torch|jax|numpy|ml_dtypes|typing_ext" - - # HIP_VISIBLE_DEVICES=1 ci/pytorch.sh > /workspace/torch_sgpu.log 2>&1 & - # torch_pid=$!; echo Pytorch test pid $! + - name: Run sGPU tests + id: sgpu-tests + continue-on-error: true + run: | + # Cleanup previous failure markers if any. Don't actually do anything on k8s pods + rm -f FAIL_* + + docker exec \ + -e TEST_SGPU=1 \ + -e TEST_LEVEL=${{ env.TEST_LEVEL }} \ + te-runner bash -c "$(cat <<'EOF' + #!/usr/bin/bash + set -x -o pipefail + ulimit -c 0 # Disable core dumps + + # debug output + ls -d /opt/rocm* + python --version + pip list | egrep "transformer_e|torch|jax|numpy|ml_dtypes|typing_ext" + + HIP_VISIBLE_DEVICES=1 ci/pytorch.sh > /workspace/torch_sgpu.log 2>&1 & + torch_pid=$!; echo Pytorch test pid $! - # HIP_VISIBLE_DEVICES=2 ci/jax.sh > /workspace/jax_sgpu.log 2>&1 & - # jax_pid=$!; echo JAX test pid $! + HIP_VISIBLE_DEVICES=2 ci/jax.sh > /workspace/jax_sgpu.log 2>&1 & + jax_pid=$!; echo JAX test pid $! - # HIP_VISIBLE_DEVICES=3 ci/core.sh > /workspace/core_sgpu.log 2>&1 & - # core_pid=$!; echo Core test pid $! + HIP_VISIBLE_DEVICES=3 ci/core.sh > /workspace/core_sgpu.log 2>&1 & + core_pid=$!; echo Core test pid $! - # wait $core_pid; core_rc=$? - # wait $jax_pid; jax_rc=$? - # wait $torch_pid; torch_rc=$? + wait $core_pid; core_rc=$? + wait $jax_pid; jax_rc=$? + wait $torch_pid; torch_rc=$? - # # /workspace/FAIL_* files are for failure markers we can extract to the host runner and process later - # # Check PyTorch - # if [ $torch_rc -ne 0 ]; then - # echo "::group::[FAILED] PyTorch sGPU Log" - # cat /workspace/torch_sgpu.log - # echo "::endgroup::" - # echo "::error::Pytorch sGPU test FAILED." - # touch /workspace/FAIL_TORCH_SGPU - # fi - - # # Check JAX - # if [ $jax_rc -ne 0 ]; then - # echo "::group::[FAILED] JAX sGPU Log" - # cat /workspace/jax_sgpu.log - # echo "::endgroup::" - # echo "::error::JAX sGPU test FAILED." - # touch /workspace/FAIL_JAX_SGPU - # fi - - # # Check Core - # if [ $core_rc -ne 0 ]; then - # echo "::group::[FAILED] Core sGPU Log" - # cat /workspace/core_sgpu.log - # echo "::endgroup::" - # echo "::error::Core sGPU test FAILED." - # touch /workspace/FAIL_CORE_SGPU - # fi + # /workspace/FAIL_* files are for failure markers we can extract to the host runner and process later + # Check PyTorch + if [ $torch_rc -ne 0 ]; then + echo "::group::[FAILED] PyTorch sGPU Log" + cat /workspace/torch_sgpu.log + echo "::endgroup::" + echo "::error::Pytorch sGPU test FAILED." + touch /workspace/FAIL_TORCH_SGPU + fi + + # Check JAX + if [ $jax_rc -ne 0 ]; then + echo "::group::[FAILED] JAX sGPU Log" + cat /workspace/jax_sgpu.log + echo "::endgroup::" + echo "::error::JAX sGPU test FAILED." + touch /workspace/FAIL_JAX_SGPU + fi + + # Check Core + if [ $core_rc -ne 0 ]; then + echo "::group::[FAILED] Core sGPU Log" + cat /workspace/core_sgpu.log + echo "::endgroup::" + echo "::error::Core sGPU test FAILED." 
+ touch /workspace/FAIL_CORE_SGPU + fi - # test $torch_rc -eq 0 -a $jax_rc -eq 0 -a $core_rc -eq 0 - # EOF - # )" + test $torch_rc -eq 0 -a $jax_rc -eq 0 -a $core_rc -eq 0 + EOF + )" - # # Export failed tests statuses to host runner - # if [ -f FAIL_TORCH_SGPU ]; then echo "torch=fail" >> $GITHUB_OUTPUT; fi - # if [ -f FAIL_JAX_SGPU ]; then echo "jax=fail" >> $GITHUB_OUTPUT; fi - # if [ -f FAIL_CORE_SGPU ]; then echo "core=fail" >> $GITHUB_OUTPUT; fi - - # - name: Run mGPU tests - # id: mgpu-tests - # continue-on-error: true - # run: | - # docker exec \ - # -e TEST_MGPU=1 \ - # -e TEST_LEVEL=${{ env.TEST_LEVEL }} \ - # te-runner bash -c "$(cat <<'EOF' - # #!/usr/bin/bash - # set -x -o pipefail - # ulimit -c 0 # Disable core dumps + # Export failed tests statuses to host runner + if [ -f FAIL_TORCH_SGPU ]; then echo "torch=fail" >> $GITHUB_OUTPUT; fi + if [ -f FAIL_JAX_SGPU ]; then echo "jax=fail" >> $GITHUB_OUTPUT; fi + if [ -f FAIL_CORE_SGPU ]; then echo "core=fail" >> $GITHUB_OUTPUT; fi + + - name: Run mGPU tests + id: mgpu-tests + continue-on-error: true + run: | + docker exec \ + -e TEST_MGPU=1 \ + -e TEST_LEVEL=${{ env.TEST_LEVEL }} \ + te-runner bash -c "$(cat <<'EOF' + #!/usr/bin/bash + set -x -o pipefail + ulimit -c 0 # Disable core dumps - # # Run PyTorch - # ci/pytorch.sh > /workspace/torch_mgpu.log 2>&1 - # torch_rc=$? + # Run PyTorch + ci/pytorch.sh > /workspace/torch_mgpu.log 2>&1 + torch_rc=$? - # # Run JAX - # ci/jax.sh > /workspace/jax_mgpu.log 2>&1 - # jax_rc=$? + # Run JAX + ci/jax.sh > /workspace/jax_mgpu.log 2>&1 + jax_rc=$? - # # /workspace/FAIL_* files are for failure markers we can extract to the host runner and process later - # if [ $torch_rc -ne 0 ]; then - # echo "::group::[FAILED] PyTorch mGPU Log" - # cat /workspace/torch_mgpu.log - # echo "::endgroup::" - # echo "::error::Pytorch mGPU test FAILED." - # touch /workspace/FAIL_TORCH_MGPU - # fi - - # if [ $jax_rc -ne 0 ]; then - # echo "::group::[FAILED] JAX mGPU Log" - # cat /workspace/jax_mgpu.log - # echo "::endgroup::" - # echo "::error::JAX mGPU test FAILED." - # touch /workspace/FAIL_JAX_MGPU - # fi + # /workspace/FAIL_* files are for failure markers we can extract to the host runner and process later + if [ $torch_rc -ne 0 ]; then + echo "::group::[FAILED] PyTorch mGPU Log" + cat /workspace/torch_mgpu.log + echo "::endgroup::" + echo "::error::Pytorch mGPU test FAILED." + touch /workspace/FAIL_TORCH_MGPU + fi + + if [ $jax_rc -ne 0 ]; then + echo "::group::[FAILED] JAX mGPU Log" + cat /workspace/jax_mgpu.log + echo "::endgroup::" + echo "::error::JAX mGPU test FAILED." 
+ touch /workspace/FAIL_JAX_MGPU + fi - # test $torch_rc -eq 0 -a $jax_rc -eq 0 - # EOF - # )" - - # # Export failed tests statuses to host runner - # if [ -f FAIL_TORCH_MGPU ]; then echo "torch=fail" >> $GITHUB_OUTPUT; fi - # if [ -f FAIL_JAX_MGPU ]; then echo "jax=fail" >> $GITHUB_OUTPUT; fi - - # - name: Run Examples - # id: examples-tests - # continue-on-error: true - # env: - # HF_TOKEN: ${{ secrets.HF_TOKEN }} - # run: | - # docker exec -e HF_TOKEN="$HF_TOKEN" te-runner bash -c "$(cat <<'EOF' - # #!/usr/bin/bash - # set -ex -o pipefail - # ulimit -c 0 # Disable core dumps - - # # Check whether the HF_TOKEN is present - # python -c "import os; print('HF_TOKEN set:', bool(os.environ.get('HF_TOKEN')))" - - # cd /workspace/examples/pytorch/mnist - # python main.py 2>&1 | tee /workspace/examples.log - # python main.py --use-te 2>&1 | tee -a /workspace/examples.log - # python main.py --use-fp8 2>&1 | tee -a /workspace/examples.log + test $torch_rc -eq 0 -a $jax_rc -eq 0 + EOF + )" + + # Export failed tests statuses to host runner + if [ -f FAIL_TORCH_MGPU ]; then echo "torch=fail" >> $GITHUB_OUTPUT; fi + if [ -f FAIL_JAX_MGPU ]; then echo "jax=fail" >> $GITHUB_OUTPUT; fi + + - name: Run Examples + id: examples-tests + continue-on-error: true + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} + run: | + docker exec -e HF_TOKEN="$HF_TOKEN" te-runner bash -c "$(cat <<'EOF' + #!/usr/bin/bash + set -ex -o pipefail + ulimit -c 0 # Disable core dumps + + # Check whether the HF_TOKEN is present + python -c "import os; print('HF_TOKEN set:', bool(os.environ.get('HF_TOKEN')))" + + cd /workspace/examples/pytorch/mnist + python main.py 2>&1 | tee /workspace/examples.log + python main.py --use-te 2>&1 | tee -a /workspace/examples.log + python main.py --use-fp8 2>&1 | tee -a /workspace/examples.log - # cd /workspace/examples/jax/mnist - # pip3 install -r requirements.txt - # python test_single_gpu_mnist.py 2>&1 | tee -a /workspace/examples.log - # python test_single_gpu_mnist.py --use-te 2>&1 | tee -a /workspace/examples.log - # python test_single_gpu_mnist.py --use-fp8 2>&1 | tee -a /workspace/examples.log + cd /workspace/examples/jax/mnist + pip3 install -r requirements.txt + python test_single_gpu_mnist.py 2>&1 | tee -a /workspace/examples.log + python test_single_gpu_mnist.py --use-te 2>&1 | tee -a /workspace/examples.log + python test_single_gpu_mnist.py --use-fp8 2>&1 | tee -a /workspace/examples.log - # cd /workspace/examples/jax/encoder - # pip3 install -r requirements.txt - # python test_single_gpu_encoder.py 2>&1 | tee -a /workspace/examples.log - # python test_single_gpu_encoder.py --use-fp8 2>&1 | tee -a /workspace/examples.log - # EOF - # )" + cd /workspace/examples/jax/encoder + pip3 install -r requirements.txt + python test_single_gpu_encoder.py 2>&1 | tee -a /workspace/examples.log + python test_single_gpu_encoder.py --use-fp8 2>&1 | tee -a /workspace/examples.log + EOF + )" - name: "Performance regression check" env: @@ -530,69 +530,69 @@ jobs: --body-file /tmp/perf_comment.md fi - # - name: Check Test Failure Status - # if: always() - # run: | - # EXIT_STATUS=0 - # # Check outcomes of the specific test steps - # # "outcome" will be 'failure' even if continue-on-error was true - - # # sGPU CHECKS - # # We check for the file existence directly because the 'Run sGPU tests' step - # # halts immediately on docker failure, skipping the lines that set step outputs. - # if [[ -f FAIL_CORE_SGPU ]]; then - # echo "::error::Core sGPU Tests Failed." 
- # EXIT_STATUS=1 - # fi - # if [[ -f FAIL_TORCH_SGPU ]]; then - # echo "::error::PyTorch sGPU Tests Failed." - # EXIT_STATUS=1 - # fi - # if [[ -f FAIL_JAX_SGPU ]]; then - # echo "::error::JAX sGPU Tests Failed." - # EXIT_STATUS=1 - # fi + - name: Check Test Failure Status + if: always() + run: | + EXIT_STATUS=0 + # Check outcomes of the specific test steps + # "outcome" will be 'failure' even if continue-on-error was true + + # sGPU CHECKS + # We check for the file existence directly because the 'Run sGPU tests' step + # halts immediately on docker failure, skipping the lines that set step outputs. + if [[ -f FAIL_CORE_SGPU ]]; then + echo "::error::Core sGPU Tests Failed." + EXIT_STATUS=1 + fi + if [[ -f FAIL_TORCH_SGPU ]]; then + echo "::error::PyTorch sGPU Tests Failed." + EXIT_STATUS=1 + fi + if [[ -f FAIL_JAX_SGPU ]]; then + echo "::error::JAX sGPU Tests Failed." + EXIT_STATUS=1 + fi - # # mGPU CHECKS - # if [[ -f FAIL_TORCH_MGPU ]]; then - # echo "::error::PyTorch mGPU Tests Failed." - # EXIT_STATUS=1 - # fi - # if [[ -f FAIL_JAX_MGPU ]]; then - # echo "::error::JAX mGPU Tests Failed." - # EXIT_STATUS=1 - # fi - - # # EXAMPLES CHECK - # # Examples script does not use marker files, so we rely on step outcome - # if [[ "${{ steps.examples-tests.outcome }}" == "failure" ]]; then - # echo "::error::Example Tests Failed." - # EXIT_STATUS=1 - # fi - - # # Fail the job if any errors were detected - # if [[ "$EXIT_STATUS" == "1" ]]; then - # exit 1 - # fi - - # - name: Copy logs and reports from container - # if: always() - # run: | - # docker cp te-runner:/workspace/torch_sgpu.log ./torch_sgpu.log || true - # docker cp te-runner:/workspace/jax_sgpu.log ./jax_sgpu.log || true - # docker cp te-runner:/workspace/core_sgpu.log ./core_sgpu.log || true - # docker cp te-runner:/workspace/torch_mgpu.log ./torch_mgpu.log || true - # docker cp te-runner:/workspace/jax_mgpu.log ./jax_mgpu.log || true - - # - name: Upload logs and test reports - # if: always() - # uses: actions/upload-artifact@v4 - # with: - # name: logs-and-reports-${{ matrix.runner }} - # path: | - # *.log - # if-no-files-found: ignore - # retention-days: 5 + # mGPU CHECKS + if [[ -f FAIL_TORCH_MGPU ]]; then + echo "::error::PyTorch mGPU Tests Failed." + EXIT_STATUS=1 + fi + if [[ -f FAIL_JAX_MGPU ]]; then + echo "::error::JAX mGPU Tests Failed." + EXIT_STATUS=1 + fi + + # EXAMPLES CHECK + # Examples script does not use marker files, so we rely on step outcome + if [[ "${{ steps.examples-tests.outcome }}" == "failure" ]]; then + echo "::error::Example Tests Failed." 
+ EXIT_STATUS=1 + fi + + # Fail the job if any errors were detected + if [[ "$EXIT_STATUS" == "1" ]]; then + exit 1 + fi + + - name: Copy logs and reports from container + if: always() + run: | + docker cp te-runner:/workspace/torch_sgpu.log ./torch_sgpu.log || true + docker cp te-runner:/workspace/jax_sgpu.log ./jax_sgpu.log || true + docker cp te-runner:/workspace/core_sgpu.log ./core_sgpu.log || true + docker cp te-runner:/workspace/torch_mgpu.log ./torch_mgpu.log || true + docker cp te-runner:/workspace/jax_mgpu.log ./jax_mgpu.log || true + + - name: Upload logs and test reports + if: always() + uses: actions/upload-artifact@v4 + with: + name: logs-and-reports-${{ matrix.runner }} + path: | + *.log + if-no-files-found: ignore + retention-days: 5 - name: Cleanup container if: always() From 9c771b4dde4d2a8ca11035c77f5adcac57dfc1bb Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 13 Mar 2026 14:33:52 -0500 Subject: [PATCH 09/25] add attention, casting, normalization --- .../microbenchmarks/benchmark_attention.py | 188 ++++++++++++++++++ .../microbenchmarks/benchmark_casting.py | 163 +++++++++++++++ .../benchmark_normalization.py | 174 ++++++++++++++++ benchmarks/microbenchmarks/compare_results.py | 4 +- 4 files changed, 527 insertions(+), 2 deletions(-) create mode 100755 benchmarks/microbenchmarks/benchmark_attention.py create mode 100755 benchmarks/microbenchmarks/benchmark_casting.py create mode 100755 benchmarks/microbenchmarks/benchmark_normalization.py mode change 100644 => 100755 benchmarks/microbenchmarks/compare_results.py diff --git a/benchmarks/microbenchmarks/benchmark_attention.py b/benchmarks/microbenchmarks/benchmark_attention.py new file mode 100755 index 000000000..6f419b62d --- /dev/null +++ b/benchmarks/microbenchmarks/benchmark_attention.py @@ -0,0 +1,188 @@ +#!/usr/bin/env python +############################################################################### +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### +""" +Attention micro-benchmark using te.DotProductAttention. + +Benchmarks fused multi-head attention (with flash attention backend) for +model configurations with grouped-query attention (GQA). 
+ +Models: + - Llama 3 8B (TP=1, TP=8), 70B (TP=8), 405B (TP=8) + - Qwen 2.5 7B (TP=1), 72B (TP=8) + +Forward FLOPs = 4 * batch * num_q_heads * seq_len^2 * head_dim + (two matmuls: Q@K^T and attn@V, each contributing 2*b*h*s^2*d) +Backward FLOPs = 2 * Forward FLOPs (approximately) + +Sources for model configs: + https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json + https://huggingface.co/meta-llama/Llama-3.1-70B/blob/main/config.json + https://huggingface.co/meta-llama/Llama-3.1-405B/blob/main/config.json + https://huggingface.co/Qwen/Qwen2.5-7B-Instruct/blob/main/config.json + https://huggingface.co/Qwen/Qwen2.5-72B-Instruct/blob/main/config.json + +Output: benchmark_attention.csv (written to cwd) +""" + +import torch +import torch.utils.benchmark as benchmark + +import transformer_engine.pytorch as te + +# Sweep parameters +BATCH_SIZE = 2 +SEQ_LEN_LIST = [1024, 2048, 4096, 8192] + +# (name, num_q_heads, num_kv_heads, head_dim, tp) +MODEL_CONFIGS = [ + ("Llama3-8B/TP1", 32, 8, 128, 1), + ("Llama3-8B/TP8", 32, 8, 128, 8), + ("Llama3-70B/TP8", 64, 8, 128, 8), + ("Llama3-405B/TP8", 128, 8, 128, 8), + ("Qwen2.5-7B/TP1", 28, 4, 128, 1), + ("Qwen2.5-72B/TP8", 64, 8, 128, 8), +] + + +def _generate_attn_test_cases(): + test_cases = [] + for (name, n_q, n_kv, hd, tp) in MODEL_CONFIGS: + q_per_gpu = n_q // tp + kv_per_gpu = n_kv // tp + if q_per_gpu < 1 or kv_per_gpu < 1: + continue + for seq_len in SEQ_LEN_LIST: + test_cases.append({ + "Case": name, + "batch": BATCH_SIZE, + "seq_len": seq_len, + "num_q_heads": q_per_gpu, + "num_kv_heads": kv_per_gpu, + "head_dim": hd, + }) + return test_cases + + +def bench_attention(batch, seq_len, num_q_heads, num_kv_heads, head_dim): + device = "cuda" + dtype = torch.bfloat16 + + attn = te.DotProductAttention( + num_attention_heads=num_q_heads, + kv_channels=head_dim, + num_gqa_groups=num_kv_heads, + attn_mask_type="causal", + ).to(device=device, dtype=dtype) + + q = torch.randn(seq_len, batch, num_q_heads, head_dim, + dtype=dtype, device=device, requires_grad=True) + k = torch.randn(seq_len, batch, num_kv_heads, head_dim, + dtype=dtype, device=device, requires_grad=True) + v = torch.randn(seq_len, batch, num_kv_heads, head_dim, + dtype=dtype, device=device, requires_grad=True) + + fwd_func = lambda: attn(q, k, v) + out = fwd_func() + grad_out = torch.randn_like(out) + + def fwd_bwd_func(): + out = attn(q, k, v) + out.backward(grad_out) + q.grad = None + k.grad = None + v.grad = None + + fwd_bwd_func() + + # FLOPs: two matmuls (Q@K^T and attn@V), each 2*b*h*s^2*d + fwd_flops = 4 * batch * num_q_heads * seq_len * seq_len * head_dim + bwd_flops = 2 * fwd_flops + + # Warmup + for _ in range(20): + fwd_func() + fwd_bwd_func() + torch.cuda.synchronize() + + # Benchmark + n_iters = 100 + + fwd_ms = benchmark.Timer(stmt="fn()", globals={"fn": fwd_func}).timeit(n_iters).mean * 1e3 + fwd_bwd_ms = benchmark.Timer(stmt="fn()", globals={"fn": fwd_bwd_func}).timeit(n_iters).mean * 1e3 + + bwd_ms = max(fwd_bwd_ms - fwd_ms, 0.0) + + fwd_tflops = fwd_flops / (fwd_ms * 1e-3) / 1e12 + bwd_tflops = bwd_flops / (bwd_ms * 1e-3) / 1e12 if bwd_ms > 0 else 0.0 + + print(f" Forward {fwd_ms:.3f} ms | {fwd_tflops:.2f} TFLOPS") + print(f" Backward {bwd_ms:.3f} ms | {bwd_tflops:.2f} TFLOPS (derived)") + + return { + "TE Forward Time (ms)": f"{fwd_ms:.2f}", + "TE Forward TFLOPS": f"{fwd_tflops:.2f}", + "TE Backward Time (ms)": f"{bwd_ms:.2f}", + "TE Backward TFLOPS": f"{bwd_tflops:.2f}", + } + + +if __name__ == "__main__": + import pandas as pd + + test_cases = 
_generate_attn_test_cases() + + columns = [ + "Case", "batch", "seq_len", "num_q_heads", "num_kv_heads", "head_dim", + "TE Forward Time (ms)", + "TE Forward TFLOPS", + "TE Backward Time (ms)", + "TE Backward TFLOPS", + ] + rows = [] + + # Warmup run + c = test_cases[0] + print(f"\n{'='*60}") + print(f"WARMUP: {c['Case']} b={c['batch']} s={c['seq_len']} " + f"qh={c['num_q_heads']} kvh={c['num_kv_heads']} hd={c['head_dim']}") + print(f"{'='*60}") + bench_attention(batch=c["batch"], seq_len=c["seq_len"], + num_q_heads=c["num_q_heads"], num_kv_heads=c["num_kv_heads"], + head_dim=c["head_dim"]) + + for case in test_cases: + print(f"\n{'='*60}") + print(f"Testing: {case['Case']} b={case['batch']} s={case['seq_len']} " + f"qh={case['num_q_heads']} kvh={case['num_kv_heads']} hd={case['head_dim']}") + print(f"{'='*60}") + try: + metrics = bench_attention( + batch=case["batch"], + seq_len=case["seq_len"], + num_q_heads=case["num_q_heads"], + num_kv_heads=case["num_kv_heads"], + head_dim=case["head_dim"], + ) + row = { + "Case": case["Case"], + "batch": case["batch"], + "seq_len": case["seq_len"], + "num_q_heads": case["num_q_heads"], + "num_kv_heads": case["num_kv_heads"], + "head_dim": case["head_dim"], + **metrics, + } + rows.append(row) + except Exception as e: + print(f"FAILED: {case['Case']}: {e}") + raise + + results = pd.DataFrame(rows, columns=columns) + + out_csv = "benchmark_attention.csv" + results.to_csv(out_csv, index=False) + print(f"\nResults saved to {out_csv}") diff --git a/benchmarks/microbenchmarks/benchmark_casting.py b/benchmarks/microbenchmarks/benchmark_casting.py new file mode 100755 index 000000000..c53810765 --- /dev/null +++ b/benchmarks/microbenchmarks/benchmark_casting.py @@ -0,0 +1,163 @@ +#!/usr/bin/env python +############################################################################### +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### +""" +FP8 casting micro-benchmark. + +Benchmarks quantization (BF16 -> FP8) and dequantization (FP8 -> BF16) for +both E4M3 (activations/weights) and E5M2 (gradients) formats. + +Shapes are (M, hidden_size) matching the activation tensors from models: + - Llama 3 8B, 70B, 405B + - Qwen 2.5 7B, 72B + +These casts are memory-bound; we report GB/s (input + output bytes). 
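+(Worked example with an assumed, not measured, timing: quantizing an
+8192 x 8192 BF16 tensor to FP8 moves 8192 * 8192 * (2 + 1) bytes
+~= 0.20 GB, so a kernel that takes 0.1 ms sustains roughly 2 TB/s.)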
+ +Sources for model configs: + https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json + https://huggingface.co/meta-llama/Llama-3.1-70B/blob/main/config.json + https://huggingface.co/meta-llama/Llama-3.1-405B/blob/main/config.json + https://huggingface.co/Qwen/Qwen2.5-7B-Instruct/blob/main/config.json + https://huggingface.co/Qwen/Qwen2.5-72B-Instruct/blob/main/config.json + +Output: benchmark_casting.csv (written to cwd) +""" + +import torch +import torch.utils.benchmark as benchmark + + +# Detect FP8 dtypes (ROCm vs CUDA) +if hasattr(torch, "float8_e4m3fnuz"): + FP8_E4M3 = torch.float8_e4m3fnuz + FP8_E5M2 = torch.float8_e5m2fnuz +else: + FP8_E4M3 = torch.float8_e4m3fn + FP8_E5M2 = torch.float8_e5m2 + +# Sequence / batch-token sizes to sweep +M_SIZE_LIST = [1024, 2048, 4096, 8192] + +# (model_name, hidden_size) +MODEL_HIDDEN_SIZES = [ + ("Llama3-8B", 4096), + ("Llama3-70B", 8192), + ("Llama3-405B", 16384), + ("Qwen2.5-7B", 3584), + ("Qwen2.5-72B", 8192), +] + +# (cast_name, src_dtype, dst_dtype) +CAST_CONFIGS = [ + ("BF16-to-FP8-E4M3", torch.bfloat16, FP8_E4M3), + ("FP8-E4M3-to-BF16", FP8_E4M3, torch.bfloat16), + ("BF16-to-FP8-E5M2", torch.bfloat16, FP8_E5M2), + ("FP8-E5M2-to-BF16", FP8_E5M2, torch.bfloat16), +] + + +def _generate_cast_test_cases(): + test_cases = [] + for model_name, hidden in MODEL_HIDDEN_SIZES: + for cast_name, src_dtype, dst_dtype in CAST_CONFIGS: + for M in M_SIZE_LIST: + test_cases.append({ + "Case": f"{model_name}/{cast_name}", + "M": M, + "hidden_size": hidden, + "src_dtype": src_dtype, + "dst_dtype": dst_dtype, + "dtype_str": cast_name, + }) + return test_cases + + +def bench_cast(M, hidden_size, src_dtype, dst_dtype): + device = "cuda" + + # For FP8 source, create via cast from randn + if src_dtype in (FP8_E4M3, FP8_E5M2): + x = torch.randn(M, hidden_size, dtype=torch.bfloat16, device=device).to(src_dtype) + else: + x = torch.randn(M, hidden_size, dtype=src_dtype, device=device) + + cast_func = lambda: x.to(dst_dtype) + + # Sanity check + cast_func() + + # Total bytes moved: read input + write output + numel = M * hidden_size + src_bytes = numel * x.element_size() + dst_bytes = numel * cast_func().element_size() + total_bytes = src_bytes + dst_bytes + + # Warmup + for _ in range(20): + cast_func() + torch.cuda.synchronize() + + # Benchmark + n_iters = 100 + ms = benchmark.Timer(stmt="fn()", globals={"fn": cast_func}).timeit(n_iters).mean * 1e3 + gbps = total_bytes / (ms * 1e-3) / 1e9 + + print(f" {ms:.4f} ms | {gbps:.1f} GB/s") + + return { + "Cast Time (ms)": f"{ms:.4f}", + "Cast GB/s": f"{gbps:.1f}", + } + + +if __name__ == "__main__": + import pandas as pd + + test_cases = _generate_cast_test_cases() + + columns = [ + "Case", "M", "hidden_size", "dtype_str", + "Cast Time (ms)", + "Cast GB/s", + ] + rows = [] + + # Warmup run + c = test_cases[0] + print(f"\n{'='*60}") + print(f"WARMUP: {c['Case']} M={c['M']} hidden={c['hidden_size']}") + print(f"{'='*60}") + bench_cast(M=c["M"], hidden_size=c["hidden_size"], + src_dtype=c["src_dtype"], dst_dtype=c["dst_dtype"]) + + for case in test_cases: + print(f"\n{'='*60}") + print(f"Testing: {case['Case']} M={case['M']} hidden={case['hidden_size']}") + print(f"{'='*60}") + try: + metrics = bench_cast( + M=case["M"], + hidden_size=case["hidden_size"], + src_dtype=case["src_dtype"], + dst_dtype=case["dst_dtype"], + ) + row = { + "Case": case["Case"], + "M": case["M"], + "hidden_size": case["hidden_size"], + "dtype_str": case["dtype_str"], + **metrics, + } + rows.append(row) + except Exception as e: + 
print(f"FAILED: {case['Case']}: {e}") + raise + + results = pd.DataFrame(rows, columns=columns) + + out_csv = "benchmark_casting.csv" + results.to_csv(out_csv, index=False) + print(f"\nResults saved to {out_csv}") diff --git a/benchmarks/microbenchmarks/benchmark_normalization.py b/benchmarks/microbenchmarks/benchmark_normalization.py new file mode 100755 index 000000000..1caa04f43 --- /dev/null +++ b/benchmarks/microbenchmarks/benchmark_normalization.py @@ -0,0 +1,174 @@ +#!/usr/bin/env python +############################################################################### +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### +""" +Normalization micro-benchmark using te.LayerNorm and te.RMSNorm. + +Shapes are derived from training workloads: + - Llama 3 8B, 70B, 405B (all use RMSNorm) + - Qwen 2.5 7B, 72B (all use RMSNorm) + +Modern models predominantly use RMSNorm, but we benchmark both +LayerNorm and RMSNorm since TE supports both and they share the +same kernel infrastructure. + +The M dimension (batch * seq_len) is swept across typical training sizes. + +Sources for model configs: + https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json + https://huggingface.co/meta-llama/Llama-3.1-70B/blob/main/config.json + https://huggingface.co/meta-llama/Llama-3.1-405B/blob/main/config.json + https://huggingface.co/Qwen/Qwen2.5-7B-Instruct/blob/main/config.json + https://huggingface.co/Qwen/Qwen2.5-72B-Instruct/blob/main/config.json + +Output: benchmark_normalization.csv (written to cwd) +""" + +import torch +import torch.utils.benchmark as benchmark + +import transformer_engine.pytorch as te + +# Sequence / batch-token sizes to sweep +M_SIZE_LIST = [1024, 2048, 4096, 8192] + +# (model_name, hidden_size) +MODEL_HIDDEN_SIZES = [ + ("Llama3-8B", 4096), + ("Llama3-70B", 8192), + ("Llama3-405B", 16384), + ("Qwen2.5-7B", 3584), + ("Qwen2.5-72B", 8192), +] + +NORM_TYPES = [ + ("RMSNorm", te.RMSNorm), + ("LayerNorm", te.LayerNorm), +] + + +def _generate_norm_test_cases(): + test_cases = [] + for model_name, hidden in MODEL_HIDDEN_SIZES: + for norm_name, norm_cls in NORM_TYPES: + for M in M_SIZE_LIST: + test_cases.append({ + "Case": f"{model_name}/{norm_name}", + "M": M, + "hidden_size": hidden, + "norm_name": norm_name, + "norm_cls": norm_cls, + "dtype": torch.bfloat16, + }) + return test_cases + + +def bench_norm(M, hidden_size, norm_cls, dtype): + device = "cuda" + + norm = norm_cls(hidden_size).to(device=device, dtype=dtype) + x = torch.randn(M, hidden_size, dtype=dtype, device=device, requires_grad=True) + + fwd_func = lambda: norm(x) + out = fwd_func() + grad_out = torch.randn_like(out) + + def fwd_bwd_func(): + out = norm(x) + out.backward(grad_out) + x.grad = None + for p in norm.parameters(): + p.grad = None + + fwd_bwd_func() + + # Normalization is memory-bound; report bandwidth instead of FLOPS. + # Each element is read once (fwd) or read+written (bwd), plus the + # weight/bias vectors. We report effective GB/s based on the + # minimum data movement: fwd reads x and writes y, bwd reads + # grad_out+x+saved_stats and writes grad_x+grad_weight. 
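+    # For example (illustrative numbers, not measurements): with M=4096,
+    # hidden=8192 and BF16 (elem_bytes=2), fwd_bytes = 2*4096*8192*2 B
+    # = 128 MiB, so a 0.1 ms forward pass would sustain ~1.3 TB/s.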
+ elem_bytes = x.element_size() + fwd_bytes = 2 * M * hidden_size * elem_bytes # read x, write y + bwd_bytes = 4 * M * hidden_size * elem_bytes # read grad+x+y, write grad_x + + # Warmup + for _ in range(20): + fwd_func() + fwd_bwd_func() + torch.cuda.synchronize() + + # Benchmark + n_iters = 100 + + fwd_ms = benchmark.Timer(stmt="fn()", globals={"fn": fwd_func}).timeit(n_iters).mean * 1e3 + fwd_bwd_ms = benchmark.Timer(stmt="fn()", globals={"fn": fwd_bwd_func}).timeit(n_iters).mean * 1e3 + + bwd_ms = max(fwd_bwd_ms - fwd_ms, 0.0) + + fwd_gbps = fwd_bytes / (fwd_ms * 1e-3) / 1e9 + bwd_gbps = bwd_bytes / (bwd_ms * 1e-3) / 1e9 if bwd_ms > 0 else 0.0 + + print(f" Forward {fwd_ms:.3f} ms | {fwd_gbps:.1f} GB/s") + print(f" Backward {bwd_ms:.3f} ms | {bwd_gbps:.1f} GB/s (derived)") + + return { + "TE Forward Time (ms)": f"{fwd_ms:.4f}", + "TE Forward GB/s": f"{fwd_gbps:.1f}", + "TE Backward Time (ms)": f"{bwd_ms:.4f}", + "TE Backward GB/s": f"{bwd_gbps:.1f}", + } + + +if __name__ == "__main__": + import pandas as pd + + test_cases = _generate_norm_test_cases() + + columns = [ + "Case", "M", "hidden_size", "dtype", + "TE Forward Time (ms)", + "TE Forward GB/s", + "TE Backward Time (ms)", + "TE Backward GB/s", + ] + rows = [] + + # Warmup run + c = test_cases[0] + print(f"\n{'='*60}") + print(f"WARMUP: {c['Case']} M={c['M']} hidden={c['hidden_size']}") + print(f"{'='*60}") + bench_norm(M=c["M"], hidden_size=c["hidden_size"], + norm_cls=c["norm_cls"], dtype=c["dtype"]) + + for case in test_cases: + print(f"\n{'='*60}") + print(f"Testing: {case['Case']} M={case['M']} hidden={case['hidden_size']}") + print(f"{'='*60}") + try: + metrics = bench_norm( + M=case["M"], + hidden_size=case["hidden_size"], + norm_cls=case["norm_cls"], + dtype=case["dtype"], + ) + row = { + "Case": case["Case"], + "M": case["M"], + "hidden_size": case["hidden_size"], + "dtype": str(case["dtype"]), + **metrics, + } + rows.append(row) + except Exception as e: + print(f"FAILED: {case['Case']}: {e}") + raise + + results = pd.DataFrame(rows, columns=columns) + + out_csv = "benchmark_normalization.csv" + results.to_csv(out_csv, index=False) + print(f"\nResults saved to {out_csv}") diff --git a/benchmarks/microbenchmarks/compare_results.py b/benchmarks/microbenchmarks/compare_results.py old mode 100644 new mode 100755 index c64c124fa..7353066bc --- a/benchmarks/microbenchmarks/compare_results.py +++ b/benchmarks/microbenchmarks/compare_results.py @@ -7,7 +7,7 @@ """ Compare two CSVs from the same benchmark (base branch vs PR branch). -Auto-detects metric columns (containing "TFLOPS") and key columns. +Auto-detects metric columns (containing "TFLOPS"/ "GB/s") and key columns. Outputs a markdown
block to stdout with per-config results, and optionally appends a summary table row to --summary-file. @@ -25,7 +25,7 @@ def auto_detect_columns(df): - metric_cols = [c for c in df.columns if "TFLOPS" in c] + metric_cols = [c for c in df.columns if "TFLOPS" in c or "GB/s" in c] key_cols = [ c for c in df.columns if c not in metric_cols and c not in SKIP_COLS From 64e8da83980e7d85435c76fcfa2427e06abe1253 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 13 Mar 2026 15:05:45 -0500 Subject: [PATCH 10/25] add timestamp and commit ID --- .github/workflows/rocm-ci.yml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/rocm-ci.yml b/.github/workflows/rocm-ci.yml index f81a3e347..84271b4aa 100644 --- a/.github/workflows/rocm-ci.yml +++ b/.github/workflows/rocm-ci.yml @@ -371,7 +371,7 @@ jobs: EOF )" - - name: "Performance regression check" + - name: Performance regression check env: GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} PR_NUMBER: ${{ github.event.pull_request.number }} @@ -470,9 +470,12 @@ jobs: SECTION_START="" SECTION_END="" + CI_TRIGGERED_AT="$(TZ='America/Chicago' date -d '${{ github.event.pull_request.updated_at }}' '+%Y-%m-%d %H:%M:%S %Z')" + SECTION=$(cat <PR commit: ${{ github.sha }} | Base: \`${{ github.base_ref }}\` | ${CI_TRIGGERED_AT} | Benchmark suite | Median speedup | Min speedup | Max speedup | |---|---|---|---| From c986c97de14eeec0e93442d34983f6be1f6259d9 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 13 Mar 2026 15:30:52 -0500 Subject: [PATCH 11/25] add FP8 GEMM --- .../microbenchmarks/benchmark_gemm_fp8.py | 178 ++++++++++++++++++ 1 file changed, 178 insertions(+) create mode 100755 benchmarks/microbenchmarks/benchmark_gemm_fp8.py diff --git a/benchmarks/microbenchmarks/benchmark_gemm_fp8.py b/benchmarks/microbenchmarks/benchmark_gemm_fp8.py new file mode 100755 index 000000000..400383749 --- /dev/null +++ b/benchmarks/microbenchmarks/benchmark_gemm_fp8.py @@ -0,0 +1,178 @@ +#!/usr/bin/env python +############################################################################### +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### +""" +FP8 GEMM micro-benchmark using te.Linear under fp8_autocast. 
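+(All GEMMs run under the DelayedScaling recipe defined below as
+FP8_RECIPE: Format.HYBRID, i.e. E4M3 for forward tensors and E5M2 for
+gradients, with a 16-step amax history.)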
+ +Same model shapes as benchmark_gemm.py: + - Llama 3 8B (TP=1, TP=8), 70B (TP=8), 405B (TP=8) + - Qwen 2.5 7B (TP=1), 72B (TP=8) + +Each model contributes four GEMM shapes: + QKV projection (column-parallel) N = (Qheads + 2*KVheads)*head_dim / TP, K = hidden + Attention output (row-parallel) N = hidden, K = Qheads*head_dim / TP + MLP Gate+Up (column-parallel) N = 2*intermediate / TP, K = hidden (SwiGLU) + MLP Down (row-parallel) N = hidden, K = intermediate / TP + +Sources for model configs: + https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json + https://huggingface.co/meta-llama/Llama-3.1-70B/blob/main/config.json + https://huggingface.co/meta-llama/Llama-3.1-405B/blob/main/config.json + https://huggingface.co/Qwen/Qwen2.5-7B-Instruct/blob/main/config.json + https://huggingface.co/Qwen/Qwen2.5-72B-Instruct/blob/main/config.json + +Output: benchmark_fp8_gemm.csv (written to cwd) +""" + +import torch +import torch.utils.benchmark as benchmark + +import transformer_engine.pytorch as te +from transformer_engine.common.recipe import DelayedScaling, Format + +# Sequence / batch-token sizes to sweep +M_SIZE_LIST = [1024, 2048, 4096, 8192] + +# (name, hidden, intermediate, num_q_heads, num_kv_heads, head_dim, tp) +MODEL_CONFIGS = [ + ("Llama3-8B/TP1", 4096, 14336, 32, 8, 128, 1), + ("Llama3-8B/TP8", 4096, 14336, 32, 8, 128, 8), + ("Llama3-70B/TP8", 8192, 28672, 64, 8, 128, 8), + ("Llama3-405B/TP8", 16384, 53248, 128, 8, 128, 8), + ("Qwen2.5-7B/TP1", 3584, 18944, 28, 4, 128, 1), + ("Qwen2.5-72B/TP8", 8192, 29568, 64, 8, 128, 8), +] + +FP8_RECIPE = DelayedScaling( + fp8_format=Format.HYBRID, + amax_history_len=16, + amax_compute_algo="max", +) + + +def _generate_gemm_test_cases(): + test_cases = [] + for (name, hidden, intermediate, n_q, n_kv, hd, tp) in MODEL_CONFIGS: + shapes = { + f"{name}-QKV": ((n_q * hd + 2 * n_kv * hd) // tp, hidden), + f"{name}-AttnOut": (hidden, (n_q * hd) // tp), + f"{name}-GateUp": ((2 * intermediate) // tp, hidden), + f"{name}-Down": (hidden, intermediate // tp), + } + for M in M_SIZE_LIST: + for case_name, (N, K) in shapes.items(): + test_cases.append({ + "Case": case_name, + "M": M, + "N": N, + "K": K, + "dtype": torch.bfloat16, + }) + return test_cases + + +def bench_fp8_gemm(M, N, K, dtype): + device = "cuda" + + linear = te.Linear(K, N, bias=False).to(device=device, dtype=dtype) + x = torch.randn(M, K, dtype=dtype, device=device, requires_grad=True) + grad_out = torch.randn(M, N, dtype=dtype, device=device) + + # Forward under fp8_autocast + def fwd_func(): + with te.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE): + return linear(x) + + # Combined fwd+bwd (TE consumes saved state on backward, no retain_graph) + def fwd_bwd_func(): + with te.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE): + out = linear(x) + out.backward(grad_out) + x.grad = None + linear.weight.grad = None + + # Sanity run + fwd_func() + fwd_bwd_func() + + fwd_flops = 2 * M * N * K + bwd_flops = 2 * fwd_flops + + # Warmup + for _ in range(20): + fwd_func() + fwd_bwd_func() + torch.cuda.synchronize() + + # Benchmark + n_iters = 100 + + fwd_ms = benchmark.Timer(stmt="fn()", globals={"fn": fwd_func}).timeit(n_iters).mean * 1e3 + fwd_bwd_ms = benchmark.Timer(stmt="fn()", globals={"fn": fwd_bwd_func}).timeit(n_iters).mean * 1e3 + + bwd_ms = max(fwd_bwd_ms - fwd_ms, 0.0) + + fwd_tflops = fwd_flops / (fwd_ms * 1e-3) / 1e12 + bwd_tflops = bwd_flops / (bwd_ms * 1e-3) / 1e12 if bwd_ms > 0 else 0.0 + + print(f" Forward {fwd_ms:.3f} ms | {fwd_tflops:.2f} TFLOPS") + print(f" Backward 
{bwd_ms:.3f} ms | {bwd_tflops:.2f} TFLOPS (derived)") + + return { + "FP8 Forward Time (ms)": f"{fwd_ms:.2f}", + "FP8 Forward TFLOPS": f"{fwd_tflops:.2f}", + "FP8 Backward Time (ms)": f"{bwd_ms:.2f}", + "FP8 Backward TFLOPS": f"{bwd_tflops:.2f}", + } + + +if __name__ == "__main__": + import pandas as pd + + test_cases = _generate_gemm_test_cases() + + columns = [ + "Case", "M", "N", "K", "dtype", + "FP8 Forward Time (ms)", + "FP8 Forward TFLOPS", + "FP8 Backward Time (ms)", + "FP8 Backward TFLOPS", + ] + rows = [] + + # Warmup run + c = test_cases[0] + print(f"\n{'='*60}") + print(f"WARMUP: {c}") + print(f"{'='*60}") + bench_fp8_gemm(M=c["M"], N=c["N"], K=c["K"], dtype=c["dtype"]) + + for case in test_cases: + print(f"\n{'='*60}") + print(f"Testing: {case}") + print(f"{'='*60}") + try: + metrics = bench_fp8_gemm( + M=case["M"], N=case["N"], K=case["K"], dtype=case["dtype"] + ) + row = { + "Case": case["Case"], + "M": case["M"], + "N": case["N"], + "K": case["K"], + "dtype": str(case["dtype"]), + **metrics, + } + rows.append(row) + except Exception as e: + print(f"FAILED: {case}: {e}") + raise + + results = pd.DataFrame(rows, columns=columns) + + out_csv = "benchmark_fp8_gemm.csv" + results.to_csv(out_csv, index=False) + print(f"\nResults saved to {out_csv}") From 4f6dc865403f75f1cc294a17b448e98cdb8146e9 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 13 Mar 2026 22:04:40 -0500 Subject: [PATCH 12/25] fix name --- benchmarks/microbenchmarks/benchmark_gemm_fp8.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmarks/microbenchmarks/benchmark_gemm_fp8.py b/benchmarks/microbenchmarks/benchmark_gemm_fp8.py index 400383749..3d96edbd3 100755 --- a/benchmarks/microbenchmarks/benchmark_gemm_fp8.py +++ b/benchmarks/microbenchmarks/benchmark_gemm_fp8.py @@ -24,7 +24,7 @@ https://huggingface.co/Qwen/Qwen2.5-7B-Instruct/blob/main/config.json https://huggingface.co/Qwen/Qwen2.5-72B-Instruct/blob/main/config.json -Output: benchmark_fp8_gemm.csv (written to cwd) +Output: benchmark_gemm_fp8.csv (written to cwd) """ import torch @@ -173,6 +173,6 @@ def fwd_bwd_func(): results = pd.DataFrame(rows, columns=columns) - out_csv = "benchmark_fp8_gemm.csv" + out_csv = "benchmark_gemm_fp8.csv" results.to_csv(out_csv, index=False) print(f"\nResults saved to {out_csv}") From c9d6d4d0300b838961ba244b2ce1c7b6adeb04e0 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 17 Mar 2026 18:30:42 -0500 Subject: [PATCH 13/25] updates casting --- .../microbenchmarks/benchmark_casting.py | 66 ++++++++++--------- 1 file changed, 34 insertions(+), 32 deletions(-) diff --git a/benchmarks/microbenchmarks/benchmark_casting.py b/benchmarks/microbenchmarks/benchmark_casting.py index c53810765..5df5ff7a3 100755 --- a/benchmarks/microbenchmarks/benchmark_casting.py +++ b/benchmarks/microbenchmarks/benchmark_casting.py @@ -11,7 +11,7 @@ both E4M3 (activations/weights) and E5M2 (gradients) formats. Shapes are (M, hidden_size) matching the activation tensors from models: - - Llama 3 8B, 70B, 405B + - Llama 3.1 8B, 70B, 405B - Qwen 2.5 7B, 72B These casts are memory-bound; we report GB/s (input + output bytes). 
@@ -28,15 +28,13 @@ import torch import torch.utils.benchmark as benchmark +import transformer_engine +import transformer_engine_torch as tex +from transformer_engine.pytorch import Float8Quantizer -# Detect FP8 dtypes (ROCm vs CUDA) -if hasattr(torch, "float8_e4m3fnuz"): - FP8_E4M3 = torch.float8_e4m3fnuz - FP8_E5M2 = torch.float8_e5m2fnuz -else: - FP8_E4M3 = torch.float8_e4m3fn - FP8_E5M2 = torch.float8_e5m2 +TE_FP8_E4M3 = tex.DType.kFloat8E4M3 +TE_FP8_E5M2 = tex.DType.kFloat8E5M2 # Sequence / batch-token sizes to sweep M_SIZE_LIST = [1024, 2048, 4096, 8192] @@ -50,51 +48,55 @@ ("Qwen2.5-72B", 8192), ] -# (cast_name, src_dtype, dst_dtype) CAST_CONFIGS = [ - ("BF16-to-FP8-E4M3", torch.bfloat16, FP8_E4M3), - ("FP8-E4M3-to-BF16", FP8_E4M3, torch.bfloat16), - ("BF16-to-FP8-E5M2", torch.bfloat16, FP8_E5M2), - ("FP8-E5M2-to-BF16", FP8_E5M2, torch.bfloat16), + # (name, direction, fp8_dtype) + ("BF16-to-FP8-E4M3", "quantize", TE_FP8_E4M3), + ("FP8-E4M3-to-BF16", "dequantize", TE_FP8_E4M3), + ("BF16-to-FP8-E5M2", "quantize", TE_FP8_E5M2), + ("FP8-E5M2-to-BF16", "dequantize", TE_FP8_E5M2), ] def _generate_cast_test_cases(): test_cases = [] for model_name, hidden in MODEL_HIDDEN_SIZES: - for cast_name, src_dtype, dst_dtype in CAST_CONFIGS: + for cast_name, direction, fp8_dtype in CAST_CONFIGS: for M in M_SIZE_LIST: test_cases.append({ "Case": f"{model_name}/{cast_name}", "M": M, "hidden_size": hidden, - "src_dtype": src_dtype, - "dst_dtype": dst_dtype, + "direction": direction, + "fp8_dtype": fp8_dtype, "dtype_str": cast_name, }) return test_cases -def bench_cast(M, hidden_size, src_dtype, dst_dtype): +def bench_cast(M, hidden_size, direction, fp8_dtype): device = "cuda" - # For FP8 source, create via cast from randn - if src_dtype in (FP8_E4M3, FP8_E5M2): - x = torch.randn(M, hidden_size, dtype=torch.bfloat16, device=device).to(src_dtype) + numel = M * hidden_size + scale = torch.ones(1, dtype=torch.float32, device=device) + amax = torch.zeros(1, dtype=torch.float32, device=device) + quantizer = Float8Quantizer(scale, amax, fp8_dtype) + + if direction == "quantize": + x = torch.randn(M, hidden_size, dtype=torch.bfloat16, device=device) + cast_func = lambda: quantizer(x) + + # BF16 read (2 bytes) + FP8 write (1 byte) + total_bytes = numel * (2 + 1) else: - x = torch.randn(M, hidden_size, dtype=src_dtype, device=device) + x = torch.randn(M, hidden_size, dtype=torch.bfloat16, device=device) + fp8_tensor = quantizer(x) + cast_func = lambda: fp8_tensor.dequantize() - cast_func = lambda: x.to(dst_dtype) + # FP8 read (1 byte) + BF16 write (2 bytes) + total_bytes = numel * (1 + 2) - # Sanity check cast_func() - # Total bytes moved: read input + write output - numel = M * hidden_size - src_bytes = numel * x.element_size() - dst_bytes = numel * cast_func().element_size() - total_bytes = src_bytes + dst_bytes - # Warmup for _ in range(20): cast_func() @@ -131,7 +133,7 @@ def bench_cast(M, hidden_size, src_dtype, dst_dtype): print(f"WARMUP: {c['Case']} M={c['M']} hidden={c['hidden_size']}") print(f"{'='*60}") bench_cast(M=c["M"], hidden_size=c["hidden_size"], - src_dtype=c["src_dtype"], dst_dtype=c["dst_dtype"]) + direction=c["direction"], fp8_dtype=c["fp8_dtype"]) for case in test_cases: print(f"\n{'='*60}") @@ -141,8 +143,8 @@ def bench_cast(M, hidden_size, src_dtype, dst_dtype): metrics = bench_cast( M=case["M"], hidden_size=case["hidden_size"], - src_dtype=case["src_dtype"], - dst_dtype=case["dst_dtype"], + direction=case["direction"], + fp8_dtype=case["fp8_dtype"], ) row = { "Case": case["Case"], From 
de21a7765856ca7243389521f73382f78cd3c7cb Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 20 Apr 2026 09:16:50 -0500 Subject: [PATCH 14/25] remove attention --- .../microbenchmarks/benchmark_attention.py | 188 ------------------ 1 file changed, 188 deletions(-) delete mode 100755 benchmarks/microbenchmarks/benchmark_attention.py diff --git a/benchmarks/microbenchmarks/benchmark_attention.py b/benchmarks/microbenchmarks/benchmark_attention.py deleted file mode 100755 index 6f419b62d..000000000 --- a/benchmarks/microbenchmarks/benchmark_attention.py +++ /dev/null @@ -1,188 +0,0 @@ -#!/usr/bin/env python -############################################################################### -# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. -# -# See LICENSE for license information. -############################################################################### -""" -Attention micro-benchmark using te.DotProductAttention. - -Benchmarks fused multi-head attention (with flash attention backend) for -model configurations with grouped-query attention (GQA). - -Models: - - Llama 3 8B (TP=1, TP=8), 70B (TP=8), 405B (TP=8) - - Qwen 2.5 7B (TP=1), 72B (TP=8) - -Forward FLOPs = 4 * batch * num_q_heads * seq_len^2 * head_dim - (two matmuls: Q@K^T and attn@V, each contributing 2*b*h*s^2*d) -Backward FLOPs = 2 * Forward FLOPs (approximately) - -Sources for model configs: - https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json - https://huggingface.co/meta-llama/Llama-3.1-70B/blob/main/config.json - https://huggingface.co/meta-llama/Llama-3.1-405B/blob/main/config.json - https://huggingface.co/Qwen/Qwen2.5-7B-Instruct/blob/main/config.json - https://huggingface.co/Qwen/Qwen2.5-72B-Instruct/blob/main/config.json - -Output: benchmark_attention.csv (written to cwd) -""" - -import torch -import torch.utils.benchmark as benchmark - -import transformer_engine.pytorch as te - -# Sweep parameters -BATCH_SIZE = 2 -SEQ_LEN_LIST = [1024, 2048, 4096, 8192] - -# (name, num_q_heads, num_kv_heads, head_dim, tp) -MODEL_CONFIGS = [ - ("Llama3-8B/TP1", 32, 8, 128, 1), - ("Llama3-8B/TP8", 32, 8, 128, 8), - ("Llama3-70B/TP8", 64, 8, 128, 8), - ("Llama3-405B/TP8", 128, 8, 128, 8), - ("Qwen2.5-7B/TP1", 28, 4, 128, 1), - ("Qwen2.5-72B/TP8", 64, 8, 128, 8), -] - - -def _generate_attn_test_cases(): - test_cases = [] - for (name, n_q, n_kv, hd, tp) in MODEL_CONFIGS: - q_per_gpu = n_q // tp - kv_per_gpu = n_kv // tp - if q_per_gpu < 1 or kv_per_gpu < 1: - continue - for seq_len in SEQ_LEN_LIST: - test_cases.append({ - "Case": name, - "batch": BATCH_SIZE, - "seq_len": seq_len, - "num_q_heads": q_per_gpu, - "num_kv_heads": kv_per_gpu, - "head_dim": hd, - }) - return test_cases - - -def bench_attention(batch, seq_len, num_q_heads, num_kv_heads, head_dim): - device = "cuda" - dtype = torch.bfloat16 - - attn = te.DotProductAttention( - num_attention_heads=num_q_heads, - kv_channels=head_dim, - num_gqa_groups=num_kv_heads, - attn_mask_type="causal", - ).to(device=device, dtype=dtype) - - q = torch.randn(seq_len, batch, num_q_heads, head_dim, - dtype=dtype, device=device, requires_grad=True) - k = torch.randn(seq_len, batch, num_kv_heads, head_dim, - dtype=dtype, device=device, requires_grad=True) - v = torch.randn(seq_len, batch, num_kv_heads, head_dim, - dtype=dtype, device=device, requires_grad=True) - - fwd_func = lambda: attn(q, k, v) - out = fwd_func() - grad_out = torch.randn_like(out) - - def fwd_bwd_func(): - out = attn(q, k, v) - out.backward(grad_out) - q.grad = None - k.grad = None - 
v.grad = None - - fwd_bwd_func() - - # FLOPs: two matmuls (Q@K^T and attn@V), each 2*b*h*s^2*d - fwd_flops = 4 * batch * num_q_heads * seq_len * seq_len * head_dim - bwd_flops = 2 * fwd_flops - - # Warmup - for _ in range(20): - fwd_func() - fwd_bwd_func() - torch.cuda.synchronize() - - # Benchmark - n_iters = 100 - - fwd_ms = benchmark.Timer(stmt="fn()", globals={"fn": fwd_func}).timeit(n_iters).mean * 1e3 - fwd_bwd_ms = benchmark.Timer(stmt="fn()", globals={"fn": fwd_bwd_func}).timeit(n_iters).mean * 1e3 - - bwd_ms = max(fwd_bwd_ms - fwd_ms, 0.0) - - fwd_tflops = fwd_flops / (fwd_ms * 1e-3) / 1e12 - bwd_tflops = bwd_flops / (bwd_ms * 1e-3) / 1e12 if bwd_ms > 0 else 0.0 - - print(f" Forward {fwd_ms:.3f} ms | {fwd_tflops:.2f} TFLOPS") - print(f" Backward {bwd_ms:.3f} ms | {bwd_tflops:.2f} TFLOPS (derived)") - - return { - "TE Forward Time (ms)": f"{fwd_ms:.2f}", - "TE Forward TFLOPS": f"{fwd_tflops:.2f}", - "TE Backward Time (ms)": f"{bwd_ms:.2f}", - "TE Backward TFLOPS": f"{bwd_tflops:.2f}", - } - - -if __name__ == "__main__": - import pandas as pd - - test_cases = _generate_attn_test_cases() - - columns = [ - "Case", "batch", "seq_len", "num_q_heads", "num_kv_heads", "head_dim", - "TE Forward Time (ms)", - "TE Forward TFLOPS", - "TE Backward Time (ms)", - "TE Backward TFLOPS", - ] - rows = [] - - # Warmup run - c = test_cases[0] - print(f"\n{'='*60}") - print(f"WARMUP: {c['Case']} b={c['batch']} s={c['seq_len']} " - f"qh={c['num_q_heads']} kvh={c['num_kv_heads']} hd={c['head_dim']}") - print(f"{'='*60}") - bench_attention(batch=c["batch"], seq_len=c["seq_len"], - num_q_heads=c["num_q_heads"], num_kv_heads=c["num_kv_heads"], - head_dim=c["head_dim"]) - - for case in test_cases: - print(f"\n{'='*60}") - print(f"Testing: {case['Case']} b={case['batch']} s={case['seq_len']} " - f"qh={case['num_q_heads']} kvh={case['num_kv_heads']} hd={case['head_dim']}") - print(f"{'='*60}") - try: - metrics = bench_attention( - batch=case["batch"], - seq_len=case["seq_len"], - num_q_heads=case["num_q_heads"], - num_kv_heads=case["num_kv_heads"], - head_dim=case["head_dim"], - ) - row = { - "Case": case["Case"], - "batch": case["batch"], - "seq_len": case["seq_len"], - "num_q_heads": case["num_q_heads"], - "num_kv_heads": case["num_kv_heads"], - "head_dim": case["head_dim"], - **metrics, - } - rows.append(row) - except Exception as e: - print(f"FAILED: {case['Case']}: {e}") - raise - - results = pd.DataFrame(rows, columns=columns) - - out_csv = "benchmark_attention.csv" - results.to_csv(out_csv, index=False) - print(f"\nResults saved to {out_csv}") From 1d6f869df6e45ffec9a7709a72c2812854da8f19 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 20 Apr 2026 09:37:31 -0500 Subject: [PATCH 15/25] fix grouped gemm --- benchmarks/microbenchmarks/benchmark_grouped_gemm.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/benchmarks/microbenchmarks/benchmark_grouped_gemm.py b/benchmarks/microbenchmarks/benchmark_grouped_gemm.py index 7cee6edd4..06759e82d 100755 --- a/benchmarks/microbenchmarks/benchmark_grouped_gemm.py +++ b/benchmarks/microbenchmarks/benchmark_grouped_gemm.py @@ -83,7 +83,6 @@ def generate_grok_v2_test_cases(): def make_fwd_bwd_funcs_te(x, w, group_lens, activation_dtype): - from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace from transformer_engine.pytorch.cpp_extensions import general_grouped_gemm B = int(group_lens.numel()) @@ -99,8 +98,6 @@ def make_fwd_bwd_funcs_te(x, w, group_lens, activation_dtype): xs = 
list(torch.split(x_view, m_splits)) weights = [w[i] for i in range(B)] - workspaces = get_multi_stream_cublas_workspace() - # Forward output buffer out = torch.empty((sum_M, N), device=x.device, dtype=activation_dtype) @@ -109,8 +106,8 @@ def fwd_func_te(): A=weights, B=xs, out=[out], + quantization_params=[None] * B, out_dtype=activation_dtype, - workspaces=workspaces, single_output=True, m_splits=m_splits, use_bias=False, @@ -135,8 +132,8 @@ def bwd_func_te(grad_out): A=weights, B=splits, out=dxs, + quantization_params=[None] * B, out_dtype=activation_dtype, - workspaces=workspaces, single_output=False, layout="NN", m_splits=m_splits, @@ -149,8 +146,8 @@ def bwd_func_te(grad_out): A=xs, B=splits, out=dws, + quantization_params=[None] * B, out_dtype=activation_dtype, - workspaces=workspaces, single_output=False, layout="NT", m_splits=m_splits, From 12b4218e383b1c1eb36adf59d44324b6bf7a9b60 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 20 Apr 2026 09:59:05 -0500 Subject: [PATCH 16/25] remove CI part --- .github/workflows/rocm-ci.yml | 165 ---------------------------------- 1 file changed, 165 deletions(-) diff --git a/.github/workflows/rocm-ci.yml b/.github/workflows/rocm-ci.yml index 167a86810..c85f1bed2 100644 --- a/.github/workflows/rocm-ci.yml +++ b/.github/workflows/rocm-ci.yml @@ -45,9 +45,6 @@ concurrency: group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: true -permissions: - pull-requests: write - jobs: build_and_test: name: Build and Test on GPU (${{ matrix.runner }}) - Level ${{ (github.event_name == 'push' && '3') || inputs.test_level || '1' }} @@ -375,168 +372,6 @@ jobs: EOF )" - - name: Performance regression check - env: - GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} - PR_NUMBER: ${{ github.event.pull_request.number }} - GH_REPO: ${{ github.repository }} - RUNNER_NAME: ${{ matrix.runner }} - run: | - set -ex - - # Map runner names to display names - case "${RUNNER_NAME}" in - linux-te-mi325*) DISPLAY_NAME="MI325" ;; - linux-te-mi355*) DISPLAY_NAME="MI355" ;; - *) DISPLAY_NAME="${RUNNER_NAME}" ;; - esac - - # Restore PR checkout no matter how this step exits, - # in case a later step needs to access the PR code. - # Note that the PR code is *not* recompiled. 
- trap 'git checkout ${{ github.sha }} && git submodule update --init --recursive' EXIT - - # Benchmark PR branch (already built) - docker exec te-runner bash -c "$(cat <<'OUTER' - set -ex - pip install pandas tabulate - cd /workspace - - mkdir -p perf_results/pr - for bench in benchmarks/microbenchmarks/benchmark_*.py; do - name=$(basename "$bench" .py) - echo "=== Running $name (PR) ===" - python "$bench" - mv "${name}.csv" perf_results/pr/ - done - - # Stash benchmark scripts so they survive the base branch checkout - mkdir -p .perf_stash - cp benchmarks/microbenchmarks/benchmark_*.py benchmarks/microbenchmarks/compare_results.py .perf_stash/ - OUTER - )" - - # Checkout base branch (on host, where git credentials exist) - git fetch origin ${{ github.base_ref }} --depth=1 - git checkout FETCH_HEAD - git submodule update --init --recursive - - # Rebuild base, benchmark, compare, build report - docker exec \ - -e GPU_ARCH=${{ steps.container-diag.outputs.arch }} \ - te-runner bash -c "$(cat <<'OUTER' - set -ex - cd /workspace - - # Rebuild base branch - export HIP_PATH="" - export PYTORCH_ROCM_ARCH=$GPU_ARCH - export NVTE_ROCM_ARCH=$GPU_ARCH - export NVTE_AITER_PREBUILT_BASE_URL=https://compute-artifactory.amd.com:5000/artifactory/rocm-generic-local/te-ci/aiter-prebuilts - pip install ninja - git config --global --add safe.directory '*' - pip install --no-build-isolation . 2>&1 - - # Benchmark base branch - mkdir -p perf_results/base - for bench in .perf_stash/benchmark_*.py; do - name=$(basename "$bench" .py) - echo "=== Running $name (base) ===" - python "$bench" - mv "${name}.csv" perf_results/base/ - done - - # Compare and build report - mkdir -p perf_results/reports - SUMMARY="perf_results/reports/summary.md" - DETAILS="perf_results/reports/details.md" - : > "$SUMMARY" - : > "$DETAILS" - - for pr_csv in perf_results/pr/benchmark_*.csv; do - name=$(basename "$pr_csv" .csv) - base_csv="perf_results/base/${name}.csv" - [ -f "$base_csv" ] || continue - echo "========== Comparing: $name ==========" - python .perf_stash/compare_results.py "$base_csv" "$pr_csv" \ - --bench-name "$name" \ - --summary-file "$SUMMARY" \ - >> "$DETAILS" - done - OUTER - )" - - # Assemble this runner's section - SUMMARY="perf_results/reports/summary.md" - DETAILS="perf_results/reports/details.md" - [ -f "$SUMMARY" ] || exit 0 - - SECTION_START="" - SECTION_END="" - - CI_TRIGGERED_AT="$(TZ='America/Chicago' date -d '${{ github.event.pull_request.updated_at }}' '+%Y-%m-%d %H:%M:%S %Z')" - - SECTION=$(cat <PR commit: ${{ github.sha }} | Base: \`${{ github.base_ref }}\` | ${CI_TRIGGERED_AT} - - | Benchmark suite | Median speedup | Min speedup | Max speedup | - |---|---|---|---| - $(cat "$SUMMARY") - - $(cat "$DETAILS") - ${SECTION_END} - EOF - ) - - echo "$SECTION" > /tmp/perf_section.md - - echo "" - echo "========== Performance Report ==========" - cat /tmp/perf_section.md - echo "========================================" - - # Post or update the single shared PR comment (skip under nektos act) - if [ -n "${ACT:-}" ]; then - echo "Running under nektos act, skipping PR comment." 
- exit 0 - fi - - COMMENT_MARKER="" - - COMMENT_ID=$(gh api "repos/${GH_REPO}/issues/${PR_NUMBER}/comments" \ - --paginate --jq ".[] | select(.body | contains(\"${COMMENT_MARKER}\")) | .id" \ - | head -1) - - if [ -n "$COMMENT_ID" ]; then - gh api "repos/${GH_REPO}/issues/comments/${COMMENT_ID}" --jq .body \ - > /tmp/perf_existing.md - - if grep -qF "$SECTION_START" /tmp/perf_existing.md; then - awk -v start="$SECTION_START" -v end="$SECTION_END" -v sf="/tmp/perf_section.md" ' - $0 ~ start { skip=1; while((getline l < sf)>0) print l; next } - $0 ~ end { skip=0; next } - !skip { print } - ' /tmp/perf_existing.md > /tmp/perf_comment.md - else - { cat /tmp/perf_existing.md; echo ""; cat /tmp/perf_section.md; } > /tmp/perf_comment.md - fi - - gh api "repos/${GH_REPO}/issues/comments/${COMMENT_ID}" \ - --method PATCH --field body=@/tmp/perf_comment.md - else - { - echo "${COMMENT_MARKER}" - echo "## Performance Report" - echo "" - cat /tmp/perf_section.md - } > /tmp/perf_comment.md - - gh pr comment "$PR_NUMBER" --repo "$GH_REPO" \ - --body-file /tmp/perf_comment.md - fi - - name: Check Test Failure Status if: always() run: | From 75c82913892a7d402ff6b825627a090aae38a365 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 21 Apr 2026 16:34:31 -0500 Subject: [PATCH 17/25] use adaptive_autorange, cleanups --- .../microbenchmarks/benchmark_casting.py | 54 ++++++------------ benchmarks/microbenchmarks/benchmark_gemm.py | 52 ++++++----------- .../microbenchmarks/benchmark_gemm_fp8.py | 56 ++++++------------ .../microbenchmarks/benchmark_grouped_gemm.py | 52 ++++++----------- .../benchmark_normalization.py | 57 +++++++------------ 5 files changed, 87 insertions(+), 184 deletions(-) diff --git a/benchmarks/microbenchmarks/benchmark_casting.py b/benchmarks/microbenchmarks/benchmark_casting.py index 5df5ff7a3..2db87d190 100755 --- a/benchmarks/microbenchmarks/benchmark_casting.py +++ b/benchmarks/microbenchmarks/benchmark_casting.py @@ -83,7 +83,8 @@ def bench_cast(M, hidden_size, direction, fp8_dtype): if direction == "quantize": x = torch.randn(M, hidden_size, dtype=torch.bfloat16, device=device) - cast_func = lambda: quantizer(x) + out = quantizer(x) # pre-allocate output tensor + cast_func = lambda: quantizer.quantize(x, out=out) # BF16 read (2 bytes) + FP8 write (1 byte) total_bytes = numel * (2 + 1) @@ -95,16 +96,8 @@ def bench_cast(M, hidden_size, direction, fp8_dtype): # FP8 read (1 byte) + BF16 write (2 bytes) total_bytes = numel * (1 + 2) - cast_func() - - # Warmup - for _ in range(20): - cast_func() - torch.cuda.synchronize() - # Benchmark - n_iters = 100 - ms = benchmark.Timer(stmt="fn()", globals={"fn": cast_func}).timeit(n_iters).mean * 1e3 + ms = benchmark.Timer(stmt="fn()", globals={"fn": cast_func}).blocked_autorange().mean * 1e3 gbps = total_bytes / (ms * 1e-3) / 1e9 print(f" {ms:.4f} ms | {gbps:.1f} GB/s") @@ -127,36 +120,25 @@ def bench_cast(M, hidden_size, direction, fp8_dtype): ] rows = [] - # Warmup run - c = test_cases[0] - print(f"\n{'='*60}") - print(f"WARMUP: {c['Case']} M={c['M']} hidden={c['hidden_size']}") - print(f"{'='*60}") - bench_cast(M=c["M"], hidden_size=c["hidden_size"], - direction=c["direction"], fp8_dtype=c["fp8_dtype"]) - for case in test_cases: print(f"\n{'='*60}") print(f"Testing: {case['Case']} M={case['M']} hidden={case['hidden_size']}") print(f"{'='*60}") - try: - metrics = bench_cast( - M=case["M"], - hidden_size=case["hidden_size"], - direction=case["direction"], - fp8_dtype=case["fp8_dtype"], - ) - row = { - "Case": case["Case"], - "M": case["M"], 
- "hidden_size": case["hidden_size"], - "dtype_str": case["dtype_str"], - **metrics, - } - rows.append(row) - except Exception as e: - print(f"FAILED: {case['Case']}: {e}") - raise + + metrics = bench_cast( + M=case["M"], + hidden_size=case["hidden_size"], + direction=case["direction"], + fp8_dtype=case["fp8_dtype"], + ) + row = { + "Case": case["Case"], + "M": case["M"], + "hidden_size": case["hidden_size"], + "dtype_str": case["dtype_str"], + **metrics, + } + rows.append(row) results = pd.DataFrame(rows, columns=columns) diff --git a/benchmarks/microbenchmarks/benchmark_gemm.py b/benchmarks/microbenchmarks/benchmark_gemm.py index cd651c172..d028c7068 100755 --- a/benchmarks/microbenchmarks/benchmark_gemm.py +++ b/benchmarks/microbenchmarks/benchmark_gemm.py @@ -87,22 +87,14 @@ def bwd_func(): fwd_flops = 2 * M * N * K bwd_flops = 2 * fwd_flops # dX + dW - # Warmup - for _ in range(20): - fwd_func() - bwd_func() - torch.cuda.synchronize() - # Benchmark - n_iters = 100 - - fwd_ms = benchmark.Timer(stmt="fn()", globals={"fn": fwd_func}).timeit(n_iters).mean * 1e3 - fwd_bwd_ms = benchmark.Timer(stmt="fn()", globals={"fn": bwd_func}).timeit(n_iters).mean * 1e3 + fwd_ms = benchmark.Timer(stmt="fn()", globals={"fn": fwd_func}).adaptive_autorange().mean * 1e3 + fwd_bwd_ms = benchmark.Timer(stmt="fn()", globals={"fn": bwd_func}).adaptive_autorange().mean * 1e3 - bwd_ms = max(fwd_bwd_ms - fwd_ms, 0.0) + bwd_ms = fwd_bwd_ms - fwd_ms fwd_tflops = fwd_flops / (fwd_ms * 1e-3) / 1e12 - bwd_tflops = bwd_flops / (bwd_ms * 1e-3) / 1e12 if bwd_ms > 0 else 0.0 + bwd_tflops = bwd_flops / (bwd_ms * 1e-3) / 1e12 print(f" Forward {fwd_ms:.3f} ms | {fwd_tflops:.2f} TFLOPS") print(f" Backward {bwd_ms:.3f} ms | {bwd_tflops:.2f} TFLOPS (derived)") @@ -129,33 +121,23 @@ def bwd_func(): ] rows = [] - # Warmup run - c = test_cases[0] - print(f"\n{'='*60}") - print(f"WARMUP: {c}") - print(f"{'='*60}") - bench_gemm(M=c["M"], N=c["N"], K=c["K"], dtype=c["dtype"]) - for case in test_cases: print(f"\n{'='*60}") print(f"Testing: {case}") print(f"{'='*60}") - try: - metrics = bench_gemm( - M=case["M"], N=case["N"], K=case["K"], dtype=case["dtype"] - ) - row = { - "Case": case["Case"], - "M": case["M"], - "N": case["N"], - "K": case["K"], - "dtype": str(case["dtype"]), - **metrics, - } - rows.append(row) - except Exception as e: - print(f"FAILED: {case}: {e}") - raise + + metrics = bench_gemm( + M=case["M"], N=case["N"], K=case["K"], dtype=case["dtype"] + ) + row = { + "Case": case["Case"], + "M": case["M"], + "N": case["N"], + "K": case["K"], + "dtype": str(case["dtype"]), + **metrics, + } + rows.append(row) results = pd.DataFrame(rows, columns=columns) diff --git a/benchmarks/microbenchmarks/benchmark_gemm_fp8.py b/benchmarks/microbenchmarks/benchmark_gemm_fp8.py index 3d96edbd3..7c910c2c9 100755 --- a/benchmarks/microbenchmarks/benchmark_gemm_fp8.py +++ b/benchmarks/microbenchmarks/benchmark_gemm_fp8.py @@ -94,29 +94,17 @@ def fwd_bwd_func(): x.grad = None linear.weight.grad = None - # Sanity run - fwd_func() - fwd_bwd_func() - fwd_flops = 2 * M * N * K bwd_flops = 2 * fwd_flops - # Warmup - for _ in range(20): - fwd_func() - fwd_bwd_func() - torch.cuda.synchronize() - # Benchmark - n_iters = 100 - - fwd_ms = benchmark.Timer(stmt="fn()", globals={"fn": fwd_func}).timeit(n_iters).mean * 1e3 - fwd_bwd_ms = benchmark.Timer(stmt="fn()", globals={"fn": fwd_bwd_func}).timeit(n_iters).mean * 1e3 + fwd_ms = benchmark.Timer(stmt="fn()", globals={"fn": fwd_func}).adaptive_autorange().mean * 1e3 + fwd_bwd_ms = 
benchmark.Timer(stmt="fn()", globals={"fn": fwd_bwd_func}).adaptive_autorange().mean * 1e3 - bwd_ms = max(fwd_bwd_ms - fwd_ms, 0.0) + bwd_ms = fwd_bwd_ms - fwd_ms fwd_tflops = fwd_flops / (fwd_ms * 1e-3) / 1e12 - bwd_tflops = bwd_flops / (bwd_ms * 1e-3) / 1e12 if bwd_ms > 0 else 0.0 + bwd_tflops = bwd_flops / (bwd_ms * 1e-3) / 1e12 print(f" Forward {fwd_ms:.3f} ms | {fwd_tflops:.2f} TFLOPS") print(f" Backward {bwd_ms:.3f} ms | {bwd_tflops:.2f} TFLOPS (derived)") @@ -143,33 +131,23 @@ def fwd_bwd_func(): ] rows = [] - # Warmup run - c = test_cases[0] - print(f"\n{'='*60}") - print(f"WARMUP: {c}") - print(f"{'='*60}") - bench_fp8_gemm(M=c["M"], N=c["N"], K=c["K"], dtype=c["dtype"]) - for case in test_cases: print(f"\n{'='*60}") print(f"Testing: {case}") print(f"{'='*60}") - try: - metrics = bench_fp8_gemm( - M=case["M"], N=case["N"], K=case["K"], dtype=case["dtype"] - ) - row = { - "Case": case["Case"], - "M": case["M"], - "N": case["N"], - "K": case["K"], - "dtype": str(case["dtype"]), - **metrics, - } - rows.append(row) - except Exception as e: - print(f"FAILED: {case}: {e}") - raise + + metrics = bench_fp8_gemm( + M=case["M"], N=case["N"], K=case["K"], dtype=case["dtype"] + ) + row = { + "Case": case["Case"], + "M": case["M"], + "N": case["N"], + "K": case["K"], + "dtype": str(case["dtype"]), + **metrics, + } + rows.append(row) results = pd.DataFrame(rows, columns=columns) diff --git a/benchmarks/microbenchmarks/benchmark_grouped_gemm.py b/benchmarks/microbenchmarks/benchmark_grouped_gemm.py index 06759e82d..81fa1b1cf 100755 --- a/benchmarks/microbenchmarks/benchmark_grouped_gemm.py +++ b/benchmarks/microbenchmarks/benchmark_grouped_gemm.py @@ -183,24 +183,14 @@ def bench_grouped_gemm(B, M, N, K, dtype): out_te = fwd_func_te() grad_out = torch.randn_like(out_te) bwd_func_te = lambda: bwd_func_te_inner(grad_out) - dx_te, dw_te = bwd_func_te() # FLOPs fwd_total_flops = 2 * B * M * N * K bwd_total_flops = 2 * fwd_total_flops - # Warmup - for _ in range(20): - fwd_func_te() - bwd_func_te() - - torch.cuda.synchronize() - # Benchmark - n_iters = 100 - - fwd_te_ms = benchmark.Timer(stmt="fn()", globals={"fn": fwd_func_te}).timeit(n_iters).mean * 1e3 - bwd_te_ms = benchmark.Timer(stmt="fn()", globals={"fn": bwd_func_te}).timeit(n_iters).mean * 1e3 + fwd_te_ms = benchmark.Timer(stmt="fn()", globals={"fn": fwd_func_te}).adaptive_autorange().mean * 1e3 + bwd_te_ms = benchmark.Timer(stmt="fn()", globals={"fn": bwd_func_te}).adaptive_autorange().mean * 1e3 fwd_te_tflops = fwd_total_flops / (fwd_te_ms * 1e-3) / 1e12 bwd_te_tflops = bwd_total_flops / (bwd_te_ms * 1e-3) / 1e12 @@ -235,34 +225,24 @@ def bench_grouped_gemm(B, M, N, K, dtype): ] rows = [] - # Warmup run - c = test_cases[0] - print(f"\n{'='*50}") - print(f"WARMUP: {c}") - print(f"{'='*50}") - bench_grouped_gemm(B=c["B"], M=c["M"], N=c["N"], K=c["K"], dtype=c["dtype"]) - for case in test_cases: print(f"\n{'='*50}") print(f"Testing: {case}") print(f"{'='*50}") - try: - metrics = bench_grouped_gemm( - B=case["B"], M=case["M"], N=case["N"], K=case["K"], dtype=case["dtype"] - ) - row = { - "Case": case["Case"], - "B": case["B"], - "M": case["M"], - "N": case["N"], - "K": case["K"], - "dtype": str(case["dtype"]), - **metrics, - } - rows.append(row) - except Exception as e: - print(f"FAILED: {case}: {e}") - raise + + metrics = bench_grouped_gemm( + B=case["B"], M=case["M"], N=case["N"], K=case["K"], dtype=case["dtype"] + ) + row = { + "Case": case["Case"], + "B": case["B"], + "M": case["M"], + "N": case["N"], + "K": case["K"], + "dtype": 
str(case["dtype"]), + **metrics, + } + rows.append(row) results = pd.DataFrame(rows, columns=columns) diff --git a/benchmarks/microbenchmarks/benchmark_normalization.py b/benchmarks/microbenchmarks/benchmark_normalization.py index 1caa04f43..3bebfc39b 100755 --- a/benchmarks/microbenchmarks/benchmark_normalization.py +++ b/benchmarks/microbenchmarks/benchmark_normalization.py @@ -94,22 +94,14 @@ def fwd_bwd_func(): fwd_bytes = 2 * M * hidden_size * elem_bytes # read x, write y bwd_bytes = 4 * M * hidden_size * elem_bytes # read grad+x+y, write grad_x - # Warmup - for _ in range(20): - fwd_func() - fwd_bwd_func() - torch.cuda.synchronize() - # Benchmark - n_iters = 100 - - fwd_ms = benchmark.Timer(stmt="fn()", globals={"fn": fwd_func}).timeit(n_iters).mean * 1e3 - fwd_bwd_ms = benchmark.Timer(stmt="fn()", globals={"fn": fwd_bwd_func}).timeit(n_iters).mean * 1e3 + fwd_ms = benchmark.Timer(stmt="fn()", globals={"fn": fwd_func}).adaptive_autorange().mean * 1e3 + fwd_bwd_ms = benchmark.Timer(stmt="fn()", globals={"fn": fwd_bwd_func}).adaptive_autorange().mean * 1e3 - bwd_ms = max(fwd_bwd_ms - fwd_ms, 0.0) + bwd_ms = fwd_bwd_ms - fwd_ms fwd_gbps = fwd_bytes / (fwd_ms * 1e-3) / 1e9 - bwd_gbps = bwd_bytes / (bwd_ms * 1e-3) / 1e9 if bwd_ms > 0 else 0.0 + bwd_gbps = bwd_bytes / (bwd_ms * 1e-3) / 1e9 print(f" Forward {fwd_ms:.3f} ms | {fwd_gbps:.1f} GB/s") print(f" Backward {bwd_ms:.3f} ms | {bwd_gbps:.1f} GB/s (derived)") @@ -136,36 +128,25 @@ def fwd_bwd_func(): ] rows = [] - # Warmup run - c = test_cases[0] - print(f"\n{'='*60}") - print(f"WARMUP: {c['Case']} M={c['M']} hidden={c['hidden_size']}") - print(f"{'='*60}") - bench_norm(M=c["M"], hidden_size=c["hidden_size"], - norm_cls=c["norm_cls"], dtype=c["dtype"]) - for case in test_cases: print(f"\n{'='*60}") print(f"Testing: {case['Case']} M={case['M']} hidden={case['hidden_size']}") print(f"{'='*60}") - try: - metrics = bench_norm( - M=case["M"], - hidden_size=case["hidden_size"], - norm_cls=case["norm_cls"], - dtype=case["dtype"], - ) - row = { - "Case": case["Case"], - "M": case["M"], - "hidden_size": case["hidden_size"], - "dtype": str(case["dtype"]), - **metrics, - } - rows.append(row) - except Exception as e: - print(f"FAILED: {case['Case']}: {e}") - raise + + metrics = bench_norm( + M=case["M"], + hidden_size=case["hidden_size"], + norm_cls=case["norm_cls"], + dtype=case["dtype"], + ) + row = { + "Case": case["Case"], + "M": case["M"], + "hidden_size": case["hidden_size"], + "dtype": str(case["dtype"]), + **metrics, + } + rows.append(row) results = pd.DataFrame(rows, columns=columns) From 2e6da68d8ea881bab2d73716d7bd76b5a381d37b Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 21 Apr 2026 16:48:52 -0500 Subject: [PATCH 18/25] add csv to asv converter --- benchmarks/microbenchmarks/asv.conf.json | 15 + benchmarks/microbenchmarks/csv_to_asv.py | 455 +++++++++++++++++++++++ 2 files changed, 470 insertions(+) create mode 100644 benchmarks/microbenchmarks/asv.conf.json create mode 100755 benchmarks/microbenchmarks/csv_to_asv.py diff --git a/benchmarks/microbenchmarks/asv.conf.json b/benchmarks/microbenchmarks/asv.conf.json new file mode 100644 index 000000000..0645b7212 --- /dev/null +++ b/benchmarks/microbenchmarks/asv.conf.json @@ -0,0 +1,15 @@ +{ + "version": 1, + "project": "TransformerEngine", + "project_url": "https://github.com/ROCm/TransformerEngine", + "repo": "../..", + "branches": ["HEAD"], + "environment_type": "existing", + "install_command": [], + "build_command": [], + "benchmark_dir": ".", + "results_dir": 
"../.asv/results", + "html_dir": "../.asv/html", + "benchmark_timeout": 1200, + "launch_method": "spawn" +} diff --git a/benchmarks/microbenchmarks/csv_to_asv.py b/benchmarks/microbenchmarks/csv_to_asv.py new file mode 100755 index 000000000..b6ddace59 --- /dev/null +++ b/benchmarks/microbenchmarks/csv_to_asv.py @@ -0,0 +1,455 @@ +#!/usr/bin/env python +############################################################################### +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### +""" +Convert benchmark CSV files (from benchmark_*.py) to ASV-compatible JSON. + +Reads one or more CSV files produced by the microbenchmarks and writes +the same JSON result format that asv requires, so results +can be visualised with ``asv publish && asv preview`` or compared with +``asv compare``. + +Usage: + # Convert all CSVs in the current directory + python csv_to_asv.py benchmark_gemm.csv benchmark_casting.csv ... + + # Convert all benchmark CSVs found in a directory + python csv_to_asv.py perf_results/pr/*.csv + + # Specify output directory (default: benchmarks/.asv/results) + python csv_to_asv.py --results-dir ./my_results benchmark_gemm.csv + + # Provide a custom machine name or commit hash + python csv_to_asv.py --machine mi325 --commit abc1234 *.csv +""" + +import argparse +import glob +import hashlib +import json +import os +import platform +import subprocess +import sys +import time + +import pandas as pd + + +# --------------------------------------------------------------------------- +# Column classification +# --------------------------------------------------------------------------- + +# Columns that are never parameters and never metrics +_SKIP_COLS = {"TestID", "Label"} + + +def _classify_columns(df): + """Split DataFrame columns into (key_cols, time_cols, throughput_cols). + + Heuristic: + - Columns containing "Time" and "(ms)" are time metrics. + - Columns containing "TFLOPS" or "GB/s" are throughput metrics. + - Everything else (except _SKIP_COLS) is a key/parameter column. + """ + time_cols = [] + throughput_cols = [] + key_cols = [] + + for c in df.columns: + if c in _SKIP_COLS: + continue + if "Time" in c and "(ms)" in c: + time_cols.append(c) + elif "TFLOPS" in c or "GB/s" in c: + throughput_cols.append(c) + else: + key_cols.append(c) + + return key_cols, time_cols, throughput_cols + + +def _pair_time_throughput(time_cols, throughput_cols): + """Pair each time column with its throughput companion, if any. + + Matching heuristic: strip the distinctive suffixes and compare the + remaining prefix. E.g. + "TE Forward Time (ms)" <-> "TE Forward TFLOPS" + "Cast Time (ms)" <-> "Cast GB/s" + + Returns a list of (time_col, throughput_col_or_None) tuples. 
+ """ + + def _time_key(col): + return col.replace(" Time (ms)", "").strip() + + def _tp_key(col): + for suffix in (" TFLOPS", " GB/s"): + if col.endswith(suffix): + return col[: -len(suffix)].strip() + return col.strip() + + tp_by_key = {} + for tc in throughput_cols: + tp_by_key[_tp_key(tc)] = tc + + pairs = [] + matched_tp = set() + for tc in time_cols: + key = _time_key(tc) + companion = tp_by_key.get(key) + pairs.append((tc, companion)) + if companion: + matched_tp.add(companion) + + # Standalone throughput columns (no matching time col) + for tc in throughput_cols: + if tc not in matched_tp: + pairs.append((None, tc)) + + return pairs + + +# --------------------------------------------------------------------------- +# ASV helpers (mirrored from PR #487 driver.py) +# --------------------------------------------------------------------------- + +def _get_machine_info(): + """Build the params / machine dict that ASV expects.""" + machine = platform.node() + info = { + "arch": platform.machine(), + "cpu": "", + "machine": machine, + "num_cpu": str(os.cpu_count()), + "os": f"{platform.system()} {platform.release()}", + "ram": "", + } + try: + with open("/proc/cpuinfo") as f: + for line in f: + if line.startswith("model name"): + info["cpu"] = line.split(":", 1)[1].strip() + break + with open("/proc/meminfo") as f: + for line in f: + if line.startswith("MemTotal"): + info["ram"] = line.split()[1] # kB + break + except OSError: + pass + return machine, info + + +def _get_commit_hash(): + """Get the current git HEAD hash.""" + try: + return ( + subprocess.check_output( + ["git", "rev-parse", "HEAD"], stderr=subprocess.DEVNULL + ) + .decode() + .strip() + ) + except Exception: + return "unknown" + + +def _format_param_value(v): + """Format a parameter value the way ASV stores it in JSON.""" + if isinstance(v, str): + return f"'{v}'" + return repr(v) + + +def _make_version(suite_name, bench_name, param_names): + """Deterministic version hash for a benchmark entry.""" + code = f"{suite_name}.{bench_name}({', '.join(param_names)})" + return hashlib.sha256(code.encode()).hexdigest() + + +# --------------------------------------------------------------------------- +# Conversion +# --------------------------------------------------------------------------- + +def csv_to_asv_entries(csv_path): + """Convert one CSV file into ASV result + meta dicts. + + Returns (results_dict, meta_dict) where each key is a fully-qualified + benchmark name like ``benchmark_gemm.time_forward``. 
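+
+    Example: the "TE Forward Time (ms)" column of benchmark_gemm.csv becomes
+    the entry ``benchmark_gemm.time_te_forward`` (values converted from ms to
+    seconds, per ASV convention), and its "TE Forward TFLOPS" companion is
+    stored as ``benchmark_gemm.throughput_te_forward``.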
+ """ + df = pd.read_csv(csv_path) + if df.empty: + return {}, {} + + suite_name = os.path.splitext(os.path.basename(csv_path))[0] + key_cols, time_cols, throughput_cols = _classify_columns(df) + pairs = _pair_time_throughput(time_cols, throughput_cols) + + # Build the ASV parameter axes from unique values in key columns + param_names = list(key_cols) + asv_params = [] + for col in key_cols: + asv_params.append([_format_param_value(v) for v in df[col].unique().tolist()]) + + # Build a lookup from key tuple -> row index for fast access + key_tuples = [tuple(row) for row in df[key_cols].values] + + # Cross-product of unique param values (in the order they appear) + unique_per_col = [df[col].unique().tolist() for col in key_cols] + import itertools + + all_combos = list(itertools.product(*unique_per_col)) + + combo_to_idx = {} + for idx, kt in enumerate(key_tuples): + combo_to_idx[kt] = idx + + results = {} + meta = {} + now_ms = int(time.time() * 1000) + + for time_col, tp_col in pairs: + # Derive a short benchmark name + if time_col: + # e.g. "TE Forward Time (ms)" -> "time_te_forward" + short = time_col.replace(" Time (ms)", "").strip() + short = "time_" + short.lower().replace(" ", "_").replace("(", "").replace(")", "") + bench_key = f"{suite_name}.{short}" + elif tp_col: + short = tp_col.strip() + for suffix in (" TFLOPS", " GB/s"): + short = short.replace(suffix, "") + short = short.strip().lower().replace(" ", "_").replace("(", "").replace(")", "") + bench_key = f"{suite_name}.throughput_{short}" + else: + continue + + version = _make_version(suite_name, short, param_names) + + # Populate values for every combo in the cross-product + time_values = [] + for combo in all_combos: + idx = combo_to_idx.get(combo) + if idx is not None and time_col and time_col in df.columns: + val = df.loc[idx, time_col] + try: + # Convert ms -> seconds (ASV convention) + time_values.append(float(val) / 1000.0) + except (ValueError, TypeError): + time_values.append(None) + else: + time_values.append(None) + + # Store the time benchmark + if time_col: + n = len(time_values) + results[bench_key] = [ + time_values, # result (medians) + asv_params, # params + version, # version + now_ms, # started_at + 0, # duration + [None] * n, # stats_ci_99_a + [None] * n, # stats_ci_99_b + [None] * n, # stats_q_25 + [None] * n, # stats_q_75 + [1] * n, # stats_number + [1] * n, # stats_repeat + ] + meta[bench_key] = { + "code": "", + "name": bench_key, + "param_names": param_names, + "params": asv_params, + "timeout": 300, + "type": "time", + "unit": "seconds", + "version": version, + } + + # Store the throughput companion + if tp_col: + tp_values = [] + for combo in all_combos: + idx = combo_to_idx.get(combo) + if idx is not None and tp_col in df.columns: + val = df.loc[idx, tp_col] + try: + tp_values.append(float(val)) + except (ValueError, TypeError): + tp_values.append(None) + else: + tp_values.append(None) + + if "TFLOPS" in tp_col: + tp_unit = "TFLOPS" + elif "GB/s" in tp_col: + tp_unit = "GB/s" + else: + tp_unit = "" + + tp_short = tp_col.strip() + for suffix in (" TFLOPS", " GB/s"): + tp_short = tp_short.replace(suffix, "") + tp_short = tp_short.strip().lower().replace(" ", "_").replace("(", "").replace(")", "") + tp_key = f"{suite_name}.throughput_{tp_short}" + tp_version = _make_version(suite_name, f"throughput_{tp_short}", param_names) + + n = len(tp_values) + results[tp_key] = [ + tp_values, + asv_params, + tp_version, + now_ms, + 0, + [None] * n, + [None] * n, + [None] * n, + [None] * n, + [1] * n, + [1] * n, + ] 
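+            # Companion metadata entry so ASV can label and render the
+            # throughput series alongside its time counterpart.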
+ meta[tp_key] = { + "code": "", + "name": tp_key, + "param_names": param_names, + "params": asv_params, + "timeout": 300, + "type": "time", + "unit": tp_unit, + "version": tp_version, + } + + return results, meta + + +def save_asv_results(all_results, all_meta, results_dir, machine_name=None, + commit_hash=None): + """Write results and benchmark index to ASV's results directory.""" + if commit_hash is None: + commit_hash = _get_commit_hash() + detected_machine, machine_info = _get_machine_info() + if machine_name: + machine_info["machine"] = machine_name + else: + machine_name = detected_machine + + env_name = "existing-" + sys.executable.replace("/", "_").strip("_") + machine_dir = os.path.join(results_dir, machine_name) + os.makedirs(machine_dir, exist_ok=True) + + # Write machine.json if missing + machine_json = os.path.join(machine_dir, "machine.json") + if not os.path.exists(machine_json): + with open(machine_json, "w") as f: + json.dump({**machine_info, "version": 1}, f, indent=4) + + # Load existing result file or start fresh + filename = f"{commit_hash[:8]}-{env_name}.json" + result_path = os.path.join(machine_dir, filename) + if os.path.exists(result_path): + with open(result_path) as f: + data = json.load(f) + else: + data = { + "commit_hash": commit_hash, + "env_name": env_name, + "date": int(time.time() * 1000), + "params": {**machine_info, "python": sys.executable}, + "python": sys.executable, + "requirements": {}, + "env_vars": {}, + "result_columns": [ + "result", "params", "version", + "started_at", "duration", + "stats_ci_99_a", "stats_ci_99_b", + "stats_q_25", "stats_q_75", + "stats_number", "stats_repeat", + ], + "results": {}, + "durations": {}, + "version": 2, + } + + # Merge new results + for bench_key, bench_data in all_results.items(): + data["results"][bench_key] = bench_data + + with open(result_path, "w") as f: + json.dump(data, f, indent=2) + + print(f"Results saved to {result_path}") + + # Update benchmarks.json index + benchmarks_path = os.path.join(results_dir, "benchmarks.json") + if os.path.exists(benchmarks_path): + with open(benchmarks_path) as f: + benchmarks_data = json.load(f) + else: + benchmarks_data = {"version": 2} + + benchmarks_data.update(all_meta) + + with open(benchmarks_path, "w") as f: + json.dump(benchmarks_data, f, indent=4) + + print(f"Updated {benchmarks_path}") + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + +def main(): + parser = argparse.ArgumentParser( + description="Convert benchmark CSV files to ASV-compatible JSON.") + parser.add_argument("csv_files", nargs="+", + help="CSV files produced by benchmark_*.py") + parser.add_argument("--results-dir", default=None, + help="ASV results directory " + "(default: benchmarks/.asv/results relative to repo root)") + parser.add_argument("--machine", default=None, + help="Machine name for ASV (default: hostname)") + parser.add_argument("--commit", default=None, + help="Commit hash (default: git rev-parse HEAD)") + args = parser.parse_args() + + if args.results_dir is None: + # Default: benchmarks/.asv/results relative to repo root + try: + repo_root = ( + subprocess.check_output( + ["git", "rev-parse", "--show-toplevel"], stderr=subprocess.DEVNULL + ) + .decode() + .strip() + ) + except Exception: + repo_root = os.getcwd() + args.results_dir = os.path.join(repo_root, "benchmarks", ".asv", "results") + + all_results = {} + all_meta = {} + + for csv_path in args.csv_files: 
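+        # Expand glob patterns ourselves so quoted wildcards that the shell
+        # did not expand (e.g. 'perf_results/pr/*.csv') are still matched.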
+ for f in glob.glob(csv_path): + print(f"Processing {f} ...") + results, meta_data = csv_to_asv_entries(f) + all_results.update(results) + all_meta.update(meta_data) + print(f" {len(results)} benchmark entries extracted") + + if not all_results: + print("No benchmark data found.") + return + + save_asv_results(all_results, all_meta, args.results_dir, + machine_name=args.machine, commit_hash=args.commit) + + +if __name__ == "__main__": + main() From aa8997c0bbe67d35455d9a119fdc2f53303c827e Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 7 May 2026 12:20:59 -0500 Subject: [PATCH 19/25] refactor --- .../microbenchmarks/benchmark_casting.py | 90 ++--------- benchmarks/microbenchmarks/benchmark_gemm.py | 134 +++++---------- .../microbenchmarks/benchmark_gemm_fp8.py | 124 ++++---------- .../microbenchmarks/benchmark_grouped_gemm.py | 67 +++----- .../benchmark_normalization.py | 99 +++--------- benchmarks/microbenchmarks/utils.py | 152 ++++++++++++++++++ 6 files changed, 279 insertions(+), 387 deletions(-) create mode 100644 benchmarks/microbenchmarks/utils.py diff --git a/benchmarks/microbenchmarks/benchmark_casting.py b/benchmarks/microbenchmarks/benchmark_casting.py index 2db87d190..0d7878620 100755 --- a/benchmarks/microbenchmarks/benchmark_casting.py +++ b/benchmarks/microbenchmarks/benchmark_casting.py @@ -10,44 +10,22 @@ Benchmarks quantization (BF16 -> FP8) and dequantization (FP8 -> BF16) for both E4M3 (activations/weights) and E5M2 (gradients) formats. -Shapes are (M, hidden_size) matching the activation tensors from models: - - Llama 3.1 8B, 70B, 405B - - Qwen 2.5 7B, 72B - These casts are memory-bound; we report GB/s (input + output bytes). - -Sources for model configs: - https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json - https://huggingface.co/meta-llama/Llama-3.1-70B/blob/main/config.json - https://huggingface.co/meta-llama/Llama-3.1-405B/blob/main/config.json - https://huggingface.co/Qwen/Qwen2.5-7B-Instruct/blob/main/config.json - https://huggingface.co/Qwen/Qwen2.5-72B-Instruct/blob/main/config.json - Output: benchmark_casting.csv (written to cwd) """ import torch -import torch.utils.benchmark as benchmark import transformer_engine import transformer_engine_torch as tex from transformer_engine.pytorch import Float8Quantizer - +from utils import ( + MODEL_HIDDEN_SIZES, M_SIZE_LIST, + time_func, compute_gbps, run_benchmarks, +) TE_FP8_E4M3 = tex.DType.kFloat8E4M3 TE_FP8_E5M2 = tex.DType.kFloat8E5M2 -# Sequence / batch-token sizes to sweep -M_SIZE_LIST = [1024, 2048, 4096, 8192] - -# (model_name, hidden_size) -MODEL_HIDDEN_SIZES = [ - ("Llama3-8B", 4096), - ("Llama3-70B", 8192), - ("Llama3-405B", 16384), - ("Qwen2.5-7B", 3584), - ("Qwen2.5-72B", 8192), -] - CAST_CONFIGS = [ # (name, direction, fp8_dtype) ("BF16-to-FP8-E4M3", "quantize", TE_FP8_E4M3), @@ -73,7 +51,7 @@ def _generate_cast_test_cases(): return test_cases -def bench_cast(M, hidden_size, direction, fp8_dtype): +def bench_cast(Case, M, hidden_size, direction, fp8_dtype, dtype_str): device = "cuda" numel = M * hidden_size @@ -83,22 +61,17 @@ def bench_cast(M, hidden_size, direction, fp8_dtype): if direction == "quantize": x = torch.randn(M, hidden_size, dtype=torch.bfloat16, device=device) - out = quantizer(x) # pre-allocate output tensor + out = quantizer(x) cast_func = lambda: quantizer.quantize(x, out=out) - - # BF16 read (2 bytes) + FP8 write (1 byte) - total_bytes = numel * (2 + 1) + total_bytes = numel * (2 + 1) # BF16 read + FP8 write else: x = torch.randn(M, hidden_size, 
dtype=torch.bfloat16, device=device) fp8_tensor = quantizer(x) cast_func = lambda: fp8_tensor.dequantize() + total_bytes = numel * (1 + 2) # FP8 read + BF16 write - # FP8 read (1 byte) + BF16 write (2 bytes) - total_bytes = numel * (1 + 2) - - # Benchmark - ms = benchmark.Timer(stmt="fn()", globals={"fn": cast_func}).blocked_autorange().mean * 1e3 - gbps = total_bytes / (ms * 1e-3) / 1e9 + ms = time_func(cast_func, method="blocked") + gbps = compute_gbps(total_bytes, ms) print(f" {ms:.4f} ms | {gbps:.1f} GB/s") @@ -109,39 +82,10 @@ def bench_cast(M, hidden_size, direction, fp8_dtype): if __name__ == "__main__": - import pandas as pd - - test_cases = _generate_cast_test_cases() - - columns = [ - "Case", "M", "hidden_size", "dtype_str", - "Cast Time (ms)", - "Cast GB/s", - ] - rows = [] - - for case in test_cases: - print(f"\n{'='*60}") - print(f"Testing: {case['Case']} M={case['M']} hidden={case['hidden_size']}") - print(f"{'='*60}") - - metrics = bench_cast( - M=case["M"], - hidden_size=case["hidden_size"], - direction=case["direction"], - fp8_dtype=case["fp8_dtype"], - ) - row = { - "Case": case["Case"], - "M": case["M"], - "hidden_size": case["hidden_size"], - "dtype_str": case["dtype_str"], - **metrics, - } - rows.append(row) - - results = pd.DataFrame(rows, columns=columns) - - out_csv = "benchmark_casting.csv" - results.to_csv(out_csv, index=False) - print(f"\nResults saved to {out_csv}") + run_benchmarks( + test_cases=_generate_cast_test_cases(), + bench_fn=bench_cast, + param_columns=["Case", "M", "hidden_size", "dtype_str"], + metric_columns=["Cast Time (ms)", "Cast GB/s"], + default_csv="benchmark_casting.csv", + ) diff --git a/benchmarks/microbenchmarks/benchmark_gemm.py b/benchmarks/microbenchmarks/benchmark_gemm.py index d028c7068..36c7bbf1c 100755 --- a/benchmarks/microbenchmarks/benchmark_gemm.py +++ b/benchmarks/microbenchmarks/benchmark_gemm.py @@ -7,65 +7,42 @@ import torch -import torch.utils.benchmark as benchmark - import transformer_engine.pytorch as te - -# Sequence / batch-token sizes to sweep -M_SIZE_LIST = [1024, 2048, 4096, 8192] - -# Model configurations -# Sources: -# - Llama 3 8B (hidden=4096, intermediate=14336, heads=32, kv_heads=8, head_dim=128) -# https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json - -# - Llama 3 70B (hidden=8192, intermediate=28672, heads=64, kv_heads=8, head_dim=128) -# https://huggingface.co/meta-llama/Llama-3.1-70B/blob/main/config.json - -# - Llama 3 405B (hidden=16384, intermediate=53248, heads=128, kv_heads=8, head_dim=128) -# https://huggingface.co/meta-llama/Llama-3.1-405B/blob/main/config.json - -# - Qwen 2.5 7B (hidden=3584, intermediate=18944, heads=28, kv_heads=4, head_dim=128) -# https://huggingface.co/Qwen/Qwen2.5-7B-Instruct/blob/main/config.json - -# - Qwen 2.5 72B (hidden=8192, intermediate=29568, heads=64, kv_heads=8, head_dim=128) -# https://huggingface.co/Qwen/Qwen2.5-72B-Instruct/blob/main/config.json - -MODEL_CONFIGS = [ - # (name, hidden, intermediate, num_q_heads, num_kv_heads, head_dim, tp) - ("Llama3-8B/TP1", 4096, 14336, 32, 8, 128, 1), - ("Llama3-8B/TP8", 4096, 14336, 32, 8, 128, 8), - ("Llama3-70B/TP8", 8192, 28672, 64, 8, 128, 8), - ("Llama3-405B/TP8", 16384, 53248, 128, 8, 128, 8), - ("Qwen2.5-7B/TP1", 3584, 18944, 28, 4, 128, 1), - ("Qwen2.5-72B/TP8", 8192, 29568, 64, 8, 128, 8), +from utils import ( + MODEL_CONFIGS, M_SIZE_LIST, gemm_shapes, + time_func, compute_tflops, run_benchmarks, +) + +# Select which configs / shapes to run (comment/uncomment as needed) +ACTIVE_CONFIGS = [ + 
MODEL_CONFIGS[0], # Llama3-8B/TP1 + # MODEL_CONFIGS[1], # Llama3-8B/TP8 + # MODEL_CONFIGS[2], # Llama3-70B/TP8 + # MODEL_CONFIGS[3], # Llama3-405B/TP8 + # MODEL_CONFIGS[4], # Qwen2.5-7B/TP1 + # MODEL_CONFIGS[5], # Qwen2.5-72B/TP8 ] +ACTIVE_SHAPES = gemm_shapes(ACTIVE_CONFIGS) +# To restrict shapes, filter the dict: +ACTIVE_SHAPES = {k: v for k, v in ACTIVE_SHAPES.items() if "QKV" in k} + def _generate_gemm_test_cases(): test_cases = [] - - for (name, hidden, intermediate, n_q, n_kv, hd, tp) in MODEL_CONFIGS: - shapes = { - f"{name}-QKV": ((n_q * hd + 2 * n_kv * hd) // tp, hidden), - f"{name}-AttnOut": (hidden, (n_q * hd) // tp), - f"{name}-GateUp": ((2 * intermediate) // tp, hidden), - f"{name}-Down": (hidden, intermediate // tp), - } - - for M in M_SIZE_LIST: - for case_name, (N, K) in shapes.items(): - test_cases.append({ - "Case": case_name, - "M": M, - "N": N, - "K": K, - "dtype": torch.bfloat16, - }) + for M in M_SIZE_LIST: + for case_name, (N, K) in ACTIVE_SHAPES.items(): + test_cases.append({ + "Case": case_name, + "M": M, + "N": N, + "K": K, + "dtype": torch.bfloat16, + }) return test_cases -def bench_gemm(M, N, K, dtype): +def bench_gemm(Case, M, N, K, dtype): device = "cuda" linear = te.Linear(K, N, bias=False).to(device=device, dtype=dtype) @@ -78,7 +55,6 @@ def bench_gemm(M, N, K, dtype): def bwd_func(): out = linear(x) out.backward(grad_out) - # Clear grads so they don't accumulate across iterations x.grad = None linear.weight.grad = None @@ -87,14 +63,12 @@ def bwd_func(): fwd_flops = 2 * M * N * K bwd_flops = 2 * fwd_flops # dX + dW - # Benchmark - fwd_ms = benchmark.Timer(stmt="fn()", globals={"fn": fwd_func}).adaptive_autorange().mean * 1e3 - fwd_bwd_ms = benchmark.Timer(stmt="fn()", globals={"fn": bwd_func}).adaptive_autorange().mean * 1e3 - + fwd_ms = time_func(fwd_func) + fwd_bwd_ms = time_func(bwd_func) bwd_ms = fwd_bwd_ms - fwd_ms - fwd_tflops = fwd_flops / (fwd_ms * 1e-3) / 1e12 - bwd_tflops = bwd_flops / (bwd_ms * 1e-3) / 1e12 + fwd_tflops = compute_tflops(fwd_flops, fwd_ms) + bwd_tflops = compute_tflops(bwd_flops, bwd_ms) print(f" Forward {fwd_ms:.3f} ms | {fwd_tflops:.2f} TFLOPS") print(f" Backward {bwd_ms:.3f} ms | {bwd_tflops:.2f} TFLOPS (derived)") @@ -108,39 +82,13 @@ def bwd_func(): if __name__ == "__main__": - import pandas as pd - - test_cases = _generate_gemm_test_cases() - - columns = [ - "Case", "M", "N", "K", "dtype", - "TE Forward Time (ms)", - "TE Forward TFLOPS", - "TE Backward Time (ms)", - "TE Backward TFLOPS", - ] - rows = [] - - for case in test_cases: - print(f"\n{'='*60}") - print(f"Testing: {case}") - print(f"{'='*60}") - - metrics = bench_gemm( - M=case["M"], N=case["N"], K=case["K"], dtype=case["dtype"] - ) - row = { - "Case": case["Case"], - "M": case["M"], - "N": case["N"], - "K": case["K"], - "dtype": str(case["dtype"]), - **metrics, - } - rows.append(row) - - results = pd.DataFrame(rows, columns=columns) - - out_csv = "benchmark_gemm.csv" - results.to_csv(out_csv, index=False) - print(f"\nResults saved to {out_csv}") + run_benchmarks( + test_cases=_generate_gemm_test_cases(), + bench_fn=bench_gemm, + param_columns=["Case", "M", "N", "K", "dtype"], + metric_columns=[ + "TE Forward Time (ms)", "TE Forward TFLOPS", + "TE Backward Time (ms)", "TE Backward TFLOPS", + ], + default_csv="benchmark_gemm.csv", + ) diff --git a/benchmarks/microbenchmarks/benchmark_gemm_fp8.py b/benchmarks/microbenchmarks/benchmark_gemm_fp8.py index 7c910c2c9..886cedcb3 100755 --- a/benchmarks/microbenchmarks/benchmark_gemm_fp8.py +++ 
b/benchmarks/microbenchmarks/benchmark_gemm_fp8.py @@ -7,44 +7,17 @@ """ FP8 GEMM micro-benchmark using te.Linear under fp8_autocast. -Same model shapes as benchmark_gemm.py: - - Llama 3 8B (TP=1, TP=8), 70B (TP=8), 405B (TP=8) - - Qwen 2.5 7B (TP=1), 72B (TP=8) - -Each model contributes four GEMM shapes: - QKV projection (column-parallel) N = (Qheads + 2*KVheads)*head_dim / TP, K = hidden - Attention output (row-parallel) N = hidden, K = Qheads*head_dim / TP - MLP Gate+Up (column-parallel) N = 2*intermediate / TP, K = hidden (SwiGLU) - MLP Down (row-parallel) N = hidden, K = intermediate / TP - -Sources for model configs: - https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json - https://huggingface.co/meta-llama/Llama-3.1-70B/blob/main/config.json - https://huggingface.co/meta-llama/Llama-3.1-405B/blob/main/config.json - https://huggingface.co/Qwen/Qwen2.5-7B-Instruct/blob/main/config.json - https://huggingface.co/Qwen/Qwen2.5-72B-Instruct/blob/main/config.json - +Same model shapes as benchmark_gemm.py. Output: benchmark_gemm_fp8.csv (written to cwd) """ import torch -import torch.utils.benchmark as benchmark - import transformer_engine.pytorch as te from transformer_engine.common.recipe import DelayedScaling, Format - -# Sequence / batch-token sizes to sweep -M_SIZE_LIST = [1024, 2048, 4096, 8192] - -# (name, hidden, intermediate, num_q_heads, num_kv_heads, head_dim, tp) -MODEL_CONFIGS = [ - ("Llama3-8B/TP1", 4096, 14336, 32, 8, 128, 1), - ("Llama3-8B/TP8", 4096, 14336, 32, 8, 128, 8), - ("Llama3-70B/TP8", 8192, 28672, 64, 8, 128, 8), - ("Llama3-405B/TP8", 16384, 53248, 128, 8, 128, 8), - ("Qwen2.5-7B/TP1", 3584, 18944, 28, 4, 128, 1), - ("Qwen2.5-72B/TP8", 8192, 29568, 64, 8, 128, 8), -] +from utils import ( + MODEL_CONFIGS, M_SIZE_LIST, gemm_shapes, + time_func, compute_tflops, run_benchmarks, +) FP8_RECIPE = DelayedScaling( fp8_format=Format.HYBRID, @@ -52,41 +25,34 @@ amax_compute_algo="max", ) +ACTIVE_SHAPES = gemm_shapes(MODEL_CONFIGS) + def _generate_gemm_test_cases(): test_cases = [] - for (name, hidden, intermediate, n_q, n_kv, hd, tp) in MODEL_CONFIGS: - shapes = { - f"{name}-QKV": ((n_q * hd + 2 * n_kv * hd) // tp, hidden), - f"{name}-AttnOut": (hidden, (n_q * hd) // tp), - f"{name}-GateUp": ((2 * intermediate) // tp, hidden), - f"{name}-Down": (hidden, intermediate // tp), - } - for M in M_SIZE_LIST: - for case_name, (N, K) in shapes.items(): - test_cases.append({ - "Case": case_name, - "M": M, - "N": N, - "K": K, - "dtype": torch.bfloat16, - }) + for M in M_SIZE_LIST: + for case_name, (N, K) in ACTIVE_SHAPES.items(): + test_cases.append({ + "Case": case_name, + "M": M, + "N": N, + "K": K, + "dtype": torch.bfloat16, + }) return test_cases -def bench_fp8_gemm(M, N, K, dtype): +def bench_fp8_gemm(Case, M, N, K, dtype): device = "cuda" linear = te.Linear(K, N, bias=False).to(device=device, dtype=dtype) x = torch.randn(M, K, dtype=dtype, device=device, requires_grad=True) grad_out = torch.randn(M, N, dtype=dtype, device=device) - # Forward under fp8_autocast def fwd_func(): with te.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE): return linear(x) - # Combined fwd+bwd (TE consumes saved state on backward, no retain_graph) def fwd_bwd_func(): with te.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE): out = linear(x) @@ -97,14 +63,12 @@ def fwd_bwd_func(): fwd_flops = 2 * M * N * K bwd_flops = 2 * fwd_flops - # Benchmark - fwd_ms = benchmark.Timer(stmt="fn()", globals={"fn": fwd_func}).adaptive_autorange().mean * 1e3 - fwd_bwd_ms = benchmark.Timer(stmt="fn()", 
globals={"fn": fwd_bwd_func}).adaptive_autorange().mean * 1e3 - + fwd_ms = time_func(fwd_func) + fwd_bwd_ms = time_func(fwd_bwd_func) bwd_ms = fwd_bwd_ms - fwd_ms - fwd_tflops = fwd_flops / (fwd_ms * 1e-3) / 1e12 - bwd_tflops = bwd_flops / (bwd_ms * 1e-3) / 1e12 + fwd_tflops = compute_tflops(fwd_flops, fwd_ms) + bwd_tflops = compute_tflops(bwd_flops, bwd_ms) print(f" Forward {fwd_ms:.3f} ms | {fwd_tflops:.2f} TFLOPS") print(f" Backward {bwd_ms:.3f} ms | {bwd_tflops:.2f} TFLOPS (derived)") @@ -118,39 +82,13 @@ def fwd_bwd_func(): if __name__ == "__main__": - import pandas as pd - - test_cases = _generate_gemm_test_cases() - - columns = [ - "Case", "M", "N", "K", "dtype", - "FP8 Forward Time (ms)", - "FP8 Forward TFLOPS", - "FP8 Backward Time (ms)", - "FP8 Backward TFLOPS", - ] - rows = [] - - for case in test_cases: - print(f"\n{'='*60}") - print(f"Testing: {case}") - print(f"{'='*60}") - - metrics = bench_fp8_gemm( - M=case["M"], N=case["N"], K=case["K"], dtype=case["dtype"] - ) - row = { - "Case": case["Case"], - "M": case["M"], - "N": case["N"], - "K": case["K"], - "dtype": str(case["dtype"]), - **metrics, - } - rows.append(row) - - results = pd.DataFrame(rows, columns=columns) - - out_csv = "benchmark_gemm_fp8.csv" - results.to_csv(out_csv, index=False) - print(f"\nResults saved to {out_csv}") + run_benchmarks( + test_cases=_generate_gemm_test_cases(), + bench_fn=bench_fp8_gemm, + param_columns=["Case", "M", "N", "K", "dtype"], + metric_columns=[ + "FP8 Forward Time (ms)", "FP8 Forward TFLOPS", + "FP8 Backward Time (ms)", "FP8 Backward TFLOPS", + ], + default_csv="benchmark_gemm_fp8.csv", + ) diff --git a/benchmarks/microbenchmarks/benchmark_grouped_gemm.py b/benchmarks/microbenchmarks/benchmark_grouped_gemm.py index 81fa1b1cf..58ffe5ad3 100755 --- a/benchmarks/microbenchmarks/benchmark_grouped_gemm.py +++ b/benchmarks/microbenchmarks/benchmark_grouped_gemm.py @@ -7,7 +7,7 @@ import os import torch -import torch.utils.benchmark as benchmark +from utils import time_func, compute_tflops, run_benchmarks def generate_grouped_gemm_group_lens(b, m, balance: bool): if balance: @@ -19,8 +19,8 @@ def generate_grouped_gemm_group_lens(b, m, balance: bool): error = b * m - group_lens.sum() group_lens[-1] += error return group_lens - -M_SIZE_LIST = [512, 1024, 2048, 4096]#, 8192, 16384] + +M_SIZE_LIST = [512, 1024, 2048, 4096] EP_SIZE_LIST = [32, 16, 8] @@ -98,7 +98,6 @@ def make_fwd_bwd_funcs_te(x, w, group_lens, activation_dtype): xs = list(torch.split(x_view, m_splits)) weights = [w[i] for i in range(B)] - # Forward output buffer out = torch.empty((sum_M, N), device=x.device, dtype=activation_dtype) def fwd_func_te(): @@ -116,11 +115,9 @@ def fwd_func_te(): ) return out - # dx buffers dx = torch.empty((sum_M, K), device=x.device, dtype=activation_dtype) dxs = list(torch.split(dx, m_splits)) - # dw buffers dw_stacked = torch.empty((B, N, K), device=x.device, dtype=activation_dtype) dws = [dw_stacked[i] for i in range(B)] @@ -162,7 +159,7 @@ def bwd_func_te(grad_out): return fwd_func_te, bwd_func_te -def bench_grouped_gemm(B, M, N, K, dtype): +def bench_grouped_gemm(Case, B, M, N, K, dtype): device = "cuda" x = torch.randn((B * M, K), dtype=dtype, device=device, requires_grad=True) @@ -173,7 +170,6 @@ def bench_grouped_gemm(B, M, N, K, dtype): os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1" os.environ["NVTE_CUTLASS_GROUPED_GEMM_WARN_FALLBACK"] = "1" - # TE grouped (CK_Tile) x_te = x.clone().detach() w_te = w.clone().detach() fwd_func_te, bwd_func_te_inner = make_fwd_bwd_funcs_te( @@ -184,16 
+180,14 @@ def bench_grouped_gemm(B, M, N, K, dtype): grad_out = torch.randn_like(out_te) bwd_func_te = lambda: bwd_func_te_inner(grad_out) - # FLOPs fwd_total_flops = 2 * B * M * N * K bwd_total_flops = 2 * fwd_total_flops - # Benchmark - fwd_te_ms = benchmark.Timer(stmt="fn()", globals={"fn": fwd_func_te}).adaptive_autorange().mean * 1e3 - bwd_te_ms = benchmark.Timer(stmt="fn()", globals={"fn": bwd_func_te}).adaptive_autorange().mean * 1e3 + fwd_te_ms = time_func(fwd_func_te) + bwd_te_ms = time_func(bwd_func_te) - fwd_te_tflops = fwd_total_flops / (fwd_te_ms * 1e-3) / 1e12 - bwd_te_tflops = bwd_total_flops / (bwd_te_ms * 1e-3) / 1e12 + fwd_te_tflops = compute_tflops(fwd_total_flops, fwd_te_ms) + bwd_te_tflops = compute_tflops(bwd_total_flops, bwd_te_ms) print(f"TE (CK_Tile) Forward {fwd_te_ms:.3f} ms | {fwd_te_tflops:.2f} TFLOPS") print(f"TE (CK_Tile) Backward {bwd_te_ms:.3f} ms | {bwd_te_tflops:.2f} TFLOPS") @@ -207,8 +201,6 @@ def bench_grouped_gemm(B, M, N, K, dtype): if __name__ == "__main__": - import pandas as pd - test_cases = ( generate_deepseekv2_lite_test_cases() + generate_deepseekv2_test_cases() @@ -216,36 +208,13 @@ def bench_grouped_gemm(B, M, N, K, dtype): + generate_grok_v2_test_cases() ) - columns = [ - "Case", "B", "M", "N", "K", "dtype", - "TE (CK_Tile) Forward Time (ms)", - "TE (CK_Tile) Forward TFLOPS", - "TE (CK_Tile) Backward Time (ms)", - "TE (CK_Tile) Backward TFLOPS", - ] - rows = [] - - for case in test_cases: - print(f"\n{'='*50}") - print(f"Testing: {case}") - print(f"{'='*50}") - - metrics = bench_grouped_gemm( - B=case["B"], M=case["M"], N=case["N"], K=case["K"], dtype=case["dtype"] - ) - row = { - "Case": case["Case"], - "B": case["B"], - "M": case["M"], - "N": case["N"], - "K": case["K"], - "dtype": str(case["dtype"]), - **metrics, - } - rows.append(row) - - results = pd.DataFrame(rows, columns=columns) - - out_csv = "benchmark_grouped_gemm.csv" - results.to_csv(out_csv, index=False) - print(f"\nResults saved to {out_csv}") + run_benchmarks( + test_cases=test_cases, + bench_fn=bench_grouped_gemm, + param_columns=["Case", "B", "M", "N", "K", "dtype"], + metric_columns=[ + "TE (CK_Tile) Forward Time (ms)", "TE (CK_Tile) Forward TFLOPS", + "TE (CK_Tile) Backward Time (ms)", "TE (CK_Tile) Backward TFLOPS", + ], + default_csv="benchmark_grouped_gemm.csv", + ) diff --git a/benchmarks/microbenchmarks/benchmark_normalization.py b/benchmarks/microbenchmarks/benchmark_normalization.py index 3bebfc39b..c6af11beb 100755 --- a/benchmarks/microbenchmarks/benchmark_normalization.py +++ b/benchmarks/microbenchmarks/benchmark_normalization.py @@ -7,42 +7,18 @@ """ Normalization micro-benchmark using te.LayerNorm and te.RMSNorm. -Shapes are derived from training workloads: - - Llama 3 8B, 70B, 405B (all use RMSNorm) - - Qwen 2.5 7B, 72B (all use RMSNorm) - -Modern models predominantly use RMSNorm, but we benchmark both -LayerNorm and RMSNorm since TE supports both and they share the -same kernel infrastructure. - +Both LayerNorm and RMSNorm share the same kernel infrastructure. The M dimension (batch * seq_len) is swept across typical training sizes. 
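+
+For example, at M = 8192 and hidden = 8192 in bf16 (2 bytes per element),
+a forward pass reads x and writes y, moving roughly
+2 * 8192 * 8192 * 2 bytes ~= 268 MB per call, which is why this benchmark
+reports GB/s rather than TFLOPS.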
-Sources for model configs: - https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json - https://huggingface.co/meta-llama/Llama-3.1-70B/blob/main/config.json - https://huggingface.co/meta-llama/Llama-3.1-405B/blob/main/config.json - https://huggingface.co/Qwen/Qwen2.5-7B-Instruct/blob/main/config.json - https://huggingface.co/Qwen/Qwen2.5-72B-Instruct/blob/main/config.json - Output: benchmark_normalization.csv (written to cwd) """ import torch -import torch.utils.benchmark as benchmark - import transformer_engine.pytorch as te - -# Sequence / batch-token sizes to sweep -M_SIZE_LIST = [1024, 2048, 4096, 8192] - -# (model_name, hidden_size) -MODEL_HIDDEN_SIZES = [ - ("Llama3-8B", 4096), - ("Llama3-70B", 8192), - ("Llama3-405B", 16384), - ("Qwen2.5-7B", 3584), - ("Qwen2.5-72B", 8192), -] +from utils import ( + MODEL_HIDDEN_SIZES, M_SIZE_LIST, + time_func, compute_gbps, run_benchmarks, +) NORM_TYPES = [ ("RMSNorm", te.RMSNorm), @@ -66,7 +42,7 @@ def _generate_norm_test_cases(): return test_cases -def bench_norm(M, hidden_size, norm_cls, dtype): +def bench_norm(Case, M, hidden_size, norm_name, norm_cls, dtype): device = "cuda" norm = norm_cls(hidden_size).to(device=device, dtype=dtype) @@ -85,23 +61,16 @@ def fwd_bwd_func(): fwd_bwd_func() - # Normalization is memory-bound; report bandwidth instead of FLOPS. - # Each element is read once (fwd) or read+written (bwd), plus the - # weight/bias vectors. We report effective GB/s based on the - # minimum data movement: fwd reads x and writes y, bwd reads - # grad_out+x+saved_stats and writes grad_x+grad_weight. elem_bytes = x.element_size() fwd_bytes = 2 * M * hidden_size * elem_bytes # read x, write y bwd_bytes = 4 * M * hidden_size * elem_bytes # read grad+x+y, write grad_x - # Benchmark - fwd_ms = benchmark.Timer(stmt="fn()", globals={"fn": fwd_func}).adaptive_autorange().mean * 1e3 - fwd_bwd_ms = benchmark.Timer(stmt="fn()", globals={"fn": fwd_bwd_func}).adaptive_autorange().mean * 1e3 - + fwd_ms = time_func(fwd_func) + fwd_bwd_ms = time_func(fwd_bwd_func) bwd_ms = fwd_bwd_ms - fwd_ms - fwd_gbps = fwd_bytes / (fwd_ms * 1e-3) / 1e9 - bwd_gbps = bwd_bytes / (bwd_ms * 1e-3) / 1e9 + fwd_gbps = compute_gbps(fwd_bytes, fwd_ms) + bwd_gbps = compute_gbps(bwd_bytes, bwd_ms) print(f" Forward {fwd_ms:.3f} ms | {fwd_gbps:.1f} GB/s") print(f" Backward {bwd_ms:.3f} ms | {bwd_gbps:.1f} GB/s (derived)") @@ -115,41 +84,13 @@ def fwd_bwd_func(): if __name__ == "__main__": - import pandas as pd - - test_cases = _generate_norm_test_cases() - - columns = [ - "Case", "M", "hidden_size", "dtype", - "TE Forward Time (ms)", - "TE Forward GB/s", - "TE Backward Time (ms)", - "TE Backward GB/s", - ] - rows = [] - - for case in test_cases: - print(f"\n{'='*60}") - print(f"Testing: {case['Case']} M={case['M']} hidden={case['hidden_size']}") - print(f"{'='*60}") - - metrics = bench_norm( - M=case["M"], - hidden_size=case["hidden_size"], - norm_cls=case["norm_cls"], - dtype=case["dtype"], - ) - row = { - "Case": case["Case"], - "M": case["M"], - "hidden_size": case["hidden_size"], - "dtype": str(case["dtype"]), - **metrics, - } - rows.append(row) - - results = pd.DataFrame(rows, columns=columns) - - out_csv = "benchmark_normalization.csv" - results.to_csv(out_csv, index=False) - print(f"\nResults saved to {out_csv}") + run_benchmarks( + test_cases=_generate_norm_test_cases(), + bench_fn=bench_norm, + param_columns=["Case", "M", "hidden_size", "dtype"], + metric_columns=[ + "TE Forward Time (ms)", "TE Forward GB/s", + "TE Backward Time (ms)", "TE Backward GB/s", + 
], + default_csv="benchmark_normalization.csv", + ) diff --git a/benchmarks/microbenchmarks/utils.py b/benchmarks/microbenchmarks/utils.py new file mode 100644 index 000000000..0e4a27be3 --- /dev/null +++ b/benchmarks/microbenchmarks/utils.py @@ -0,0 +1,152 @@ +#!/usr/bin/env python +############################################################################### +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### +"""Shared utilities for microbenchmarks: model configs, timing, throughput, runner.""" + +import argparse +import torch +import torch.utils.benchmark as benchmark + +# --------------------------------------------------------------------------- +# Sequence / batch-token sizes +# --------------------------------------------------------------------------- +M_SIZE_LIST = [1024, 2048, 4096, 8192] + +# --------------------------------------------------------------------------- +# Model configurations +# --------------------------------------------------------------------------- +# (name, hidden, intermediate, num_q_heads, num_kv_heads, head_dim, tp) +# +# Sources: +# - Llama 3 8B https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json +# - Llama 3 70B https://huggingface.co/meta-llama/Llama-3.1-70B/blob/main/config.json +# - Llama 3 405B https://huggingface.co/meta-llama/Llama-3.1-405B/blob/main/config.json +# - Qwen 2.5 7B https://huggingface.co/Qwen/Qwen2.5-7B-Instruct/blob/main/config.json +# - Qwen 2.5 72B https://huggingface.co/Qwen/Qwen2.5-72B-Instruct/blob/main/config.json + +MODEL_CONFIGS = [ + ("Llama3-8B/TP1", 4096, 14336, 32, 8, 128, 1), + ("Llama3-8B/TP8", 4096, 14336, 32, 8, 128, 8), + ("Llama3-70B/TP8", 8192, 28672, 64, 8, 128, 8), + ("Llama3-405B/TP8", 16384, 53248, 128, 8, 128, 8), + ("Qwen2.5-7B/TP1", 3584, 18944, 28, 4, 128, 1), + ("Qwen2.5-72B/TP8", 8192, 29568, 64, 8, 128, 8), +] + +# Unique (model_name, hidden_size) pairs for element-wise benchmarks +MODEL_HIDDEN_SIZES = [ + ("Llama3-8B", 4096), + ("Llama3-70B", 8192), + ("Llama3-405B", 16384), + ("Qwen2.5-7B", 3584), + ("Qwen2.5-72B", 8192), +] + + +def gemm_shapes(configs=None): + """Generate {case_name: (N, K)} dict from MODEL_CONFIGS. + + Each model contributes up to four GEMM shapes: + QKV, AttnOut, GateUp (SwiGLU), Down. + """ + shapes = {} + for (name, hidden, intermediate, n_q, n_kv, hd, tp) in (configs or MODEL_CONFIGS): + shapes[f"{name}-QKV"] = ((n_q * hd + 2 * n_kv * hd) // tp, hidden) + shapes[f"{name}-AttnOut"] = (hidden, (n_q * hd) // tp) + shapes[f"{name}-GateUp"] = ((2 * intermediate) // tp, hidden) + shapes[f"{name}-Down"] = (hidden, intermediate // tp) + return shapes + + +# --------------------------------------------------------------------------- +# Timing helpers +# --------------------------------------------------------------------------- + +def time_func(fn, method="adaptive"): + """Time *fn* and return elapsed milliseconds. + + method: "adaptive" uses adaptive_autorange (good for compute-bound), + "blocked" uses blocked_autorange (good for memory-bound). 
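+
+    A usage sketch (``a``, ``b``, ``x`` and ``y`` are assumed to be
+    preallocated CUDA tensors)::
+
+        gemm_ms = time_func(lambda: torch.mm(a, b))                # compute-bound
+        copy_ms = time_func(lambda: y.copy_(x), method="blocked")  # memory-bound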
+ """ + timer = benchmark.Timer(stmt="fn()", globals={"fn": fn}) + if method == "blocked": + return timer.blocked_autorange().mean * 1e3 + return timer.adaptive_autorange().mean * 1e3 + + +# --------------------------------------------------------------------------- +# Throughput helpers +# --------------------------------------------------------------------------- + +def compute_tflops(flops, ms): + """TFLOPS from operation count and milliseconds.""" + return flops / (ms * 1e-3) / 1e12 + + +def compute_gbps(nbytes, ms): + """GB/s from byte count and milliseconds.""" + return nbytes / (ms * 1e-3) / 1e9 + + +# --------------------------------------------------------------------------- +# Benchmark runner +# --------------------------------------------------------------------------- + +def add_csv_arg(parser): + """Add a ``--csv`` flag to an argparse parser.""" + parser.add_argument( + "--csv", nargs="?", const=True, default=None, metavar="FILE", + help="Write results to CSV. Optional filename; default derived from script name.", + ) + + +def run_benchmarks(test_cases, bench_fn, param_columns, metric_columns, + default_csv=None): + """Iterate *test_cases*, call *bench_fn*, and optionally write a CSV. + + Parameters + ---------- + test_cases : list[dict] + Each dict has at least the keys in *param_columns* plus any extra + keys the bench_fn needs (passed as **case). + bench_fn : callable + Called as ``bench_fn(**case)`` and must return a dict whose keys + match *metric_columns*. + param_columns : list[str] + Column names to pull from each test case into the output row. + metric_columns : list[str] + Column names to pull from the bench_fn return value. + default_csv : str or None + Default CSV filename used when ``--csv`` is passed without a + filename. CSV output is only written when the caller passes + ``--csv`` on the command line. 
+ """ + parser = argparse.ArgumentParser(add_help=False) + add_csv_arg(parser) + args, _ = parser.parse_known_args() + + columns = param_columns + metric_columns + rows = [] + + for case in test_cases: + label = " ".join(f"{k}={case[k]}" for k in param_columns) + print(f"\n{'='*60}") + print(f"Testing: {label}") + print(f"{'='*60}") + + metrics = bench_fn(**case) + + row = {k: (str(case[k]) if isinstance(case[k], torch.dtype) else case[k]) + for k in param_columns} + row.update(metrics) + rows.append(row) + + if args.csv is not None: + import pandas as pd + out_csv = args.csv if isinstance(args.csv, str) else default_csv + results = pd.DataFrame(rows, columns=columns) + results.to_csv(out_csv, index=False) + print(f"\nResults saved to {out_csv}") From fefaf1337b630dae907dbf8c95ebe5754a535f68 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 7 May 2026 12:27:44 -0500 Subject: [PATCH 20/25] remove asv converter --- benchmarks/microbenchmarks/asv.conf.json | 15 - benchmarks/microbenchmarks/csv_to_asv.py | 455 ----------------------- 2 files changed, 470 deletions(-) delete mode 100644 benchmarks/microbenchmarks/asv.conf.json delete mode 100755 benchmarks/microbenchmarks/csv_to_asv.py diff --git a/benchmarks/microbenchmarks/asv.conf.json b/benchmarks/microbenchmarks/asv.conf.json deleted file mode 100644 index 0645b7212..000000000 --- a/benchmarks/microbenchmarks/asv.conf.json +++ /dev/null @@ -1,15 +0,0 @@ -{ - "version": 1, - "project": "TransformerEngine", - "project_url": "https://github.com/ROCm/TransformerEngine", - "repo": "../..", - "branches": ["HEAD"], - "environment_type": "existing", - "install_command": [], - "build_command": [], - "benchmark_dir": ".", - "results_dir": "../.asv/results", - "html_dir": "../.asv/html", - "benchmark_timeout": 1200, - "launch_method": "spawn" -} diff --git a/benchmarks/microbenchmarks/csv_to_asv.py b/benchmarks/microbenchmarks/csv_to_asv.py deleted file mode 100755 index b6ddace59..000000000 --- a/benchmarks/microbenchmarks/csv_to_asv.py +++ /dev/null @@ -1,455 +0,0 @@ -#!/usr/bin/env python -############################################################################### -# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. -# -# See LICENSE for license information. -############################################################################### -""" -Convert benchmark CSV files (from benchmark_*.py) to ASV-compatible JSON. - -Reads one or more CSV files produced by the microbenchmarks and writes -the same JSON result format that asv requires, so results -can be visualised with ``asv publish && asv preview`` or compared with -``asv compare``. - -Usage: - # Convert all CSVs in the current directory - python csv_to_asv.py benchmark_gemm.csv benchmark_casting.csv ... 
- - # Convert all benchmark CSVs found in a directory - python csv_to_asv.py perf_results/pr/*.csv - - # Specify output directory (default: benchmarks/.asv/results) - python csv_to_asv.py --results-dir ./my_results benchmark_gemm.csv - - # Provide a custom machine name or commit hash - python csv_to_asv.py --machine mi325 --commit abc1234 *.csv -""" - -import argparse -import glob -import hashlib -import json -import os -import platform -import subprocess -import sys -import time - -import pandas as pd - - -# --------------------------------------------------------------------------- -# Column classification -# --------------------------------------------------------------------------- - -# Columns that are never parameters and never metrics -_SKIP_COLS = {"TestID", "Label"} - - -def _classify_columns(df): - """Split DataFrame columns into (key_cols, time_cols, throughput_cols). - - Heuristic: - - Columns containing "Time" and "(ms)" are time metrics. - - Columns containing "TFLOPS" or "GB/s" are throughput metrics. - - Everything else (except _SKIP_COLS) is a key/parameter column. - """ - time_cols = [] - throughput_cols = [] - key_cols = [] - - for c in df.columns: - if c in _SKIP_COLS: - continue - if "Time" in c and "(ms)" in c: - time_cols.append(c) - elif "TFLOPS" in c or "GB/s" in c: - throughput_cols.append(c) - else: - key_cols.append(c) - - return key_cols, time_cols, throughput_cols - - -def _pair_time_throughput(time_cols, throughput_cols): - """Pair each time column with its throughput companion, if any. - - Matching heuristic: strip the distinctive suffixes and compare the - remaining prefix. E.g. - "TE Forward Time (ms)" <-> "TE Forward TFLOPS" - "Cast Time (ms)" <-> "Cast GB/s" - - Returns a list of (time_col, throughput_col_or_None) tuples. 
- """ - - def _time_key(col): - return col.replace(" Time (ms)", "").strip() - - def _tp_key(col): - for suffix in (" TFLOPS", " GB/s"): - if col.endswith(suffix): - return col[: -len(suffix)].strip() - return col.strip() - - tp_by_key = {} - for tc in throughput_cols: - tp_by_key[_tp_key(tc)] = tc - - pairs = [] - matched_tp = set() - for tc in time_cols: - key = _time_key(tc) - companion = tp_by_key.get(key) - pairs.append((tc, companion)) - if companion: - matched_tp.add(companion) - - # Standalone throughput columns (no matching time col) - for tc in throughput_cols: - if tc not in matched_tp: - pairs.append((None, tc)) - - return pairs - - -# --------------------------------------------------------------------------- -# ASV helpers (mirrored from PR #487 driver.py) -# --------------------------------------------------------------------------- - -def _get_machine_info(): - """Build the params / machine dict that ASV expects.""" - machine = platform.node() - info = { - "arch": platform.machine(), - "cpu": "", - "machine": machine, - "num_cpu": str(os.cpu_count()), - "os": f"{platform.system()} {platform.release()}", - "ram": "", - } - try: - with open("/proc/cpuinfo") as f: - for line in f: - if line.startswith("model name"): - info["cpu"] = line.split(":", 1)[1].strip() - break - with open("/proc/meminfo") as f: - for line in f: - if line.startswith("MemTotal"): - info["ram"] = line.split()[1] # kB - break - except OSError: - pass - return machine, info - - -def _get_commit_hash(): - """Get the current git HEAD hash.""" - try: - return ( - subprocess.check_output( - ["git", "rev-parse", "HEAD"], stderr=subprocess.DEVNULL - ) - .decode() - .strip() - ) - except Exception: - return "unknown" - - -def _format_param_value(v): - """Format a parameter value the way ASV stores it in JSON.""" - if isinstance(v, str): - return f"'{v}'" - return repr(v) - - -def _make_version(suite_name, bench_name, param_names): - """Deterministic version hash for a benchmark entry.""" - code = f"{suite_name}.{bench_name}({', '.join(param_names)})" - return hashlib.sha256(code.encode()).hexdigest() - - -# --------------------------------------------------------------------------- -# Conversion -# --------------------------------------------------------------------------- - -def csv_to_asv_entries(csv_path): - """Convert one CSV file into ASV result + meta dicts. - - Returns (results_dict, meta_dict) where each key is a fully-qualified - benchmark name like ``benchmark_gemm.time_forward``. 
- """ - df = pd.read_csv(csv_path) - if df.empty: - return {}, {} - - suite_name = os.path.splitext(os.path.basename(csv_path))[0] - key_cols, time_cols, throughput_cols = _classify_columns(df) - pairs = _pair_time_throughput(time_cols, throughput_cols) - - # Build the ASV parameter axes from unique values in key columns - param_names = list(key_cols) - asv_params = [] - for col in key_cols: - asv_params.append([_format_param_value(v) for v in df[col].unique().tolist()]) - - # Build a lookup from key tuple -> row index for fast access - key_tuples = [tuple(row) for row in df[key_cols].values] - - # Cross-product of unique param values (in the order they appear) - unique_per_col = [df[col].unique().tolist() for col in key_cols] - import itertools - - all_combos = list(itertools.product(*unique_per_col)) - - combo_to_idx = {} - for idx, kt in enumerate(key_tuples): - combo_to_idx[kt] = idx - - results = {} - meta = {} - now_ms = int(time.time() * 1000) - - for time_col, tp_col in pairs: - # Derive a short benchmark name - if time_col: - # e.g. "TE Forward Time (ms)" -> "time_te_forward" - short = time_col.replace(" Time (ms)", "").strip() - short = "time_" + short.lower().replace(" ", "_").replace("(", "").replace(")", "") - bench_key = f"{suite_name}.{short}" - elif tp_col: - short = tp_col.strip() - for suffix in (" TFLOPS", " GB/s"): - short = short.replace(suffix, "") - short = short.strip().lower().replace(" ", "_").replace("(", "").replace(")", "") - bench_key = f"{suite_name}.throughput_{short}" - else: - continue - - version = _make_version(suite_name, short, param_names) - - # Populate values for every combo in the cross-product - time_values = [] - for combo in all_combos: - idx = combo_to_idx.get(combo) - if idx is not None and time_col and time_col in df.columns: - val = df.loc[idx, time_col] - try: - # Convert ms -> seconds (ASV convention) - time_values.append(float(val) / 1000.0) - except (ValueError, TypeError): - time_values.append(None) - else: - time_values.append(None) - - # Store the time benchmark - if time_col: - n = len(time_values) - results[bench_key] = [ - time_values, # result (medians) - asv_params, # params - version, # version - now_ms, # started_at - 0, # duration - [None] * n, # stats_ci_99_a - [None] * n, # stats_ci_99_b - [None] * n, # stats_q_25 - [None] * n, # stats_q_75 - [1] * n, # stats_number - [1] * n, # stats_repeat - ] - meta[bench_key] = { - "code": "", - "name": bench_key, - "param_names": param_names, - "params": asv_params, - "timeout": 300, - "type": "time", - "unit": "seconds", - "version": version, - } - - # Store the throughput companion - if tp_col: - tp_values = [] - for combo in all_combos: - idx = combo_to_idx.get(combo) - if idx is not None and tp_col in df.columns: - val = df.loc[idx, tp_col] - try: - tp_values.append(float(val)) - except (ValueError, TypeError): - tp_values.append(None) - else: - tp_values.append(None) - - if "TFLOPS" in tp_col: - tp_unit = "TFLOPS" - elif "GB/s" in tp_col: - tp_unit = "GB/s" - else: - tp_unit = "" - - tp_short = tp_col.strip() - for suffix in (" TFLOPS", " GB/s"): - tp_short = tp_short.replace(suffix, "") - tp_short = tp_short.strip().lower().replace(" ", "_").replace("(", "").replace(")", "") - tp_key = f"{suite_name}.throughput_{tp_short}" - tp_version = _make_version(suite_name, f"throughput_{tp_short}", param_names) - - n = len(tp_values) - results[tp_key] = [ - tp_values, - asv_params, - tp_version, - now_ms, - 0, - [None] * n, - [None] * n, - [None] * n, - [None] * n, - [1] * n, - [1] * n, - ] 
- meta[tp_key] = { - "code": "", - "name": tp_key, - "param_names": param_names, - "params": asv_params, - "timeout": 300, - "type": "time", - "unit": tp_unit, - "version": tp_version, - } - - return results, meta - - -def save_asv_results(all_results, all_meta, results_dir, machine_name=None, - commit_hash=None): - """Write results and benchmark index to ASV's results directory.""" - if commit_hash is None: - commit_hash = _get_commit_hash() - detected_machine, machine_info = _get_machine_info() - if machine_name: - machine_info["machine"] = machine_name - else: - machine_name = detected_machine - - env_name = "existing-" + sys.executable.replace("/", "_").strip("_") - machine_dir = os.path.join(results_dir, machine_name) - os.makedirs(machine_dir, exist_ok=True) - - # Write machine.json if missing - machine_json = os.path.join(machine_dir, "machine.json") - if not os.path.exists(machine_json): - with open(machine_json, "w") as f: - json.dump({**machine_info, "version": 1}, f, indent=4) - - # Load existing result file or start fresh - filename = f"{commit_hash[:8]}-{env_name}.json" - result_path = os.path.join(machine_dir, filename) - if os.path.exists(result_path): - with open(result_path) as f: - data = json.load(f) - else: - data = { - "commit_hash": commit_hash, - "env_name": env_name, - "date": int(time.time() * 1000), - "params": {**machine_info, "python": sys.executable}, - "python": sys.executable, - "requirements": {}, - "env_vars": {}, - "result_columns": [ - "result", "params", "version", - "started_at", "duration", - "stats_ci_99_a", "stats_ci_99_b", - "stats_q_25", "stats_q_75", - "stats_number", "stats_repeat", - ], - "results": {}, - "durations": {}, - "version": 2, - } - - # Merge new results - for bench_key, bench_data in all_results.items(): - data["results"][bench_key] = bench_data - - with open(result_path, "w") as f: - json.dump(data, f, indent=2) - - print(f"Results saved to {result_path}") - - # Update benchmarks.json index - benchmarks_path = os.path.join(results_dir, "benchmarks.json") - if os.path.exists(benchmarks_path): - with open(benchmarks_path) as f: - benchmarks_data = json.load(f) - else: - benchmarks_data = {"version": 2} - - benchmarks_data.update(all_meta) - - with open(benchmarks_path, "w") as f: - json.dump(benchmarks_data, f, indent=4) - - print(f"Updated {benchmarks_path}") - - -# --------------------------------------------------------------------------- -# CLI -# --------------------------------------------------------------------------- - -def main(): - parser = argparse.ArgumentParser( - description="Convert benchmark CSV files to ASV-compatible JSON.") - parser.add_argument("csv_files", nargs="+", - help="CSV files produced by benchmark_*.py") - parser.add_argument("--results-dir", default=None, - help="ASV results directory " - "(default: benchmarks/.asv/results relative to repo root)") - parser.add_argument("--machine", default=None, - help="Machine name for ASV (default: hostname)") - parser.add_argument("--commit", default=None, - help="Commit hash (default: git rev-parse HEAD)") - args = parser.parse_args() - - if args.results_dir is None: - # Default: benchmarks/.asv/results relative to repo root - try: - repo_root = ( - subprocess.check_output( - ["git", "rev-parse", "--show-toplevel"], stderr=subprocess.DEVNULL - ) - .decode() - .strip() - ) - except Exception: - repo_root = os.getcwd() - args.results_dir = os.path.join(repo_root, "benchmarks", ".asv", "results") - - all_results = {} - all_meta = {} - - for csv_path in args.csv_files: 
- for f in glob.glob(csv_path): - print(f"Processing {f} ...") - results, meta_data = csv_to_asv_entries(f) - all_results.update(results) - all_meta.update(meta_data) - print(f" {len(results)} benchmark entries extracted") - - if not all_results: - print("No benchmark data found.") - return - - save_asv_results(all_results, all_meta, args.results_dir, - machine_name=args.machine, commit_hash=args.commit) - - -if __name__ == "__main__": - main() From 7f2669d13aabc69a829aec3e8994fdeb45705136 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 7 May 2026 15:54:49 -0500 Subject: [PATCH 21/25] cleanups, misc fixes --- benchmarks/microbenchmarks/benchmark_gemm.py | 20 ++++-------------- .../microbenchmarks/benchmark_gemm_fp8.py | 2 +- .../microbenchmarks/benchmark_grouped_gemm.py | 21 +++++++------------ 3 files changed, 13 insertions(+), 30 deletions(-) diff --git a/benchmarks/microbenchmarks/benchmark_gemm.py b/benchmarks/microbenchmarks/benchmark_gemm.py index 36c7bbf1c..28353183a 100755 --- a/benchmarks/microbenchmarks/benchmark_gemm.py +++ b/benchmarks/microbenchmarks/benchmark_gemm.py @@ -13,19 +13,7 @@ time_func, compute_tflops, run_benchmarks, ) -# Select which configs / shapes to run (comment/uncomment as needed) -ACTIVE_CONFIGS = [ - MODEL_CONFIGS[0], # Llama3-8B/TP1 - # MODEL_CONFIGS[1], # Llama3-8B/TP8 - # MODEL_CONFIGS[2], # Llama3-70B/TP8 - # MODEL_CONFIGS[3], # Llama3-405B/TP8 - # MODEL_CONFIGS[4], # Qwen2.5-7B/TP1 - # MODEL_CONFIGS[5], # Qwen2.5-72B/TP8 -] - -ACTIVE_SHAPES = gemm_shapes(ACTIVE_CONFIGS) -# To restrict shapes, filter the dict: -ACTIVE_SHAPES = {k: v for k, v in ACTIVE_SHAPES.items() if "QKV" in k} +ACTIVE_SHAPES = gemm_shapes(MODEL_CONFIGS) def _generate_gemm_test_cases(): @@ -52,19 +40,19 @@ def bench_gemm(Case, M, N, K, dtype): out = fwd_func() grad_out = torch.randn_like(out) - def bwd_func(): + def fwd_bwd_func(): out = linear(x) out.backward(grad_out) x.grad = None linear.weight.grad = None - bwd_func() + fwd_bwd_func() fwd_flops = 2 * M * N * K bwd_flops = 2 * fwd_flops # dX + dW fwd_ms = time_func(fwd_func) - fwd_bwd_ms = time_func(bwd_func) + fwd_bwd_ms = time_func(fwd_bwd_func) bwd_ms = fwd_bwd_ms - fwd_ms fwd_tflops = compute_tflops(fwd_flops, fwd_ms) diff --git a/benchmarks/microbenchmarks/benchmark_gemm_fp8.py b/benchmarks/microbenchmarks/benchmark_gemm_fp8.py index 886cedcb3..3c7ebae34 100755 --- a/benchmarks/microbenchmarks/benchmark_gemm_fp8.py +++ b/benchmarks/microbenchmarks/benchmark_gemm_fp8.py @@ -56,7 +56,7 @@ def fwd_func(): def fwd_bwd_func(): with te.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE): out = linear(x) - out.backward(grad_out) + out.backward(grad_out) x.grad = None linear.weight.grad = None diff --git a/benchmarks/microbenchmarks/benchmark_grouped_gemm.py b/benchmarks/microbenchmarks/benchmark_grouped_gemm.py index 58ffe5ad3..484a4b3aa 100755 --- a/benchmarks/microbenchmarks/benchmark_grouped_gemm.py +++ b/benchmarks/microbenchmarks/benchmark_grouped_gemm.py @@ -5,7 +5,6 @@ # See LICENSE for license information. 
############################################################################### -import os import torch from utils import time_func, compute_tflops, run_benchmarks @@ -165,10 +164,6 @@ def bench_grouped_gemm(Case, B, M, N, K, dtype): x = torch.randn((B * M, K), dtype=dtype, device=device, requires_grad=True) w = torch.randn((B, N, K), dtype=dtype, device=device, requires_grad=True) group_lens = generate_grouped_gemm_group_lens(B, M, balance=True).to(device) - print("group_lens: ", group_lens) - - os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1" - os.environ["NVTE_CUTLASS_GROUPED_GEMM_WARN_FALLBACK"] = "1" x_te = x.clone().detach() w_te = w.clone().detach() @@ -189,14 +184,14 @@ def bench_grouped_gemm(Case, B, M, N, K, dtype): fwd_te_tflops = compute_tflops(fwd_total_flops, fwd_te_ms) bwd_te_tflops = compute_tflops(bwd_total_flops, bwd_te_ms) - print(f"TE (CK_Tile) Forward {fwd_te_ms:.3f} ms | {fwd_te_tflops:.2f} TFLOPS") - print(f"TE (CK_Tile) Backward {bwd_te_ms:.3f} ms | {bwd_te_tflops:.2f} TFLOPS") + print(f" Forward {fwd_te_ms:.3f} ms | {fwd_te_tflops:.2f} TFLOPS") + print(f" Backward {bwd_te_ms:.3f} ms | {bwd_te_tflops:.2f} TFLOPS") return { - "TE (CK_Tile) Forward Time (ms)": f"{fwd_te_ms:.2f}", - "TE (CK_Tile) Forward TFLOPS": f"{fwd_te_tflops:.2f}", - "TE (CK_Tile) Backward Time (ms)": f"{bwd_te_ms:.2f}", - "TE (CK_Tile) Backward TFLOPS": f"{bwd_te_tflops:.2f}", + "TE Forward Time (ms)": f"{fwd_te_ms:.2f}", + "TE Forward TFLOPS": f"{fwd_te_tflops:.2f}", + "TE Backward Time (ms)": f"{bwd_te_ms:.2f}", + "TE Backward TFLOPS": f"{bwd_te_tflops:.2f}", } @@ -213,8 +208,8 @@ def bench_grouped_gemm(Case, B, M, N, K, dtype): bench_fn=bench_grouped_gemm, param_columns=["Case", "B", "M", "N", "K", "dtype"], metric_columns=[ - "TE (CK_Tile) Forward Time (ms)", "TE (CK_Tile) Forward TFLOPS", - "TE (CK_Tile) Backward Time (ms)", "TE (CK_Tile) Backward TFLOPS", + "TE Forward Time (ms)", "TE Forward TFLOPS", + "TE Backward Time (ms)", "TE Backward TFLOPS", ], default_csv="benchmark_grouped_gemm.csv", ) From d4e116af203df2e821370aa5682fce82bcb3c7bc Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 11 May 2026 14:12:39 -0500 Subject: [PATCH 22/25] simplifications, address review comments --- .../microbenchmarks/benchmark_casting.py | 12 +- benchmarks/microbenchmarks/benchmark_gemm.py | 26 ++--- .../microbenchmarks/benchmark_gemm_fp8.py | 26 ++--- .../microbenchmarks/benchmark_grouped_gemm.py | 30 ++--- .../benchmark_normalization.py | 26 ++--- benchmarks/microbenchmarks/compare_results.py | 78 +++++++++---- benchmarks/microbenchmarks/utils.py | 107 ++++++++++++++++-- 7 files changed, 210 insertions(+), 95 deletions(-) diff --git a/benchmarks/microbenchmarks/benchmark_casting.py b/benchmarks/microbenchmarks/benchmark_casting.py index 0d7878620..1d8056d08 100755 --- a/benchmarks/microbenchmarks/benchmark_casting.py +++ b/benchmarks/microbenchmarks/benchmark_casting.py @@ -20,12 +20,14 @@ from transformer_engine.pytorch import Float8Quantizer from utils import ( MODEL_HIDDEN_SIZES, M_SIZE_LIST, - time_func, compute_gbps, run_benchmarks, + time_func, compute_gbps, make_metric_record, run_benchmarks, ) TE_FP8_E4M3 = tex.DType.kFloat8E4M3 TE_FP8_E5M2 = tex.DType.kFloat8E5M2 +CAST_LABEL = "Cast" + CAST_CONFIGS = [ # (name, direction, fp8_dtype) ("BF16-to-FP8-E4M3", "quantize", TE_FP8_E4M3), @@ -73,12 +75,7 @@ def bench_cast(Case, M, hidden_size, direction, fp8_dtype, dtype_str): ms = time_func(cast_func, method="blocked") gbps = compute_gbps(total_bytes, ms) - print(f" {ms:.4f} ms | {gbps:.1f} 
GB/s") - - return { - "Cast Time (ms)": f"{ms:.4f}", - "Cast GB/s": f"{gbps:.1f}", - } + return [make_metric_record(CAST_LABEL, ms, "GB/s", gbps)] if __name__ == "__main__": @@ -86,6 +83,5 @@ def bench_cast(Case, M, hidden_size, direction, fp8_dtype, dtype_str): test_cases=_generate_cast_test_cases(), bench_fn=bench_cast, param_columns=["Case", "M", "hidden_size", "dtype_str"], - metric_columns=["Cast Time (ms)", "Cast GB/s"], default_csv="benchmark_casting.csv", ) diff --git a/benchmarks/microbenchmarks/benchmark_gemm.py b/benchmarks/microbenchmarks/benchmark_gemm.py index 28353183a..f5270e4cc 100755 --- a/benchmarks/microbenchmarks/benchmark_gemm.py +++ b/benchmarks/microbenchmarks/benchmark_gemm.py @@ -10,11 +10,13 @@ import transformer_engine.pytorch as te from utils import ( MODEL_CONFIGS, M_SIZE_LIST, gemm_shapes, - time_func, compute_tflops, run_benchmarks, + time_func, compute_tflops, make_forward_backward_metric_records, run_benchmarks, ) ACTIVE_SHAPES = gemm_shapes(MODEL_CONFIGS) +BENCHMARK_LABEL = "BF16 GEMM" + def _generate_gemm_test_cases(): test_cases = [] @@ -58,15 +60,15 @@ def fwd_bwd_func(): fwd_tflops = compute_tflops(fwd_flops, fwd_ms) bwd_tflops = compute_tflops(bwd_flops, bwd_ms) - print(f" Forward {fwd_ms:.3f} ms | {fwd_tflops:.2f} TFLOPS") - print(f" Backward {bwd_ms:.3f} ms | {bwd_tflops:.2f} TFLOPS (derived)") - - return { - "TE Forward Time (ms)": f"{fwd_ms:.2f}", - "TE Forward TFLOPS": f"{fwd_tflops:.2f}", - "TE Backward Time (ms)": f"{bwd_ms:.2f}", - "TE Backward TFLOPS": f"{bwd_tflops:.2f}", - } + return make_forward_backward_metric_records( + BENCHMARK_LABEL, + "TFLOPS", + fwd_ms, + fwd_tflops, + bwd_ms, + bwd_tflops, + backward_derived=True, + ) if __name__ == "__main__": @@ -74,9 +76,5 @@ def fwd_bwd_func(): test_cases=_generate_gemm_test_cases(), bench_fn=bench_gemm, param_columns=["Case", "M", "N", "K", "dtype"], - metric_columns=[ - "TE Forward Time (ms)", "TE Forward TFLOPS", - "TE Backward Time (ms)", "TE Backward TFLOPS", - ], default_csv="benchmark_gemm.csv", ) diff --git a/benchmarks/microbenchmarks/benchmark_gemm_fp8.py b/benchmarks/microbenchmarks/benchmark_gemm_fp8.py index 3c7ebae34..e11a39a0d 100755 --- a/benchmarks/microbenchmarks/benchmark_gemm_fp8.py +++ b/benchmarks/microbenchmarks/benchmark_gemm_fp8.py @@ -16,7 +16,7 @@ from transformer_engine.common.recipe import DelayedScaling, Format from utils import ( MODEL_CONFIGS, M_SIZE_LIST, gemm_shapes, - time_func, compute_tflops, run_benchmarks, + time_func, compute_tflops, make_forward_backward_metric_records, run_benchmarks, ) FP8_RECIPE = DelayedScaling( @@ -27,6 +27,8 @@ ACTIVE_SHAPES = gemm_shapes(MODEL_CONFIGS) +BENCHMARK_LABEL = "FP8 GEMM" + def _generate_gemm_test_cases(): test_cases = [] @@ -70,15 +72,15 @@ def fwd_bwd_func(): fwd_tflops = compute_tflops(fwd_flops, fwd_ms) bwd_tflops = compute_tflops(bwd_flops, bwd_ms) - print(f" Forward {fwd_ms:.3f} ms | {fwd_tflops:.2f} TFLOPS") - print(f" Backward {bwd_ms:.3f} ms | {bwd_tflops:.2f} TFLOPS (derived)") - - return { - "FP8 Forward Time (ms)": f"{fwd_ms:.2f}", - "FP8 Forward TFLOPS": f"{fwd_tflops:.2f}", - "FP8 Backward Time (ms)": f"{bwd_ms:.2f}", - "FP8 Backward TFLOPS": f"{bwd_tflops:.2f}", - } + return make_forward_backward_metric_records( + BENCHMARK_LABEL, + "TFLOPS", + fwd_ms, + fwd_tflops, + bwd_ms, + bwd_tflops, + backward_derived=True, + ) if __name__ == "__main__": @@ -86,9 +88,5 @@ def fwd_bwd_func(): test_cases=_generate_gemm_test_cases(), bench_fn=bench_fp8_gemm, param_columns=["Case", "M", "N", "K", "dtype"], - 
metric_columns=[ - "FP8 Forward Time (ms)", "FP8 Forward TFLOPS", - "FP8 Backward Time (ms)", "FP8 Backward TFLOPS", - ], default_csv="benchmark_gemm_fp8.csv", ) diff --git a/benchmarks/microbenchmarks/benchmark_grouped_gemm.py b/benchmarks/microbenchmarks/benchmark_grouped_gemm.py index 484a4b3aa..63617949d 100755 --- a/benchmarks/microbenchmarks/benchmark_grouped_gemm.py +++ b/benchmarks/microbenchmarks/benchmark_grouped_gemm.py @@ -6,7 +6,14 @@ ############################################################################### import torch -from utils import time_func, compute_tflops, run_benchmarks +from utils import ( + time_func, + compute_tflops, + make_forward_backward_metric_records, + run_benchmarks, +) + +BENCHMARK_LABEL = "BF16 Grouped GEMM" def generate_grouped_gemm_group_lens(b, m, balance: bool): if balance: @@ -184,15 +191,14 @@ def bench_grouped_gemm(Case, B, M, N, K, dtype): fwd_te_tflops = compute_tflops(fwd_total_flops, fwd_te_ms) bwd_te_tflops = compute_tflops(bwd_total_flops, bwd_te_ms) - print(f" Forward {fwd_te_ms:.3f} ms | {fwd_te_tflops:.2f} TFLOPS") - print(f" Backward {bwd_te_ms:.3f} ms | {bwd_te_tflops:.2f} TFLOPS") - - return { - "TE Forward Time (ms)": f"{fwd_te_ms:.2f}", - "TE Forward TFLOPS": f"{fwd_te_tflops:.2f}", - "TE Backward Time (ms)": f"{bwd_te_ms:.2f}", - "TE Backward TFLOPS": f"{bwd_te_tflops:.2f}", - } + return make_forward_backward_metric_records( + BENCHMARK_LABEL, + "TFLOPS", + fwd_te_ms, + fwd_te_tflops, + bwd_te_ms, + bwd_te_tflops, + ) if __name__ == "__main__": @@ -207,9 +213,5 @@ def bench_grouped_gemm(Case, B, M, N, K, dtype): test_cases=test_cases, bench_fn=bench_grouped_gemm, param_columns=["Case", "B", "M", "N", "K", "dtype"], - metric_columns=[ - "TE Forward Time (ms)", "TE Forward TFLOPS", - "TE Backward Time (ms)", "TE Backward TFLOPS", - ], default_csv="benchmark_grouped_gemm.csv", ) diff --git a/benchmarks/microbenchmarks/benchmark_normalization.py b/benchmarks/microbenchmarks/benchmark_normalization.py index c6af11beb..bcaadb509 100755 --- a/benchmarks/microbenchmarks/benchmark_normalization.py +++ b/benchmarks/microbenchmarks/benchmark_normalization.py @@ -17,7 +17,7 @@ import transformer_engine.pytorch as te from utils import ( MODEL_HIDDEN_SIZES, M_SIZE_LIST, - time_func, compute_gbps, run_benchmarks, + time_func, compute_gbps, make_forward_backward_metric_records, run_benchmarks, ) NORM_TYPES = [ @@ -25,6 +25,8 @@ ("LayerNorm", te.LayerNorm), ] +BENCHMARK_LABEL = "Normalization" + def _generate_norm_test_cases(): test_cases = [] @@ -72,15 +74,15 @@ def fwd_bwd_func(): fwd_gbps = compute_gbps(fwd_bytes, fwd_ms) bwd_gbps = compute_gbps(bwd_bytes, bwd_ms) - print(f" Forward {fwd_ms:.3f} ms | {fwd_gbps:.1f} GB/s") - print(f" Backward {bwd_ms:.3f} ms | {bwd_gbps:.1f} GB/s (derived)") - - return { - "TE Forward Time (ms)": f"{fwd_ms:.4f}", - "TE Forward GB/s": f"{fwd_gbps:.1f}", - "TE Backward Time (ms)": f"{bwd_ms:.4f}", - "TE Backward GB/s": f"{bwd_gbps:.1f}", - } + return make_forward_backward_metric_records( + BENCHMARK_LABEL, + "GB/s", + fwd_ms, + fwd_gbps, + bwd_ms, + bwd_gbps, + backward_derived=True + ) if __name__ == "__main__": @@ -88,9 +90,5 @@ def fwd_bwd_func(): test_cases=_generate_norm_test_cases(), bench_fn=bench_norm, param_columns=["Case", "M", "hidden_size", "dtype"], - metric_columns=[ - "TE Forward Time (ms)", "TE Forward GB/s", - "TE Backward Time (ms)", "TE Backward GB/s", - ], default_csv="benchmark_normalization.csv", ) diff --git a/benchmarks/microbenchmarks/compare_results.py 
b/benchmarks/microbenchmarks/compare_results.py index 7353066bc..95dda31c5 100755 --- a/benchmarks/microbenchmarks/compare_results.py +++ b/benchmarks/microbenchmarks/compare_results.py @@ -5,14 +5,14 @@ # See LICENSE for license information. ############################################################################### """ -Compare two CSVs from the same benchmark (base branch vs PR branch). +Compare two CSVs from the same benchmark suite. -Auto-detects metric columns (containing "TFLOPS"/ "GB/s") and key columns. +Auto-detects metric columns (containing "TFLOPS" or "GB/s") and key columns. Outputs a markdown
block to stdout with per-config results, and optionally appends a summary table row to --summary-file. Usage: - python compare_results.py base.csv pr.csv --bench-name NAME --summary-file FILE + python compare_results.py baseline.csv candidate.csv --bench-name NAME --summary-file FILE """ import argparse @@ -22,6 +22,7 @@ import pandas as pd SKIP_COLS = {"TestID", "Label"} +DEFAULT_MIN_BASELINE_METRIC = 0.5 def auto_detect_columns(df): @@ -36,30 +37,45 @@ def auto_detect_columns(df): def main(): parser = argparse.ArgumentParser(description="Compare benchmark CSVs") - parser.add_argument("base_csv", help="Base branch CSV") - parser.add_argument("pr_csv", help="PR branch CSV") + parser.add_argument("baseline_csv", help="Baseline CSV") + parser.add_argument("candidate_csv", help="Candidate CSV") parser.add_argument("--bench-name", default="benchmark", help="Benchmark name for markdown output") parser.add_argument("--summary-file", default=None, help="Append a summary table row (markdown) to this file") + parser.add_argument( + "--min-baseline-metric", + type=float, + default=DEFAULT_MIN_BASELINE_METRIC, + help=( + "Small baseline metrics can produce noisy speedups; skip speedup " + "calculations when the baseline metric is below this threshold. " + "Set to 0 to disable the filter." + ), + ) args = parser.parse_args() - base_df = pd.read_csv(args.base_csv) - pr_df = pd.read_csv(args.pr_csv) + baseline_df = pd.read_csv(args.baseline_csv) + candidate_df = pd.read_csv(args.candidate_csv) - key_cols, metric_cols = auto_detect_columns(base_df) + key_cols, metric_cols = auto_detect_columns(baseline_df) if not metric_cols: print("No metric columns found.") return 0 for col in metric_cols: - base_df[col] = pd.to_numeric(base_df[col], errors="coerce") - pr_df[col] = pd.to_numeric(pr_df[col], errors="coerce") - - merged = base_df.merge(pr_df, on=key_cols, suffixes=("_base", "_pr"), how="inner") + baseline_df[col] = pd.to_numeric(baseline_df[col], errors="coerce") + candidate_df[col] = pd.to_numeric(candidate_df[col], errors="coerce") + + merged = baseline_df.merge( + candidate_df, + on=key_cols, + suffixes=("_baseline", "_candidate"), + how="inner", + ) if merged.empty: - print("WARNING: No matching rows between base and PR.") + print("WARNING: No matching rows between baseline and candidate CSVs.") return 0 all_speedups = [] @@ -70,16 +86,30 @@ def main(): row_metrics = {} for metric in metric_cols: - bc, pc = f"{metric}_base", f"{metric}_pr" - bv = merged.loc[idx, bc] - pv = merged.loc[idx, pc] + baseline_col = f"{metric}_baseline" + candidate_col = f"{metric}_candidate" + baseline_value = merged.loc[idx, baseline_col] + candidate_value = merged.loc[idx, candidate_col] - if pd.isna(bv) or pd.isna(pv) or bv < 0.5: + if pd.isna(baseline_value) or pd.isna(candidate_value): + continue + if not np.isfinite(baseline_value) or not np.isfinite(candidate_value): + continue + if baseline_value <= 0: + continue + if ( + args.min_baseline_metric > 0 + and baseline_value < args.min_baseline_metric + ): continue - speedup = pv / bv + speedup = candidate_value / baseline_value all_speedups.append(speedup) - row_metrics[metric] = {"base": bv, "pr": pv, "speedup": speedup} + row_metrics[metric] = { + "baseline": baseline_value, + "candidate": candidate_value, + "speedup": speedup, + } if row_metrics: per_row_data.append({"keys": row_keys, "metrics": row_metrics}) @@ -102,7 +132,11 @@ def main(): header_cols = list(key_cols) for m in metric_cols: short = m.replace(" TFLOPS", "") - header_cols.extend([f"{short} Base", 
f"{short} PR", f"{short} Speedup"]) + header_cols.extend([ + f"{short} Baseline", + f"{short} Candidate", + f"{short} Speedup", + ]) print("| " + " | ".join(header_cols) + " |") print("|" + "|".join(["---"] * len(header_cols)) + "|") @@ -112,8 +146,8 @@ def main(): for metric in metric_cols: if metric in row["metrics"]: v = row["metrics"][metric] - cells.append(f"{v['base']:.2f}") - cells.append(f"{v['pr']:.2f}") + cells.append(f"{v['baseline']:.2f}") + cells.append(f"{v['candidate']:.2f}") cells.append(f"{v['speedup']:.3f}x") else: cells.extend(["", "", ""]) diff --git a/benchmarks/microbenchmarks/utils.py b/benchmarks/microbenchmarks/utils.py index 0e4a27be3..678db25f5 100644 --- a/benchmarks/microbenchmarks/utils.py +++ b/benchmarks/microbenchmarks/utils.py @@ -91,6 +91,86 @@ def compute_gbps(nbytes, ms): return nbytes / (ms * 1e-3) / 1e9 +def make_metric_record(label, ms, unit, value, derived=False, + ms_precision=3, value_precision=2): + """Create a structured metric record for stdout and CSV generation.""" + return { + "label": label, + "ms": ms, + "unit": unit, + "value": value, + "derived": derived, + "ms_precision": ms_precision, + "value_precision": value_precision, + } + + +def make_forward_backward_metric_records(label_prefix, unit, + forward_ms, forward_value, + backward_ms, backward_value, + backward_derived=False, + ms_precision=3, + value_precision=2): + """Create standard forward/backward metric records for a benchmark.""" + return [ + make_metric_record( + f"{label_prefix} Forward", + forward_ms, + unit, + forward_value, + ms_precision=ms_precision, + value_precision=value_precision, + ), + make_metric_record( + f"{label_prefix} Backward", + backward_ms, + unit, + backward_value, + derived=backward_derived, + ms_precision=ms_precision, + value_precision=value_precision, + ), + ] + + +def _metric_time_key(metric): + return f"{metric['label']} Time (ms)" + + +def _metric_value_key(metric): + return f"{metric['label']} {metric['unit']}" + + +def _format_metric_number(value, precision): + return f"{value:.{precision}f}" + + +def _metric_row_from_records(metric_records): + row = {} + for metric in metric_records: + row[_metric_time_key(metric)] = _format_metric_number( + metric["ms"], metric.get("ms_precision", 3) + ) + row[_metric_value_key(metric)] = _format_metric_number( + metric["value"], metric.get("value_precision", 2) + ) + return row + + +def _print_metric_records(metric_records): + label_width = max(24, *(len(metric["label"]) for metric in metric_records)) + for metric in metric_records: + ms_str = _format_metric_number(metric["ms"], metric.get("ms_precision", 3)) + value_str = _format_metric_number( + metric["value"], metric.get("value_precision", 2) + ) + derived_suffix = " (derived)" if metric.get("derived", False) else "" + print( + f" {metric['label']:<{label_width}} {ms_str} ms | " + f"{value_str} {metric['unit']}{derived_suffix}" + ) + + # --------------------------------------------------------------------------- # Benchmark runner # --------------------------------------------------------------------------- @@ -103,8 +183,7 @@ def add_csv_arg(parser): ) -def run_benchmarks(test_cases, bench_fn, param_columns, metric_columns, - default_csv=None): +def run_benchmarks(test_cases, bench_fn, param_columns, default_csv=None): """Iterate *test_cases*, call *bench_fn*, and optionally write a CSV. 
Parameters @@ -113,12 +192,10 @@ def run_benchmarks(test_cases, bench_fn, param_columns, metric_columns, Each dict has at least the keys in *param_columns* plus any extra keys the bench_fn needs (passed as **case). bench_fn : callable - Called as ``bench_fn(**case)`` and must return a dict whose keys - match *metric_columns*. + Called as ``bench_fn(**case)`` and must return a list of metric + records created by ``make_metric_record``. param_columns : list[str] Column names to pull from each test case into the output row. - metric_columns : list[str] - Column names to pull from the bench_fn return value. default_csv : str or None Default CSV filename used when ``--csv`` is passed without a filename. CSV output is only written when the caller passes @@ -128,8 +205,8 @@ def run_benchmarks(test_cases, bench_fn, param_columns, metric_columns, add_csv_arg(parser) args, _ = parser.parse_known_args() - columns = param_columns + metric_columns rows = [] + resolved_metric_columns = None for case in test_cases: label = " ".join(f"{k}={case[k]}" for k in param_columns) @@ -137,16 +214,28 @@ def run_benchmarks(test_cases, bench_fn, param_columns, metric_columns, print(f"Testing: {label}") print(f"{'='*60}") - metrics = bench_fn(**case) + metric_records = bench_fn(**case) + metric_row = _metric_row_from_records(metric_records) + _print_metric_records(metric_records) + current_metric_columns = list(metric_row.keys()) + + if resolved_metric_columns is None: + resolved_metric_columns = current_metric_columns + elif current_metric_columns != resolved_metric_columns: + raise ValueError( + f"Inconsistent metric columns for case {case}: " + f"expected {resolved_metric_columns}, got {current_metric_columns}" + ) row = {k: (str(case[k]) if isinstance(case[k], torch.dtype) else case[k]) for k in param_columns} - row.update(metrics) + row.update(metric_row) rows.append(row) if args.csv is not None: import pandas as pd out_csv = args.csv if isinstance(args.csv, str) else default_csv + columns = param_columns + (resolved_metric_columns or []) results = pd.DataFrame(rows, columns=columns) results.to_csv(out_csv, index=False) print(f"\nResults saved to {out_csv}") From 372e6df0ddc86cd2c62525cc014d6da18463f63c Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 11 May 2026 14:17:41 -0500 Subject: [PATCH 23/25] Llama 3.1 --- benchmarks/microbenchmarks/utils.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/benchmarks/microbenchmarks/utils.py b/benchmarks/microbenchmarks/utils.py index 678db25f5..99e4ebcbb 100644 --- a/benchmarks/microbenchmarks/utils.py +++ b/benchmarks/microbenchmarks/utils.py @@ -21,17 +21,17 @@ # (name, hidden, intermediate, num_q_heads, num_kv_heads, head_dim, tp) # # Sources: -# - Llama 3 8B https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json -# - Llama 3 70B https://huggingface.co/meta-llama/Llama-3.1-70B/blob/main/config.json -# - Llama 3 405B https://huggingface.co/meta-llama/Llama-3.1-405B/blob/main/config.json +# - Llama 3.1 8B https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json +# - Llama 3.1 70B https://huggingface.co/meta-llama/Llama-3.1-70B/blob/main/config.json +# - Llama 3.1 405B https://huggingface.co/meta-llama/Llama-3.1-405B/blob/main/config.json # - Qwen 2.5 7B https://huggingface.co/Qwen/Qwen2.5-7B-Instruct/blob/main/config.json # - Qwen 2.5 72B https://huggingface.co/Qwen/Qwen2.5-72B-Instruct/blob/main/config.json MODEL_CONFIGS = [ - ("Llama3-8B/TP1", 4096, 14336, 32, 8, 128, 1), - ("Llama3-8B/TP8", 4096, 
14336, 32, 8, 128, 8), - ("Llama3-70B/TP8", 8192, 28672, 64, 8, 128, 8), - ("Llama3-405B/TP8", 16384, 53248, 128, 8, 128, 8), + ("Llama3.1-8B/TP1", 4096, 14336, 32, 8, 128, 1), + ("Llama3.1-8B/TP8", 4096, 14336, 32, 8, 128, 8), + ("Llama3.1-70B/TP8", 8192, 28672, 64, 8, 128, 8), + ("Llama3.1-405B/TP8", 16384, 53248, 128, 8, 128, 8), ("Qwen2.5-7B/TP1", 3584, 18944, 28, 4, 128, 1), ("Qwen2.5-72B/TP8", 8192, 29568, 64, 8, 128, 8), ] From 284addab3df9761fe76620c94b55201c7ce14b51 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 11 May 2026 15:35:56 -0500 Subject: [PATCH 24/25] address reviewer comments --- .../microbenchmarks/benchmark_casting.py | 1 - benchmarks/microbenchmarks/benchmark_gemm.py | 25 +--- .../microbenchmarks/benchmark_gemm_fp8.py | 33 ++--- .../microbenchmarks/benchmark_grouped_gemm.py | 14 +- .../benchmark_normalization.py | 20 +-- benchmarks/microbenchmarks/compare_results.py | 123 ++++++++++++------ benchmarks/microbenchmarks/utils.py | 87 +++++++++---- 7 files changed, 177 insertions(+), 126 deletions(-) diff --git a/benchmarks/microbenchmarks/benchmark_casting.py b/benchmarks/microbenchmarks/benchmark_casting.py index 1d8056d08..118070770 100755 --- a/benchmarks/microbenchmarks/benchmark_casting.py +++ b/benchmarks/microbenchmarks/benchmark_casting.py @@ -83,5 +83,4 @@ def bench_cast(Case, M, hidden_size, direction, fp8_dtype, dtype_str): test_cases=_generate_cast_test_cases(), bench_fn=bench_cast, param_columns=["Case", "M", "hidden_size", "dtype_str"], - default_csv="benchmark_casting.csv", ) diff --git a/benchmarks/microbenchmarks/benchmark_gemm.py b/benchmarks/microbenchmarks/benchmark_gemm.py index f5270e4cc..8634e7f09 100755 --- a/benchmarks/microbenchmarks/benchmark_gemm.py +++ b/benchmarks/microbenchmarks/benchmark_gemm.py @@ -1,6 +1,6 @@ #!/usr/bin/env python ############################################################################### -# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. # # See LICENSE for license information. 
############################################################################### @@ -9,27 +9,11 @@ import torch import transformer_engine.pytorch as te from utils import ( - MODEL_CONFIGS, M_SIZE_LIST, gemm_shapes, + generate_gemm_test_cases, time_func, compute_tflops, make_forward_backward_metric_records, run_benchmarks, ) -ACTIVE_SHAPES = gemm_shapes(MODEL_CONFIGS) - -BENCHMARK_LABEL = "BF16 GEMM" - - -def _generate_gemm_test_cases(): - test_cases = [] - for M in M_SIZE_LIST: - for case_name, (N, K) in ACTIVE_SHAPES.items(): - test_cases.append({ - "Case": case_name, - "M": M, - "N": N, - "K": K, - "dtype": torch.bfloat16, - }) - return test_cases +BENCHMARK_LABEL = "GEMM" def bench_gemm(Case, M, N, K, dtype): @@ -73,8 +57,7 @@ def fwd_bwd_func(): if __name__ == "__main__": run_benchmarks( - test_cases=_generate_gemm_test_cases(), + test_cases=generate_gemm_test_cases(), bench_fn=bench_gemm, param_columns=["Case", "M", "N", "K", "dtype"], - default_csv="benchmark_gemm.csv", ) diff --git a/benchmarks/microbenchmarks/benchmark_gemm_fp8.py b/benchmarks/microbenchmarks/benchmark_gemm_fp8.py index e11a39a0d..85623204a 100755 --- a/benchmarks/microbenchmarks/benchmark_gemm_fp8.py +++ b/benchmarks/microbenchmarks/benchmark_gemm_fp8.py @@ -15,35 +15,23 @@ import transformer_engine.pytorch as te from transformer_engine.common.recipe import DelayedScaling, Format from utils import ( - MODEL_CONFIGS, M_SIZE_LIST, gemm_shapes, + generate_gemm_test_cases, time_func, compute_tflops, make_forward_backward_metric_records, run_benchmarks, ) -FP8_RECIPE = DelayedScaling( - fp8_format=Format.HYBRID, - amax_history_len=16, - amax_compute_algo="max", -) +RECIPES = { + "hybrid": DelayedScaling( + fp8_format=Format.HYBRID, + amax_history_len=16, + amax_compute_algo="max", + ), +} -ACTIVE_SHAPES = gemm_shapes(MODEL_CONFIGS) +FP8_RECIPE = RECIPES["hybrid"] BENCHMARK_LABEL = "FP8 GEMM" -def _generate_gemm_test_cases(): - test_cases = [] - for M in M_SIZE_LIST: - for case_name, (N, K) in ACTIVE_SHAPES.items(): - test_cases.append({ - "Case": case_name, - "M": M, - "N": N, - "K": K, - "dtype": torch.bfloat16, - }) - return test_cases - - def bench_fp8_gemm(Case, M, N, K, dtype): device = "cuda" @@ -85,8 +73,7 @@ def fwd_bwd_func(): if __name__ == "__main__": run_benchmarks( - test_cases=_generate_gemm_test_cases(), + test_cases=generate_gemm_test_cases(), bench_fn=bench_fp8_gemm, param_columns=["Case", "M", "N", "K", "dtype"], - default_csv="benchmark_gemm_fp8.csv", ) diff --git a/benchmarks/microbenchmarks/benchmark_grouped_gemm.py b/benchmarks/microbenchmarks/benchmark_grouped_gemm.py index 63617949d..5197a9a38 100755 --- a/benchmarks/microbenchmarks/benchmark_grouped_gemm.py +++ b/benchmarks/microbenchmarks/benchmark_grouped_gemm.py @@ -1,19 +1,20 @@ #!/usr/bin/env python ############################################################################### -# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. # # See LICENSE for license information. 
############################################################################### import torch from utils import ( + DTYPE_LIST, time_func, compute_tflops, make_forward_backward_metric_records, run_benchmarks, ) -BENCHMARK_LABEL = "BF16 Grouped GEMM" +BENCHMARK_LABEL = "Grouped GEMM" def generate_grouped_gemm_group_lens(b, m, balance: bool): if balance: @@ -26,7 +27,9 @@ def generate_grouped_gemm_group_lens(b, m, balance: bool): group_lens[-1] += error return group_lens -M_SIZE_LIST = [512, 1024, 2048, 4096] +# Grouped GEMM scales with expert count B, so we sweep smaller M values than +# the dense GEMM benchmarks to keep the working set and runtime reasonable. +GROUPED_GEMM_M_SIZE_LIST = [512, 1024, 2048, 4096] EP_SIZE_LIST = [32, 16, 8] @@ -48,9 +51,9 @@ def _generate_moe_test_cases( B = n_routed_experts // ep if B < 1: continue - for M in M_SIZE_LIST: + for M in GROUPED_GEMM_M_SIZE_LIST: for name, (N, K) in shapes_dict.items(): - for dtype in [torch.bfloat16]: + for dtype in DTYPE_LIST: test_cases.append( { "Case": name, @@ -213,5 +216,4 @@ def bench_grouped_gemm(Case, B, M, N, K, dtype): test_cases=test_cases, bench_fn=bench_grouped_gemm, param_columns=["Case", "B", "M", "N", "K", "dtype"], - default_csv="benchmark_grouped_gemm.csv", ) diff --git a/benchmarks/microbenchmarks/benchmark_normalization.py b/benchmarks/microbenchmarks/benchmark_normalization.py index bcaadb509..25c8cef46 100755 --- a/benchmarks/microbenchmarks/benchmark_normalization.py +++ b/benchmarks/microbenchmarks/benchmark_normalization.py @@ -16,7 +16,7 @@ import torch import transformer_engine.pytorch as te from utils import ( - MODEL_HIDDEN_SIZES, M_SIZE_LIST, + DTYPE_LIST, MODEL_HIDDEN_SIZES, M_SIZE_LIST, time_func, compute_gbps, make_forward_backward_metric_records, run_benchmarks, ) @@ -33,14 +33,15 @@ def _generate_norm_test_cases(): for model_name, hidden in MODEL_HIDDEN_SIZES: for norm_name, norm_cls in NORM_TYPES: for M in M_SIZE_LIST: - test_cases.append({ - "Case": f"{model_name}/{norm_name}", - "M": M, - "hidden_size": hidden, - "norm_name": norm_name, - "norm_cls": norm_cls, - "dtype": torch.bfloat16, - }) + for dtype in DTYPE_LIST: + test_cases.append({ + "Case": f"{model_name}/{norm_name}", + "M": M, + "hidden_size": hidden, + "norm_name": norm_name, + "norm_cls": norm_cls, + "dtype": dtype, + }) return test_cases @@ -90,5 +91,4 @@ def fwd_bwd_func(): test_cases=_generate_norm_test_cases(), bench_fn=bench_norm, param_columns=["Case", "M", "hidden_size", "dtype"], - default_csv="benchmark_normalization.csv", ) diff --git a/benchmarks/microbenchmarks/compare_results.py b/benchmarks/microbenchmarks/compare_results.py index 95dda31c5..4a7e1dab8 100755 --- a/benchmarks/microbenchmarks/compare_results.py +++ b/benchmarks/microbenchmarks/compare_results.py @@ -35,6 +35,20 @@ def auto_detect_columns(df): return key_cols, metric_cols +def print_key_table(title, rows_df, key_cols): + if rows_df.empty: + return + + print(title) + print() + print("| " + " | ".join(key_cols) + " |") + print("|" + "|".join(["---"] * len(key_cols)) + "|") + for idx in rows_df.index: + cells = [str(rows_df.loc[idx, key]) for key in key_cols] + print("| " + " | ".join(cells) + " |") + print() + + def main(): parser = argparse.ArgumentParser(description="Compare benchmark CSVs") parser.add_argument("baseline_csv", help="Baseline CSV") @@ -72,24 +86,29 @@ def main(): candidate_df, on=key_cols, suffixes=("_baseline", "_candidate"), - how="inner", + how="outer", + indicator=True, ) if merged.empty: - print("WARNING: No matching rows between 
diff --git a/benchmarks/microbenchmarks/compare_results.py b/benchmarks/microbenchmarks/compare_results.py
index 95dda31c5..4a7e1dab8 100755
--- a/benchmarks/microbenchmarks/compare_results.py
+++ b/benchmarks/microbenchmarks/compare_results.py
@@ -35,6 +35,20 @@ def auto_detect_columns(df):
     return key_cols, metric_cols


+def print_key_table(title, rows_df, key_cols):
+    if rows_df.empty:
+        return
+
+    print(title)
+    print()
+    print("| " + " | ".join(key_cols) + " |")
+    print("|" + "|".join(["---"] * len(key_cols)) + "|")
+    for idx in rows_df.index:
+        cells = [str(rows_df.loc[idx, key]) for key in key_cols]
+        print("| " + " | ".join(cells) + " |")
+    print()
+
+
 def main():
     parser = argparse.ArgumentParser(description="Compare benchmark CSVs")
     parser.add_argument("baseline_csv", help="Baseline CSV")
@@ -72,24 +86,29 @@ def main():
         candidate_df,
         on=key_cols,
         suffixes=("_baseline", "_candidate"),
-        how="inner",
+        how="outer",
+        indicator=True,
     )

     if merged.empty:
-        print("WARNING: No matching rows between baseline and candidate CSVs.")
+        print("WARNING: No rows found in baseline or candidate CSVs.")
         return 0

+    matched = merged[merged["_merge"] == "both"]
+    baseline_only = merged[merged["_merge"] == "left_only"]
+    candidate_only = merged[merged["_merge"] == "right_only"]
+
     all_speedups = []
     per_row_data = []

-    for idx in merged.index:
-        row_keys = {k: merged.loc[idx, k] for k in key_cols}
+    for idx in matched.index:
+        row_keys = {k: matched.loc[idx, k] for k in key_cols}
         row_metrics = {}

         for metric in metric_cols:
             baseline_col = f"{metric}_baseline"
             candidate_col = f"{metric}_candidate"

-            baseline_value = merged.loc[idx, baseline_col]
-            candidate_value = merged.loc[idx, candidate_col]
+            baseline_value = matched.loc[idx, baseline_col]
+            candidate_value = matched.loc[idx, candidate_col]

             if pd.isna(baseline_value) or pd.isna(candidate_value):
                 continue
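
The switch from how="inner" to how="outer" with indicator=True above is what
makes the unmatched-row reporting in the next hunk possible. A self-contained
pandas illustration of the _merge values the script filters on; the "Case" and
"Fwd TFLOPS" columns are hypothetical stand-ins for the real CSV headers:

import pandas as pd

baseline = pd.DataFrame({"Case": ["a", "b"], "Fwd TFLOPS": [10.0, 20.0]})
candidate = pd.DataFrame({"Case": ["b", "c"], "Fwd TFLOPS": [21.0, 30.0]})

merged = baseline.merge(candidate, on=["Case"],
                        suffixes=("_baseline", "_candidate"),
                        how="outer", indicator=True)
# "both" = row present in both CSVs, "left_only" = baseline only,
# "right_only" = candidate only.
print(merged["_merge"].tolist())   # ['left_only', 'both', 'right_only']
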
") - print(f"{args.bench_name} " - f"(median {median_sp:.3f}x, min {min_sp:.3f}x, max {max_sp:.3f}x)") + print(summary) print() - header_cols = list(key_cols) - for m in metric_cols: - short = m.replace(" TFLOPS", "") - header_cols.extend([ - f"{short} Baseline", - f"{short} Candidate", - f"{short} Speedup", - ]) - - print("| " + " | ".join(header_cols) + " |") - print("|" + "|".join(["---"] * len(header_cols)) + "|") - - for row in per_row_data: - cells = [str(row["keys"].get(k, "")) for k in key_cols] - for metric in metric_cols: - if metric in row["metrics"]: - v = row["metrics"][metric] - cells.append(f"{v['baseline']:.2f}") - cells.append(f"{v['candidate']:.2f}") - cells.append(f"{v['speedup']:.3f}x") - else: - cells.extend(["", "", ""]) - print("| " + " | ".join(cells) + " |") + if per_row_data: + header_cols = list(key_cols) + for m in metric_cols: + short = m.replace(" TFLOPS", "") + header_cols.extend([ + f"{short} Baseline", + f"{short} Candidate", + f"{short} Speedup", + ]) + + print("| " + " | ".join(header_cols) + " |") + print("|" + "|".join(["---"] * len(header_cols)) + "|") + + for row in per_row_data: + cells = [str(row["keys"].get(k, "")) for k in key_cols] + for metric in metric_cols: + if metric in row["metrics"]: + v = row["metrics"][metric] + cells.append(f"{v['baseline']:.2f}") + cells.append(f"{v['candidate']:.2f}") + cells.append(f"{v['speedup']:.3f}x") + else: + cells.extend(["", "", ""]) + print("| " + " | ".join(cells) + " |") + print() + elif not matched.empty: + print("No overlapping metric rows produced a valid speedup after filtering.") + print() + + print_key_table("Rows only in candidate", candidate_only, key_cols) + print_key_table("Rows only in baseline", baseline_only, key_cols) - print() print("
") print() # Summary row if args.summary_file: with open(args.summary_file, "a") as f: - f.write(f"| {args.bench_name} | {median_sp:.3f}x | {min_sp:.3f}x | {max_sp:.3f}x |\n") + if summary_row is not None: + f.write(summary_row) + else: + f.write(f"| {args.bench_name} | n/a | n/a | n/a |\n") return 0 diff --git a/benchmarks/microbenchmarks/utils.py b/benchmarks/microbenchmarks/utils.py index 99e4ebcbb..cad25c413 100644 --- a/benchmarks/microbenchmarks/utils.py +++ b/benchmarks/microbenchmarks/utils.py @@ -15,6 +15,12 @@ # --------------------------------------------------------------------------- M_SIZE_LIST = [1024, 2048, 4096, 8192] +# Shared dtype sweep for TE activation benchmarks. Extend this list to add +# additional precisions such as torch.float16. +DTYPE_LIST = [torch.bfloat16] + +DEFAULT_MIN_RUN_TIME_SECONDS = 0.2 + # --------------------------------------------------------------------------- # Model configurations # --------------------------------------------------------------------------- @@ -61,11 +67,28 @@ def gemm_shapes(configs=None): return shapes +def generate_gemm_test_cases(configs=None, m_sizes=None, dtypes=None): + """Generate dense GEMM benchmark cases shared by BF16 and FP8 GEMM.""" + test_cases = [] + active_shapes = gemm_shapes(configs) + for m_value in (m_sizes or M_SIZE_LIST): + for case_name, (n_value, k_value) in active_shapes.items(): + for dtype in (dtypes or DTYPE_LIST): + test_cases.append({ + "Case": case_name, + "M": m_value, + "N": n_value, + "K": k_value, + "dtype": dtype, + }) + return test_cases + + # --------------------------------------------------------------------------- # Timing helpers # --------------------------------------------------------------------------- -def time_func(fn, method="adaptive"): +def time_func(fn, method="adaptive", min_run_time=DEFAULT_MIN_RUN_TIME_SECONDS): """Time *fn* and return elapsed milliseconds. method: "adaptive" uses adaptive_autorange (good for compute-bound), @@ -73,8 +96,8 @@ def time_func(fn, method="adaptive"): """ timer = benchmark.Timer(stmt="fn()", globals={"fn": fn}) if method == "blocked": - return timer.blocked_autorange().mean * 1e3 - return timer.adaptive_autorange().mean * 1e3 + return timer.blocked_autorange(min_run_time=min_run_time).mean * 1e3 + return timer.adaptive_autorange(min_run_time=min_run_time).mean * 1e3 # --------------------------------------------------------------------------- @@ -91,44 +114,49 @@ def compute_gbps(nbytes, ms): return nbytes / (ms * 1e-3) / 1e9 -def make_metric_record(label, ms, unit, value, derived=False, - ms_precision=3, value_precision=2): - """Create a structured metric record for stdout and CSV generation.""" +def make_metric_record(label, ms, unit, throughput, derived=False, + ms_precision=3, throughput_precision=2): + """Create a structured metric record for stdout and CSV generation. + + Each record describes one benchmark line item such as "GEMM Forward". + ``run_benchmarks`` formats these records for stdout and expands them into + ``