Microbenchmark suite #487
Open
Micky774 wants to merge 21 commits into dev from zain/asv-demo.
Commits (21):
d7c643c  Initial benchmark porting to ASV (Micky774)
b829122  Update casting benchmark (Micky774)
21678b4  Added helper script and documentation (Micky774)
6cb91a5  Corrected local benchmarking (Micky774)
1a98989  Added direct-run option to bypass subprocess overhead (Micky774)
498f16d  Refactor to prefer direct runs, and moved asv conf (Micky774)
1e41715  Allowed for direct run of bench files (Micky774)
c1e489d  Remove CI component (Micky774)
9772f2d  Rename direct_run to driver (Micky774)
770a3f0  Refactored driver, streamlined README.md (Micky774)
aa2a4a1  Updated to CUDA event based timing (Micky774)
a2e5999  Added throughput/bandwidth calc, improved driver (Micky774)
89ebfa5  Streamline and clean code (Micky774)
91b6b2c  Updated readme, simplified helper script (Micky774)
1b5d042  Updated docstrings to include config sources (Micky774)
c6df4d7  Added missing var (Micky774)
37b8f0d  Merge remote-tracking branch 'origin/dev' into zain/asv-demo (matthiasdiener)
29465a4  Added cold-cache support as well as inner runs for launch amortization (Micky774)
8c23f72  Trimmed implementation to only use ASV for dashboard (Micky774)
6fea829  Merge branch 'dev' into zain/asv-demo (Micky774)
f0c7096  Refactor to remove ASV format for now, defaulting to CSV (Micky774)
@@ -0,0 +1,127 @@

# Microbenchmarks for TransformerEngine

GPU microbenchmarks driven by `driver.py`. Each `bench_*.py` file defines one
or more bench classes following an ASV-style API (`params`, `param_names`,
`time_*` methods, optional `work_*` companions). Timing uses
`torch.utils.benchmark.Timer` under the hood. The driver runs each suite
in-process and writes results as long-format CSV — one row per Timer block —
intended to be consumed by a separate analysis tool (statistical tests,
cross-run comparison).

## Prerequisites

- TransformerEngine must already be built and installed in the current Python environment.
- A ROCm or CUDA GPU must be available.

## Running benchmarks

Each `bench_*.py` file is directly executable, or you can drive them through
`driver.py`. Results are written by default to
`benchmarks/.bench-results/<machine>/<commit-short>.csv`.

```bash
cd benchmarks/microbench
python driver.py --all                 # run every suite
python driver.py bench_gemm            # run one suite via driver
python bench_gemm.py                   # run one suite directly
python bench_gemm.py time_forward      # filter to method names containing this string
python bench_casting.py --no-csv       # stdout only, don't write CSV
python bench_casting.py --csv out.csv  # custom output path
python bench_casting.py --append       # append to existing CSV
```

### Helper script

`run_benchmarks.sh` wraps common tasks and can be run from anywhere.

```bash
bash benchmarks/microbench/run_benchmarks.sh <command> [options]
```

| Command | Description |
|---|---|
| `run [suite] [method]` | Run benchmarks in-process and write CSV |
| `list` | List available benchmark suites |
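
For example, `bash benchmarks/microbench/run_benchmarks.sh run bench_gemm time_forward` would run only the `time_forward` methods of the `bench_gemm` suite (an illustrative invocation assembled from the usage line and the table above).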

## Output format

Long-format CSV — one row per `torch.utils.benchmark` block. Default location
is `benchmarks/.bench-results/<machine>/<commit-short>.csv`; the
`.bench-results` tree is in `.gitignore`. Schema:

| Column | Type | Description |
|---|---|---|
| `suite` | str | Module name (e.g. `bench_gemm`) |
| `class` | str | Bench class name (e.g. `BenchGemm`) |
| `method` | str | Timed method (e.g. `time_forward`) |
| `params` | str | `k1=v1;k2=v2` canonical form for joining across runs |
| `sample_idx` | int | Block index within this Measurement |
| `time_s` | float | Per-call elapsed seconds (Timer normalizes by `number_per_run`) |
| `number_per_run` | int | Kernel invocations averaged into this row's `time_s` |
| `tflops` | float | Per-call throughput; empty if no `work_*` flops |
| `gbps` | float | Per-call bandwidth; empty if no `work_*` bytes |
| `commit` | str | Short git HEAD hash |
| `machine` | str | `platform.node()` |
| `started_at_ms` | int | Unix-ms timestamp when this method's run began |
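
Two illustrative rows (all values are made up, shown only to make the schema concrete):

```
suite,class,method,params,sample_idx,time_s,number_per_run,tflops,gbps,commit,machine,started_at_ms
bench_gemm,BenchGemm,time_forward,M=4096;shape=Llama3-8B_TP1-QKV,0,0.000412,50,500.4,,f0c7096,gpunode01,1717000000000
bench_casting,BenchCasting,time_cast,M=4096;model=Llama3-8B;cast=BF16_to_E4M3,0,0.000031,200,,1623.6,f0c7096,gpunode01,1717000000000
```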

Per-PR comparison and statistical tests are handled by a separate analysis
tool (TBD) that reads two or more of these CSVs and joins on
`(suite, class, method, params)`. Note that `time_s` is a *block mean* —
the analysis tool should weight by `number_per_run` (or treat blocks as
independent samples) when computing significance.
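
A minimal sketch of such a comparison, assuming pandas is available and that two result files follow the schema above (the file names and the weighted-mean scheme are illustrative, not part of this PR):

```python
import pandas as pd

KEY = ["suite", "class", "method", "params"]

def weighted_mean_times(csv_path: str) -> pd.Series:
    """Collapse per-block rows to one weighted-mean time per benchmark key."""
    df = pd.read_csv(csv_path)
    # Each row's time_s is a block mean; weight it by the number of kernel
    # invocations averaged into that block.
    df["weighted"] = df["time_s"] * df["number_per_run"]
    grouped = df.groupby(KEY)
    return grouped["weighted"].sum() / grouped["number_per_run"].sum()

base = weighted_mean_times("base.csv")  # e.g. a dev-branch run
pr = weighted_mean_times("pr.csv")      # e.g. a run from this PR's branch
joined = pd.DataFrame({"base_s": base, "pr_s": pr}).dropna()
joined["speedup"] = joined["base_s"] / joined["pr_s"]
print(joined.sort_values("speedup").to_string())
```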

## Writing new benchmarks

Create a new file in `benchmarks/microbench/` following the naming convention `bench_<name>.py`.

```python
#!/usr/bin/env python3
import torch
import transformer_engine.pytorch as te

from driver import time_func


class BenchSomething:
    params = [[1024, 4096], ["config_a", "config_b"]]
    param_names = ["M", "config"]
    timeout = 300  # seconds, per parameter combination

    def setup(self, M, config):
        # Allocate tensors, create modules.
        # Runs once per (combo, method); same instance is reused for warmup
        # and timed Timer blocks.
        self.module = ...
        self.x = ...
        self.grad_out = ...  # required by time_forward_backward below

    def time_forward(self, M, config):
        return time_func(lambda: self.module(self.x))

    def time_forward_backward(self, M, config):
        def fn():
            out = self.module(self.x)
            out.backward(self.grad_out)
        return time_func(fn)

    # Optional: define work_<name> to get throughput columns (TFLOPS / GB/s).
    def work_forward(self, M, config):
        # Assumes setup() also stored self.N / self.K (or self.hidden).
        return {"flops": 2 * M * self.N * self.K}  # compute-bound
        # return {"bytes": M * self.hidden * 4}    # memory-bound


if __name__ == "__main__":
    from driver import run_as_main
    run_as_main(__file__)
```

Key rules:
- Method names starting with `time_` are automatically timed.
- `time_*` methods must return `time_func(fn)` — a `torch.utils.benchmark.Measurement`.
- Inside `fn`, do whatever per-call work you want measured. For backward,
  let gradients accumulate in place across iterations — Timer's repeated
  invocations won't OOM (grads accumulate into the same tensor), and the
  numerical correctness of the accumulated grad doesn't affect timing.
- Optionally define `work_<name>` companions to get TFLOPS or GB/s columns.
  Return per-call work; the driver derives per-sample throughput.
- The `params` list defines a cross-product; keep the matrix size reasonable
  (see the sketch below).
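
To make the cross-product rule concrete, here is roughly how a driver can expand `params` (a sketch of the assumed expansion, not the actual `driver.py` code):

```python
import itertools

params = [[1024, 4096], ["config_a", "config_b"]]
param_names = ["M", "config"]

# Every combination of one value per sub-list gets benchmarked:
for combo in itertools.product(*params):
    print(dict(zip(param_names, combo)))
# -> {'M': 1024, 'config': 'config_a'}
#    {'M': 1024, 'config': 'config_b'}
#    {'M': 4096, 'config': 'config_a'}
#    {'M': 4096, 'config': 'config_b'}
```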
Empty file.
@@ -0,0 +1,89 @@

#!/usr/bin/env python3
###############################################################################
# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
#
# See LICENSE for license information.
###############################################################################
"""
Attention micro-benchmark using te.DotProductAttention.

Benchmarks fused multi-head attention (with flash attention backend) for
model configurations with grouped-query attention (GQA).

Models:
- Llama 3 8B (TP=1, TP=8), 70B (TP=8), 405B (TP=8)
- Qwen 2.5 7B (TP=1), 72B (TP=8)

Forward FLOPs = 4 * batch * num_q_heads * seq_len^2 * head_dim
  (two matmuls: Q@K^T and attn@V, each contributing 2*b*h*s^2*d)
Backward FLOPs = 2 * Forward FLOPs (approximately)

Sources for model configs:
  https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json
  https://huggingface.co/meta-llama/Llama-3.1-70B/blob/main/config.json
  https://huggingface.co/meta-llama/Llama-3.1-405B/blob/main/config.json
  https://huggingface.co/Qwen/Qwen2.5-7B-Instruct/blob/main/config.json
  https://huggingface.co/Qwen/Qwen2.5-72B-Instruct/blob/main/config.json
"""

import torch
import transformer_engine.pytorch as te

from driver import time_func

BATCH = 2

# (num_q_heads, num_kv_heads, head_dim, tp)
MODELS = {
    "Llama3-8B_TP1": (32, 8, 128, 1),
    "Llama3-8B_TP8": (32, 8, 128, 8),
    "Llama3-70B_TP8": (64, 8, 128, 8),
    "Llama3-405B_TP8": (128, 8, 128, 8),
    "Qwen2.5-7B_TP1": (28, 4, 128, 1),
    "Qwen2.5-72B_TP8": (64, 8, 128, 8),
}


class BenchAttention:
    params = [[1024, 2048, 4096, 8192], list(MODELS)]
    param_names = ["seq_len", "model"]
    timeout = 300

    def setup(self, seq_len, model):
        n_q, n_kv, hd, tp = MODELS[model]
        qh, kvh = n_q // tp, n_kv // tp
        dtype = torch.bfloat16

        self.attn = te.DotProductAttention(
            num_attention_heads=qh, kv_channels=hd,
            num_gqa_groups=kvh, attn_mask_type="causal",
        ).to(device="cuda", dtype=dtype)

        self.q = torch.randn(seq_len, BATCH, qh, hd, dtype=dtype, device="cuda", requires_grad=True)
        self.k = torch.randn(seq_len, BATCH, kvh, hd, dtype=dtype, device="cuda", requires_grad=True)
        self.v = torch.randn(seq_len, BATCH, kvh, hd, dtype=dtype, device="cuda", requires_grad=True)
        self.grad_out = torch.randn_like(self.attn(self.q, self.k, self.v))

    def work_forward(self, seq_len, model):
        n_q, n_kv, hd, tp = MODELS[model]
        qh = n_q // tp
        return {"flops": 4 * BATCH * qh * seq_len * seq_len * hd}

    def work_forward_backward(self, seq_len, model):
        n_q, n_kv, hd, tp = MODELS[model]
        qh = n_q // tp
        return {"flops": 3 * 4 * BATCH * qh * seq_len * seq_len * hd}

    def time_forward(self, seq_len, model):
        return time_func(lambda: self.attn(self.q, self.k, self.v))

    def time_forward_backward(self, seq_len, model):
        def fn():
            out = self.attn(self.q, self.k, self.v)
            out.backward(self.grad_out)
        return time_func(fn)


if __name__ == "__main__":
    from driver import run_as_main
    run_as_main(__file__)
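
A quick sanity check of the docstring's FLOPs formula (illustrative arithmetic, not part of the diff):

```python
# Worked example of work_forward for Llama3-8B_TP1 at seq_len 4096.
batch, q_heads, head_dim, seq_len = 2, 32, 128, 4096
fwd_flops = 4 * batch * q_heads * seq_len**2 * head_dim
print(fwd_flops)                 # 549755813888, i.e. ~0.55 TFLOP per call
print(fwd_flops / 1e-3 / 1e12)   # ~550 TFLOPS if one call took 1 ms
```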
@@ -0,0 +1,83 @@

#!/usr/bin/env python3
###############################################################################
# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
#
# See LICENSE for license information.
###############################################################################
"""
Benchmarks quantization (BF16 -> FP8) and dequantization (FP8 -> BF16) for
both E4M3 (activations/weights) and E5M2 (gradients) formats.

Shapes are (M, hidden_size) matching the activation tensors from models:
- Llama 3.1 8B, 70B, 405B
- Qwen 2.5 7B, 72B

These casts are memory-bound; we report GB/s (input + output bytes).

Sources for model configs:
  https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json
  https://huggingface.co/meta-llama/Llama-3.1-70B/blob/main/config.json
  https://huggingface.co/meta-llama/Llama-3.1-405B/blob/main/config.json
  https://huggingface.co/Qwen/Qwen2.5-7B-Instruct/blob/main/config.json
  https://huggingface.co/Qwen/Qwen2.5-72B-Instruct/blob/main/config.json
"""

import torch
from transformer_engine.pytorch import Float8CurrentScalingQuantizer
from transformer_engine_torch import DType as TE_DType

from driver import time_func

HIDDEN_SIZES = {
    "Llama3-8B": 4096,
    "Llama3-70B": 8192,
    "Llama3-405B": 16384,
    "Qwen2.5-7B": 3584,
    "Qwen2.5-72B": 8192,
}

CAST_CONFIGS = {
    "BF16_to_E4M3": ("quantize", TE_DType.kFloat8E4M3),
    "E4M3_to_BF16": ("dequantize", TE_DType.kFloat8E4M3),
    "BF16_to_E5M2": ("quantize", TE_DType.kFloat8E5M2),
    "E5M2_to_BF16": ("dequantize", TE_DType.kFloat8E5M2),
}


class BenchCasting:
    params = [[1024, 2048, 4096, 8192], list(HIDDEN_SIZES), list(CAST_CONFIGS)]
    param_names = ["M", "model", "cast"]
    timeout = 120

    def setup(self, M, model, cast):
        hidden = HIDDEN_SIZES[model]
        direction, fp8_dtype = CAST_CONFIGS[cast]
        self.direction = direction
        quantizer = Float8CurrentScalingQuantizer(
            fp8_dtype=fp8_dtype,
            device=torch.device("cuda"),
            rowwise=True,
            columnwise=False,
        )
        if direction == "dequantize":
            bf16_tensor = torch.randn(M, hidden, dtype=torch.bfloat16, device="cuda")
            self.x = quantizer.quantize(bf16_tensor)
        else:
            self.x = torch.randn(M, hidden, dtype=torch.bfloat16, device="cuda")
        self.quantizer = quantizer

    def work_cast(self, M, model, cast):
        hidden = HIDDEN_SIZES[model]
        # Read input (1B FP8 or 2B BF16) + write output + scale (~hidden bytes total)
        # Approximated as 3 bytes per element either direction.
        return {"bytes": M * hidden * 3}

    def time_cast(self, M, model, cast):
        if self.direction == "quantize":
            return time_func(lambda: self.quantizer.quantize(self.x))
        return time_func(lambda: self.x.dequantize(dtype=torch.bfloat16))


if __name__ == "__main__":
    from driver import run_as_main
    run_as_main(__file__)
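
To make the 3-bytes-per-element model concrete (illustrative arithmetic, not part of the diff):

```python
# Bandwidth arithmetic for work_cast's bytes model.
M, hidden = 8192, 8192        # e.g. Llama3-70B activation shape
nbytes = M * hidden * 3       # matches work_cast above: ~0.20 GB per call
time_s = 100e-6               # hypothetical measured per-call time
print(nbytes / time_s / 1e9)  # ~2013 GB/s
```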
@@ -0,0 +1,85 @@

#!/usr/bin/env python3
###############################################################################
# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
#
# See LICENSE for license information.
###############################################################################
"""BF16 GEMM benchmarks via te.Linear.

GEMM shapes derived from transformer layer projections:
QKV, AttnOut, GateUp (SwiGLU), Down.

Model configuration sources:
- Llama 3 8B (hidden=4096, intermediate=14336, heads=32, kv_heads=8, head_dim=128)
  https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json

- Llama 3 70B (hidden=8192, intermediate=28672, heads=64, kv_heads=8, head_dim=128)
  https://huggingface.co/meta-llama/Llama-3.1-70B/blob/main/config.json

- Llama 3 405B (hidden=16384, intermediate=53248, heads=128, kv_heads=8, head_dim=128)
  https://huggingface.co/meta-llama/Llama-3.1-405B/blob/main/config.json

- Qwen 2.5 7B (hidden=3584, intermediate=18944, heads=28, kv_heads=4, head_dim=128)
  https://huggingface.co/Qwen/Qwen2.5-7B-Instruct/blob/main/config.json

- Qwen 2.5 72B (hidden=8192, intermediate=29568, heads=64, kv_heads=8, head_dim=128)
  https://huggingface.co/Qwen/Qwen2.5-72B-Instruct/blob/main/config.json
"""

import torch
import transformer_engine.pytorch as te

from driver import time_func

# (hidden, intermediate, num_q_heads, num_kv_heads, head_dim, tp)
MODELS = {
    "Llama3-8B_TP1": (4096, 14336, 32, 8, 128, 1),
    "Llama3-8B_TP8": (4096, 14336, 32, 8, 128, 8),
    "Llama3-70B_TP8": (8192, 28672, 64, 8, 128, 8),
    "Llama3-405B_TP8": (16384, 53248, 128, 8, 128, 8),
    "Qwen2.5-7B_TP1": (3584, 18944, 28, 4, 128, 1),
    "Qwen2.5-72B_TP8": (8192, 29568, 64, 8, 128, 8),
}

# Pre-compute (N, K) for each GEMM shape
SHAPES = {}
for _name, (h, inter, nq, nkv, hd, tp) in MODELS.items():
    SHAPES[f"{_name}-QKV"] = ((nq * hd + 2 * nkv * hd) // tp, h)
    SHAPES[f"{_name}-AttnOut"] = (h, (nq * hd) // tp)
    SHAPES[f"{_name}-GateUp"] = ((2 * inter) // tp, h)
    SHAPES[f"{_name}-Down"] = (h, inter // tp)


class BenchGemm:
    params = [[1024, 2048, 4096, 8192], list(SHAPES)]
    param_names = ["M", "shape"]
    timeout = 300

    def setup(self, M, shape):
        N, K = SHAPES[shape]
        dtype = torch.bfloat16
        self.linear = te.Linear(K, N, bias=False).to(device="cuda", dtype=dtype)
        self.x = torch.randn(M, K, dtype=dtype, device="cuda", requires_grad=True)
        self.grad_out = torch.randn_like(self.linear(self.x))

    def work_forward(self, M, shape):
        N, K = SHAPES[shape]
        return {"flops": 2 * M * N * K}

    def work_forward_backward(self, M, shape):
        N, K = SHAPES[shape]
        return {"flops": 3 * 2 * M * N * K}

    def time_forward(self, M, shape):
        return time_func(lambda: self.linear(self.x))

    def time_forward_backward(self, M, shape):
        def fn():
            out = self.linear(self.x)
            out.backward(self.grad_out)
        return time_func(fn)


if __name__ == "__main__":
    from driver import run_as_main
    run_as_main(__file__)
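
A worked example of the SHAPES and work_forward arithmetic (illustrative only, not part of the diff):

```python
# QKV projection for Llama3-8B_TP1, following the SHAPES formula above.
nq, nkv, hd, h, tp = 32, 8, 128, 4096, 1
N = (nq * hd + 2 * nkv * hd) // tp  # 6144
K = h                               # 4096
M = 8192
print(2 * M * N * K)                # 412316860416, i.e. ~0.41 TFLOP per call
```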
Instead of creating a new attention microbenchmark, should we use the attention microbenchmark(s) already part of TE (in https://github.com/ROCm/TransformerEngine/tree/dev/benchmarks/attention)?