Microbenchmarking, Torch+CSV-based #478
Open

matthiasdiener wants to merge 37 commits into dev from mdiener/ci-microbench
Commits (37, all by matthiasdiener)
- `8a0ea47` initial impl
- `4270296` Merge remote-tracking branch 'origin/dev' into mdiener/ci-microbench
- `6ddb77d` put into benchmarks subfolder
- `fb2b3f3` restructure comment
- `d4e9b1e` misc updates
- `95358f4` python fix
- `a675d17` Merge remote-tracking branch 'origin/dev' into mdiener/ci-microbench
- `d0a320d` another embedded python fix
- `6f45853` replace py code
- `e5eaf10` Merge branch 'dev' into mdiener/ci-microbench
- `55e7eb5` restore disabled parts of CI
- `7072e82` Merge remote-tracking branch 'origin/dev' into mdiener/ci-microbench
- `9c771b4` add attention, casting, normalization
- `64e8da8` add timestamp and commit ID
- `c986c97` add FP8 GEMM
- `4f6dc86` fix name
- `811e329` Merge remote-tracking branch 'origin/dev' into mdiener/ci-microbench
- `bd6c3e7` Merge remote-tracking branch 'origin/dev' into mdiener/ci-microbench
- `c9d6d4d` updates casting
- `4bc11df` Merge branch 'dev' into mdiener/ci-microbench
- `de21a77` remove attention
- `1d6f869` fix grouped gemm
- `12b4218` remove CI part
- `75c8291` use adaptive_autorange, cleanups
- `2e6da68` add csv to asv converter
- `6353411` Merge remote-tracking branch 'origin/dev' into mdiener/ci-microbench
- `a1c6453` Merge remote-tracking branch 'origin/dev' into mdiener/ci-microbench
- `33a3137` Merge remote-tracking branch 'upstream/dev' into mdiener/ci-microbench
- `aa8997c` refactor
- `fefaf13` remove asv converter
- `7f2669d` cleanups, misc fixes
- `117c2d7` Merge remote-tracking branch 'origin/dev' into mdiener/ci-microbench
- `bc824d7` Merge remote-tracking branch 'upstream/dev' into mdiener/ci-microbench
- `d4e116a` simplifications, address review comments
- `372e6df` Llama 3.1
- `284adda` address reviewer comments
- `ca1f442` add readme
`README.md` (new file, 66 lines):

````markdown
# Transformer Engine Microbenchmarks

This directory contains lightweight Python microbenchmarks for selected
Transformer Engine kernels and helper scripts for comparing benchmark CSVs.

## Benchmarks

- `benchmark_gemm.py`: dense BF16 GEMM benchmark
- `benchmark_gemm_fp8.py`: dense FP8 GEMM benchmark using `fp8_autocast`
- `benchmark_grouped_gemm.py`: grouped GEMM benchmark for MoE-style shapes
- `benchmark_casting.py`: BF16 `<->` FP8 casting benchmark
- `benchmark_normalization.py`: LayerNorm and RMSNorm benchmark

Run a benchmark directly from this directory. Pass `--csv` to write results.
When no filename is provided, `run_benchmarks` derives the CSV name from the
benchmark script file name.

```bash
python benchmark_gemm.py --csv
python benchmark_grouped_gemm.py --csv grouped_results.csv
```

## Shared configuration

Common benchmark settings live in `utils.py`.

- `M_SIZE_LIST`: default token-count sweep for dense and elementwise kernels
- `DTYPE_LIST`: shared dtype sweep for TE activation benchmarks
- `MODEL_CONFIGS`: dense GEMM model shapes
- `MODEL_HIDDEN_SIZES`: hidden sizes for elementwise kernels

Grouped GEMM keeps its own smaller M sweep because its working set scales with
expert count `B` in addition to `M`.

## Adding a benchmark

Use `run_benchmarks(test_cases, bench_fn, param_columns)`.

- `test_cases` is a list of dictionaries containing benchmark inputs.
- `param_columns` lists the case fields that should appear in stdout headers
  and CSV output.
- `bench_fn(**case)` must return a list of metric records created by
  `make_metric_record(...)` or `make_forward_backward_metric_records(...)`.

Each metric record represents one benchmark line such as `GEMM Forward`. The
runner prints that line to stdout and expands it into two CSV columns:

- `<label> Time (ms)`
- `<label> <unit>`

For example, a `GEMM Forward` metric with unit `TFLOPS` becomes:

- `GEMM Forward Time (ms)`
- `GEMM Forward TFLOPS`

## Comparing results

Use `compare_results.py` to compare two CSV files from the same benchmark
family:

```bash
python compare_results.py baseline.csv candidate.csv --bench-name GEMM
```

The script auto-detects metric columns, computes speedups for overlapping rows,
and reports rows that exist only in the baseline or only in the candidate.
````
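For illustration, a new benchmark following the README's recipe might look like the sketch below. This is a hypothetical example, not a file from this PR: the elementwise add kernel and `_generate_test_cases` are invented, while `time_func`, `compute_gbps`, `make_metric_record`, and `run_benchmarks` are the helpers exposed by the PR's `utils.py`.

```python
# Hypothetical minimal benchmark in the style of this PR (not part of the PR).
import torch
from utils import time_func, compute_gbps, make_metric_record, run_benchmarks


def _generate_test_cases():
    # Invented sweep, for illustration only.
    return [{"Case": "demo/add", "M": m, "hidden_size": 4096}
            for m in (1024, 4096, 16384)]


def bench_add(Case, M, hidden_size):
    x = torch.randn(M, hidden_size, dtype=torch.bfloat16, device="cuda")
    y = torch.randn_like(x)
    ms = time_func(lambda: x + y)
    # Two BF16 reads + one BF16 write, 2 bytes per element.
    total_bytes = x.numel() * 2 * 3
    return [make_metric_record("Add", ms, "GB/s", compute_gbps(total_bytes, ms))]


if __name__ == "__main__":
    run_benchmarks(
        test_cases=_generate_test_cases(),
        bench_fn=bench_add,
        param_columns=["Case", "M", "hidden_size"],
    )
```

Under the column rules above, this sketch would produce the CSV columns `Case, M, hidden_size, Add Time (ms), Add GB/s`.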
`benchmark_casting.py` (new file, 86 lines):

```python
#!/usr/bin/env python
###############################################################################
# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
#
# See LICENSE for license information.
###############################################################################
"""
FP8 casting micro-benchmark.

Benchmarks quantization (BF16 -> FP8) and dequantization (FP8 -> BF16) for
both E4M3 (activations/weights) and E5M2 (gradients) formats.

These casts are memory-bound; we report GB/s (input + output bytes).
Output: benchmark_casting.csv (written to cwd)
"""

import torch
import transformer_engine
import transformer_engine_torch as tex
from transformer_engine.pytorch import Float8Quantizer
from utils import (
    MODEL_HIDDEN_SIZES, M_SIZE_LIST,
    time_func, compute_gbps, make_metric_record, run_benchmarks,
)

TE_FP8_E4M3 = tex.DType.kFloat8E4M3
TE_FP8_E5M2 = tex.DType.kFloat8E5M2

CAST_LABEL = "Cast"

CAST_CONFIGS = [
    # (name, direction, fp8_dtype)
    ("BF16-to-FP8-E4M3", "quantize", TE_FP8_E4M3),
    ("FP8-E4M3-to-BF16", "dequantize", TE_FP8_E4M3),
    ("BF16-to-FP8-E5M2", "quantize", TE_FP8_E5M2),
    ("FP8-E5M2-to-BF16", "dequantize", TE_FP8_E5M2),
]


def _generate_cast_test_cases():
    test_cases = []
    for model_name, hidden in MODEL_HIDDEN_SIZES:
        for cast_name, direction, fp8_dtype in CAST_CONFIGS:
            for M in M_SIZE_LIST:
                test_cases.append({
                    "Case": f"{model_name}/{cast_name}",
                    "M": M,
                    "hidden_size": hidden,
                    "direction": direction,
                    "fp8_dtype": fp8_dtype,
                    "dtype_str": cast_name,
                })
    return test_cases


def bench_cast(Case, M, hidden_size, direction, fp8_dtype, dtype_str):
    device = "cuda"

    numel = M * hidden_size
    scale = torch.ones(1, dtype=torch.float32, device=device)
    amax = torch.zeros(1, dtype=torch.float32, device=device)
    quantizer = Float8Quantizer(scale, amax, fp8_dtype)

    if direction == "quantize":
        x = torch.randn(M, hidden_size, dtype=torch.bfloat16, device=device)
        out = quantizer(x)
        cast_func = lambda: quantizer.quantize(x, out=out)
        total_bytes = numel * (2 + 1)  # BF16 read + FP8 write
    else:
        x = torch.randn(M, hidden_size, dtype=torch.bfloat16, device=device)
        fp8_tensor = quantizer(x)
        cast_func = lambda: fp8_tensor.dequantize()
        total_bytes = numel * (1 + 2)  # FP8 read + BF16 write

    ms = time_func(cast_func, method="blocked")
    gbps = compute_gbps(total_bytes, ms)

    return [make_metric_record(CAST_LABEL, ms, "GB/s", gbps)]


if __name__ == "__main__":
    run_benchmarks(
        test_cases=_generate_cast_test_cases(),
        bench_fn=bench_cast,
        param_columns=["Case", "M", "hidden_size", "dtype_str"],
    )
```
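As a sanity check of the traffic model, the arithmetic for one quantize case works out as follows. This assumes `compute_gbps` is simply total bytes divided by elapsed time, which matches the `compute_gbps(total_bytes, ms)` call above but is not a verified definition, and the timing value is made up:

```python
# Worked bandwidth arithmetic for a hypothetical quantize case.
M, hidden = 4096, 8192
numel = M * hidden                  # 33,554,432 elements
quantize_bytes = numel * (2 + 1)    # BF16 read (2 B) + FP8 write (1 B), ~100.7 MB
ms = 0.05                           # hypothetical measured time
gbps = quantize_bytes / 1e9 / (ms / 1e3)
print(gbps)                         # ~2013 GB/s under these assumptions
```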
`benchmark_gemm.py` (new file, 63 lines):

```python
#!/usr/bin/env python
###############################################################################
# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
#
# See LICENSE for license information.
###############################################################################

import torch
import transformer_engine.pytorch as te
from utils import (
    generate_gemm_test_cases,
    time_func, compute_tflops, make_forward_backward_metric_records, run_benchmarks,
)

BENCHMARK_LABEL = "GEMM"


def bench_gemm(Case, M, N, K, dtype):
    device = "cuda"

    linear = te.Linear(K, N, bias=False).to(device=device, dtype=dtype)
    x = torch.randn(M, K, dtype=dtype, device=device, requires_grad=True)

    fwd_func = lambda: linear(x)
    out = fwd_func()
    grad_out = torch.randn_like(out)

    def fwd_bwd_func():
        out = linear(x)
        out.backward(grad_out)
        x.grad = None
        linear.weight.grad = None

    fwd_bwd_func()

    fwd_flops = 2 * M * N * K
    bwd_flops = 2 * fwd_flops  # dX + dW

    fwd_ms = time_func(fwd_func)
    fwd_bwd_ms = time_func(fwd_bwd_func)
    bwd_ms = fwd_bwd_ms - fwd_ms

    fwd_tflops = compute_tflops(fwd_flops, fwd_ms)
    bwd_tflops = compute_tflops(bwd_flops, bwd_ms)

    return make_forward_backward_metric_records(
        BENCHMARK_LABEL,
        "TFLOPS",
        fwd_ms,
        fwd_tflops,
        bwd_ms,
        bwd_tflops,
        backward_derived=True,
    )


if __name__ == "__main__":
    run_benchmarks(
        test_cases=generate_gemm_test_cases(),
        bench_fn=bench_gemm,
        param_columns=["Case", "M", "N", "K", "dtype"],
    )
```
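The FLOP model here is the standard one for a dense layer `y = x @ W.T` with `x` of shape `(M, K)` and weight of shape `(N, K)`: the forward pass is one GEMM costing `2*M*N*K` FLOPs, and the backward pass is two GEMMs of the same cost (`dx = dy @ W` and `dW = dy.T @ x`), hence `bwd_flops = 2 * fwd_flops`. Note also that the backward time is derived by subtraction (`fwd_bwd_ms - fwd_ms`) rather than timed in isolation, which appears to be what `backward_derived=True` signals. A worked example with invented numbers:

```python
# FLOP accounting for a hypothetical case; the shape and timing are made up.
M, N, K = 8192, 4096, 4096
fwd_flops = 2 * M * N * K                    # ~0.275 TFLOP per forward pass
bwd_flops = 2 * fwd_flops                    # dx GEMM + dW GEMM
fwd_ms = 0.9                                 # hypothetical measured time
tflops = fwd_flops / 1e12 / (fwd_ms / 1e3)   # ~305 TFLOPS under these assumptions
```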
`benchmark_gemm_fp8.py` (new file, 79 lines):

```python
#!/usr/bin/env python
###############################################################################
# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
#
# See LICENSE for license information.
###############################################################################
"""
FP8 GEMM micro-benchmark using te.Linear under fp8_autocast.

Same model shapes as benchmark_gemm.py.
Output: benchmark_gemm_fp8.csv (written to cwd)
"""

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format
from utils import (
    generate_gemm_test_cases,
    time_func, compute_tflops, make_forward_backward_metric_records, run_benchmarks,
)

RECIPES = {
    "hybrid": DelayedScaling(
        fp8_format=Format.HYBRID,
        amax_history_len=16,
        amax_compute_algo="max",
    ),
}

FP8_RECIPE = RECIPES["hybrid"]

BENCHMARK_LABEL = "FP8 GEMM"


def bench_fp8_gemm(Case, M, N, K, dtype):
    device = "cuda"

    linear = te.Linear(K, N, bias=False).to(device=device, dtype=dtype)
    x = torch.randn(M, K, dtype=dtype, device=device, requires_grad=True)
    grad_out = torch.randn(M, N, dtype=dtype, device=device)

    def fwd_func():
        with te.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE):
            return linear(x)

    def fwd_bwd_func():
        with te.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE):
            out = linear(x)
        out.backward(grad_out)
        x.grad = None
        linear.weight.grad = None

    fwd_flops = 2 * M * N * K
    bwd_flops = 2 * fwd_flops

    fwd_ms = time_func(fwd_func)
    fwd_bwd_ms = time_func(fwd_bwd_func)
    bwd_ms = fwd_bwd_ms - fwd_ms

    fwd_tflops = compute_tflops(fwd_flops, fwd_ms)
    bwd_tflops = compute_tflops(bwd_flops, bwd_ms)

    return make_forward_backward_metric_records(
        BENCHMARK_LABEL,
        "TFLOPS",
        fwd_ms,
        fwd_tflops,
        bwd_ms,
        bwd_tflops,
        backward_derived=True,
    )


if __name__ == "__main__":
    run_benchmarks(
        test_cases=generate_gemm_test_cases(),
        bench_fn=bench_fp8_gemm,
        param_columns=["Case", "M", "N", "K", "dtype"],
    )
```
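For context on the recipe (based on the public Transformer Engine API rather than anything in this diff): `Format.HYBRID` uses the E4M3 FP8 format for forward-pass tensors and E5M2 for gradients, and `DelayedScaling` derives scaling factors from a rolling amax history, here 16 steps reduced with `max`. A minimal standalone use of the same recipe might look like:

```python
# Standalone sketch of the hybrid delayed-scaling recipe used above.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16,
                        amax_compute_algo="max")
layer = te.Linear(4096, 4096, bias=False).to(device="cuda", dtype=torch.bfloat16)
x = torch.randn(64, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    y = layer(x)        # GEMM executes in FP8 under the hybrid recipe
y.sum().backward()      # backward runs outside the autocast region
```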
Review comment: Do we really need this to be a separate benchmark entirely, or can we combine it with `benchmark_gemm.py` and include it as e.g. a parameterization option?

Reply (matthiasdiener): I kept it separate for now because it has FP8-specific recipe/autocast setup and a separate output result.
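For what the parameterization option might look like (a hypothetical sketch of the reviewer's suggestion, not code from this PR): `te.fp8_autocast` takes an `enabled` flag, so the two benchmarks could in principle share one benchmark body with a per-case field:

```python
# Hypothetical merged benchmark body; use_fp8 would come from each test case.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

FP8_RECIPE = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16,
                            amax_compute_algo="max")

def forward(linear, x, use_fp8):
    # enabled=False makes fp8_autocast a no-op, giving the plain BF16 path.
    with te.fp8_autocast(enabled=use_fp8, fp8_recipe=FP8_RECIPE):
        return linear(x)

linear = te.Linear(4096, 4096, bias=False).to(device="cuda", dtype=torch.bfloat16)
x = torch.randn(64, 4096, device="cuda", dtype=torch.bfloat16)
y_bf16 = forward(linear, x, use_fp8=False)
y_fp8 = forward(linear, x, use_fp8=True)
```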