diff --git a/.gitignore b/.gitignore
index a1da56aa9..726bf6f9e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -58,4 +58,5 @@ artifacts/
 **/times.csv
 transformer_engine/build_info.txt
 transformer_engine/common/util/hip_nvml.*
+.bench-results/
 *.DS_Store
diff --git a/benchmarks/microbench/README.md b/benchmarks/microbench/README.md
new file mode 100644
index 000000000..6ffe98ae0
--- /dev/null
+++ b/benchmarks/microbench/README.md
@@ -0,0 +1,114 @@
+# Microbenchmarks for TransformerEngine
+
GPU microbenchmarks driven by `driver.py`. Each `bench_*.py` file defines one
or more bench classes following an ASV-style API (`params`, `param_names`,
`time_*` methods, optional `work_*` companions). Timing uses
`torch.utils.benchmark.Timer` under the hood. The driver runs each suite
in-process and writes results as long-format CSV — one row per Timer block —
intended to be consumed by a separate analysis tool (statistical tests,
cross-run comparison).

## Prerequisites

- TransformerEngine must already be built and installed in the current Python environment.
- A ROCm or CUDA GPU must be available.

## Running benchmarks

Each `bench_*.py` file is directly executable, or you can drive them through
`driver.py`. Results are written by default to
`benchmarks/.bench-results/<machine>/<commit>.csv`.

```bash
cd benchmarks/microbench
python driver.py --all                 # run every suite
python driver.py bench_gemm            # run one suite via driver
python bench_gemm.py                   # run one suite directly
python bench_gemm.py time_forward      # filter to method names containing this string
python bench_casting.py --no-csv       # stdout only, don't write CSV
python bench_casting.py --csv out.csv  # custom output path
python bench_casting.py --append       # append to existing CSV
```

## Output format

Long-format CSV — one row per `torch.utils.benchmark` block. Default location
is `benchmarks/.bench-results/<machine>/<commit>.csv`; the
`.bench-results` tree is in `.gitignore`. Schema:

| Column | Type | Description |
|---|---|---|
| `suite` | str | Module name (e.g. `bench_gemm`) |
| `class` | str | Bench class name (e.g. `BenchGemm`) |
| `method` | str | Timed method (e.g. `time_forward`) |
| `params` | str | `k1=v1;k2=v2` canonical form for joining across runs |
| `sample_idx` | int | Block index within this Measurement |
| `time_s` | float | Per-call elapsed seconds (Timer normalizes by `number_per_run`) |
| `number_per_run` | int | Kernel invocations averaged into this row's `time_s` |
| `tflops` | float | Per-call throughput, empty if no `work_*` flops |
| `gbps` | float | Per-call bandwidth, empty if no `work_*` bytes |
| `commit` | str | Short git HEAD hash |
| `machine` | str | `platform.node()` |
| `started_at_ms` | int | Unix-ms timestamp when this method's run began |

Per-PR comparison and statistical tests are handled by a separate analysis
tool (TBD) that reads two or more of these CSVs and joins on
`(suite, class, method, params)`. Note that `time_s` is a *block mean* —
the analysis tool should weight by `number_per_run` (or use blocks as
independent samples) when computing significance.
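As a rough illustration, such a comparison could be done with pandas along
these lines (a minimal sketch, not the actual analysis tool; `baseline.csv`
and `candidate.csv` are hypothetical file names):

```python
import pandas as pd

KEY = ["suite", "class", "method", "params"]

def load(path):
    # Each CSV row is a block mean over `number_per_run` calls, so collapse
    # each benchmark to a run-level mean weighted by number_per_run.
    df = pd.read_csv(path)
    df["total_s"] = df["time_s"] * df["number_per_run"]
    g = df.groupby(KEY)[["total_s", "number_per_run"]].sum()
    return g["total_s"] / g["number_per_run"]

base = load("baseline.csv")   # hypothetical: run from the baseline commit
cand = load("candidate.csv")  # hypothetical: run from the candidate commit
speedup = (base / cand).rename("speedup")  # >1 means the candidate got faster
print(speedup.sort_values().to_string())
```

## Writing new benchmarks

Create a new file in `benchmarks/microbench/` following the naming convention `bench_<name>.py`.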

```python
#!/usr/bin/env python3
import torch
import transformer_engine.pytorch as te

from driver import time_func


class BenchSomething:
    params = [[1024, 4096], ["config_a", "config_b"]]
    param_names = ["M", "config"]
    timeout = 300  # seconds, per parameter combination

    def setup(self, M, config):
        # Allocate tensors, create modules.
        # Runs once per (combo, method); the same instance is reused for
        # warmup and timed Timer blocks.
        self.module = ...
        self.x = ...
        self.grad_out = ...  # consumed by time_forward_backward below

    def time_forward(self, M, config):
        return time_func(lambda: self.module(self.x))

    def time_forward_backward(self, M, config):
        def fn():
            out = self.module(self.x)
            out.backward(self.grad_out)
        return time_func(fn)

    # Optional: define work_<name> to get throughput columns (TFLOPS / GB/s).
    def work_forward(self, M, config):
        return {"flops": 2 * M * self.N * self.K}  # compute-bound
        # return {"bytes": M * self.hidden * 4}    # memory-bound


if __name__ == "__main__":
    from driver import run_as_main
    run_as_main(__file__)
```

Key rules:
- Method names starting with `time_` are automatically timed.
- `time_*` methods must return `time_func(fn)` — a `torch.utils.benchmark.Measurement`.
- Inside `fn`, do whatever per-call work you want measured. For backward,
  let gradients accumulate in place across iterations — Timer's repeated
  invocations won't OOM, since each backward accumulates into the same
  `.grad` tensors instead of allocating new ones, and the accumulated
  values don't affect timing.
- Optionally define `work_<name>` companions to get TFLOPS or GB/s columns.
  Return per-call work; the driver derives per-sample throughput.
- The `params` list defines a cross-product; keep the matrix size reasonable.

diff --git a/benchmarks/microbench/__init__.py b/benchmarks/microbench/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/benchmarks/microbench/bench_attention.py b/benchmarks/microbench/bench_attention.py
new file mode 100644
index 000000000..5ece0a0d0
--- /dev/null
+++ b/benchmarks/microbench/bench_attention.py
@@ -0,0 +1,89 @@
+#!/usr/bin/env python3
###############################################################################
# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
#
# See LICENSE for license information.
###############################################################################
"""
Attention micro-benchmark using te.DotProductAttention.

Benchmarks fused multi-head attention (with flash attention backend) for
model configurations with grouped-query attention (GQA).
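
Q/K/V tensors are created in the module's default "sbhd" layout:
(seq_len, batch, heads, head_dim), as in setup() below.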
+ +Models: + - Llama 3 8B (TP=1, TP=8), 70B (TP=8), 405B (TP=8) + - Qwen 2.5 7B (TP=1), 72B (TP=8) + +Forward FLOPs = 4 * batch * num_q_heads * seq_len^2 * head_dim + (two matmuls: Q@K^T and attn@V, each contributing 2*b*h*s^2*d) +Backward FLOPs = 2 * Forward FLOPs (approximately) + +Sources for model configs: + https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json + https://huggingface.co/meta-llama/Llama-3.1-70B/blob/main/config.json + https://huggingface.co/meta-llama/Llama-3.1-405B/blob/main/config.json + https://huggingface.co/Qwen/Qwen2.5-7B-Instruct/blob/main/config.json + https://huggingface.co/Qwen/Qwen2.5-72B-Instruct/blob/main/config.json +""" + +import torch +import transformer_engine.pytorch as te + +from driver import time_func + +BATCH = 2 + +# (num_q_heads, num_kv_heads, head_dim, tp) +MODELS = { + "Llama3-8B_TP1": (32, 8, 128, 1), + "Llama3-8B_TP8": (32, 8, 128, 8), + "Llama3-70B_TP8": (64, 8, 128, 8), + "Llama3-405B_TP8": (128, 8, 128, 8), + "Qwen2.5-7B_TP1": (28, 4, 128, 1), + "Qwen2.5-72B_TP8": (64, 8, 128, 8), +} + + +class BenchAttention: + params = [[1024, 2048, 4096, 8192], list(MODELS)] + param_names = ["seq_len", "model"] + timeout = 300 + + def setup(self, seq_len, model): + n_q, n_kv, hd, tp = MODELS[model] + qh, kvh = n_q // tp, n_kv // tp + dtype = torch.bfloat16 + + self.attn = te.DotProductAttention( + num_attention_heads=qh, kv_channels=hd, + num_gqa_groups=kvh, attn_mask_type="causal", + ).to(device="cuda", dtype=dtype) + + self.q = torch.randn(seq_len, BATCH, qh, hd, dtype=dtype, device="cuda", requires_grad=True) + self.k = torch.randn(seq_len, BATCH, kvh, hd, dtype=dtype, device="cuda", requires_grad=True) + self.v = torch.randn(seq_len, BATCH, kvh, hd, dtype=dtype, device="cuda", requires_grad=True) + self.grad_out = torch.randn_like(self.attn(self.q, self.k, self.v)) + + def work_forward(self, seq_len, model): + n_q, n_kv, hd, tp = MODELS[model] + qh = n_q // tp + return {"flops": 4 * BATCH * qh * seq_len * seq_len * hd} + + def work_forward_backward(self, seq_len, model): + n_q, n_kv, hd, tp = MODELS[model] + qh = n_q // tp + return {"flops": 3 * 4 * BATCH * qh * seq_len * seq_len * hd} + + def time_forward(self, seq_len, model): + return time_func(lambda: self.attn(self.q, self.k, self.v)) + + def time_forward_backward(self, seq_len, model): + def fn(): + out = self.attn(self.q, self.k, self.v) + out.backward(self.grad_out) + return time_func(fn) + + +if __name__ == "__main__": + from driver import run_as_main + run_as_main(__file__) diff --git a/benchmarks/microbench/bench_casting.py b/benchmarks/microbench/bench_casting.py new file mode 100644 index 000000000..b406b84fc --- /dev/null +++ b/benchmarks/microbench/bench_casting.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python3 +############################################################################### +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### +""" +Benchmarks quantization (BF16 -> FP8) and dequantization (FP8 -> BF16) for +both E4M3 (activations/weights) and E5M2 (gradients) formats. + +Shapes are (M, hidden_size) matching the activation tensors from models: + - Llama 3.1 8B, 70B, 405B + - Qwen 2.5 7B, 72B + +These casts are memory-bound; we report GB/s (input + output bytes). 
+ +Sources for model configs: + https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json + https://huggingface.co/meta-llama/Llama-3.1-70B/blob/main/config.json + https://huggingface.co/meta-llama/Llama-3.1-405B/blob/main/config.json + https://huggingface.co/Qwen/Qwen2.5-7B-Instruct/blob/main/config.json + https://huggingface.co/Qwen/Qwen2.5-72B-Instruct/blob/main/config.json +""" + +import torch +from transformer_engine.pytorch import Float8CurrentScalingQuantizer +from transformer_engine_torch import DType as TE_DType + +from driver import time_func + +HIDDEN_SIZES = { + "Llama3-8B": 4096, + "Llama3-70B": 8192, + "Llama3-405B": 16384, + "Qwen2.5-7B": 3584, + "Qwen2.5-72B": 8192, +} + +CAST_CONFIGS = { + "BF16_to_E4M3": ("quantize", TE_DType.kFloat8E4M3), + "E4M3_to_BF16": ("dequantize", TE_DType.kFloat8E4M3), + "BF16_to_E5M2": ("quantize", TE_DType.kFloat8E5M2), + "E5M2_to_BF16": ("dequantize", TE_DType.kFloat8E5M2), +} + + +class BenchCasting: + params = [[1024, 2048, 4096, 8192], list(HIDDEN_SIZES), list(CAST_CONFIGS)] + param_names = ["M", "model", "cast"] + timeout = 120 + + def setup(self, M, model, cast): + hidden = HIDDEN_SIZES[model] + direction, fp8_dtype = CAST_CONFIGS[cast] + self.direction = direction + quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=fp8_dtype, + device=torch.device("cuda"), + rowwise=True, + columnwise=False, + ) + if direction == "dequantize": + bf16_tensor = torch.randn(M, hidden, dtype=torch.bfloat16, device="cuda") + self.x = quantizer.quantize(bf16_tensor) + else: + self.x = torch.randn(M, hidden, dtype=torch.bfloat16, device="cuda") + self.quantizer = quantizer + + def work_cast(self, M, model, cast): + hidden = HIDDEN_SIZES[model] + # Read input (1B FP8 or 2B BF16) + write output + scale (~hidden bytes total) + # Approximated as 3 bytes per element either direction. + return {"bytes": M * hidden * 3} + + def time_cast(self, M, model, cast): + if self.direction == "quantize": + return time_func(lambda: self.quantizer.quantize(self.x)) + return time_func(lambda: self.x.dequantize(dtype=torch.bfloat16)) + + +if __name__ == "__main__": + from driver import run_as_main + run_as_main(__file__) diff --git a/benchmarks/microbench/bench_gemm.py b/benchmarks/microbench/bench_gemm.py new file mode 100644 index 000000000..2c2adc9f6 --- /dev/null +++ b/benchmarks/microbench/bench_gemm.py @@ -0,0 +1,85 @@ +#!/usr/bin/env python3 +############################################################################### +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### +"""BF16 GEMM benchmarks via te.Linear. + +GEMM shapes derived from transformer layer projections: + QKV, AttnOut, GateUp (SwiGLU), Down. 
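
For each model the resulting (N, K) per projection is (see SHAPES below):
    QKV (column-parallel):     N = (Qheads + 2*KVheads)*head_dim / TP, K = hidden
    AttnOut (row-parallel):    N = hidden,               K = Qheads*head_dim / TP
    GateUp (column-parallel):  N = 2*intermediate / TP,  K = hidden
    Down (row-parallel):       N = hidden,               K = intermediate / TP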
+ +Model configuration sources: +- Llama 3 8B (hidden=4096, intermediate=14336, heads=32, kv_heads=8, head_dim=128) + https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json + +- Llama 3 70B (hidden=8192, intermediate=28672, heads=64, kv_heads=8, head_dim=128) + https://huggingface.co/meta-llama/Llama-3.1-70B/blob/main/config.json + +- Llama 3 405B (hidden=16384, intermediate=53248, heads=128, kv_heads=8, head_dim=128) + https://huggingface.co/meta-llama/Llama-3.1-405B/blob/main/config.json + +- Qwen 2.5 7B (hidden=3584, intermediate=18944, heads=28, kv_heads=4, head_dim=128) + https://huggingface.co/Qwen/Qwen2.5-7B-Instruct/blob/main/config.json + +- Qwen 2.5 72B (hidden=8192, intermediate=29568, heads=64, kv_heads=8, head_dim=128) + https://huggingface.co/Qwen/Qwen2.5-72B-Instruct/blob/main/config.json +""" + +import torch +import transformer_engine.pytorch as te + +from driver import time_func + +# (hidden, intermediate, num_q_heads, num_kv_heads, head_dim, tp) +MODELS = { + "Llama3-8B_TP1": (4096, 14336, 32, 8, 128, 1), + "Llama3-8B_TP8": (4096, 14336, 32, 8, 128, 8), + "Llama3-70B_TP8": (8192, 28672, 64, 8, 128, 8), + "Llama3-405B_TP8": (16384, 53248, 128, 8, 128, 8), + "Qwen2.5-7B_TP1": (3584, 18944, 28, 4, 128, 1), + "Qwen2.5-72B_TP8": (8192, 29568, 64, 8, 128, 8), +} + +# Pre-compute (N, K) for each GEMM shape +SHAPES = {} +for _name, (h, inter, nq, nkv, hd, tp) in MODELS.items(): + SHAPES[f"{_name}-QKV"] = ((nq * hd + 2 * nkv * hd) // tp, h) + SHAPES[f"{_name}-AttnOut"] = (h, (nq * hd) // tp) + SHAPES[f"{_name}-GateUp"] = ((2 * inter) // tp, h) + SHAPES[f"{_name}-Down"] = (h, inter // tp) + + +class BenchGemm: + params = [[1024, 2048, 4096, 8192], list(SHAPES)] + param_names = ["M", "shape"] + timeout = 300 + + def setup(self, M, shape): + N, K = SHAPES[shape] + dtype = torch.bfloat16 + self.linear = te.Linear(K, N, bias=False).to(device="cuda", dtype=dtype) + self.x = torch.randn(M, K, dtype=dtype, device="cuda", requires_grad=True) + self.grad_out = torch.randn_like(self.linear(self.x)) + + def work_forward(self, M, shape): + N, K = SHAPES[shape] + return {"flops": 2 * M * N * K} + + def work_forward_backward(self, M, shape): + N, K = SHAPES[shape] + return {"flops": 3 * 2 * M * N * K} + + def time_forward(self, M, shape): + return time_func(lambda: self.linear(self.x)) + + def time_forward_backward(self, M, shape): + def fn(): + out = self.linear(self.x) + out.backward(self.grad_out) + return time_func(fn) + + +if __name__ == "__main__": + from driver import run_as_main + run_as_main(__file__) diff --git a/benchmarks/microbench/bench_gemm_fp8.py b/benchmarks/microbench/bench_gemm_fp8.py new file mode 100644 index 000000000..9669ca2d2 --- /dev/null +++ b/benchmarks/microbench/bench_gemm_fp8.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python3 +############################################################################### +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### +""" +FP8 GEMM benchmarks via te.Linear under fp8_autocast. 
+ +Same shapes as bench_gemm.py but with FP8 quantized compute: + - Llama 3 8B (TP=1, TP=8), 70B (TP=8), 405B (TP=8) + - Qwen 2.5 7B (TP=1), 72B (TP=8) + +Each model contributes four GEMM shapes: + QKV projection (column-parallel) N = (Qheads + 2*KVheads)*head_dim / TP, K = hidden + Attention output (row-parallel) N = hidden, K = Qheads*head_dim / TP + MLP Gate+Up (column-parallel) N = 2*intermediate / TP, K = hidden (SwiGLU) + MLP Down (row-parallel) N = hidden, K = intermediate / TP + +Sources for model configs: + https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json + https://huggingface.co/meta-llama/Llama-3.1-70B/blob/main/config.json + https://huggingface.co/meta-llama/Llama-3.1-405B/blob/main/config.json + https://huggingface.co/Qwen/Qwen2.5-7B-Instruct/blob/main/config.json + https://huggingface.co/Qwen/Qwen2.5-72B-Instruct/blob/main/config.json +""" + +import torch +import transformer_engine.pytorch as te +from transformer_engine.common.recipe import DelayedScaling, Format + +from driver import time_func + +# (hidden, intermediate, num_q_heads, num_kv_heads, head_dim, tp) +MODELS = { + "Llama3-8B_TP1": (4096, 14336, 32, 8, 128, 1), + "Llama3-8B_TP8": (4096, 14336, 32, 8, 128, 8), + "Llama3-70B_TP8": (8192, 28672, 64, 8, 128, 8), + "Llama3-405B_TP8": (16384, 53248, 128, 8, 128, 8), + "Qwen2.5-7B_TP1": (3584, 18944, 28, 4, 128, 1), + "Qwen2.5-72B_TP8": (8192, 29568, 64, 8, 128, 8), +} + +SHAPES = {} +for _name, (h, inter, nq, nkv, hd, tp) in MODELS.items(): + SHAPES[f"{_name}-QKV"] = ((nq * hd + 2 * nkv * hd) // tp, h) + SHAPES[f"{_name}-AttnOut"] = (h, (nq * hd) // tp) + SHAPES[f"{_name}-GateUp"] = ((2 * inter) // tp, h) + SHAPES[f"{_name}-Down"] = (h, inter // tp) + +FP8_RECIPE = DelayedScaling( + fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo="max", +) + + +class BenchGemmFP8: + params = [[1024, 2048, 4096, 8192], list(SHAPES)] + param_names = ["M", "shape"] + timeout = 300 + + def setup(self, M, shape): + N, K = SHAPES[shape] + dtype = torch.bfloat16 + self.linear = te.Linear(K, N, bias=False).to(device="cuda", dtype=dtype) + self.x = torch.randn(M, K, dtype=dtype, device="cuda", requires_grad=True) + self.grad_out = torch.randn(M, N, dtype=dtype, device="cuda") + + def work_forward(self, M, shape): + N, K = SHAPES[shape] + return {"flops": 2 * M * N * K} + + def work_forward_backward(self, M, shape): + N, K = SHAPES[shape] + return {"flops": 3 * 2 * M * N * K} + + def time_forward(self, M, shape): + def fn(): + with te.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE): + self.linear(self.x) + return time_func(fn) + + def time_forward_backward(self, M, shape): + def fn(): + with te.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE): + out = self.linear(self.x) + out.backward(self.grad_out) + return time_func(fn) + + +if __name__ == "__main__": + from driver import run_as_main + run_as_main(__file__) diff --git a/benchmarks/microbench/bench_grouped_gemm.py b/benchmarks/microbench/bench_grouped_gemm.py new file mode 100644 index 000000000..31243e2a1 --- /dev/null +++ b/benchmarks/microbench/bench_grouped_gemm.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python3 +############################################################################### +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### +"""Grouped GEMM benchmarks via te.GroupedLinear. + +MoE model configurations with GateUp and Down projections. 
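For each expert-parallel degree EP, num_gemms = n_routed_experts / EP (the
experts resident on one rank); GateUp is (N, K) = (2*moe_intermediate, hidden)
and Down is (hidden, moe_intermediate).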
+Configurations are based on: +https://github.com/AMD-AGI/Primus-Turbo/blob/main/benchmark/ops/config.py +""" + +import torch +import transformer_engine.pytorch as te + +from driver import time_func + +# (n_routed_experts, moe_intermediate_size, hidden_size) +MOE_MODELS = { + "DSV2-Lite": (64, 1408, 2048), + "DSV2": (160, 1536, 5120), + "DSV3": (256, 2048, 7168), + "Grok-V2": (8, 16384, 8192), +} + +# Build (config_key -> (num_gemms, N, K)) mapping +CONFIGS = {} +for model, (n_experts, inter, hidden) in MOE_MODELS.items(): + for ep in [32, 16, 8]: + if n_experts % ep != 0: + continue + B = n_experts // ep + CONFIGS[f"{model}_EP{ep}-GateUp"] = (B, 2 * inter, hidden) + CONFIGS[f"{model}_EP{ep}-Down"] = (B, hidden, inter) + + +class BenchGroupedGemm: + params = [[512, 1024, 2048, 4096], list(CONFIGS)] + param_names = ["M", "config"] + timeout = 300 + + def setup(self, M, config): + B, N, K = CONFIGS[config] + dtype = torch.bfloat16 + + self.module = te.GroupedLinear( + num_gemms=B, in_features=K, out_features=N, bias=False, + ).to(device="cuda", dtype=dtype) + + self.xs = [ + torch.randn(M, K, dtype=dtype, device="cuda", requires_grad=True) + for _ in range(B) + ] + outs = self.module(self.xs) + self.grad_outs = [torch.randn_like(o) for o in outs] + + def work_forward(self, M, config): + B, N, K = CONFIGS[config] + return {"flops": B * 2 * M * N * K} + + def work_forward_backward(self, M, config): + B, N, K = CONFIGS[config] + return {"flops": B * 3 * 2 * M * N * K} + + def time_forward(self, M, config): + return time_func(lambda: self.module(self.xs)) + + def time_forward_backward(self, M, config): + def fn(): + outs = self.module(self.xs) + torch.autograd.backward(outs, self.grad_outs) + return time_func(fn) + + +if __name__ == "__main__": + from driver import run_as_main + run_as_main(__file__) diff --git a/benchmarks/microbench/bench_normalization.py b/benchmarks/microbench/bench_normalization.py new file mode 100644 index 000000000..16e49394b --- /dev/null +++ b/benchmarks/microbench/bench_normalization.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python3 +############################################################################### +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### +""" +RMSNorm and LayerNorm benchmarks on activation-sized tensors. + +Shapes are derived from training workloads: + - Llama 3 8B, 70B, 405B (all use RMSNorm) + - Qwen 2.5 7B, 72B (all use RMSNorm) + +Modern models predominantly use RMSNorm, but we benchmark both +LayerNorm and RMSNorm since TE supports both and they share the +same kernel infrastructure. + +The M dimension (batch * seq_len) is swept across typical training sizes. 
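Hidden sizes swept: 3584 (Qwen2.5-7B), 4096 (Llama3-8B),
8192 (Llama3-70B / Qwen2.5-72B), 16384 (Llama3-405B).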

Sources for model configs:
  https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json
  https://huggingface.co/meta-llama/Llama-3.1-70B/blob/main/config.json
  https://huggingface.co/meta-llama/Llama-3.1-405B/blob/main/config.json
  https://huggingface.co/Qwen/Qwen2.5-7B-Instruct/blob/main/config.json
  https://huggingface.co/Qwen/Qwen2.5-72B-Instruct/blob/main/config.json
"""

import torch
import transformer_engine.pytorch as te

from driver import time_func

NORMS = {"RMSNorm": te.RMSNorm, "LayerNorm": te.LayerNorm}
HIDDEN_SIZES = [3584, 4096, 8192, 16384]


class BenchNormalization:
    params = [[1024, 2048, 4096, 8192], HIDDEN_SIZES, list(NORMS)]
    param_names = ["M", "hidden", "norm_type"]
    timeout = 120

    def setup(self, M, hidden, norm_type):
        dtype = torch.bfloat16
        self.norm = NORMS[norm_type](hidden).to(device="cuda", dtype=dtype)
        self.x = torch.randn(M, hidden, dtype=dtype, device="cuda", requires_grad=True)
        self.grad_out = torch.randn_like(self.norm(self.x))

    def work_forward(self, M, hidden, norm_type):
        # Read input (2B) + write output (2B) = 4 bytes per element
        return {"bytes": M * hidden * 4}

    def work_forward_backward(self, M, hidden, norm_type):
        # Fwd: read+write (4B), Bwd: read input+grad_out+write grad_in (6B) = 10B
        return {"bytes": M * hidden * 10}

    def time_forward(self, M, hidden, norm_type):
        return time_func(lambda: self.norm(self.x))

    def time_forward_backward(self, M, hidden, norm_type):
        def fn():
            out = self.norm(self.x)
            out.backward(self.grad_out)
        return time_func(fn)


if __name__ == "__main__":
    from driver import run_as_main
    run_as_main(__file__)
diff --git a/benchmarks/microbench/driver.py b/benchmarks/microbench/driver.py
new file mode 100644
index 000000000..e7354859c
--- /dev/null
+++ b/benchmarks/microbench/driver.py
@@ -0,0 +1,309 @@
+#!/usr/bin/env python3
###############################################################################
# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
#
# See LICENSE for license information.
###############################################################################
"""Microbenchmark driver — runs bench classes via torch.utils.benchmark.Timer
and writes long-format CSV.

Usage:
    python driver.py <suite> [method_filter] [--csv FILE | --no-csv]
    python driver.py --all [--csv FILE | --no-csv]
    python bench_gemm.py [method_filter] [--csv FILE | --no-csv]

CSV schema (one row per Timer block):
    suite, class, method, params, sample_idx, time_s, number_per_run,
    tflops, gbps, commit, machine, started_at_ms

Each row's `time_s` is one block's per-call mean (block_total / number_per_run).
The downstream analysis tool can group by (suite, class, method, params) to
recover the distribution of block-mean per-call times.
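
An illustrative row (hypothetical values, wrapped here for readability):

    bench_gemm,BenchGemm,time_forward,M=1024;shape=Llama3-8B_TP1-QKV,0,
    0.000121,83,426.0,,1a2b3c4d,gpu-node-01,1761000000000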
+""" + +import argparse +import csv +import glob +import importlib +import itertools +import os +import platform +import subprocess +import sys +import time + +import torch.utils.benchmark as benchmark + + +# --------------------------------------------------------------------------- +# Environment metadata +# --------------------------------------------------------------------------- + +def _get_machine_name(): + return platform.node() or "unknown" + + +def _get_commit_hash(short=False): + try: + sha = subprocess.check_output( + ["git", "rev-parse", "HEAD"], stderr=subprocess.DEVNULL + ).decode().strip() + return sha[:8] if short else sha + except Exception: + return "unknown" + + +# --------------------------------------------------------------------------- +# Timing helper used by bench files +# --------------------------------------------------------------------------- + +def time_func(fn, min_run_time=1.0, method="blocked"): + """Time *fn* with torch.utils.benchmark.Timer and return the Measurement. + + The Measurement object exposes per-block elapsed times (`.times`) and + `.number_per_run` (kernel invocations averaged per block). The driver + flattens these into long-format CSV rows. + + method: + "blocked" — fixed-block sampling, more samples (recommended for stats). + "adaptive" — stops when noise threshold is met; fewer, variable samples. + """ + timer = benchmark.Timer(stmt="fn()", globals={"fn": fn}) + if method == "adaptive": + return timer.adaptive_autorange(min_run_time=min_run_time) + return timer.blocked_autorange(min_run_time=min_run_time) + + +# --------------------------------------------------------------------------- +# CSV output +# --------------------------------------------------------------------------- + +CSV_COLUMNS = [ + "suite", "class", "method", "params", "sample_idx", "time_s", + "number_per_run", "tflops", "gbps", "commit", "machine", "started_at_ms", +] + + +def _default_csv_path(script_dir): + """benchmarks/.bench-results//.csv, anchored at the repo root.""" + repo_root = os.path.abspath(os.path.join(script_dir, "..", "..")) + return os.path.join( + repo_root, "benchmarks", ".bench-results", + _get_machine_name(), f"{_get_commit_hash(short=True)}.csv", + ) + + +def save_csv_results(rows, csv_path, append=False): + """Write sample rows to *csv_path* (long format, one row per Timer block).""" + os.makedirs(os.path.dirname(os.path.abspath(csv_path)) or ".", exist_ok=True) + write_header = not (append and os.path.exists(csv_path)) + with open(csv_path, "a" if append else "w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=CSV_COLUMNS, extrasaction="ignore") + if write_header: + writer.writeheader() + writer.writerows(rows) + print(f"\nResults {'appended to' if append else 'saved to'} {csv_path} " + f"({len(rows)} rows)") + + +# --------------------------------------------------------------------------- +# Benchmark runner +# --------------------------------------------------------------------------- + +def _format_params(param_names, combo): + """Canonical 'k1=v1;k2=v2' string for joining across runs.""" + return ";".join(f"{n}={v}" for n, v in zip(param_names, combo)) + + +def _measurement_to_rows(measurement, *, suite, class_name, method_name, + params_str, work, commit, machine, started_at_ms): + """Flatten a Timer Measurement into one CSV row per block. + + Measurement.times is already per-call (Timer divides by number_per_run + internally). 
number_per_run is recorded as metadata so the analysis tool + knows how many invocations were averaged into each row's time_s. + """ + n = measurement.number_per_run + rows = [] + for i, per_call_s in enumerate(measurement.times): + rows.append({ + "suite": suite, + "class": class_name, + "method": method_name, + "params": params_str, + "sample_idx": i, + "time_s": per_call_s, + "number_per_run": n, + "tflops": (work["flops"] / per_call_s / 1e12) + if "flops" in work and per_call_s > 0 else "", + "gbps": (work["bytes"] / per_call_s / 1e9) + if "bytes" in work and per_call_s > 0 else "", + "commit": commit, + "machine": machine, + "started_at_ms": started_at_ms, + }) + return rows + + +def run_class(suite_name, cls, class_name, method_filter=None, + commit=None, machine=None): + """Run all benchmarks in a class. Returns a list of CSV row dicts.""" + methods = sorted(m for m in dir(cls) if m.startswith("time_")) + if method_filter: + methods = [m for m in methods if method_filter in m] + if not methods: + return [] + + params = getattr(cls, "params", [[]]) + param_names = getattr(cls, "param_names", []) + combos = list(itertools.product(*params)) + + # Discover throughput columns from work_* companions + probe_keys = set() + for m in methods: + wfn = getattr(cls, "work_" + m[5:], None) + if wfn: + try: + probe_keys.update(wfn(cls(), *combos[0])) + except Exception: + pass + has_tflops = "flops" in probe_keys + has_gbps = "bytes" in probe_keys + + print(f"\n{class_name} ({len(combos)} combos x {len(methods)} methods, " + "Timer-driven)") + extra_hdr = "" + if has_tflops: + extra_hdr += f" {'TFLOPS':>10}" + if has_gbps: + extra_hdr += f" {'GB/s':>10}" + HDR = (f" {'median':>10} {'mean':>10} {'iqr':>10} {'n_blocks':>9}" + f" {'per_run':>8}" + extra_hdr + f" {'method':<30} params") + print("-" * len(HDR)) + print(HDR) + print("-" * len(HDR)) + + rows = [] + for method_name in methods: + started_at_ms = int(time.time() * 1000) + for combo in combos: + label = ", ".join(f"{n}={v}" for n, v in zip(param_names, combo)) + params_str = _format_params(param_names, combo) + instance = cls() + try: + instance.setup(*combo) + except Exception as e: + print(f" SKIP {label} setup failed: {e}") + continue + + method = getattr(instance, method_name) + try: + measurement = method(*combo) + except Exception as e: + print(f" SKIP {label} {method_name} failed: {e}") + continue + + wfn = getattr(instance, "work_" + method_name[5:], None) + work = {} + if wfn: + try: + work = wfn(*combo) + except Exception: + pass + + rows.extend(_measurement_to_rows( + measurement, suite=suite_name, class_name=class_name, + method_name=method_name, params_str=params_str, work=work, + commit=commit, machine=machine, started_at_ms=started_at_ms, + )) + + median_s = measurement.median + mean_s = measurement.mean + iqr_s = measurement.iqr + extra_cols = "" + if has_tflops: + extra_cols += (f" {work['flops'] / median_s / 1e12:>10.1f}" + if "flops" in work and median_s > 0 else f" {'':>10}") + if has_gbps: + extra_cols += (f" {work['bytes'] / median_s / 1e9:>10.1f}" + if "bytes" in work and median_s > 0 else f" {'':>10}") + print(f" {median_s*1000:>8.3f}ms {mean_s*1000:>8.3f}ms " + f"{iqr_s*1000:>8.3f}ms {len(measurement.times):>9} " + f"{measurement.number_per_run:>8}" + f"{extra_cols} {method_name:<30} {label}") + + return rows + + +def run_as_main(caller_file=None): + """Run benchmarks from a bench file or from the command line. 
    When called with a file path (from a bench file's ``__main__`` block),
    the suite is derived from the filename. When called without one
    (i.e. ``python driver.py bench_gemm``), the suite name is taken from
    argv.

    Usage from a bench file::

        if __name__ == "__main__":
            from driver import run_as_main
            run_as_main(__file__)
    """
    parser = argparse.ArgumentParser(
        description="Run microbenchmarks via torch.utils.benchmark and emit CSV.")
    if caller_file is None:
        parser.add_argument("suite", nargs="?", default=None,
                            help="Benchmark module name (e.g. bench_casting)")
        parser.add_argument("--all", action="store_true",
                            help="Run all bench_*.py suites in the directory")
    parser.add_argument("method_filter", nargs="?", default=None,
                        help="Only run time_* methods containing this string")
    parser.add_argument("--csv", default=None, metavar="FILE",
                        help="Output CSV path. Default: "
                             "benchmarks/.bench-results/<machine>/<commit>.csv")
    parser.add_argument("--no-csv", action="store_true",
                        help="Don't write CSV (stdout summary only).")
    parser.add_argument("--append", action="store_true",
                        help="Append to the CSV instead of overwriting.")
    args = parser.parse_args()

    if caller_file is not None:
        script_dir = os.path.dirname(os.path.abspath(caller_file))
        suite_names = [os.path.splitext(os.path.basename(caller_file))[0]]
    else:
        script_dir = os.path.dirname(os.path.abspath(__file__))
        run_all = getattr(args, "all", False)
        if run_all:
            suite_names = sorted(
                os.path.splitext(os.path.basename(f))[0]
                for f in glob.glob(os.path.join(script_dir, "bench_*.py"))
            )
        elif args.suite:
            suite_names = [args.suite]
        else:
            parser.error("provide a suite name or use --all")

    os.chdir(script_dir)
    if script_dir not in sys.path:
        sys.path.insert(0, script_dir)

    commit = _get_commit_hash(short=True)
    machine = _get_machine_name()

    all_rows = []
    for suite_name in suite_names:
        mod = importlib.import_module(suite_name)
        for name in sorted(dir(mod)):
            obj = getattr(mod, name)
            if isinstance(obj, type) and name.startswith("Bench"):
                all_rows.extend(run_class(
                    suite_name, obj, name, args.method_filter,
                    commit=commit, machine=machine,
                ))

    if all_rows and not args.no_csv:
        csv_path = args.csv or _default_csv_path(script_dir)
        save_csv_results(all_rows, csv_path, append=args.append)


if __name__ == "__main__":
    run_as_main()