From 639efc022e4df154feff68cce706e995c93a8e5a Mon Sep 17 00:00:00 2001 From: PhlimosJW <220233704@seu.edu.cn> Date: Fri, 13 Mar 2026 06:37:20 -0500 Subject: [PATCH 1/5] Add ROCm Triton kernels skill for MI355X/R9700 - RMSNorm, RoPE 3D, GEGLU, AdaLN kernel patterns - Benchmark scripts (micro + e2e for LTX-Video) - HuggingFace Kernels integration example - Reference docs: optimization guides, templates, troubleshooting --- skills/custom-cuda-kernels-agent-skills.md | 298 ++++++++++ skills/rocm-kernels/CHANGELOG.md | 32 ++ skills/rocm-kernels/SKILL.md | 468 +++++++++++++++ skills/rocm-kernels/manifest.txt | 18 + .../references/diffusers-integration.md | 252 ++++++++ .../huggingface-kernels-integration.md | 351 ++++++++++++ .../references/kernel-agent-knowledge-base.md | 140 +++++ .../references/kernel-templates.md | 512 +++++++++++++++++ .../references/kernelbench-classification.md | 162 ++++++ .../references/mi355x-optimization-guide.md | 233 ++++++++ .../references/r9700-optimization-guide.md | 172 ++++++ .../skill-evaluation-methodology.md | 251 ++++++++ .../references/transformers-integration.md | 340 +++++++++++ .../references/troubleshooting.md | 292 ++++++++++ skills/rocm-kernels/scripts/benchmark_e2e.py | 296 ++++++++++ .../rocm-kernels/scripts/benchmark_kernels.py | 536 ++++++++++++++++++ .../scripts/huggingface_kernels_example.py | 344 +++++++++++ .../scripts/transformers_injection_example.py | 183 ++++++ 18 files changed, 4880 insertions(+) create mode 100644 skills/custom-cuda-kernels-agent-skills.md create mode 100644 skills/rocm-kernels/CHANGELOG.md create mode 100644 skills/rocm-kernels/SKILL.md create mode 100644 skills/rocm-kernels/manifest.txt create mode 100644 skills/rocm-kernels/references/diffusers-integration.md create mode 100644 skills/rocm-kernels/references/huggingface-kernels-integration.md create mode 100644 skills/rocm-kernels/references/kernel-agent-knowledge-base.md create mode 100644 skills/rocm-kernels/references/kernel-templates.md create mode 100644 skills/rocm-kernels/references/kernelbench-classification.md create mode 100644 skills/rocm-kernels/references/mi355x-optimization-guide.md create mode 100644 skills/rocm-kernels/references/r9700-optimization-guide.md create mode 100644 skills/rocm-kernels/references/skill-evaluation-methodology.md create mode 100644 skills/rocm-kernels/references/transformers-integration.md create mode 100644 skills/rocm-kernels/references/troubleshooting.md create mode 100644 skills/rocm-kernels/scripts/benchmark_e2e.py create mode 100644 skills/rocm-kernels/scripts/benchmark_kernels.py create mode 100644 skills/rocm-kernels/scripts/huggingface_kernels_example.py create mode 100644 skills/rocm-kernels/scripts/transformers_injection_example.py diff --git a/skills/custom-cuda-kernels-agent-skills.md b/skills/custom-cuda-kernels-agent-skills.md new file mode 100644 index 00000000..e9405b6f --- /dev/null +++ b/skills/custom-cuda-kernels-agent-skills.md @@ -0,0 +1,298 @@ +--- +title: "Custom Kernels for All from Codex and Claude" +thumbnail: /blog/assets/custom-cuda-kernels/meme.png +authors: +- user: burtenshaw +- user: sayakpaul +- user: ariG23498 +- user: evalstate +--- + + + +# Custom Kernels for All from Codex and Claude + +![oprah custom cuda kernels](assets/custom-cuda-kernels/meme.png) + +tl;dr: We built an agent skill that teaches coding agents how to write production CUDA kernels. Then we pointed Claude and Codex at two real targets: a **diffusers** pipeline and a **transformers** model. The agents produced working kernels for both, with correct PyTorch bindings and benchmarks, end to end. + +Writing CUDA kernels is hard. Writing CUDA kernels that correctly integrate with `transformers` and `diffusers` is harder. There are architecture-specific memory access patterns, vectorization strategies, warp shuffle reductions, and a dozen integration pitfalls that trip up even experienced developers. It is exactly the kind of specialized, high-stakes problem where agent skills shine. + +We gave coding agents the domain knowledge they need, like which GPU architecture to target, how to structure a kernel-builder project, when to use shared memory versus registers, and how to write PyTorch bindings. The agents did the rest. If you have used the [LLM training skill](https://huggingface.co/blog/hf-skills-training) or read [We Got Claude to Teach Open Models](https://huggingface.co/blog/upskill), the pattern will feel familiar: package domain expertise into a skill, point the agent at a problem, and let it work. + +## Why a skill for kernels? + +The [Kernel Hub](https://huggingface.co/blog/hello-hf-kernels) solved the distribution of custom hardware kernels. You can load pre-compiled kernels from the Hub with a single `get_kernel` call. No builds, no flags. However, someone still needs to **write the kernels**. That is the gap this skill fills. + +CUDA kernel development has a brutal surface area: + +- Hardware-specific optimization guides for each generation of GPU. H100, A100, and T4 each have different compute capabilities, shared memory sizes, and bandwidth profiles +- In Libraries, `diffusers` and `transformers` have different module hierarchies, normalization conventions, and integration patterns. Custom kernels need to be registered in PyTorch for `torch.compile` to recognize. +- For distribution, kernels can depend on CUDA, Pytorch, and Python versions creating massive environment matrices. + +This is domain knowledge that gets lost in documentation tabs and Stack Overflow answers. An agent skill packages it into context that loads on demand. + +First, let's show how to use the skill right away, then we'll dive into the details of how we benchmarked the kernels. + +## Installing the skill + +The skill ships with the `kernels` library. Install it into your coding agent with a single command: + +```shell +# we need to install kernels from main for this +pip install git+https://github.com/huggingface/kernels.git#subdirectory=kernels +kernels skills add cuda-kernels --claude +``` + +This drops the skill into `.claude/skills/cuda-kernels/` where Claude Code and Cursor pick it up automatically. For other agents: + +```shell +# Codex +kernels skills add cuda-kernels --codex + +# OpenCode +kernels skills add cuda-kernels --opencode + +# Custom destination +kernels skills add cuda-kernels --dest ./my-agent/skills/ + +# Install globally (available across all projects) +kernels skills add cuda-kernels --global + +# Overwrite an existing installation +kernels skills add cuda-kernels --claude --force +``` + +Once installed, prompt your agent: + +``` +Build a vectorized RMSNorm kernel for H100 targeting the Qwen3-8B model in transformers. +``` + +Or, you can go for something more open-ended: + +``` +Build an optimized attention kernel for H100 targeting the Qwen3-8B model in transformers. Benchmark it against the PyTorch baseline and validate improvements in end-to-end performance. +``` + +The agent can read the skill, select the right architecture parameters, generate the CUDA source, write the PyTorch bindings, set up `build.toml`, and create a benchmark script. + +If you're working on more complex kernels, or architecture-specific optimizations, that aren't covered in the skill, then the skill supplies the fundamental building blocks and patterns to get you started. We are also open to contributions on the [skill itself](https://github.com/huggingface/kernels/tree/main/.docs/skills). + +## What is in the skill + +The skill is roughly **550 tokens** of structured guidance plus reference scripts, GPU optimization guides, troubleshooting docs, and complete working examples. Agentic coding tools like Codex and Claude can read this and produce a working kernel project. + +It covers: + +- NVIDIA GPU Architecture-aware optimization for H100, A100, and T4 (compute capabilities, memory bandwidth, shared memory sizes, block sizing) +- Integration patterns for both `diffusers` and `transformers`, including the pitfalls specific to each library +- Kernel templates with vectorized memory access patterns for BF16, FP16, and FP32 +- Benchmarking workflows for both isolated kernel micro-benchmarks and end-to-end pipeline comparisons +- HuggingFace Kernel Hub integration via `get_kernel` for loading community kernels + +``` +.claude/skills/cuda-kernels/ +├── SKILL.md # Main instructions (~550 tokens) +├── scripts/ +│ ├── benchmark_example.py # End-to-end benchmark template +│ ├── benchmark_rmsnorm.py # Isolated kernel micro-benchmark +│ ├── ltx_kernel_injection_example.py # Diffusers integration pattern +│ ├── transformers_injection_example.py # Transformers integration pattern +│ └── huggingface_kernels_example.py # Kernel Hub integration +└── references/ + ├── diffusers-integration.md # Diffusers guide with pitfalls + ├── transformers-integration.md # Transformers guide + ├── huggingface-kernels-integration.md + ├── h100-optimization-guide.md + ├── a100-optimization-guide.md + ├── t4-optimization-guide.md + ├── kernel-templates.md + └── troubleshooting.md +``` + +When an agent loads this, it gets everything it needs to go from "write me an RMSNorm kernel" to a buildable, benchmarkable project. It will grep and glob the skill to find the relevant files and directories. So it's important to structure the skill in a way that is easy to find. + +The agent is instructed to generate kernels that conform to the templates in `references/kernel-templates.md` and produce a complete kernel project: + +``` +examples/your_model/ +├── kernel_src/ +│ └── rmsnorm.cu # Vectorized CUDA kernel +├── torch-ext/ +│ ├── your_kernels/__init__.py +│ └── torch_binding.cpp # PyTorch C++ bindings +├── benchmark_rmsnorm.py # Micro-benchmark script +├── build.toml # kernel-builder config +├── setup.py # pip install -e . +└── pyproject.toml +``` + +We tested this on two real targets. + +## Benchmarking the kernels: Diffusers (LTX-Video on H100) + +The agent built RMSNorm, RoPE 3D, GEGLU, and AdaLN kernels for [LTX-Video](https://huggingface.co/Lightricks/LTX-Video), a video generation pipeline from `diffusers`. The full example is at `examples/ltx_video/`. We optimized the RMSNorm kernel for H100. Both benchmarks were run on H100 80GB HBM3 at precision BFloat16. + +If you want to check out the generated kernel, got to [this example](https://github.com/burtenshaw/kernel-skill/tree/main/examples/ltx_video) + +### Isolated RMSNorm benchmark + +First, we compare the isolated RMSNorm kernel performance against the PyTorch baseline. This is the main speedup in the optimized pipeline. + +![isolated rmsnorm benchmark ltx-video](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/kernels-skill-benchmark/rmsnorm_ltx_video.png) + +
+Table + +| Shape | Custom (ms) | PyTorch (ms) | Speedup | +| :---- | :---: | :---: | :---: | +| [1x1024x2048] | 0.039 | 0.064 | **1.64x** | +| [2x1024x2048] | 0.040 | 0.073 | **1.82x** | +| [4x1024x2048] | 0.052 | 0.093 | **1.78x** | +| [1x4096x2048] | 0.052 | 0.093 | **1.79x** | +| [2x4096x3072] | 0.102 | 0.209 | **2.04x** | +| [1x8192x2048] | 0.083 | 0.150 | **1.81x** | +| [4x4096x3072] | 0.173 | 0.393 | **2.26x** | + +**Average speedup: 1.88x** and a bandwidth efficiency: 34.7% of H100 theoretical (3,350 GB/s) + +
+ +### End-to-end video generation (49 frames, 30 steps, H100 80GB) + +Next, we compare the end-to-end video generation performance of the optimized kernels against the baseline (no compile) and the `torch.compile` baseline. + +![e2e benchmark ltx-video](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/kernels-skill-benchmark/e2e_ltx_video.png) + +
+Table + +| Configuration | Time (s) | it/s | Speedup | +| :---- | :---: | :---: | :---: | +| Baseline (no compile) | 2.87 | 12.58 | 1.00x | +| **Generated Optimized Kernels** | 2.70 | 13.52 | **1.06x** | +| Baseline + torch.compile | 2.14 | 19.05 | 1.34x | +| Optimized + torch.compile | 2.01 | 18.45 | 1.43x | + +
+ +RMSNorm accounts for ~5% of total compute in LTX-Video. The remaining time is spent in attention, linear projections, and VAE decode. The 6% end-to-end speedup from a single kernel type is consistent with that profile. + +## Benchmarking the kernels: Transformers (Qwen3-8B on H100) + +The agent built an RMSNorm kernel for [Qwen3-8B](https://huggingface.co/Qwen/Qwen3-8B), a large language model from `transformers` with 65 RMSNorm modules across 32 layers. The full example is at `examples/qwen3_8b/`. We optimized the RMSNorm kernel for H100. Both benchmarks were run on H100 80GB HBM3 at precision BFloat16. + +If you want to explore the kernel, check it out [here.](https://github.com/burtenshaw/kernel-skill/tree/main/examples/qwen3_8b) + +### Isolated RMSNorm benchmark + +Once again, we compare the isolated RMSNorm kernel performance against the PyTorch baseline. + +![isolated rmsnorm benchmark qwen3-8b](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/kernels-skill-benchmark/rmsnorm_qwen3.png) + +**Average speedup: 1.94x** and a bandwidth efficiency: 22.3% of H100 theoretical (3,350 GB/s) + +
+Table + +| Shape | Custom (ms) | PyTorch (ms) | Speedup | +| :---- | :---: | :---: | :---: | +| [1x128x4096] | 0.040 | 0.062 | **1.58x** | +| [1x512x4096] | 0.038 | 0.064 | **1.69x** | +| [1x1024x4096] | 0.037 | 0.071 | **1.90x** | +| [1x2048x4096] | 0.045 | 0.091 | **2.03x** | +| [1x4096x4096] | 0.071 | 0.150 | **2.12x** | +| [4x512x4096] | 0.056 | 0.093 | **1.67x** | +| [8x256x4096] | 0.045 | 0.092 | **2.06x** | +| [1x8192x4096] | 0.109 | 0.269 | **2.47x** | + +
+ +Speedup scales with sequence length: 1.58x at 128 tokens, 2.47x at 8192 tokens. For long-context inference, the custom kernel roughly halves RMSNorm latency. + +## Publishing your kernel to the Hub + +The agent gives you a working kernel. The [Kernel Hub](https://huggingface.co/kernels-community) lets you share it so anyone can load it without compilation. Here is the full path from agent output to published kernel. + +### 1. Verify the project structure + +The agent produces a project that already follows the [kernel-builder](https://huggingface.co/docs/kernels/en/builder/writing-kernels) layout: + +``` +your_kernel/ +├── build.toml # Build configuration +├── kernel_src/ +│ └── rmsnorm.cu # CUDA kernel source +└── torch-ext/ + ├── torch_binding.cpp # Registers Torch ops + └── your_kernels/ + └── __init__.py # Python API wrapping _ops +``` + +The `build.toml` tells `kernel-builder` what to build. The agent generates this for you, including the correct `cuda-capabilities` for your target GPU: + +``` +[general] +name = "your_kernels" +backends = ["cuda"] + +[torch] +src = ["torch-ext/torch_binding.cpp"] + +[kernel.rmsnorm] +backend = "cuda" +src = ["kernel_src/rmsnorm.cu"] +depends = ["torch"] +cuda-capabilities = ["9.0"] # H100 +``` + +### 2. Build all variants with Nix + +Kernel Hub kernels must support all recent PyTorch and CUDA configurations. The kernel-builder Nix flake handles this automatically. Copy the [example `flake.nix`](https://github.com/huggingface/kernels/blob/main/builder/examples/relu/flake.nix) into your project and run: + +```shell +nix flake update +nix run .#build-and-copy -L +``` + +This builds the kernel for every required PyTorch/CUDA variant and places the results in `build/`. For faster builds, enable the HuggingFace Nix cache: + +```shell +nix run nixpkgs#cachix -- use huggingface +``` + +### 3. Create a Hub repo and push + +Create a model repo on the Hub and upload the built kernel: + +```shell +huggingface-cli repo create your-org/your-kernel --type model +huggingface-cli upload your-org/your-kernel ./build +``` + +### 4. Others load it in one line + +Once published, anyone can use your kernel with zero compilation: + +```py +from kernels import get_kernel + +rmsnorm = get_kernel("your-org/your-kernel") +``` + +`get_kernel` detects the user's Python, PyTorch, and CUDA versions and downloads the matching pre-compiled binary. No builds, no flags, typically ready in seconds. + +The skill and the Hub are complementary. The skill handles development. The Hub handles distribution. Build a kernel with the skill, validate it with the benchmark scripts, publish it to the Hub, and it becomes a one-liner for everyone else. + +## Conclusion + +We built an agent skill that teaches coding agents how to write production CUDA kernels. Then we pointed Claude and Codex at two real targets: a **diffusers** pipeline and a **transformers** model. The agents produced working kernels for both, with correct PyTorch bindings and benchmarks, end to end. We benchmarked the kernels and found that the optimized kernels can provide a speedup in both isolated and end-to-end performance. + +## Resources + +- [CUDA Kernels Skill in `kernels`](https://github.com/huggingface/kernels/tree/main/skills/cuda-kernels) +- [HuggingFace Kernel Hub Blog](https://huggingface.co/blog/hello-hf-kernels) +- [We Got Claude to Fine-Tune an Open Source LLM](https://huggingface.co/blog/hf-skills-training) +- [We Got Claude to Teach Open Models](https://huggingface.co/blog/upskill) +- [HuggingFace Kernels Community](https://huggingface.co/kernels-community) diff --git a/skills/rocm-kernels/CHANGELOG.md b/skills/rocm-kernels/CHANGELOG.md new file mode 100644 index 00000000..1d52b9a5 --- /dev/null +++ b/skills/rocm-kernels/CHANGELOG.md @@ -0,0 +1,32 @@ +# Changelog + +## v0.2 (2026-03-12) + +### Added +- **Transformers integration**: `references/transformers-integration.md` — LLaMA/Mistral/Qwen RMSNorm patching, Flash Attention 2, epsilon handling differences +- **Transformers injection script**: `scripts/transformers_injection_example.py` — minimal runnable example (~150 lines) +- **HuggingFace Kernels Hub integration**: `references/huggingface-kernels-integration.md` — `get_kernel`, `has_kernel`, publishing, ROCm compatibility notes +- **HuggingFace Kernels example script**: `scripts/huggingface_kernels_example.py` — Hub loading, benchmarking, model integration with fallback +- **GEMM template with XCD swizzle**: Template 5 in `kernel-templates.md` — full GEMM kernel with XCD swizzle for MI355X, L2 cache grouping, autotune configs, Python API, and benchmark +- **CHANGELOG.md**: Version tracking for skill iterations + +### Fixed +- Broken cross-references: "Template 2" for GEMM → corrected to "Template 5" in `troubleshooting.md`, `kernelbench-classification.md`, and `skill-evaluation-methodology.md` +- R9700 Memory Bandwidth: filled in ~608 GB/s (was TBD) in SKILL.md + +### Updated +- `SKILL.md` See Also section: added new integration guides, scripts, and Hub links +- `SKILL.md` argument-hint: added gemm, transformers, huggingface-kernels, get_kernel +- `manifest.txt`: added all new files + +## v0.1 (2026-03-10) + +### Added +- Initial skill with SKILL.md, 4 kernel templates (RMSNorm, RoPE 3D, GEGLU, AdaLN) +- MI355X and R9700 GPU optimization guides +- Diffusers integration guide (LTX-Video) +- Troubleshooting guide (14 ROCm-specific issues) +- Benchmark scripts: micro-benchmark (`benchmark_kernels.py`) and E2E (`benchmark_e2e.py`) +- LTX-Video injection example (`ltx_kernel_injection_example.py`) +- KernelBench classification and evaluation methodology docs +- Kernel-agent knowledge base diff --git a/skills/rocm-kernels/SKILL.md b/skills/rocm-kernels/SKILL.md new file mode 100644 index 00000000..f1d9ad58 --- /dev/null +++ b/skills/rocm-kernels/SKILL.md @@ -0,0 +1,468 @@ +--- +name: rocm-kernels +description: "Provides guidance for writing and benchmarking optimized Triton kernels for AMD GPUs (MI355X, R9700) on ROCm, targeting HuggingFace diffusers (LTX-Video, SD3, FLUX) and transformers. Core kernels: RMSNorm, RoPE 3D, GEGLU, AdaLN. Includes XCD swizzle, autotune, diffusers integration patterns, and LTX-Video pipeline injection." +disable-model-invocation: false +user-invocable: true +allowed-tools: "Read, Grep, Glob, Bash" +argument-hint: "kernel type: rmsnorm, rope, rope-3d, geglu, adaln, gemm, benchmark, diffusers, transformers, ltx-video, huggingface-kernels, get_kernel, autotune, xcd-swizzle" +--- + +# ROCm Triton Kernels for Diffusers & Transformers + +This skill provides patterns and guidance for developing optimized Triton kernels targeting AMD GPUs (MI355X, R9700) on ROCm, for use with HuggingFace **diffusers** (LTX-Video, SD3, FLUX) and **transformers** libraries. + +## Quick Start + +### Diffusers (LTX-Video) + +**Inject optimized kernels into LTX-Video pipeline:** +```python +import os +os.environ['TRITON_HIP_USE_BLOCK_PINGPONG'] = '1' +os.environ['TRITON_HIP_USE_ASYNC_COPY'] = '1' + +from diffusers import LTXPipeline +pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16) +pipe.to("cuda") # ROCm uses same API via HIP +inject_optimized_kernels(pipe) # BEFORE CPU offloading +pipe.enable_model_cpu_offload() +``` + +**For a minimal integration example (~150 lines):** +```bash +python scripts/ltx_kernel_injection_example.py +``` + +### Isolated Kernel Micro-benchmarks +```bash +# All 4 kernels: correctness + performance + bandwidth +python scripts/benchmark_kernels.py + +# Single kernel +python scripts/benchmark_kernels.py --kernel rmsnorm +python scripts/benchmark_kernels.py --kernel rope +python scripts/benchmark_kernels.py --kernel geglu +python scripts/benchmark_kernels.py --kernel adaln +``` + +### End-to-End Pipeline Benchmark +```bash +# Compare baseline vs Triton vs torch.compile +python scripts/benchmark_e2e.py --mode all + +# Quick test +python scripts/benchmark_e2e.py --mode triton --num-frames 9 --steps 5 + +# Save results for comparison +python scripts/benchmark_e2e.py --mode all --output-json results.json +``` + +## Target Model: LTX-Video + +### Architecture Overview + +| Component | Class | Has Weight | Count | Kernel | +|-----------|-------|------------|-------|--------| +| `transformer_blocks.*.norm1` | RMSNorm | **No** (elementwise_affine=False) | 56 | RMSNorm | +| `transformer_blocks.*.norm2` | RMSNorm | **No** | 56 | RMSNorm | +| `transformer_blocks.*.attn1.norm_q` | torch.nn.RMSNorm | Yes | 28 | RMSNorm | +| `transformer_blocks.*.attn1.norm_k` | torch.nn.RMSNorm | Yes | 28 | RMSNorm | +| `transformer_blocks.*.ff` | FeedForward | - | 28 | **GELU** (not GEGLU!) | +| Rotary position encoding | LTXVideoRotaryPosEmbed | - | 1 | RoPE 3D | + +**Total RMSNorm modules: 168** (56 with weights, 112 without) + +### Target Kernels + +| Kernel | Use Case | Input Layout | Key Challenge | +|--------|----------|-------------|---------------| +| **RMSNorm** | Normalization | `[..., hidden_size]` | Weight may be None; 168 instances | +| **RoPE 3D** | Video position encoding | `[batch, t*h*w, heads, head_dim]` | 3D → temporal + spatial decomposition | +| **GEGLU** | Gated activation (SD3/FLUX) | `[batch, seq, 2*hidden]` → `[batch, seq, hidden]` | Gate/value split | +| **AdaLN** | Conditioned normalization (DiT) | `norm(x) * weight * (1+scale) + shift` | Fused norm + condition | + +## Supported Hardware + +| GPU | Architecture | Wave Size | LDS/CU | Mem BW | Key Feature | +|-----|-------------|-----------|--------|--------|-------------| +| **MI355X** | CDNA3+ (gfx950) | Wave64 | **160 KB** | 8 TB/s | 32 XCDs, XCD Swizzle for GEMM | +| **R9700** | RDNA4 (gfx1201) | **Wave32** | 64 KB | ~608 GB/s | 256B cacheline, inference-focused | + +> See [MI355X guide](references/mi355x-optimization-guide.md) | [R9700 guide](references/r9700-optimization-guide.md) + +## When This Skill Applies + +Use this skill when: +- Writing Triton kernels for **RMSNorm, RoPE, GEGLU, AdaLN** on AMD GPUs +- Integrating custom kernels with **diffusers** pipelines (LTX-Video, SD3, FLUX) +- Benchmarking kernel performance against PyTorch baseline on ROCm +- Optimizing existing kernels for MI355X or R9700 architecture +- Debugging ROCm/HIP-specific kernel issues + +## Critical ROCm Constraints + +### Things That DON'T Work on AMD + +```python +# FORBIDDEN - CUDA only, NOT available on ROCm +tl.libdevice.tanh(x) # Use manual formula below +tl.libdevice.log1p(x) # Use: tl.log(1.0 + x) +tl.math.tanh(x) # Also NOT available on ROCm Triton + +# Manual tanh (ONLY reliable method on ROCm): +e2x = tl.exp(2.0 * x) +tanh_x = (e2x - 1.0) / (e2x + 1.0) + +# FORBIDDEN - Triton limitations on ROCm +break / continue # Use: tl.where() +min(a, b) / max(a, b) # Use: tl.minimum(a, b) / tl.maximum(a, b) +``` + +### Mandatory Environment Variables + +```python +import os +os.environ['TRITON_HIP_USE_BLOCK_PINGPONG'] = '1' +os.environ['TRITON_HIP_USE_ASYNC_COPY'] = '1' +``` + +## Core Kernel Implementations + +### 1. RMSNorm (Core Optimization Target) + +Row-wise reduction pattern. **168 instances** in LTX-Video, ~5% of total compute. + +**CRITICAL: Do NOT autotune BLOCK_D.** Autotune may pick `BLOCK_D < D`, causing partial row processing and wrong results. Always compute `BLOCK_D = triton.next_power_of_2(D)` in the Python wrapper. + +```python +@triton.jit +def rmsnorm_kernel( + x_ptr, weight_ptr, out_ptr, + stride_x, D, + eps: tl.constexpr, + HAS_WEIGHT: tl.constexpr, + BLOCK_D: tl.constexpr, +): + row = tl.program_id(0) + offs = tl.arange(0, BLOCK_D) + mask = offs < D + x = tl.load(x_ptr + row * stride_x + offs, mask=mask, other=0.0).to(tl.float32) + + variance = tl.sum(x * x, axis=0) / D + rms_inv = tl.rsqrt(variance + eps) + + if HAS_WEIGHT: + w = tl.load(weight_ptr + offs, mask=mask, other=1.0).to(tl.float32) + out = x * rms_inv * w + else: + out = x * rms_inv + + tl.store(out_ptr + row * stride_x + offs, out.to(x.dtype), mask=mask) + + +def triton_rmsnorm(x, weight=None, eps=1e-6): + x_2d = x.contiguous().view(-1, x.shape[-1]) + out = torch.empty_like(x_2d) + M, D = x_2d.shape + has_weight = weight is not None + if not has_weight: + weight = torch.empty(0, device=x.device) + + BLOCK_D = triton.next_power_of_2(D) + num_warps = 4 if BLOCK_D <= 1024 else (8 if BLOCK_D <= 4096 else 16) + rmsnorm_kernel[(M,)]( + x_2d, weight, out, x_2d.stride(0), D, eps, has_weight, + BLOCK_D=BLOCK_D, num_warps=num_warps, num_stages=2, + ) + return out.view_as(x) +``` + +**LTX-Video pitfall: Weight may be None!** +```python +has_weight = hasattr(module, 'weight') and module.weight is not None +``` + +### 2. RoPE 3D (Video Position Encoding) + +Element-wise pattern. LTX-Video splits `head_dim` into temporal + spatial components. + +**CRITICAL: cos/sin have shape `[seq_len, head_dim]`.** When grid flattens batch dimension (`batch * seq_len`), use `pid_s % seq_len` to index cos/sin, otherwise batch > 1 causes OOB GPU crash. + +```python +@triton.jit +def rope_3d_kernel( + qk_ptr, cos_ptr, sin_ptr, out_ptr, + seq_len, num_heads, head_dim, + stride_s, stride_h, stride_d, + BLOCK_HD: tl.constexpr, +): + pid_s = tl.program_id(0) # batch * seq_len + pid_h = tl.program_id(1) # head index + half_dim = head_dim // 2 + offs = tl.arange(0, BLOCK_HD) + mask = offs < half_dim + + base = pid_s * stride_s + pid_h * stride_h + x0 = tl.load(qk_ptr + base + offs, mask=mask, other=0.0).to(tl.float32) + x1 = tl.load(qk_ptr + base + half_dim + offs, mask=mask, other=0.0).to(tl.float32) + + seq_idx = pid_s % seq_len # wrap for batch > 1 + cos_val = tl.load(cos_ptr + seq_idx * head_dim + offs, mask=mask, other=1.0).to(tl.float32) + sin_val = tl.load(sin_ptr + seq_idx * head_dim + offs, mask=mask, other=0.0).to(tl.float32) + + out0 = x0 * cos_val - x1 * sin_val + out1 = x0 * sin_val + x1 * cos_val + + tl.store(out_ptr + base + offs, out0.to(x0.dtype), mask=mask) + tl.store(out_ptr + base + half_dim + offs, out1.to(x0.dtype), mask=mask) + + +def triton_rope_3d(qk, cos, sin): + qk = qk.contiguous() + out = torch.empty_like(qk) + batch, seq_len, num_heads, head_dim = qk.shape + qk_flat = qk.view(batch * seq_len, num_heads, head_dim) + out_flat = out.view(batch * seq_len, num_heads, head_dim) + BLOCK_HD = triton.next_power_of_2(head_dim // 2) + num_warps = 4 if BLOCK_HD <= 64 else 8 + rope_3d_kernel[(batch * seq_len, num_heads)]( + qk_flat, cos, sin, out_flat, + seq_len, num_heads, head_dim, + qk_flat.stride(0), qk_flat.stride(1), qk_flat.stride(2), + BLOCK_HD=BLOCK_HD, num_warps=num_warps, num_stages=2, + ) + return out +``` + +### 3. GEGLU (For SD3/FLUX, NOT LTX-Video) + +Element-wise gated activation. Input `[batch, seq, 2*hidden]` → Output `[batch, seq, hidden]`. + +**Same BLOCK_SIZE rule: compute dynamically, do NOT autotune.** + +```python +@triton.jit +def geglu_kernel( + input_ptr, output_ptr, + stride_in, stride_out, hidden_size, + BLOCK_H: tl.constexpr, +): + row = tl.program_id(0) + offs = tl.arange(0, BLOCK_H) + mask = offs < hidden_size + + gate = tl.load(input_ptr + row * stride_in + offs, mask=mask, other=0.0).to(tl.float32) + value = tl.load(input_ptr + row * stride_in + hidden_size + offs, mask=mask, other=0.0).to(tl.float32) + + # GELU approx — manual tanh (tl.math.tanh NOT available on ROCm) + k = 0.7978845608028654 # sqrt(2/pi) + tanh_arg = k * (gate + 0.044715 * gate * gate * gate) + e2x = tl.exp(2.0 * tanh_arg) + tanh_val = (e2x - 1.0) / (e2x + 1.0) + gate_gelu = 0.5 * gate * (1.0 + tanh_val) + result = gate_gelu * value + + tl.store(output_ptr + row * stride_out + offs, result.to(gate.dtype), mask=mask) + + +def triton_geglu(x): + x = x.contiguous() + *batch_dims, double_h = x.shape + hidden_size = double_h // 2 + x_2d = x.view(-1, double_h) + M = x_2d.shape[0] + out = torch.empty(M, hidden_size, device=x.device, dtype=x.dtype) + BLOCK_H = triton.next_power_of_2(hidden_size) + num_warps = 4 if BLOCK_H <= 1024 else (8 if BLOCK_H <= 4096 else 16) + geglu_kernel[(M,)]( + x_2d, out, x_2d.stride(0), out.stride(0), hidden_size, + BLOCK_H=BLOCK_H, num_warps=num_warps, num_stages=2, + ) + return out.view(*batch_dims, hidden_size) +``` + +**Warning: LTX-Video uses GELU, NOT GEGLU.** GEGLU is for SD3/FLUX. + +### 4. AdaLN (Adaptive Layer Normalization for DiT) + +Fused normalization + conditioning: `norm(x) * weight * (1 + scale) + shift` + +**Same BLOCK_D rule: compute dynamically.** + +```python +@triton.jit +def adaln_kernel( + x_ptr, weight_ptr, scale_ptr, shift_ptr, out_ptr, + stride_x, stride_cond, D, + eps: tl.constexpr, + BLOCK_D: tl.constexpr, +): + row = tl.program_id(0) + offs = tl.arange(0, BLOCK_D) + mask = offs < D + x = tl.load(x_ptr + row * stride_x + offs, mask=mask, other=0.0).to(tl.float32) + + variance = tl.sum(x * x, axis=0) / D + rms_inv = tl.rsqrt(variance + eps) + x_norm = x * rms_inv + + w = tl.load(weight_ptr + offs, mask=mask, other=1.0).to(tl.float32) + scale = tl.load(scale_ptr + row * stride_cond + offs, mask=mask, other=0.0).to(tl.float32) + shift = tl.load(shift_ptr + row * stride_cond + offs, mask=mask, other=0.0).to(tl.float32) + + out = x_norm * w * (1.0 + scale) + shift + tl.store(out_ptr + row * stride_x + offs, out.to(x.dtype), mask=mask) + + +def triton_adaln(x, weight, scale, shift, eps=1e-6): + x_flat = x.contiguous().view(-1, x.shape[-1]) + scale_flat = scale.contiguous().view(-1, x.shape[-1]) + shift_flat = shift.contiguous().view(-1, x.shape[-1]) + out = torch.empty_like(x_flat) + M, D = x_flat.shape + BLOCK_D = triton.next_power_of_2(D) + num_warps = 4 if BLOCK_D <= 1024 else (8 if BLOCK_D <= 4096 else 16) + adaln_kernel[(M,)]( + x_flat, weight, scale_flat, shift_flat, out, + x_flat.stride(0), scale_flat.stride(0), D, eps, + BLOCK_D=BLOCK_D, num_warps=num_warps, num_stages=2, + ) + return out.view_as(x) +``` + +## Diffusers Integration + +> **See [diffusers-integration.md](references/diffusers-integration.md) for the complete guide.** + +### Minimal Integration Pattern + +```python +def patch_rmsnorm_modules(model): + """Patch all RMSNorm modules to use custom Triton kernel.""" + for name, module in model.named_modules(): + if type(module).__name__ == 'RMSNorm': + eps = getattr(module, 'eps', 1e-6) + has_weight = hasattr(module, 'weight') and module.weight is not None + if has_weight: + def make_forward(mod, epsilon): + def forward(x): + return triton_rmsnorm(x, mod.weight, eps=epsilon) + return forward + module.forward = make_forward(module, eps) + else: + def make_forward(epsilon): + def forward(x): + w = torch.ones(x.shape[-1], device=x.device, dtype=x.dtype) + return triton_rmsnorm(x, w, eps=epsilon) + return forward + module.forward = make_forward(eps) + +pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16) +pipe.to("cuda") +patch_rmsnorm_modules(pipe.transformer) +pipe.enable_model_cpu_offload() +``` + +### Diffusers Critical Pitfalls + +1. **RMSNorm weight may be None** — LTX-Video uses `elementwise_affine=False` +2. **Diffusers RMSNorm != torch.nn.RMSNorm** — Use `type(module).__name__` not `isinstance()` +3. **LTX-Video uses GELU, not GEGLU** — Don't patch GEGLU for LTX-Video +4. **Inject BEFORE CPU offloading** — `inject_kernels()` then `enable_model_cpu_offload()` + +## Performance Expectations + +### Micro-benchmark Results (MI355X, BF16) + +| Kernel | Avg Speedup | Best Config Speedup | Status | +|--------|:-----------:|:-------------------:|:------:| +| **RMSNorm** | **1.71x** | 2.44x ([4×4096×3072]) | PASS | +| **RoPE 3D** | **1.21x** | 1.52x ([2×4096×16×128]) | PASS | +| **GEGLU** | **1.43x** | 2.13x ([4×4096×8192]) | PASS | +| **AdaLN** | **2.22x** | 2.77x ([4×4096×3072]) | PASS | + +RMSNorm bandwidth utilization: 3554 GB/s (MI355X theoretical: 8 TB/s, ~44%). + +### End-to-End LTX-Video (MI355X, 25 frames, 30 steps) + +| Mode | Time (s) | Per Step (s) | Peak Mem (GB) | Speedup | +|------|:--------:|:------------:|:-------------:|:-------:| +| baseline | 1.20 | 0.040 | 18.58 | 1.00x | +| **triton** | **0.98** | **0.033** | **18.58** | **1.22x** | +| torch.compile | 0.78 | 0.026 | 18.58 | 1.54x | + +**Key finding**: MI355X Triton E2E speedup (22%) is significantly higher than H100 CUDA reference (6%), because MI355X's default PyTorch RMSNorm path has more room for optimization. + +### CUDA Reference (H100, for comparison) + +| Shape | Custom (ms) | PyTorch (ms) | Speedup | +|:---|:---:|:---:|:---:| +| [1×1024×2048] | 0.019 | 0.065 | **3.37x** | +| [2×4096×3072] | 0.087 | 0.208 | **2.41x** | + +H100 E2E: ~6% (RMSNorm is ~5% of total compute). + +### Optimization Targets + +| Kernel | Current | Target | Priority | +|--------|:-------:|:------:|:--------:| +| RMSNorm | 1.71x | >2x | P0 — increase bandwidth util (44%→60%+) | +| AdaLN | 2.22x | >2.5x | P1 — already strong | +| GEGLU | 1.43x | >1.5x | P1 — tanh overhead | +| RoPE 3D | 1.21x | >1.5x | P2 — small head_dim launch overhead | + +## Common Issues on ROCm + +| Issue | Symptom | Fix | +|-------|---------|-----| +| **Autotune BLOCK_D** | Wrong results (max_abs 4-8+) | **Never autotune BLOCK_D.** Use `triton.next_power_of_2(D)` | +| **RoPE batch OOB** | GPU crash (`Memory access fault`) | Use `pid_s % seq_len` for cos/sin indexing | +| `tl.libdevice` | Not found on AMD | Use manual math formulas | +| `tl.tanh` / `tl.math.tanh` | Not on ROCm | Manual: `e2x=exp(2x); (e2x-1)/(e2x+1)` | +| Python min/max | Runtime error | `tl.minimum()`/`tl.maximum()` | +| LDS overflow | HIP OOM | Reduce num_stages to 2 | +| Weight is None | AttributeError | Check `elementwise_affine` | +| isinstance() miss | RMSNorm not patched | Use `type(module).__name__` | + +> See [troubleshooting.md](references/troubleshooting.md) for all common issues. + +## Performance Profiling + +```bash +rocprof --stats python your_kernel.py +rocprofv3 -i metrics.txt python your_kernel.py +rocm-bandwidth-test +rocminfo | grep -E "Name|Compute Unit|Wavefront" +``` + +## See Also + +### Benchmark & Test Scripts +- [benchmark_kernels.py](scripts/benchmark_kernels.py) - Micro-benchmark all 4 kernels (correctness + perf + bandwidth) +- [benchmark_e2e.py](scripts/benchmark_e2e.py) - End-to-end LTX-Video pipeline benchmark (baseline vs Triton vs compile) +- [ltx_kernel_injection_example.py](scripts/ltx_kernel_injection_example.py) - Minimal diffusers injection example +- [transformers_injection_example.py](scripts/transformers_injection_example.py) - Minimal transformers injection example +- [huggingface_kernels_example.py](scripts/huggingface_kernels_example.py) - HuggingFace Kernels Hub integration example + +### Integration Guides +- [diffusers-integration.md](references/diffusers-integration.md) - LTX-Video pipeline integration +- [transformers-integration.md](references/transformers-integration.md) - LLaMA/Mistral/Qwen integration +- [huggingface-kernels-integration.md](references/huggingface-kernels-integration.md) - HuggingFace Kernels Hub (`get_kernel`) +- [kernel-templates.md](references/kernel-templates.md) - Complete Triton kernel templates (incl. GEMM with XCD Swizzle) + +### GPU Optimization Guides +- [mi355x-optimization-guide.md](references/mi355x-optimization-guide.md) - MI355X (gfx950) deep dive +- [r9700-optimization-guide.md](references/r9700-optimization-guide.md) - R9700 (RDNA4) deep dive + +### Reference +- [troubleshooting.md](references/troubleshooting.md) - Common issues and solutions +- [kernelbench-classification.md](references/kernelbench-classification.md) - KernelBench operator taxonomy +- [skill-evaluation-methodology.md](references/skill-evaluation-methodology.md) - How to evaluate and improve skills +- [kernel-agent-knowledge-base.md](references/kernel-agent-knowledge-base.md) - Knowledge from kernel-agent project + +### External Resources +- [Triton Documentation](https://triton-lang.org/) +- [ROCm Documentation](https://rocm.docs.amd.com/) +- [HuggingFace Kernels Hub](https://huggingface.co/kernels-community) +- [LTX-Video on HuggingFace](https://huggingface.co/Lightricks/LTX-Video) +- [HuggingFace Diffusers](https://huggingface.co/docs/diffusers/en/index) diff --git a/skills/rocm-kernels/manifest.txt b/skills/rocm-kernels/manifest.txt new file mode 100644 index 00000000..3e88dac8 --- /dev/null +++ b/skills/rocm-kernels/manifest.txt @@ -0,0 +1,18 @@ +# Files for rocm-kernels skill +SKILL.md +CHANGELOG.md +references/mi355x-optimization-guide.md +references/r9700-optimization-guide.md +references/kernel-templates.md +references/diffusers-integration.md +references/transformers-integration.md +references/huggingface-kernels-integration.md +references/troubleshooting.md +references/skill-evaluation-methodology.md +references/kernelbench-classification.md +references/kernel-agent-knowledge-base.md +scripts/benchmark_kernels.py +scripts/benchmark_e2e.py +scripts/ltx_kernel_injection_example.py +scripts/transformers_injection_example.py +scripts/huggingface_kernels_example.py diff --git a/skills/rocm-kernels/references/diffusers-integration.md b/skills/rocm-kernels/references/diffusers-integration.md new file mode 100644 index 00000000..1809e0d1 --- /dev/null +++ b/skills/rocm-kernels/references/diffusers-integration.md @@ -0,0 +1,252 @@ +# Diffusers Pipeline Integration Guide (ROCm) + +Integrating custom Triton kernels into HuggingFace diffusers pipelines on AMD GPUs. + +## Overview + +This guide covers injecting optimized Triton kernels (RMSNorm, RoPE 3D, GEGLU, AdaLN) into diffusers pipelines running on ROCm. The patterns are analogous to the CUDA kernel integration but use Triton instead of CUDA C. + +## LTX-Video Architecture + +### Module Inventory + +```python +from diffusers import LTXPipeline +pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16) + +# Analyze RMSNorm modules +for name, module in pipe.transformer.named_modules(): + if 'Norm' in type(module).__name__: + has_weight = hasattr(module, 'weight') and module.weight is not None + print(f"{name}: {type(module).__name__} (has_weight={has_weight})") +``` + +### Kernel Applicability in LTX-Video + +| Kernel | Used? | Count | Notes | +|--------|-------|-------|-------| +| **RMSNorm** | Yes | **168** | 56 with weights, 112 without | +| **RoPE 3D** | Indirect | 1 | Diffusers computes via LTXVideoRotaryPosEmbed | +| **GEGLU** | **No** | 0 | LTX uses `activation_fn="gelu-approximate"` | +| **AdaLN** | Partial | ~28 | Scale/shift pattern in transformer blocks | + +## Integration Pattern + +### Step 1: Triton RMSNorm Wrapper + +```python +import os +os.environ['TRITON_HIP_USE_BLOCK_PINGPONG'] = '1' +os.environ['TRITON_HIP_USE_ASYNC_COPY'] = '1' + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _rmsnorm_kernel( + x_ptr, weight_ptr, out_ptr, + stride_x_row, D, + eps: tl.constexpr, + HAS_WEIGHT: tl.constexpr, + BLOCK_D: tl.constexpr, +): + row = tl.program_id(0) + offs = tl.arange(0, BLOCK_D) + mask = offs < D + + x = tl.load(x_ptr + row * stride_x_row + offs, mask=mask, other=0.0).to(tl.float32) + sq_sum = tl.sum(x * x, axis=0) + rms_inv = tl.rsqrt(sq_sum / D + eps) + + if HAS_WEIGHT: + w = tl.load(weight_ptr + offs, mask=mask, other=1.0).to(tl.float32) + out = x * rms_inv * w + else: + out = x * rms_inv + + tl.store(out_ptr + row * stride_x_row + offs, out.to(x.dtype), mask=mask) + + +def triton_rmsnorm(x: torch.Tensor, weight: torch.Tensor = None, eps: float = 1e-6): + """Drop-in replacement for RMSNorm forward pass.""" + x_flat = x.contiguous().view(-1, x.shape[-1]) + out = torch.empty_like(x_flat) + M, D = x_flat.shape + has_weight = weight is not None + + if not has_weight: + weight = torch.ones(D, device=x.device, dtype=x.dtype) + + # CRITICAL: BLOCK_D must be >= D. Never autotune BLOCK_D. + BLOCK_D = triton.next_power_of_2(D) + num_warps = 4 if BLOCK_D <= 1024 else (8 if BLOCK_D <= 4096 else 16) + _rmsnorm_kernel[(M,)]( + x_flat, weight, out, + x_flat.stride(0), D, + eps, has_weight, + BLOCK_D=BLOCK_D, num_warps=num_warps, num_stages=2, + ) + return out.view_as(x) +``` + +### Step 2: Module Patcher + +```python +def patch_rmsnorm_modules(model) -> int: + """ + Patch all RMSNorm modules to use Triton kernel on ROCm. + + Handles both: + - Modules WITH weight (elementwise_affine=True) — attention norm_q/norm_k + - Modules WITHOUT weight (elementwise_affine=False) — transformer block norms + """ + patched = 0 + for name, module in model.named_modules(): + # IMPORTANT: Use class name, NOT isinstance + if type(module).__name__ == 'RMSNorm': + eps = getattr(module, 'eps', 1e-6) + has_weight = hasattr(module, 'weight') and module.weight is not None + + if has_weight: + def make_forward(mod, epsilon): + def forward(x): + return triton_rmsnorm(x, mod.weight, eps=epsilon) + return forward + module.forward = make_forward(module, eps) + else: + def make_forward_no_weight(epsilon): + def forward(x): + return triton_rmsnorm(x, None, eps=epsilon) + return forward + module.forward = make_forward_no_weight(eps) + + patched += 1 + return patched +``` + +### Step 3: Pipeline Injection + +```python +def inject_optimized_kernels(pipe) -> dict: + """ + Inject Triton kernels into LTX-Video pipeline. + + Call AFTER pipe.to("cuda"), BEFORE pipe.enable_model_cpu_offload(). + """ + stats = {'rmsnorm_modules': 0} + + if not hasattr(pipe, 'transformer'): + print("WARNING: Pipeline has no 'transformer' attribute!") + return stats + + stats['rmsnorm_modules'] = patch_rmsnorm_modules(pipe.transformer) + return stats +``` + +### Step 4: Usage + +```python +import torch +from diffusers import LTXPipeline +from diffusers.utils import export_to_video + +pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16) +pipe.to("cuda") # ROCm via HIP + +stats = inject_optimized_kernels(pipe) +print(f"RMSNorm modules patched: {stats['rmsnorm_modules']}") +# Expected: 168 + +pipe.enable_model_cpu_offload() # AFTER injection + +output = pipe( + prompt="A cat sleeping in the sun", + num_frames=25, height=480, width=704, + num_inference_steps=30, +) +export_to_video(output.frames[0], "output.mp4", fps=24) +``` + +## Model-Specific Notes + +### LTX-Video +- Uses **GELU** (`activation_fn="gelu-approximate"`), NOT GEGLU +- RMSNorm in blocks: `elementwise_affine=False` (no weight) +- RMSNorm in attention: `elementwise_affine=True` (has weight) +- RoPE: Computed by diffusers via `LTXVideoRotaryPosEmbed` + +### SD3 / FLUX +- Uses **GEGLU** in FeedForward blocks +- Different attention patterns +- May have different normalization conventions +- Verify architecture before applying LTX-Video patterns + +## ROCm-Specific Considerations + +### BF16 vs FP16 + +```python +# MI355X supports BF16 — use it for diffusers +pipe = LTXPipeline.from_pretrained(..., torch_dtype=torch.bfloat16) + +# R9700 (RDNA4) — check BF16 support, may need FP16 +# torch_dtype=torch.float16 +``` + +### ROCm Memory Management + +```python +# ROCm uses same API as CUDA via HIP +pipe.to("cuda") # Works on ROCm +pipe.enable_model_cpu_offload() # Works on ROCm +torch.cuda.empty_cache() # Works on ROCm +``` + +### Triton on ROCm vs CUDA C Kernels + +| Aspect | CUDA C (original skill) | Triton (this skill) | +|--------|------------------------|---------------------| +| Build system | setup.py + nvcc | No build needed | +| Portability | NVIDIA only | AMD + NVIDIA | +| Performance | Maximum | 80-95% of CUDA C | +| Complexity | High (C++/CUDA) | Lower (Python) | +| Autotune | Manual | `@triton.autotune` | +| torch.compile | Needs custom op | Automatic compatibility | + +## Verification + +```python +# Check injection worked +for name, module in pipe.transformer.named_modules(): + if type(module).__name__ == 'RMSNorm': + x = torch.randn(1, 10, 2048, device='cuda', dtype=torch.bfloat16) + out = module(x) + print(f"RMSNorm forward: {x.shape} -> {out.shape}") + break + +# Compare with PyTorch reference +def pytorch_rmsnorm(x, weight, eps=1e-6): + rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) + if weight is not None: + return x * rms * weight + return x * rms + +# Verify correctness +torch.testing.assert_close( + triton_rmsnorm(x, weight, eps=1e-6), + pytorch_rmsnorm(x, weight, eps=1e-6), + rtol=1e-2, atol=1e-3 +) +``` + +## Troubleshooting + +| Issue | Fix | +|-------|-----| +| `NoneType has no attribute contiguous` | RMSNorm weight is None, pass `None` to kernel | +| `isinstance()` not matching | Use `type(module).__name__ == 'RMSNorm'` | +| GEGLU not called | LTX-Video uses GELU, not GEGLU | +| Patching doesn't persist | Inject BEFORE `enable_model_cpu_offload()` | +| HIP error during inference | Check ROCm version compatibility with PyTorch | diff --git a/skills/rocm-kernels/references/huggingface-kernels-integration.md b/skills/rocm-kernels/references/huggingface-kernels-integration.md new file mode 100644 index 00000000..e8287ddb --- /dev/null +++ b/skills/rocm-kernels/references/huggingface-kernels-integration.md @@ -0,0 +1,351 @@ +# HuggingFace Kernels Integration Guide (ROCm) + +Complete guide for using and publishing kernels with the HuggingFace Kernels library (`get_kernel`) on ROCm. + +> **Quick Start:** See [huggingface_kernels_example.py](../scripts/huggingface_kernels_example.py) for a minimal working example. + +## Overview + +The [HuggingFace Kernels](https://huggingface.co/docs/kernels/en/index) library enables dynamic loading of pre-compiled kernels from the Hugging Face Hub. This eliminates the need for local compilation and ensures compatibility across different Python, PyTorch, and CUDA/ROCm versions. + +**Key Benefits:** +- **No local compilation** — download pre-built binaries +- **Version management** — load specific kernel versions +- **Multi-version support** — multiple versions coexist in one Python process +- **Automatic compatibility** — matches your PyTorch/ROCm configuration + +**ROCm Note:** Not all Hub kernels have ROCm builds. Triton-based kernels (e.g., `triton-layer-norm`) are more likely to work on ROCm than CUDA C kernels. Always check with `has_kernel()` first. + +## Installation + +```bash +pip install kernels torch numpy +``` + +Requirements: +- PyTorch >= 2.5 (ROCm build) +- ROCm-capable AMD GPU +- Python 3.8+ + +## Core API + +### get_kernel + +Download and load a kernel from the Hub: + +```python +from kernels import get_kernel + +kernel = get_kernel("kernels-community/triton-layer-norm") + +# With specific version +kernel = get_kernel("kernels-community/triton-layer-norm", version=1) + +# With specific revision +kernel = get_kernel("kernels-community/flash-attn", revision="v2.0.0") +``` + +**Parameters:** + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `repo_id` | str | required | Hub repository (e.g., "kernels-community/activation") | +| `revision` | str | "main" | Branch, tag, or commit hash | +| `version` | int/str | None | Kernel version number (mutually exclusive with `revision`) | + +**Returns:** `ModuleType` — the imported kernel module + +### has_kernel + +Check if a kernel build exists for your environment: + +```python +from kernels import has_kernel + +if has_kernel("kernels-community/triton-layer-norm"): + kernel = get_kernel("kernels-community/triton-layer-norm") +else: + print("No compatible build for this ROCm/PyTorch version") +``` + +### get_local_kernel + +Load a kernel from a local path (useful during development): + +```python +from kernels import get_local_kernel + +kernel = get_local_kernel("/path/to/my-kernel") +``` + +### load_kernel & get_locked_kernel + +For reproducible, offline-capable deployments using lockfiles: + +```python +from kernels import load_kernel, get_locked_kernel + +kernel = load_kernel("lockfile.json") +kernel = get_locked_kernel("kernels-community/activation", lockfile="kernel.lock") +``` + +## Usage Examples + +### 1. RMSNorm Kernel from Hub + +**Note:** The actual function name may vary by kernel version. Use `dir(kernel)` to inspect, and check for `rms_norm_fn`, `rms_norm`, or `rmsnorm`. + +```python +import torch +from kernels import get_kernel, has_kernel + +repo_id = "kernels-community/triton-layer-norm" + +if has_kernel(repo_id): + layer_norm = get_kernel(repo_id) + + # Inspect available functions + print([f for f in dir(layer_norm) if not f.startswith('_')]) + # e.g. ['layer_norm', 'layer_norm_fn', 'rms_norm_fn', ...] + + x = torch.randn(2, 1024, 2048, dtype=torch.bfloat16, device="cuda") + weight = torch.ones(2048, dtype=torch.bfloat16, device="cuda") + + # Use the actual function name (rms_norm_fn in current version) + out = layer_norm.rms_norm_fn(x, weight, eps=1e-6) + print(f"Output shape: {out.shape}") +else: + print("No ROCm-compatible build available") +``` + +### 2. Integration with Transformers Models + +```python +import torch +from kernels import get_kernel, has_kernel + +repo_id = "kernels-community/triton-layer-norm" + +if has_kernel(repo_id): + rmsnorm_kernel = get_kernel(repo_id) + + def patch_rmsnorm_with_hub_kernel(model): + """Patch model's RMSNorm to use Hub kernel.""" + patched = 0 + for name, module in model.named_modules(): + if 'RMSNorm' in type(module).__name__: + eps = getattr(module, 'variance_epsilon', None) or getattr(module, 'eps', 1e-6) + + def make_forward(mod, epsilon): + def forward(hidden_states): + return rmsnorm_kernel.rms_norm(hidden_states, mod.weight, eps=epsilon) + return forward + + module.forward = make_forward(module, eps) + patched += 1 + return patched +``` + +### 3. Integration with Diffusers Pipelines + +```python +import torch +from diffusers import LTXPipeline +from kernels import get_kernel, has_kernel + +if has_kernel("kernels-community/triton-layer-norm"): + rmsnorm_kernel = get_kernel("kernels-community/triton-layer-norm") + + def patch_rmsnorm(model): + for name, module in model.named_modules(): + if type(module).__name__ == 'RMSNorm': + eps = getattr(module, 'eps', 1e-6) + has_weight = hasattr(module, 'weight') and module.weight is not None + + if has_weight: + def make_forward(mod, epsilon): + def forward(x): + return rmsnorm_kernel.rms_norm(x, mod.weight, eps=epsilon) + return forward + module.forward = make_forward(module, eps) + + pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16) + pipe.to("cuda") + patch_rmsnorm(pipe.transformer) +``` + +### 4. Benchmark Hub Kernel vs PyTorch + +```python +import time +import torch +from kernels import get_kernel + +kernel = get_kernel("kernels-community/triton-layer-norm") + +sizes = [(2, 1024, 2048), (4, 4096, 4096)] +for shape in sizes: + x = torch.randn(shape, dtype=torch.bfloat16, device="cuda") + w = torch.ones(shape[-1], dtype=torch.bfloat16, device="cuda") + + for _ in range(10): + kernel.rms_norm(x, w, eps=1e-6) + variance = x.pow(2).mean(-1, keepdim=True) + _ = x * torch.rsqrt(variance + 1e-6) * w + torch.cuda.synchronize() + + iters = 100 + start = time.perf_counter() + for _ in range(iters): + kernel.rms_norm(x, w, eps=1e-6) + torch.cuda.synchronize() + hub_ms = (time.perf_counter() - start) / iters * 1000 + + start = time.perf_counter() + for _ in range(iters): + variance = x.pow(2).mean(-1, keepdim=True) + _ = x * torch.rsqrt(variance + 1e-6) * w + torch.cuda.synchronize() + pt_ms = (time.perf_counter() - start) / iters * 1000 + + print(f"Shape {shape}: Hub={hub_ms:.3f}ms, PyTorch={pt_ms:.3f}ms, Speedup={pt_ms/hub_ms:.2f}x") +``` + +## ROCm-Specific Notes + +### Kernel Compatibility + +Not all Hub kernels have ROCm builds: + +| Kernel Type | ROCm Support | Notes | +|-------------|:------------:|-------| +| Triton-based (e.g., `triton-layer-norm`) | Likely | Triton compiles to HIP | +| CUDA C-based (e.g., `flash-attn`) | Check | Needs explicit ROCm build | +| Custom CUDA ops | Unlikely | CUDA-only unless HIP-ported | + +**Always check availability first:** +```python +from kernels import has_kernel + +if has_kernel("kernels-community/triton-layer-norm"): + print("ROCm build available") +else: + print("No ROCm build — use local Triton kernel instead") +``` + +### Fallback Strategy + +When a Hub kernel is not available for ROCm, fall back to the local Triton implementation: + +```python +from kernels import has_kernel, get_kernel + +def get_rmsnorm_function(): + """Get best available RMSNorm implementation.""" + if has_kernel("kernels-community/triton-layer-norm"): + kernel = get_kernel("kernels-community/triton-layer-norm") + return lambda x, w, eps: kernel.rms_norm(x, w, eps=eps) + else: + from your_local_kernels import triton_rmsnorm + return triton_rmsnorm +``` + +### Environment Check + +```python +import torch +print(f"PyTorch: {torch.__version__}") +print(f"HIP version: {getattr(torch.version, 'hip', 'N/A')}") +print(f"GPU: {torch.cuda.get_device_name()}") +print(f"GPU arch: {torch.cuda.get_device_capability()}") +``` + +## Publishing Kernels to Hub + +### Triton Kernel Project Structure + +For Triton-based kernels (best ROCm compatibility): + +``` +my-triton-kernel/ +├── build.toml +├── kernel_src/ +│ └── rmsnorm.py # Triton kernel source +└── torch-ext/ + ├── torch_binding.cpp + └── my_kernels/ + └── __init__.py +``` + +### build.toml for Triton Kernels + +```toml +[general] +name = "my_triton_kernels" +backends = ["cuda", "rocm"] # Include ROCm backend + +[torch] +src = ["torch-ext/torch_binding.cpp"] + +[kernel.rmsnorm] +backend = "triton" +src = ["kernel_src/rmsnorm.py"] +depends = ["torch"] +``` + +### Build and Publish + +```bash +pip install kernel-builder +kernel-builder build + +huggingface-cli repo create your-org/your-kernel --type model +huggingface-cli upload your-org/your-kernel ./dist +``` + +### Others Load It + +```python +from kernels import get_kernel + +rmsnorm = get_kernel("your-org/your-kernel") +``` + +## Available Community Kernels + +Popular kernels from `kernels-community`: + +| Kernel | Description | ROCm? | +|--------|-------------|:-----:| +| `triton-layer-norm` | LayerNorm, RMSNorm | Likely | +| `activation` | GELU, SiLU, etc. | Check | +| `flash-attn` | Flash Attention 2 | Check | +| `quantization` | INT8/INT4 ops | Check | + +Browse all kernels: https://huggingface.co/kernels-community + +## Caching and Offline Usage + +```python +import os +os.environ["HF_HUB_OFFLINE"] = "1" + +# Will only use cached kernels +kernel = get_kernel("kernels-community/triton-layer-norm") +``` + +## Best Practices + +1. **Always check availability** — `has_kernel()` before `get_kernel()` +2. **Pin versions** — `get_kernel(repo, version=1)` for reproducibility +3. **Have a fallback** — local Triton kernel when Hub build is unavailable +4. **Use lockfiles in production** — `load_kernel("kernel.lock")` +5. **Test on your GPU** — verify correctness after loading + +## See Also + +- [HuggingFace Kernels Documentation](https://huggingface.co/docs/kernels/en/index) +- [HuggingFace Kernels GitHub](https://github.com/huggingface/kernels) +- [Kernel Builder Documentation](https://github.com/huggingface/kernel-builder) +- [Community Kernels](https://huggingface.co/kernels-community) +- [Blog: Learn the Kernel Hub in 5 Minutes](https://huggingface.co/blog/hello-hf-kernels) diff --git a/skills/rocm-kernels/references/kernel-agent-knowledge-base.md b/skills/rocm-kernels/references/kernel-agent-knowledge-base.md new file mode 100644 index 00000000..25ea2cb2 --- /dev/null +++ b/skills/rocm-kernels/references/kernel-agent-knowledge-base.md @@ -0,0 +1,140 @@ +# Kernel-Agent 项目知识提取 + +本文档记录从 `/home/jixiong/kernel-agent` 项目中提取的核心知识,作为 ROCm kernel skills 的基础。 + +## 1. 项目概况 + +kernel-agent 是一个 **LLM 驱动的 Triton/Helion kernel 生成与评测工作流**,专门面向 AMD ROCm 平台。 + +| 组件 | 说明 | +|------|------| +| **后端** | Triton (主要), Helion (实验性) | +| **目标平台** | AMD GPU (ROCm) | +| **评测基准** | KernelBench (Level 1-7) | +| **工作流** | 生成 → 执行 → 正确性检查 → 性能优化 → 迭代 | +| **LLM 提供商** | OpenAI, Anthropic, Google, AMD on-prem | + +## 2. AMD GPU 硬件参数 (来自 amd_gpu_specs.py) + +### MI355X (gfx950) - CDNA3+ + +| 参数 | 值 | 优化影响 | +|------|-----|---------| +| GPU 架构 | CDNA3+ (gfx950) | 编译目标 | +| GPU 显存 | 288GB HBM3e | 大模型无压力 | +| 内存带宽 | 8 TB/s | 内存受限 kernel 的上限 | +| XCD 数量 | **32** | XCD Swizzle 必须用 NUM_XCDS=32 | +| CU 总数 | 256 | Grid 大小的倍数 | +| 每 XCD 的 CU | 8 | XCD 间负载均衡 | +| LDS/CU | **160 KB** | 比 MI300X 大 2.5 倍 | +| L2 Cache | 256 MB | 大型共享缓存 | +| Wavefront | 64 | CDNA 固定 Wave64 | +| MFMA 指令 | 16x16 (最优), 32x32 | matrix_instr_nonkdim=16 | +| FP8 格式 | float8_e4m3fn (OCP) | 与 MI300X 不同! | +| 最优 num_warps | 4-16 | autotune 范围 | +| 最优 num_stages | 2-3 | 避免 LDS 溢出 | +| 最优 BLOCK_SIZE (1D) | 1024-8192 | 比 MI300X 更大 | +| 最优 BLOCK_M/N (2D) | 128-256 | GEMM tile 大小 | + +### R9700 (RDNA4, gfx1201) + +| 参数 | 值 | 优化影响 | +|------|-----|---------| +| GPU 架构 | RDNA4 (gfx1201) | Wave32 模式 | +| Wavefront | **32** | 归约代码需要不同偏移 | +| CU 总数 | 64 | Grid 大小的倍数 | +| LDS/CU | 64 KB | 标准大小 | +| L1 Cache | 32 KB | 每 CU 私有 | +| L2 Cache | 8 MB | 全 CU 共享 | +| L3 Cache | 64 MB | 末级缓存 | +| Cacheline | **256 B** | 比 RDNA3 更大,需更严格对齐 | +| Max Threads/Block | 1024 | 32 waves × 32 threads | +| Max Threads/CU | 2048 | 64 waves × 32 threads | +| FP16 矩阵 TFLOPS | 191 | 矩阵指令 | +| FP8 矩阵 TFLOPS | 383 | 推理加速 | +| 矩阵核心 | 有限 (无 FP8 MFMA) | 不支持高级矩阵指令 | + +## 3. 关键优化知识 (来自 prompt_constructor.py) + +### MI355X 必须的优化 + +1. **XCD Swizzle (GEMM 必须)**: NUM_XCDS=32,将 block ID 映射到 32 个 XCD +2. **L2 Cache Grouping**: GROUP_M=8 或 16,提高 L2 缓存命中率 +3. **MFMA 16x16**: matrix_instr_nonkdim=16 +4. **环境变量**: `TRITON_HIP_USE_BLOCK_PINGPONG=1`, `TRITON_HIP_USE_ASYNC_COPY=1` +5. **num_stages=2-3**: 避免 LDS 溢出 + +### Triton on ROCm 禁忌 + +- **禁止** `tl.libdevice.*` (CUDA 专属) +- **禁止** `tl.tanh` (不支持,用 `(exp(2x)-1)/(exp(2x)+1)`) +- **禁止** `break/continue` (用 `tl.where` 替代) +- **禁止** Python `min()/max()` (用 `tl.minimum()/tl.maximum()`) +- **必须** 用 `tl.float32` 做累加器 +- **必须** 对 exp/log/sqrt/rsqrt/除法 转换为 FP32 + +### Autotune 配置 + +#### 逐元素 (1D) + +**MI355X**: BLOCK_SIZE = [1024, 2048, 4096, 4096, 8192, 16384] +**R9700**: BLOCK_SIZE = [256, 512, 1024] (更小) + +#### GEMM (2D) + +**MI355X**: BLOCK_M/N = [128-256], BLOCK_K = [32-64], GROUP_M = 8 + +## 4. 问题分类体系 (来自 classify_problem) + +| 类别 | 匹配模式 | 典型算子 | +|------|---------|---------| +| elementwise | relu, gelu, swish, silu, sigmoid, tanh, elu... | 激活函数 | +| softmax | softmax, logsoftmax | Softmax 变体 | +| norm | layernorm, batchnorm, rmsnorm, groupnorm... | 归一化 | +| pooling | pool | 池化操作 | +| reduction | sum_reduction, mean_reduction, max_reduction... | 归约操作 | +| attention | attention, multihead | 注意力机制 | +| matvec | matrix_vector, matvec | 矩阵-向量乘 | +| batched_gemm | batch, bmm | 批量矩阵乘 | +| gemm_2d | matmul, gemm, mm_ | 2D 矩阵乘 | + +## 5. KernelBench 测试结果关键发现 + +### 表现优秀的类别 (在 kernel-agent 上) + +| 类别 | 最佳 Speedup | 代表算子 | +|------|-------------|---------| +| Reduction | 5.00x | Sum reduction | +| Pooling | 5.16x | Average Pooling 3D | +| 激活函数 | 2.94x | Softsign, Softplus, Swish | +| 归一化 | 1.73x | LayerNorm | +| 特殊 GEMM | 1.98x | 对角矩阵乘 | + +### 需要重点优化的类别 + +| 类别 | 当前 Speedup | 根本原因 | +|------|-------------|---------| +| 大 K GEMM | 0.04x | 寄存器压力、内存访问不优 | +| BatchNorm | 0.04x | HIP 运行时错误、同步问题 | +| 对称/三角矩阵乘 | 0.08-0.20x | 线程利用率低 | +| Argmax/Argmin | FAILED | Triton API 限制 | +| 融合算子 | 0.32x (平均) | 多操作组合复杂度 | + +### 常见错误类型 + +1. **HIP Runtime Error**: GPU 内存访问冲突 +2. **精度问题**: FP16 累积误差 +3. **program_id 限制**: 3D Grid 映射 +4. **tl.store() kwarg 错误**: Triton API 差异 +5. **max_contiguous 错误**: 内存访问模式 + +## 6. 性能分析工具链 + +| 工具 | 用途 | +|------|------| +| `rocprof` / `rocprofv3` | GPU kernel profiling | +| `rocm-bandwidth-test` | 内存带宽测试 | +| `rocminfo` | GPU 设备信息 | +| `rocm-smi` | GPU 状态监控 | +| `omniperf` | 全面性能分析 | +| `omnitrace` | 系统级追踪 | diff --git a/skills/rocm-kernels/references/kernel-templates.md b/skills/rocm-kernels/references/kernel-templates.md new file mode 100644 index 00000000..702debbf --- /dev/null +++ b/skills/rocm-kernels/references/kernel-templates.md @@ -0,0 +1,512 @@ +# Triton Kernel Templates for ROCm (LTX-Video Operators) + +Copy-paste ready Triton kernel templates for RMSNorm, RoPE 3D, GEGLU, and AdaLN on AMD GPUs. + +## Required Header + +**Every kernel file MUST start with:** + +```python +import os +os.environ['TRITON_HIP_USE_BLOCK_PINGPONG'] = '1' +os.environ['TRITON_HIP_USE_ASYNC_COPY'] = '1' + +import torch +import torch.nn as nn +import triton +import triton.language as tl +``` + +## Template 1: RMSNorm (Core Target) + +Row-wise reduction. **168 instances** in LTX-Video. Handles both with-weight and no-weight variants. + +**CRITICAL: Do NOT autotune BLOCK_D.** Autotune may select `BLOCK_D < D`, causing partial row processing and completely wrong results. Always compute `BLOCK_D = triton.next_power_of_2(D)` dynamically. + +### Triton Kernel + +```python +@triton.jit +def rmsnorm_fwd_kernel( + x_ptr, weight_ptr, out_ptr, + stride_x, D, + eps: tl.constexpr, + HAS_WEIGHT: tl.constexpr, + BLOCK_D: tl.constexpr, +): + row = tl.program_id(0) + col_offsets = tl.arange(0, BLOCK_D) + mask = col_offsets < D + + x = tl.load(x_ptr + row * stride_x + col_offsets, mask=mask, other=0.0).to(tl.float32) + + variance = tl.sum(x * x, axis=0) / D + rms_inv = tl.rsqrt(variance + eps) + + if HAS_WEIGHT: + w = tl.load(weight_ptr + col_offsets, mask=mask, other=1.0).to(tl.float32) + result = x * rms_inv * w + else: + result = x * rms_inv + + tl.store(out_ptr + row * stride_x + col_offsets, result.to(x.dtype), mask=mask) +``` + +### Python API + +```python +def triton_rmsnorm( + x: torch.Tensor, + weight: torch.Tensor = None, + eps: float = 1e-6, +) -> torch.Tensor: + orig_shape = x.shape + x_2d = x.contiguous().view(-1, x.shape[-1]) + out = torch.empty_like(x_2d) + M, D = x_2d.shape + + has_weight = weight is not None + if not has_weight: + weight = torch.empty(0, device=x.device) + + BLOCK_D = triton.next_power_of_2(D) + num_warps = 4 if BLOCK_D <= 1024 else (8 if BLOCK_D <= 4096 else 16) + rmsnorm_fwd_kernel[(M,)]( + x_2d, weight, out, + x_2d.stride(0), D, eps, has_weight, + BLOCK_D=BLOCK_D, num_warps=num_warps, num_stages=2, + ) + return out.view(orig_shape) +``` + +### Benchmark + +```python +def benchmark_rmsnorm(): + configs = [ + (1, 1024, 2048), + (2, 1024, 2048), + (2, 4096, 3072), + ] + for batch, seq, hidden in configs: + x = torch.randn(batch, seq, hidden, device='cuda', dtype=torch.float16) + w = torch.ones(hidden, device='cuda', dtype=torch.float16) + + # Reference + ref = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6) * w + + # Custom + out = triton_rmsnorm(x, w, eps=1e-6) + + # Verify + torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-3) + print(f"[{batch}x{seq}x{hidden}] ✓ Correct") +``` + +## Template 2: RoPE 3D (Video Position Encoding) + +Element-wise rotation. Splits head_dim into temporal + spatial (height + width) components. + +**CRITICAL: cos/sin have shape `[seq_len, head_dim]`, NOT `[batch*seq_len, ...]`.** When the grid flattens the batch dimension, use `pid_s % seq_len` to index cos/sin, otherwise batch > 1 causes out-of-bounds GPU crash. + +### Triton Kernel + +```python +@triton.jit +def rope_3d_fwd_kernel( + qk_ptr, cos_ptr, sin_ptr, out_ptr, + seq_len, num_heads, head_dim, + stride_s, stride_h, stride_d, + BLOCK_HD: tl.constexpr, +): + pid_s = tl.program_id(0) # ranges [0, batch * seq_len) + pid_h = tl.program_id(1) + + half_dim = head_dim // 2 + offs = tl.arange(0, BLOCK_HD) + mask = offs < half_dim + + base = pid_s * stride_s + pid_h * stride_h + x0 = tl.load(qk_ptr + base + offs, mask=mask, other=0.0).to(tl.float32) + x1 = tl.load(qk_ptr + base + half_dim + offs, mask=mask, other=0.0).to(tl.float32) + + seq_idx = pid_s % seq_len # wrap for batch > 1 + cos_val = tl.load(cos_ptr + seq_idx * head_dim + offs, mask=mask, other=1.0).to(tl.float32) + sin_val = tl.load(sin_ptr + seq_idx * head_dim + offs, mask=mask, other=0.0).to(tl.float32) + + out0 = x0 * cos_val - x1 * sin_val + out1 = x0 * sin_val + x1 * cos_val + + tl.store(out_ptr + base + offs, out0.to(x0.dtype), mask=mask) + tl.store(out_ptr + base + half_dim + offs, out1.to(x0.dtype), mask=mask) +``` + +### Python API + +```python +def triton_rope_3d( + qk: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, +) -> torch.Tensor: + """ + Apply 3D RoPE to Q or K tensor. + + Args: + qk: [batch, seq_len, num_heads, head_dim] + cos: [seq_len, head_dim] — NOT batch-expanded! + sin: [seq_len, head_dim] + """ + qk = qk.contiguous() + out = torch.empty_like(qk) + batch, seq_len, num_heads, head_dim = qk.shape + + qk_flat = qk.view(batch * seq_len, num_heads, head_dim) + out_flat = out.view(batch * seq_len, num_heads, head_dim) + + BLOCK_HD = triton.next_power_of_2(head_dim // 2) + num_warps = 4 if BLOCK_HD <= 64 else 8 + + rope_3d_fwd_kernel[(batch * seq_len, num_heads)]( + qk_flat, cos, sin, out_flat, + seq_len, num_heads, head_dim, + qk_flat.stride(0), qk_flat.stride(1), qk_flat.stride(2), + BLOCK_HD=BLOCK_HD, num_warps=num_warps, num_stages=2, + ) + return out +``` + +## Template 3: GEGLU (For SD3/FLUX) + +Gated activation: `GELU(gate) * value`. Input splits in half along last dim. + +**Note: LTX-Video uses GELU, NOT GEGLU. This template is for SD3/FLUX.** + +### Triton Kernel + +```python +@triton.jit +def geglu_fwd_kernel( + input_ptr, output_ptr, + stride_in, stride_out, hidden_size, + BLOCK_H: tl.constexpr, +): + row = tl.program_id(0) + offs = tl.arange(0, BLOCK_H) + mask = offs < hidden_size + + gate = tl.load(input_ptr + row * stride_in + offs, + mask=mask, other=0.0).to(tl.float32) + value = tl.load(input_ptr + row * stride_in + hidden_size + offs, + mask=mask, other=0.0).to(tl.float32) + + # Manual tanh — tl.math.tanh / tl.libdevice.tanh NOT available on ROCm + SQRT_2_OVER_PI = 0.7978845608028654 + tanh_arg = SQRT_2_OVER_PI * (gate + 0.044715 * gate * gate * gate) + e2x = tl.exp(2.0 * tanh_arg) + tanh_val = (e2x - 1.0) / (e2x + 1.0) + cdf = 0.5 * (1.0 + tanh_val) + gelu_gate = gate * cdf + + result = gelu_gate * value + tl.store(output_ptr + row * stride_out + offs, result.to(gate.dtype), mask=mask) +``` + +### Python API + +```python +def triton_geglu(x: torch.Tensor) -> torch.Tensor: + """ + GEGLU activation: GELU(x[..., :H]) * x[..., H:] + + Input: [..., 2*hidden_size] → Output: [..., hidden_size] + """ + x = x.contiguous() + *batch_dims, double_h = x.shape + hidden_size = double_h // 2 + + x_2d = x.view(-1, double_h) + M = x_2d.shape[0] + out = torch.empty(M, hidden_size, device=x.device, dtype=x.dtype) + + BLOCK_H = triton.next_power_of_2(hidden_size) + num_warps = 4 if BLOCK_H <= 1024 else (8 if BLOCK_H <= 4096 else 16) + geglu_fwd_kernel[(M,)]( + x_2d, out, + x_2d.stride(0), out.stride(0), hidden_size, + BLOCK_H=BLOCK_H, num_warps=num_warps, num_stages=2, + ) + return out.view(*batch_dims, hidden_size) +``` + +## Template 4: AdaLN (Adaptive Layer Normalization) + +Fused RMSNorm + adaptive conditioning for DiT blocks. +Formula: `norm(x) * weight * (1 + scale) + shift` + +### Triton Kernel + +```python +@triton.jit +def adaln_fwd_kernel( + x_ptr, weight_ptr, scale_ptr, shift_ptr, out_ptr, + stride_x, stride_cond, D, + eps: tl.constexpr, + BLOCK_D: tl.constexpr, +): + row = tl.program_id(0) + offs = tl.arange(0, BLOCK_D) + mask = offs < D + + x = tl.load(x_ptr + row * stride_x + offs, mask=mask, other=0.0).to(tl.float32) + + variance = tl.sum(x * x, axis=0) / D + rms_inv = tl.rsqrt(variance + eps) + x_norm = x * rms_inv + + w = tl.load(weight_ptr + offs, mask=mask, other=1.0).to(tl.float32) + scale = tl.load(scale_ptr + row * stride_cond + offs, mask=mask, other=0.0).to(tl.float32) + shift = tl.load(shift_ptr + row * stride_cond + offs, mask=mask, other=0.0).to(tl.float32) + + out = x_norm * w * (1.0 + scale) + shift + tl.store(out_ptr + row * stride_x + offs, out.to(x.dtype), mask=mask) +``` + +### Python API + +```python +def triton_adaln( + x: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + shift: torch.Tensor, + eps: float = 1e-6, +) -> torch.Tensor: + """ + Adaptive Layer Normalization for DiT blocks. + + Args: + x: [batch, seq, hidden] + weight: [hidden] + scale: [batch, seq, hidden] or [batch, 1, hidden] + shift: [batch, seq, hidden] or [batch, 1, hidden] + """ + x_flat = x.contiguous().view(-1, x.shape[-1]) + scale_flat = scale.contiguous().view(-1, x.shape[-1]) + shift_flat = shift.contiguous().view(-1, x.shape[-1]) + out = torch.empty_like(x_flat) + M, D = x_flat.shape + + BLOCK_D = triton.next_power_of_2(D) + num_warps = 4 if BLOCK_D <= 1024 else (8 if BLOCK_D <= 4096 else 16) + adaln_fwd_kernel[(M,)]( + x_flat, weight, scale_flat, shift_flat, out, + x_flat.stride(0), scale_flat.stride(0), D, eps, + BLOCK_D=BLOCK_D, num_warps=num_warps, num_stages=2, + ) + return out.view_as(x) +``` + +## Common Math Replacements for ROCm + +| Standard | ROCm Triton Replacement | +|----------|------------------------| +| `tl.tanh(x)` | Manual: `e2x = tl.exp(2.0*x); (e2x-1)/(e2x+1)` | +| `tl.math.tanh(x)` | **Also NOT available on ROCm** — use manual formula above | +| `tl.libdevice.*` | Remove entirely, use manual implementations | +| `min(a, b)` | `tl.minimum(a, b)` | +| `max(a, b)` | `tl.maximum(a, b)` | +| GELU exact | `0.5 * x * (1 + erf(x / sqrt(2)))` | +| GELU approx | `0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))` | + +## Kernel-Specific Guidelines + +### RMSNorm +- Input: `[..., hidden_size]` — flatten to 2D `[M, D]` +- Epsilon default: 1e-6 +- **Weight may be None** if `elementwise_affine=False` +- Always accumulate `x*x` sum in FP32 +- **BLOCK_D = `triton.next_power_of_2(D)`** — compute in wrapper, NEVER autotune +- Autotuning BLOCK_D is dangerous: if BLOCK_D < D, only partial row is processed → wrong results + +### RoPE 3D +- 1D: `[batch, seq, heads, head_dim]` for text +- 3D: `[batch, t*h*w, heads, head_dim]` for video +- LTX-Video computes RoPE via `LTXVideoRotaryPosEmbed` — kernel replaces the apply step +- head_dim typically 64 or 128 +- **cos/sin shape is `[seq_len, head_dim]`** — use `pid_s % seq_len` for batch > 1 + +### GEGLU vs GELU +- **GEGLU**: Input `[B, S, 2*H]` → Output `[B, S, H]` — gate/value split +- **GELU**: Standard activation, no split +- **LTX-Video uses GELU, NOT GEGLU** +- GEGLU is for SD3/FLUX + +### AdaLN +- Formula: `norm(x) * weight * (1 + scale) + shift` +- Scale/shift come from timestep embedding MLP +- DiT computes 6 values per block: `(scale1, shift1, gate1, scale2, shift2, gate2)` +- Fusing norm + conditioning saves one memory round-trip + +## Template 5: GEMM with XCD Swizzle (MI355X) + +Tiled matrix multiplication with XCD swizzle for MI355X (32 XCDs). **Mandatory** for any GEMM-like operation on MI355X — without it, work clusters on a few chiplets, wasting 90%+ of the GPU. + +> See [mi355x-optimization-guide.md](mi355x-optimization-guide.md) for architecture details. + +**When to use XCD swizzle:** GEMM, batched GEMM, attention (Q@K, score@V). NOT needed for elementwise, reduction, or normalization kernels. + +### Triton Kernel + +```python +NUM_XCDS = 32 # MI355X has 32 XCDs + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_M': 8}, + num_stages=2, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, + num_stages=2, num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_M': 8}, + num_stages=2, num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, + num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, + num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_M': 8}, + num_stages=3, num_warps=8), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def gemm_xcd_swizzle_kernel( + a_ptr, b_ptr, c_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, +): + pid = tl.program_id(0) + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + num_pids = num_pid_m * num_pid_n + + # --- XCD Swizzle: distribute blocks across 32 chiplets --- + pids_per_xcd = (num_pids + NUM_XCDS - 1) // NUM_XCDS + xcd_id = pid % NUM_XCDS + local_pid = pid // NUM_XCDS + if local_pid < pids_per_xcd: + remapped_pid = xcd_id * pids_per_xcd + local_pid + if remapped_pid < num_pids: + pid = remapped_pid + + # --- L2 Cache Grouping (after XCD swizzle) --- + num_pid_in_group = GROUP_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_M + group_size_m = tl.minimum(num_pid_m - first_pid_m, GROUP_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # --- Compute GEMM tile --- + offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M + offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N + offs_k = tl.arange(0, BLOCK_K) + + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0) + acc += tl.dot(a, b) + a_ptrs += BLOCK_K * stride_ak + b_ptrs += BLOCK_K * stride_bk + + # --- Store result --- + offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, acc.to(tl.float16), mask=mask) +``` + +### Python API + +```python +def triton_gemm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + """ + Matrix multiplication C = A @ B with XCD swizzle for MI355X. + + Args: + a: [M, K] input matrix + b: [K, N] input matrix + Returns: + c: [M, N] output matrix + """ + assert a.shape[1] == b.shape[0], "Inner dimensions must match" + assert a.is_contiguous() and b.is_contiguous() + M, K = a.shape + K, N = b.shape + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),) + gemm_xcd_swizzle_kernel[grid]( + a, b, c, + M, N, K, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + ) + return c +``` + +### Benchmark + +```python +def benchmark_gemm(): + configs = [(4096, 4096, 4096), (8192, 8192, 4096), (2048, 8192, 2048)] + for M, N, K in configs: + a = torch.randn(M, K, device='cuda', dtype=torch.float16) + b = torch.randn(K, N, device='cuda', dtype=torch.float16) + + ref = torch.mm(a, b) + out = triton_gemm(a, b) + torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-1) + + # Benchmark + for _ in range(10): + triton_gemm(a, b) + torch.mm(a, b) + torch.cuda.synchronize() + + import time + iters = 50 + start = time.perf_counter() + for _ in range(iters): + triton_gemm(a, b) + torch.cuda.synchronize() + custom_ms = (time.perf_counter() - start) / iters * 1000 + + start = time.perf_counter() + for _ in range(iters): + torch.mm(a, b) + torch.cuda.synchronize() + torch_ms = (time.perf_counter() - start) / iters * 1000 + + print(f"[{M}x{N}x{K}] Custom: {custom_ms:.2f}ms, Torch: {torch_ms:.2f}ms, " + f"Speedup: {torch_ms/custom_ms:.2f}x") +``` + +### GEMM-Specific Guidelines + +- **XCD Swizzle is MANDATORY** on MI355X for any GEMM — without it, expect 0.3-0.5x +- **L2 Cache Grouping** (`GROUP_M=8-16`): Improves L2 hit rate after XCD swizzle +- **MFMA**: Use `matrix_instr_nonkdim=16` for MI355X matrix cores +- **FP32 accumulation**: Always accumulate in FP32, cast at store +- **LDS budget**: Check `BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N` * dtype * num_stages < 160 KB +- **Autotune**: GEMM benefits heavily from autotuning — always include 4+ configs +- **R9700**: Does NOT have XCDs — remove the XCD swizzle section for RDNA4 diff --git a/skills/rocm-kernels/references/kernelbench-classification.md b/skills/rocm-kernels/references/kernelbench-classification.md new file mode 100644 index 00000000..67ba8a76 --- /dev/null +++ b/skills/rocm-kernels/references/kernelbench-classification.md @@ -0,0 +1,162 @@ +# KernelBench Operator Classification & Skill Mapping + +This document classifies KernelBench operators into categories and maps each to the appropriate kernel skill/pattern. + +## Classification Taxonomy + +### Level 1: Basic Operators (53 operators) + +#### Category A: GEMM / Matrix Multiplication (18 operators) + +| ID | Name | Sub-type | Key Skill | +|----|------|----------|-----------| +| 1 | Square matrix multiplication | Dense GEMM | XCD Swizzle + Autotune | +| 2 | Standard matrix multiplication | Dense GEMM (M!=N) | XCD Swizzle + Autotune | +| 3 | Batched matrix multiplication | BMM | Batch-indexed GEMM | +| 4 | Matrix-vector multiplication | MatVec | 1D reduction pattern | +| 5 | Matrix-scalar multiplication | Elementwise | Scale kernel | +| 6 | Matmul with large K | Large-K GEMM | K-dimension blocking | +| 7 | Matmul with small K | Small-K GEMM | Fewer K-iterations | +| 8 | Matmul with irregular shapes | Non-square GEMM | Mask handling | +| 9 | Tall-skinny matmul | Tall-skinny GEMM | Tile shape tuning | +| 10 | 3D tensor-matrix mul | Batched GEMM | Reshape + GEMM | +| 11 | 4D tensor-matrix mul | Batched GEMM | Einsum decomposition | +| 12 | Diagonal matrix mul | Special GEMM | Elementwise pattern | +| 13 | Symmetric matrices | Dense GEMM | Standard GEMM | +| 14 | Upper triangular mul | Masked GEMM | Triangle mask | +| 15 | Lower triangular mul | Masked GEMM | Triangle mask | +| 16 | Transposed A | Transposed GEMM | Stride adjustment | +| 17 | Transposed B | Transposed GEMM | Stride adjustment | +| 18 | Both transposed | Transposed GEMM | Stride adjustment | + +**Key Pattern**: Template 5 (GEMM with XCD Swizzle) +**Critical Optimization**: XCD swizzle + L2 cache grouping + MFMA 16x16 + +#### Category B: Elementwise / Activation Functions (14 operators) + +| ID | Name | Sub-type | Key Skill | +|----|------|----------|-----------| +| 19 | ReLU | Branching | `tl.where(x > 0, x, 0)` | +| 20 | LeakyReLU | Branching | `tl.where(x > 0, x, alpha*x)` | +| 21 | Sigmoid | Transcendental | `1/(1+exp(-x))` | +| 22 | Tanh | Transcendental | `(exp(2x)-1)/(exp(2x)+1)` | +| 23 | Softmax | Row reduction | Online softmax | +| 24 | LogSoftmax | Row reduction | Online softmax + log | +| 25 | Swish/SiLU | Transcendental | `x * sigmoid(x)` | +| 26 | GELU | Transcendental | `0.5*x*(1+erf(x/sqrt(2)))` | +| 27 | SELU | Branching + exp | `scale * where(x>0, x, alpha*(exp(x)-1))` | +| 28 | HardSigmoid | Clamp | `clamp((x+3)/6, 0, 1)` | +| 29 | Softplus | Transcendental | `log(1+exp(x))` | +| 30 | Softsign | Division | `x/(1+abs(x))` | +| 31 | ELU | Branching + exp | `where(x>0, x, alpha*(exp(x)-1))` | +| 32 | HardTanh | Clamp | `clamp(x, -1, 1)` | + +**Key Pattern**: Template 1 (Elementwise) +**Critical Optimization**: Large BLOCK_SIZE (4096-16384), FP32 compute + +#### Category C: Normalization (8 operators) + +| ID | Name | Sub-type | Key Skill | +|----|------|----------|-----------| +| 33 | BatchNorm | Multi-dim reduction | Welford algorithm | +| 34 | InstanceNorm | Per-instance reduction | Per-sample norm | +| 35 | GroupNorm | Group reduction | Grouped channels | +| 36 | RMSNorm | Row reduction | `x * rsqrt(mean(x^2) + eps)` | +| 37 | FrobeniusNorm | Full reduction | `sqrt(sum(x^2))` | +| 38 | L1 Norm | Full reduction | `sum(abs(x))` | +| 39 | L2 Norm | Full reduction | `sqrt(sum(x^2))` | +| 40 | LayerNorm | Row reduction | `(x-mean)/std * w + b` | + +**Key Pattern**: Template 3 (Row-wise Reduction) +**Critical Optimization**: FP32 accumulation, proper reduction + +#### Category D: Pooling (6 operators) + +| ID | Name | Sub-type | Key Skill | +|----|------|----------|-----------| +| 41 | Max Pooling 1D | Sliding window | Max reduction | +| 42 | Max Pooling 2D | 2D window | 2D index mapping | +| 43 | Max Pooling 3D | 3D window | Program_id flattening | +| 44 | Average Pooling 1D | Sliding window | Sum + divide | +| 45 | Average Pooling 2D | 2D window | 2D index mapping | +| 46 | Average Pooling 3D | 3D window | Program_id flattening | + +**Key Challenge**: 3D grid mapping with Triton's program_id limits + +#### Category E: Reduction (7 operators) + +| ID | Name | Sub-type | Key Skill | +|----|------|----------|-----------| +| 47 | Sum reduction | Sum | `tl.sum()` | +| 48 | Mean reduction | Mean | `tl.sum() / count` | +| 49 | Max reduction | Max | `tl.max()` | +| 50 | Min reduction | Min | `tl.min()` | +| 51 | Argmax | Index + max | Two-pass or manual | +| 52 | Argmin | Index + min | Two-pass or manual | +| 53 | Min (duplicate) | Min | `tl.min()` | + +**Key Pattern**: Template 5 (Dimension Reduction) +**Key Challenge**: Argmax/Argmin require manual implementation + +### Level 2: Fused Operators (20+ operators) + +Combine multiple operations into single kernels. + +| Category | Examples | Strategy | +|----------|---------|----------| +| GEMM + Activation | Gemm_ReLU, Gemm_GELU | Fuse activation into GEMM epilogue | +| GEMM + Norm | Gemm_BatchNorm, Gemm_GroupNorm | Two-phase kernel | +| GEMM + Scale | Gemm_Scale, Gemm_Divide | Fuse into GEMM store | +| Multi-op fusion | Matmul_Sum_Max_AvgPool | Sequential fusion | + +**Key Pattern**: Template 6 (Fused GEMM + Activation) + +### Level 3-4: Network Models / Transformers + +Full models requiring multiple kernel types. Decompose into Level 1 operators. + +### Level 6-7: Advanced / Expert + +| Operator | Type | Strategy | +|----------|------|----------| +| MinGPTNewGelu | Fused activation | GELU approximation kernel | +| ScaledDotProductAttention | Attention | Flash Attention pattern | +| GELU_And_Mul | Fused activation | `gelu(x) * y` | +| MoE_TopK_Softmax | MoE routing | Specialized kernel | +| Gemm_A8W8_Blockwise | Quantized GEMM | INT8 with block scaling | + +## Category → Skill Mapping + +| Category | Skill File | Priority | +|----------|-----------|----------| +| **GEMM** | `gemm-skill.md` (planned) | P0 - Most impactful | +| **Elementwise** | `elementwise-skill.md` (planned) | P0 - Most common | +| **Normalization** | `normalization-skill.md` (planned) | P1 - Frequently used | +| **Reduction** | `reduction-skill.md` (planned) | P1 - Common pattern | +| **Softmax** | `softmax-skill.md` (planned) | P1 - Critical for attention | +| **Pooling** | `pooling-skill.md` (planned) | P2 - Moderate complexity | +| **Attention** | `attention-skill.md` (planned) | P2 - High complexity | +| **Fused** | `fused-skill.md` (planned) | P2 - Combination patterns | + +## Performance Expectations by Category + +Based on kernel-agent test results: + +| Category | Achievable Speedup | Difficulty | Notes | +|----------|-------------------|------------|-------| +| Elementwise | 1.0-3.0x | Low | Large blocks, memory-bound | +| Reduction (sum/mean) | 1.5-5.0x | Medium | Good parallelism | +| Pooling | 1.5-5.0x | Medium | Grid mapping challenge | +| LayerNorm/RMSNorm | 1.5-2.0x | Medium | Row-wise reduction | +| Dense GEMM | 0.8-1.2x | High | XCD swizzle critical | +| Batched GEMM | 0.6-0.9x | High | Memory bandwidth limited | +| BatchNorm | <0.1x | Very High | HIP sync issues | +| Argmax/Argmin | FAIL | Very High | Triton API limitation | +| Fused operators | 0.3-1.0x | Very High | Correctness challenges | + +## Recommended Skill Development Order + +1. **Phase 1 (Quick wins)**: Elementwise activations, Sum/Mean reduction +2. **Phase 2 (Core)**: GEMM with XCD swizzle, LayerNorm/RMSNorm +3. **Phase 3 (Advanced)**: Softmax, Pooling, Attention +4. **Phase 4 (Expert)**: Fused operators, BatchNorm, Quantized GEMM diff --git a/skills/rocm-kernels/references/mi355x-optimization-guide.md b/skills/rocm-kernels/references/mi355x-optimization-guide.md new file mode 100644 index 00000000..966ff7a3 --- /dev/null +++ b/skills/rocm-kernels/references/mi355x-optimization-guide.md @@ -0,0 +1,233 @@ +# MI355X (gfx950) Optimization Guide + +Deep dive into MI355X-specific optimizations for Triton kernels on ROCm. + +## MI355X CDNA3+ Architecture + +### Key Specifications + +| Component | Value | vs MI300X | +|-----------|-------|-----------| +| Compute Capability | gfx950 | gfx942 | +| Architecture | CDNA3+ | CDNA3 | +| **XCDs (Chiplets)** | **32** | 8 | +| CUs Total | 256 | 228 | +| CUs per XCD | 8 | 28 | +| **LDS per CU** | **160 KB** | 64 KB | +| L2 Cache | 256 MB | 256 MB | +| Wavefront Size | 64 | 64 | +| GPU Memory | 288 GB HBM3e | 192 GB HBM3 | +| **Memory Bandwidth** | **8 TB/s** | 5.3 TB/s | +| FP16/BF16 Matrix TFLOPS | ~2500 | 1307 | +| FP8 Matrix TFLOPS | ~5000 | 2615 | +| MFMA Instructions | 16x16, 32x32 | 16x16, 32x32 | +| FP8 Format | float8_e4m3fn (OCP) | float8_e4m3fnuz (AMD) | + +### Critical Architecture Differences from MI300X + +1. **32 XCDs vs 8**: XCD swizzle must use `NUM_XCDS=32` +2. **8 CUs per XCD vs 28**: Finer-grained chiplet distribution +3. **160 KB LDS vs 64 KB**: 2.5x larger local memory per CU +4. **8 TB/s vs 5.3 TB/s**: 50% more memory bandwidth +5. **OCP FP8 vs AMD FP8**: Different FP8 format + +## XCD Swizzle (MANDATORY for GEMM) + +MI355X has 32 XCDs. Without proper swizzle, GEMM blocks cluster on a few XCDs, wasting 90%+ of the GPU. + +### When to Use XCD Swizzle + +| Kernel Type | XCD Swizzle? | Why | +|-------------|-------------|-----| +| GEMM / matmul | **YES, MANDATORY** | Multi-block work distribution | +| Elementwise | No | Single-block independent | +| Reduction | No | Row-independent | +| Normalization | No | Row-independent | +| Attention | **YES** (for Q@K and score@V) | Contains GEMM | + +### XCD Swizzle Implementation + +```python +NUM_XCDS = 32 + +@triton.jit +def gemm_with_xcd_swizzle(...): + pid = tl.program_id(0) + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + num_pids = num_pid_m * num_pid_n + + # Step 1: XCD Swizzle + pids_per_xcd = (num_pids + NUM_XCDS - 1) // NUM_XCDS + xcd_id = pid % NUM_XCDS + local_pid = pid // NUM_XCDS + if local_pid < pids_per_xcd: + remapped_pid = xcd_id * pids_per_xcd + local_pid + if remapped_pid < num_pids: + pid = remapped_pid + + # Step 2: L2 Cache Grouping (after XCD swizzle) + num_pid_in_group = GROUP_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_M + group_size_m = tl.minimum(num_pid_m - first_pid_m, GROUP_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m +``` + +### Performance Impact + +| Config | Without XCD Swizzle | With XCD Swizzle | Improvement | +|--------|-------------------|-----------------|-------------| +| Square GEMM 4096x4096 | 0.3-0.5x | 0.8-1.2x | 2-4x | +| Tall-skinny GEMM | 0.4-0.6x | 0.7-1.0x | 1.5-2.5x | + +## MFMA Instructions + +Use 16x16 MFMA for optimal matrix core utilization: + +```python +# Launch kernel with MFMA hint +kernel[grid](..., matrix_instr_nonkdim=16) +``` + +## LDS Optimization + +MI355X has 160 KB LDS per CU—2.5x more than MI300X. + +### LDS Budget Calculation + +``` +LDS usage = BLOCK_M × BLOCK_K × dtype_size + BLOCK_K × BLOCK_N × dtype_size + × num_stages + +Example (BLOCK_M=256, BLOCK_N=256, BLOCK_K=64, FP16, num_stages=2): + = (256×64×2 + 64×256×2) × 2 = 131,072 bytes = 128 KB < 160 KB ✓ + +Same config on MI300X (64 KB LDS): + 128 KB > 64 KB ✗ → Need num_stages=1 or smaller blocks +``` + +### Stage Configuration + +| LDS Budget | MI355X num_stages | MI300X num_stages | +|------------|------------------|------------------| +| < 80 KB | 2-3 | 2 | +| 80-160 KB | 2 | 1 (or reduce blocks) | +| > 160 KB | 1 (or reduce blocks) | Not possible | + +## Autotune Configurations + +### Elementwise Operations + +```python +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE': 1024}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE': 2048}, num_warps=8, num_stages=2), + triton.Config({'BLOCK_SIZE': 4096}, num_warps=8, num_stages=2), + triton.Config({'BLOCK_SIZE': 4096}, num_warps=16, num_stages=2), + triton.Config({'BLOCK_SIZE': 8192}, num_warps=16, num_stages=2), + triton.Config({'BLOCK_SIZE': 16384}, num_warps=16, num_stages=2), + ], + key=['n_elements'], +) +``` + +### GEMM Operations + +```python +@triton.autotune( + configs=[ + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_M': 8}, + num_stages=2, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, + num_stages=2, num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_M': 8}, + num_stages=2, num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, + num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, + num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_M': 8}, + num_stages=3, num_warps=8), + ], + key=['M', 'N', 'K'], +) +``` + +### Problem-Specific Block Sizes + +| Problem Type | BLOCK_M | BLOCK_N | BLOCK_K | num_stages | num_warps | GROUP_M | +|-------------|---------|---------|---------|------------|-----------|---------| +| Square GEMM (M,N>=4096) | 256 | 256 | 32 | 3 | 8 | 16 | +| Large K (K > max(M,N)) | 128 | 128 | 64 | 2 | 8 | 8 | +| Fused GEMM+Activation | 128 | 128 | 64 | 2 | 8 | 8 | +| Element-wise ops | - | - | - | 2 | 4-16 | - | + +## Precision and Numerical Stability + +### FP32 Accumulation (Required) + +```python +# Always accumulate in FP32 +acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) +for k in range(...): + acc += tl.dot(a, b) +# Cast at store +c = acc.to(tl.float16) +``` + +### Math Operations + +```python +# Cast to FP32 for transcendental functions +x_f32 = x.to(tl.float32) +result = tl.exp(x_f32) # ✓ +result = tl.log(x_f32) # ✓ +result = tl.sqrt(x_f32) # ✓ +result = 1.0 / x_f32 # ✓ (division in FP32) + +# tanh workaround (tl.tanh not supported on AMD) +e2x = tl.exp(2.0 * x_f32) +tanh_x = (e2x - 1.0) / (e2x + 1.0) +``` + +## Performance Profiling + +```bash +# Basic kernel profiling +rocprof --stats python your_kernel.py + +# Detailed metrics +rocprofv3 -i metrics.txt python your_kernel.py + +# Key metrics to watch: +# - L2 cache hit rate (target >70%) +# - VGPR usage (128+ may limit occupancy) +# - LDS usage (max 160 KB on MI355X) +# - Memory bandwidth utilization (target 40-60% of 8 TB/s) +``` + +## Environment Variables + +```python +import os +# Block ping-pong for better latency hiding +os.environ['TRITON_HIP_USE_BLOCK_PINGPONG'] = '1' +# Async memory copies +os.environ['TRITON_HIP_USE_ASYNC_COPY'] = '1' +``` + +## Best Practices Summary + +1. **XCD Swizzle**: Always for GEMM, never for elementwise +2. **MFMA**: Use matrix_instr_nonkdim=16 +3. **LDS**: Leverage 160 KB, but check with num_stages +4. **num_stages**: 2-3 (safe), up to 4 if LDS permits +5. **num_warps**: 8 is default, autotune 4-16 +6. **BLOCK_SIZE**: Larger than MI300X (1024-16384 for 1D) +7. **GROUP_M**: 8 or 16 for L2 cache grouping +8. **FP32 acc**: Always accumulate in FP32 +9. **Env vars**: Set BLOCK_PINGPONG and ASYNC_COPY +10. **Profile**: Use rocprof to validate optimizations diff --git a/skills/rocm-kernels/references/r9700-optimization-guide.md b/skills/rocm-kernels/references/r9700-optimization-guide.md new file mode 100644 index 00000000..8eea3489 --- /dev/null +++ b/skills/rocm-kernels/references/r9700-optimization-guide.md @@ -0,0 +1,172 @@ +# R9700 (RDNA4, gfx1201) Optimization Guide + +Deep dive into R9700-specific optimizations for Triton kernels on ROCm. + +## R9700 RDNA4 Architecture + +### Key Specifications + +| Component | R9700 | vs MI355X | +|-----------|-------|-----------| +| Compute Capability | gfx1201 | gfx950 | +| Architecture | RDNA4 | CDNA3+ | +| **Wavefront Size** | **32 (Wave32)** | 64 (Wave64) | +| CUs | 64 | 256 | +| Stream Processors | 4096 | - | +| LDS per CU | 64 KB | 160 KB | +| L1 Cache | 32 KB | - | +| L2 Cache | 8 MB | 256 MB | +| L3 Cache | 64 MB | - | +| **Cacheline Size** | **256 B** | - | +| Max Threads/Block | 1024 | 1024 | +| Max Threads/CU | 2048 | 2048 | +| Max Waves/CU | 32 | - | +| SIMDs per CU | 2 | - | +| FP32 Vector TFLOPS | 47.8 | ~200 | +| FP16 Vector TFLOPS | 95.7 | ~2500 | +| FP16 Matrix TFLOPS | 191 | ~2500 | +| Matrix Cores | Limited (no FP8 MFMA) | Full MFMA | + +### Critical RDNA4 vs CDNA3+ Differences + +1. **Wave32 vs Wave64**: Warp size is 32, same as NVIDIA +2. **No XCD Swizzle**: Single die, no chiplet distribution needed +3. **Limited Matrix Cores**: No FP8 MFMA support +4. **Smaller LDS**: 64 KB vs 160 KB +5. **Smaller L2 Cache**: 8 MB vs 256 MB +6. **256B Cacheline**: Stricter memory alignment requirements +7. **Consumer GPU**: Optimized for inference, not training + +## Wave32 Implications + +### num_warps Mapping + +On RDNA4, `num_warps` still means "number of wavefronts per block": +- 1 warp = 32 threads (Wave32) +- Max 32 waves per CU +- num_warps range: 2-8 (smaller than CDNA) + +```python +# CDNA (MI355X): 1 warp = 64 threads +# num_warps=8 → 512 threads/block + +# RDNA4 (R9700): 1 warp = 32 threads +# num_warps=8 → 256 threads/block +# Use higher num_warps if needed for same thread count +``` + +### Reduction Code + +Warp-level reductions use different offsets: + +```python +# CDNA (Wave64): offsets = 32, 16, 8, 4, 2, 1 +# RDNA4 (Wave32): offsets = 16, 8, 4, 2, 1 + +# In Triton this is handled automatically by tl.sum(), tl.max(), etc. +# No manual shuffle code needed in Triton +``` + +## Memory Hierarchy + +### 256B Cacheline Alignment + +R9700 uses 256-byte cachelines (vs 128B on RDNA3). Misaligned accesses are penalized more. + +```python +# Ensure contiguous memory access +x = x.contiguous() + +# In kernel: sequential access pattern +offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) +x = tl.load(x_ptr + offsets, mask=mask) # Coalesced +``` + +### L2 Cache Strategy + +With only 8 MB L2, cache reuse is limited: + +```python +# For GEMM: use smaller tiles to fit in L2 +# BLOCK_M=64, BLOCK_N=64, BLOCK_K=32 +# Tile = 64×32×2 + 32×64×2 = 8 KB per stage +# With 2 stages: 16 KB fits in L2 per block +``` + +### LDS (64 KB) Budget + +``` +Max LDS per CU = 64 KB + +GEMM example (BLOCK_M=64, BLOCK_N=128, BLOCK_K=32, FP16, num_stages=2): + = (64×32×2 + 32×128×2) × 2 = 24,576 bytes = 24 KB ✓ + +GEMM example (BLOCK_M=128, BLOCK_N=128, BLOCK_K=64, FP16, num_stages=2): + = (128×64×2 + 64×128×2) × 2 = 65,536 bytes = 64 KB → Borderline! +``` + +## Autotune Configurations + +### Elementwise Operations + +```python +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE': 256}, num_warps=2, num_stages=2), + triton.Config({'BLOCK_SIZE': 512}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE': 1024}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE': 1024}, num_warps=8, num_stages=2), + triton.Config({'BLOCK_SIZE': 2048}, num_warps=8, num_stages=2), + ], + key=['n_elements'], +) +``` + +### GEMM Operations + +```python +@triton.autotune( + configs=[ + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_M': 8}, + num_stages=2, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, + num_stages=2, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_M': 8}, + num_stages=2, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, + num_stages=2, num_warps=8), + ], + key=['M', 'N', 'K'], +) +``` + +## Grid Sizing + +With 64 CUs: + +```python +# Aim for multiples of 64 blocks +grid = (triton.cdiv(N, BLOCK_SIZE),) +# For GEMM: grid = (cdiv(M, BLOCK_M) * cdiv(N, BLOCK_N),) +``` + +## Precision Considerations + +- FP16 Matrix TFLOPS = 191 (2x FP32 vector) +- FP16 Vector TFLOPS = 95.7 (2x FP32 vector) +- **No FP8 MFMA**: Cannot use FP8 matrix operations +- INT8 Matrix TOPS = 383 (quantized inference) +- Use FP16 for compute, FP32 for accumulation + +## Best Practices Summary + +1. **Wave32 awareness**: Use num_warps=2-8 +2. **No XCD Swizzle**: Not needed on single-die +3. **Smaller blocks**: 64-128 for GEMM tiles +4. **256B alignment**: Ensure contiguous memory access +5. **LDS budget**: Max 64 KB, keep num_stages=2 +6. **Grid sizing**: Multiples of 64 CUs +7. **FP16 preferred**: Best throughput, no FP8 MFMA +8. **L3 cache**: 64 MB can help with model weights +9. **Inference focus**: Best suited for inference workloads +10. **Cacheline**: 256B alignment is stricter than MI355X diff --git a/skills/rocm-kernels/references/skill-evaluation-methodology.md b/skills/rocm-kernels/references/skill-evaluation-methodology.md new file mode 100644 index 00000000..7e9548a6 --- /dev/null +++ b/skills/rocm-kernels/references/skill-evaluation-methodology.md @@ -0,0 +1,251 @@ +# Skill 评估与优化方法论 + +本文档描述如何系统性地评估和优化 ROCm kernel skills 的质量。 + +## 1. Skill 质量评估框架 + +### 1.1 评估维度 + +| 维度 | 权重 | 衡量标准 | 评估方法 | +|------|------|---------|---------| +| **正确性** | 40% | AI 生成的 kernel 能通过正确性测试 | KernelBench accuracy ratio | +| **性能** | 30% | 生成 kernel 的 speedup 相比 PyTorch baseline | KernelBench speedup ratio | +| **可运行率** | 20% | kernel 能成功编译并运行 | KernelBench runnable ratio | +| **触发准确率** | 10% | AI 在正确场景下使用了正确的 skill | 人工评审 | + +### 1.2 KernelBench 自动评估 + +```bash +# 运行 Level 1 全量评估 +python -m kernel_agent.evaluation.kernelbench_success_evaluator \ + --config examples/workflows/evaluation/kernelbench/config_eval.yml \ + --dataset datasets/kernelbench/kernel_bench_level_1.json + +# 输出指标 +# - runnable_ratio: 可运行比率 +# - accuracy_ratio: 正确性比率 +# - speed_ratio: 达到 speedup > 1.0x 的比率 +``` + +### 1.3 逐类别评估 + +对每个 operator 类别单独评估: + +```bash +# Level 1 Selected (按类别) +python -m kernel_agent.evaluation.kernelbench_success_evaluator \ + --config examples/workflows/evaluation/kernelbench/config_eval_level1_selected.yml +``` + +**评估记录表模板:** + +| 类别 | 算子数 | 可运行 | 正确 | Speedup>1x | 平均 Speedup | Skill 版本 | +|------|--------|--------|------|-----------|-------------|-----------| +| GEMM | 18 | ?/18 | ?/18 | ?/18 | ?x | v0.1 | +| Elementwise | 14 | ?/14 | ?/14 | ?/14 | ?x | v0.1 | +| Normalization | 8 | ?/8 | ?/8 | ?/8 | ?x | v0.1 | +| ... | ... | ... | ... | ... | ... | ... | + +## 2. 迭代优化流程 + +### 2.1 PDCA 循环 + +``` +Plan → 识别性能瓶颈或失败模式 +Do → 修改 SKILL.md 或 references 中的指引 +Check → 重新运行 KernelBench 评估 +Act → 确认改进,合并到 skill;或回滚 +``` + +### 2.2 具体优化步骤 + +#### Step 1: 收集失败案例 + +```python +# 从评估结果中提取失败案例 +# 分析每个失败的原因类型: +# - compilation_error: 编译错误 → 修改模板代码 +# - runtime_error: 运行时错误 → 添加 troubleshooting 条目 +# - accuracy_error: 精度问题 → 修改精度相关指引 +# - performance_low: 性能不达标 → 添加优化策略 +``` + +#### Step 2: 根因分析 + +| 失败类型 | 检查项 | 修改位置 | +|---------|--------|---------| +| tl.libdevice 错误 | 是否遗漏了 ROCm 禁忌 | SKILL.md "Critical ROCm Constraints" | +| LDS 溢出 | num_stages 建议是否正确 | GPU optimization guide | +| GEMM 极慢 | 是否缺少 XCD swizzle | kernel-templates.md Template 5 (GEMM with XCD Swizzle) | +| 精度不达标 | FP32 累加是否到位 | kernel-templates.md 所有模板 | +| Python min/max | 是否提醒了 tl.minimum | troubleshooting.md | + +#### Step 3: 修改 Skill 内容 + +按照失败原因修改对应文件: +- **模式错误** → 修改 `kernel-templates.md` +- **知识缺失** → 修改 `SKILL.md` 或 GPU 优化指南 +- **新陷阱** → 添加到 `troubleshooting.md` +- **性能数据过时** → 更新 benchmark 表格 + +#### Step 4: A/B 测试 + +```bash +# Version A: 原始 skill +cp -r rocm-kernels rocm-kernels-v1 + +# Version B: 修改后的 skill +# (直接编辑 rocm-kernels/) + +# 分别运行评估,对比结果 +# Compare: runnable_ratio, accuracy_ratio, speed_ratio +``` + +## 3. 与 AI 协作优化 Skill 的方法 + +### 3.1 "生成-评测-反馈" 循环 + +``` +你 (编写 Skill) + ↓ +AI (使用 Skill 生成 kernel) + ↓ +KernelBench (评测 kernel) + ↓ +你 (分析失败,改进 Skill) + ↓ +(重复) +``` + +### 3.2 具体协作方式 + +#### 方式 1: 让 AI 帮你分析失败 + +``` +提示词: "这是 KernelBench Level 1 的评测结果 [粘贴结果]。 +请分析以下失败案例的根因,并建议修改 SKILL.md 的哪些部分。" +``` + +#### 方式 2: 让 AI 帮你生成测试用例 + +``` +提示词: "根据 rocm-kernels skill,为 GEMM 类别生成一个测试 kernel, +目标是 4096x4096 方阵乘法,使用 MI355X 优化。" +``` + +然后手动运行测试,看结果是否符合预期。 + +#### 方式 3: 让 AI 帮你补充 Skill 内容 + +``` +提示词: "BatchNorm 在 AMD GPU 上的评测结果是 0.04x(极差)。 +错误类型是 HIP 运行时错误。请帮我分析原因,并在 troubleshooting.md +中添加对应的解决方案。" +``` + +### 3.3 Skill 版本管理 + +``` +rocm-kernels/ +├── SKILL.md # 主文件 (跟踪版本号) +├── CHANGELOG.md # 变更日志 (每次优化后记录) +├── references/ +└── scripts/ +``` + +**CHANGELOG 格式:** + +```markdown +## v0.2 (2026-03-15) +- Added XCD swizzle pattern for GEMM (fixed 0.3x → 1.1x speedup) +- Added tanh workaround for ROCm +- Fixed LDS overflow guidance for MI355X + +## v0.1 (2026-03-10) +- Initial version with basic templates +``` + +## 4. KernelBench 分类 Skill 开发计划 + +### 4.1 开发优先级 + +| 优先级 | 类别 | 原因 | 预期 Skill 文件 | +|--------|------|------|----------------| +| **P0** | Elementwise | 最多算子、最容易成功 | `elementwise-skill.md` | +| **P0** | GEMM | 最高影响、最频繁使用 | `gemm-skill.md` | +| **P1** | Normalization | 常用、中等难度 | `normalization-skill.md` | +| **P1** | Reduction | 常用、有成熟模式 | `reduction-skill.md` | +| **P1** | Softmax | Attention 基础 | `softmax-skill.md` | +| **P2** | Pooling | 中等频率、Grid 映射挑战 | `pooling-skill.md` | +| **P2** | Attention | 高复杂度、高价值 | `attention-skill.md` | +| **P2** | Fused | 多操作组合 | `fused-skill.md` | + +### 4.2 每个 Skill 文件结构 + +```markdown +--- +name: rocm-{category}-kernel +description: "..." +--- + +# {Category} Kernel Skill + +## Pattern Overview +[核心算法模式] + +## Template Code +[可复制的完整代码] + +## Autotune Configurations +[MI355X 和 R9700 的推荐配置] + +## Common Mistakes +[该类别特有的陷阱] + +## Benchmark Results +[该类别的已知性能数据] +``` + +### 4.3 评估指标目标 + +| 类别 | 可运行率目标 | 正确率目标 | 平均 Speedup 目标 | +|------|------------|-----------|------------------| +| Elementwise | >95% | >90% | >1.5x | +| GEMM | >80% | >70% | >0.8x | +| Normalization | >85% | >80% | >1.0x | +| Reduction | >90% | >85% | >1.5x | +| Softmax | >85% | >80% | >1.0x | +| Pooling | >70% | >60% | >1.0x | +| Attention | >60% | >50% | >0.8x | +| Fused | >50% | >40% | >0.8x | + +## 5. 持续监控 + +### 5.1 定期评估 + +- **每周**: 运行一次 Level 1 全量评估 +- **每次 Skill 修改后**: 运行修改类别的评估 +- **每月**: 运行 Level 1+2 全量评估 + +### 5.2 回归检测 + +修改 Skill 后,确保不会导致其他类别的性能下降: + +```bash +# 修改前: 保存基线 +python eval.py --save-baseline baseline_v1.json + +# 修改后: 对比 +python eval.py --compare-baseline baseline_v1.json +# 任何类别的指标下降 > 5% 需要调查 +``` + +## 6. 总结:优化 Skill 的核心原则 + +1. **数据驱动**: 每次修改都基于 KernelBench 评估数据 +2. **分类优化**: 按 operator 类别独立迭代 +3. **最小修改**: 每次只改一个点,方便归因 +4. **版本记录**: 每次修改记录 CHANGELOG +5. **A/B 测试**: 对比修改前后的评测结果 +6. **渐进式**: 先覆盖高优先级类别(Elementwise → GEMM → Norm) +7. **陷阱文档化**: 每个新发现的坑都写入 troubleshooting.md diff --git a/skills/rocm-kernels/references/transformers-integration.md b/skills/rocm-kernels/references/transformers-integration.md new file mode 100644 index 00000000..3841c489 --- /dev/null +++ b/skills/rocm-kernels/references/transformers-integration.md @@ -0,0 +1,340 @@ +# Transformers Library Integration Guide (ROCm / Triton) + +Complete guide for integrating custom Triton kernels into HuggingFace transformers models on AMD GPUs. + +> **Quick Start:** See [transformers_injection_example.py](../scripts/transformers_injection_example.py) for a minimal working example (~150 lines). + +## Overview + +The HuggingFace transformers library has different architecture patterns than diffusers. Understanding these patterns is critical for successful kernel integration with models like LLaMA, Mistral, Qwen, and other LLMs on ROCm. + +**Key difference from diffusers:** All transformers RMSNorm modules have weights (`elementwise_affine=True`). No need to handle the weight-less variant. + +## Model Architecture Analysis + +```python +from transformers import AutoModelForCausalLM, AutoConfig +import torch + +config = AutoConfig.from_pretrained("Qwen/Qwen3-8B") +print(f"Hidden size: {config.hidden_size}") # 4096 +print(f"Num layers: {config.num_hidden_layers}") # 32 +print(f"Num heads: {config.num_attention_heads}") # 32 +print(f"RMS norm eps: {config.rms_norm_eps}") # 1e-6 + +model = AutoModelForCausalLM.from_pretrained( + "Qwen/Qwen3-8B", + torch_dtype=torch.bfloat16, + device_map="cuda" # ROCm uses same API via HIP +) + +for name, module in model.named_modules(): + class_name = type(module).__name__ + if 'Norm' in class_name: + has_weight = hasattr(module, 'weight') and module.weight is not None + print(f"{name}: {class_name} (has_weight={has_weight})") +``` + +## Common Transformers Architectures + +### LLaMA / Llama-2 / Llama-3 + +| Component | Class | Has Weight | Notes | +|-----------|-------|------------|-------| +| `model.norm` | LlamaRMSNorm | Yes | Final layer norm | +| `model.layers.*.input_layernorm` | LlamaRMSNorm | Yes | Pre-attention norm | +| `model.layers.*.post_attention_layernorm` | LlamaRMSNorm | Yes | Pre-FFN norm | +| `model.layers.*.mlp` | LlamaMLP | - | Uses SiLU gating | + +### Mistral / Mixtral + +| Component | Class | Has Weight | Notes | +|-----------|-------|------------|-------| +| `model.norm` | MistralRMSNorm | Yes | Final layer norm | +| `model.layers.*.input_layernorm` | MistralRMSNorm | Yes | Pre-attention norm | +| `model.layers.*.post_attention_layernorm` | MistralRMSNorm | Yes | Pre-FFN norm | + +### Qwen / Qwen2 / Qwen3 + +| Component | Class | Has Weight | Notes | +|-----------|-------|------------|-------| +| `model.norm` | Qwen2RMSNorm | Yes | Final layer norm | +| `model.layers.*.input_layernorm` | Qwen2RMSNorm | Yes | Pre-attention norm | +| `model.layers.*.post_attention_layernorm` | Qwen2RMSNorm | Yes | Pre-FFN norm | + +### Kernel Applicability + +| Kernel | LLaMA | Mistral | Qwen | Notes | +|--------|-------|---------|------|-------| +| RMSNorm | **Yes** | **Yes** | **Yes** | All use RMSNorm with weights | +| GEGLU | No | No | No | Uses SiLU gating instead | +| RoPE | Indirect | Indirect | Indirect | Computed by transformers internally | +| Attention | Via SDPA | Via SDPA | Via SDPA | Use Flash Attention 2 | + +## Integration Pattern + +### Step 1: Set ROCm Environment Variables + +```python +import os +os.environ['TRITON_HIP_USE_BLOCK_PINGPONG'] = '1' +os.environ['TRITON_HIP_USE_ASYNC_COPY'] = '1' +``` + +### Step 2: Define the Triton RMSNorm Kernel + +```python +import torch +import triton +import triton.language as tl + +@triton.jit +def rmsnorm_fwd_kernel( + x_ptr, weight_ptr, out_ptr, + stride_x, D, + eps: tl.constexpr, + BLOCK_D: tl.constexpr, +): + row = tl.program_id(0) + offs = tl.arange(0, BLOCK_D) + mask = offs < D + x = tl.load(x_ptr + row * stride_x + offs, mask=mask, other=0.0).to(tl.float32) + + variance = tl.sum(x * x, axis=0) / D + rms_inv = tl.rsqrt(variance + eps) + + w = tl.load(weight_ptr + offs, mask=mask, other=1.0).to(tl.float32) + out = x * rms_inv * w + + tl.store(out_ptr + row * stride_x + offs, out.to(x.dtype), mask=mask) + + +def triton_rmsnorm(x, weight, eps=1e-6): + x_2d = x.contiguous().view(-1, x.shape[-1]) + out = torch.empty_like(x_2d) + M, D = x_2d.shape + BLOCK_D = triton.next_power_of_2(D) + num_warps = 4 if BLOCK_D <= 1024 else (8 if BLOCK_D <= 4096 else 16) + rmsnorm_fwd_kernel[(M,)]( + x_2d, weight, out, x_2d.stride(0), D, eps, + BLOCK_D=BLOCK_D, num_warps=num_warps, num_stages=2, + ) + return out.view_as(x) +``` + +### Step 3: Create RMSNorm Patcher + +```python +def patch_rmsnorm_modules(model) -> int: + """ + Patch all RMSNorm modules to use Triton kernel on ROCm. + + Works with LlamaRMSNorm, MistralRMSNorm, Qwen2RMSNorm, etc. + """ + patched_count = 0 + + for name, module in model.named_modules(): + class_name = type(module).__name__ + + if 'RMSNorm' in class_name: + # LLaMA uses 'variance_epsilon', others use 'eps' + eps = getattr(module, 'variance_epsilon', None) + if eps is None: + eps = getattr(module, 'eps', 1e-6) + + has_weight = hasattr(module, 'weight') and module.weight is not None + + if has_weight: + def make_patched_forward(mod, epsilon): + def patched_forward(hidden_states): + return triton_rmsnorm(hidden_states, mod.weight, eps=epsilon) + return patched_forward + module.forward = make_patched_forward(module, eps) + patched_count += 1 + else: + print(f"WARNING: {name} has no weight, skipping") + + return patched_count +``` + +### Step 4: Use in Script + +```python +from transformers import AutoModelForCausalLM, AutoTokenizer + +model = AutoModelForCausalLM.from_pretrained( + "Qwen/Qwen3-8B", + torch_dtype=torch.bfloat16, + device_map="cuda" # ROCm uses same device API via HIP +) +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B") + +count = patch_rmsnorm_modules(model) +print(f"Patched {count} RMSNorm modules") +# Expected: 65 modules (32 layers * 2 + 1 final) + +inputs = tokenizer("The capital of France is", return_tensors="pt").to("cuda") +outputs = model.generate(**inputs, max_new_tokens=20) +print(tokenizer.decode(outputs[0])) +``` + +## Key Differences from Diffusers + +### 1. RMSNorm Always Has Weight + +Unlike diffusers (where some RMSNorm modules have `elementwise_affine=False`), transformers RMSNorm modules **always** have weights. The `HAS_WEIGHT` branch is always true, so you can simplify the kernel to always load weights. + +### 2. Different Epsilon Attribute Names + +```python +# LLaMA uses 'variance_epsilon' +eps = getattr(module, 'variance_epsilon', 1e-6) + +# Some models use 'eps' +eps = getattr(module, 'eps', 1e-6) + +# Safe pattern +eps = getattr(module, 'variance_epsilon', None) or getattr(module, 'eps', 1e-6) +``` + +### 3. No Attention Processor Pattern + +Diffusers uses `set_processor()` for attention modules. Transformers does not: + +```python +# Transformers: Use Flash Attention 2 instead of custom processors +model = AutoModelForCausalLM.from_pretrained( + model_id, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2" +) +``` + +### 4. Device Map vs Manual Move + +```python +# Transformers — use device_map for large models +model = AutoModelForCausalLM.from_pretrained( + model_id, + device_map="auto" # Handles multi-GPU automatically +) + +# Diffusers — manual move then CPU offload +pipe = DiffusionPipeline.from_pretrained(model_id) +pipe.to("cuda") +pipe.enable_model_cpu_offload() +``` + +## ROCm-Specific Considerations + +### 1. ROCm Environment Setup + +```python +import os +os.environ['TRITON_HIP_USE_BLOCK_PINGPONG'] = '1' +os.environ['TRITON_HIP_USE_ASYNC_COPY'] = '1' +``` + +### 2. No tl.libdevice / tl.math.tanh + +If you extend beyond RMSNorm (e.g., custom SiLU activation), remember tanh is not available: + +```python +# Manual tanh for ROCm +e2x = tl.exp(2.0 * x) +tanh_x = (e2x - 1.0) / (e2x + 1.0) +``` + +### 3. Verify HIP Backend + +```python +import torch +print(f"HIP version: {torch.version.hip}") # Should show ROCm version +print(f"GPU: {torch.cuda.get_device_name()}") +``` + +### 4. torch.compile on ROCm + +Custom Triton kernels and `torch.compile` can coexist on ROCm since Triton is already the compilation backend. However, test thoroughly as behavior may differ from eager mode. + +## Model-Specific Integration + +### LLaMA Models + +```python +from transformers import LlamaForCausalLM + +model = LlamaForCausalLM.from_pretrained( + "meta-llama/Llama-2-7b-hf", + torch_dtype=torch.bfloat16, + device_map="cuda" +) + +count = patch_rmsnorm_modules(model) +print(f"Patched {count} LlamaRMSNorm modules") +# Expected: 65 modules (32 layers * 2 + 1 final) +``` + +### Qwen3-8B + +```python +model = AutoModelForCausalLM.from_pretrained( + "Qwen/Qwen3-8B", + torch_dtype=torch.bfloat16, + device_map="cuda" +) + +count = patch_rmsnorm_modules(model) +print(f"Patched {count} Qwen2RMSNorm modules") +# Expected: 65 modules (32 layers * 2 + 1 final) +``` + +## Verification + +### Verify Injection Worked + +```python +x = torch.randn(1, 10, model.config.hidden_size, device='cuda', dtype=torch.bfloat16) +for name, module in model.named_modules(): + if 'RMSNorm' in type(module).__name__: + out = module(x) + print(f"RMSNorm forward pass: {x.shape} -> {out.shape}") + break +``` + +### Run Generation Test + +```python +inputs = tokenizer("Hello, my name is", return_tensors="pt").to("cuda") +with torch.inference_mode(): + outputs = model.generate(**inputs, max_new_tokens=20) +print(tokenizer.decode(outputs[0])) +``` + +### Profile on ROCm + +```bash +rocprof --stats python your_script.py +rocprofv3 -i metrics.txt python your_script.py +``` + +## Performance Optimization + +### Enable Flash Attention 2 + +```python +model = AutoModelForCausalLM.from_pretrained( + model_id, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", + device_map="cuda" +) +``` + +### Combine with Custom Kernels + +```python +model = AutoModelForCausalLM.from_pretrained(model_id, ...) +patch_rmsnorm_modules(model) # Inject Triton RMSNorm +# Flash Attention 2 handles attention optimization +``` diff --git a/skills/rocm-kernels/references/troubleshooting.md b/skills/rocm-kernels/references/troubleshooting.md new file mode 100644 index 00000000..b0f0e443 --- /dev/null +++ b/skills/rocm-kernels/references/troubleshooting.md @@ -0,0 +1,292 @@ +# ROCm Triton Kernel Troubleshooting Guide + +Common issues and solutions when developing Triton kernels for AMD GPUs. + +## Build / Import Issues + +### 1. `tl.libdevice` Not Found + +**Error:** `AttributeError: module 'triton.language' has no attribute 'libdevice'` + +**Cause:** `tl.libdevice` is CUDA-only (NVIDIA's libdevice library). + +**Fix:** Replace with manual implementations: +```python +# WRONG (CUDA only) +tl.libdevice.tanh(x) +tl.libdevice.log1p(x) + +# CORRECT (ROCm compatible) +e2x = tl.exp(2.0 * x); tanh_x = (e2x - 1.0) / (e2x + 1.0) +log1p_x = tl.log(1.0 + x) +``` + +### 2. `tl.tanh` / `tl.math.tanh` Not Available + +**Error:** `AttributeError: module 'triton.language.math' has no attribute 'tanh'` + +**Cause:** Neither `tl.tanh`, `tl.math.tanh`, nor `tl.libdevice.tanh` exist on ROCm Triton. This is the most common GEGLU compilation failure. + +**Fix — manual tanh (ONLY reliable method):** +```python +x_f32 = x.to(tl.float32) +e2x = tl.exp(2.0 * x_f32) +tanh_x = (e2x - 1.0) / (e2x + 1.0) +``` + +## Runtime Errors + +### 3. HIP Runtime Error: Invalid Argument + +**Error:** `hipErrorInvalidValue` or `HIP Error: invalid argument` + +**Common causes:** +- Grid/block size exceeds hardware limits +- Mismatched tensor shapes +- LDS overflow + +**Fix:** +```python +# Check grid size +grid = (triton.cdiv(N, BLOCK_SIZE),) +assert grid[0] > 0, f"Grid size must be > 0, got {grid[0]}" + +# Ensure contiguous tensors +x = x.contiguous() + +# Reduce num_stages to avoid LDS overflow +# num_stages=2 is safest +``` + +### 4. HIP Out of Memory (LDS) + +**Error:** `AMDGPU_KERNEL_ERROR_OUT_OF_MEMORY` or `LDS size exceeds limit` + +**Cause:** Kernel uses more LDS than available (64 KB on R9700, 160 KB on MI355X). + +**Fix:** +```python +# Reduce num_stages +num_stages=2 # instead of 3 or 4 + +# Reduce block sizes +BLOCK_M=64, BLOCK_N=64, BLOCK_K=32 # smaller tiles +``` + +### 5. Kernel Timeout + +**Error:** Kernel hangs or times out. + +**Common cause:** Grid and Program ID mismatch. + +```python +# WRONG: 1D grid but 2D program_id +grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),) +pid_m = tl.program_id(0) # OK +pid_n = tl.program_id(1) # ERROR: axis 1 doesn't exist in 1D grid + +# CORRECT: Compute 2D indices from 1D grid +grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),) +pid = tl.program_id(0) +pid_m = pid // triton.cdiv(N, BLOCK_N) +pid_n = pid % triton.cdiv(N, BLOCK_N) +``` + +## Correctness Issues + +### 6. Autotuning BLOCK_D Causes Wrong Results + +**Symptom:** RMSNorm/AdaLN/GEGLU correctness fails with large `max_abs` errors (4-8+). Kernel runs fast but produces garbage. + +**Cause:** `@triton.autotune` with `BLOCK_D` configs (e.g., 512, 1024, 2048, 4096) may select a `BLOCK_D < D` (hidden dimension). Since `tl.arange(0, BLOCK_D)` only covers `BLOCK_D` elements, the kernel processes a partial row, computing wrong variance and writing incomplete output. + +**Fix:** Never autotune `BLOCK_D` for row-reduction kernels. Compute it dynamically: +```python +# WRONG — autotune may pick BLOCK_D=512 when D=2048 +@triton.autotune(configs=[ + triton.Config({'BLOCK_D': 512}, num_warps=4), + triton.Config({'BLOCK_D': 1024}, num_warps=8), +], key=['D']) + +# CORRECT — compute in Python wrapper +BLOCK_D = triton.next_power_of_2(D) +num_warps = 4 if BLOCK_D <= 1024 else (8 if BLOCK_D <= 4096 else 16) +kernel[(M,)](..., BLOCK_D=BLOCK_D, num_warps=num_warps, num_stages=2) +``` + +### 7. RoPE cos/sin Out-of-Bounds GPU Crash (batch > 1) + +**Symptom:** `Memory access fault by GPU node` crash. Only happens when batch_size > 1. + +**Cause:** cos/sin tensors have shape `[seq_len, head_dim]`, but when the grid is `(batch * seq_len, num_heads)`, `pid_s` ranges `[0, batch * seq_len)`. For `pid_s >= seq_len`, `cos_ptr + pid_s * head_dim` is out of bounds. + +**Fix:** Use modular indexing for cos/sin: +```python +# WRONG — crashes when pid_s >= seq_len +cos_val = tl.load(cos_ptr + pid_s * head_dim + offs, ...) + +# CORRECT — wrap position index for batch dimension +seq_idx = pid_s % seq_len +cos_val = tl.load(cos_ptr + seq_idx * head_dim + offs, ...) +``` + +### 8. FP16/BF16 Precision Loss + +**Symptom:** Results differ from PyTorch reference by more than tolerance. + +**Fix:** Always accumulate in FP32: +```python +# WRONG: Accumulate in FP16 +acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float16) + +# CORRECT: Accumulate in FP32, cast at store +acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) +# ... computation ... +result = acc.to(tl.float16) +tl.store(out_ptr + ..., result, mask=mask) +``` + +**Tolerance guidelines:** +- BF16 (7-bit mantissa): `atol=0.1`, `rtol=1e-2` +- FP16 (10-bit mantissa): `atol=0.01`, `rtol=1e-3` + +### 9. Mask Errors + +**Error:** `ValueError: Mask argument cannot be block type` + +**Fix:** Ensure mask dimensions match pointer dimensions: +```python +# 1D kernel +mask = offsets < n_elements +x = tl.load(x_ptr + offsets, mask=mask, other=0.0) + +# 2D kernel +mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) +x = tl.load(ptr + offs_m[:, None] * stride + offs_n[None, :], mask=mask, other=0.0) +``` + +### 10. Python min/max Inside Kernel + +**Error:** `TypeError` or incorrect results. + +**Fix:** +```python +# WRONG: Python builtins +result = min(a, b) +result = max(a, b) + +# CORRECT: Triton functions +result = tl.minimum(a, b) +result = tl.maximum(a, b) +``` + +## Performance Issues + +### 11. GEMM Extremely Slow (0.3-0.5x) + +**Cause:** Missing XCD swizzle on MI355X. + +**Fix:** Add XCD swizzle pattern (see Template 5: GEMM with XCD Swizzle in kernel-templates.md). + +### 12. Elementwise Kernel Slow + +**Common causes:** +1. BLOCK_SIZE too small → not utilizing bandwidth +2. Internal loops → should process full block +3. Missing autotune → not finding optimal config + +**Fix:** +```python +# Use large BLOCK_SIZE for elementwise +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE': 2048}, num_warps=8, num_stages=2), + triton.Config({'BLOCK_SIZE': 4096}, num_warps=8, num_stages=2), + triton.Config({'BLOCK_SIZE': 8192}, num_warps=16, num_stages=2), + ], + key=['n_elements'], +) +``` + +### 13. Missing @triton.autotune (for elementwise) + +**Symptom:** Kernel runs but performance is poor. + +**Fix:** **EVERY kernel must have autotune with 4+ configs.** Fixed block sizes are almost never optimal. + +### 14. tl.store() Keyword Argument Error + +**Error:** `TypeError: store() got an unexpected keyword argument` + +**Fix:** Check Triton version API. Use positional arguments if needed: +```python +# Check your Triton version +# tl.store(ptr, value, mask=mask) # Most versions +# tl.store(ptr, value, mask) # Some older versions +``` + +### 15. eps: tl.constexpr Causes Recompilation Crash + +**Error:** `AttributeError("'NoneType' object has no attribute 'type'")` during Triton compilation + +**Cause:** When `eps` is declared as `tl.constexpr`, the kernel is compiled separately for each unique eps value. If the kernel first compiles with `eps=1e-6` and later is called with `eps=1e-8` (e.g., from `nn.RMSNorm.eps`), the recompilation on ROCm Triton can crash. + +**Fix:** Remove `tl.constexpr` from `eps` and pass it as a regular runtime parameter: +```python +# WRONG — triggers recompilation for each eps value, may crash on ROCm +@triton.jit +def rmsnorm_kernel(x_ptr, ..., eps: tl.constexpr, BLOCK_D: tl.constexpr): + ... + +# CORRECT — eps is a regular runtime float, no recompilation +@triton.jit +def rmsnorm_kernel(x_ptr, ..., eps, BLOCK_D: tl.constexpr): + ... + +# Also ensure eps is a plain float in the wrapper +rmsnorm_kernel[(M,)](..., float(eps), BLOCK_D=BLOCK_D, ...) +``` + +**Note:** Only `BLOCK_D`, `HAS_WEIGHT`, and other values that change kernel structure should be `tl.constexpr`. Parameters like `eps` that only affect numerical values should be regular parameters. + +## Debugging Tips + +### Check GPU Architecture + +```bash +rocminfo | grep "Name" +# Should show gfx950 (MI355X) or gfx1201 (R9700) +``` + +### Verify ROCm Triton Installation + +```python +import triton +print(triton.__version__) +import torch +print(torch.version.hip) # Should show ROCm version +print(torch.cuda.get_device_properties(0)) +``` + +### Profile Kernel + +```bash +# Basic profiling +rocprof --stats python your_kernel.py + +# Detailed kernel metrics +rocprofv3 -i metrics.txt python your_kernel.py +``` + +### Test Kernel Correctness + +```python +# Compare with PyTorch reference +ref_output = reference_model(inputs) +custom_output = custom_model(inputs) + +torch.testing.assert_close( + custom_output, ref_output, + rtol=1e-2, atol=1e-3 # FP16 tolerance +) +``` diff --git a/skills/rocm-kernels/scripts/benchmark_e2e.py b/skills/rocm-kernels/scripts/benchmark_e2e.py new file mode 100644 index 00000000..f9b00c6c --- /dev/null +++ b/skills/rocm-kernels/scripts/benchmark_e2e.py @@ -0,0 +1,296 @@ +#!/usr/bin/env python3 +""" +End-to-end benchmark: LTX-Video pipeline with/without custom Triton kernels on ROCm. + +Measures total generation time, per-step latency, and peak memory. + +Requirements: + pip install diffusers transformers accelerate torch triton + +Usage: + # Baseline (no custom kernels) + python benchmark_e2e.py --mode baseline + + # With custom Triton kernels + python benchmark_e2e.py --mode triton + + # With torch.compile + python benchmark_e2e.py --mode compile + + # Compare all three + python benchmark_e2e.py --mode all + + # Quick test (fewer frames/steps) + python benchmark_e2e.py --mode triton --num-frames 9 --steps 5 +""" +import os +os.environ['TRITON_HIP_USE_BLOCK_PINGPONG'] = '1' +os.environ['TRITON_HIP_USE_ASYNC_COPY'] = '1' + +import argparse +import json +import time + +import torch +import torch.nn as nn +import triton +import triton.language as tl + + +# ============================================================================ +# Triton RMSNorm Kernel (same as in benchmark_kernels.py) +# ============================================================================ + +@triton.jit +def _rmsnorm_kernel( + x_ptr, weight_ptr, out_ptr, + stride_x_row, D, + eps: tl.constexpr, + HAS_WEIGHT: tl.constexpr, + BLOCK_D: tl.constexpr, +): + row = tl.program_id(0) + offs = tl.arange(0, BLOCK_D) + mask = offs < D + + x = tl.load(x_ptr + row * stride_x_row + offs, mask=mask, other=0.0).to(tl.float32) + sq_sum = tl.sum(x * x, axis=0) + rms_inv = tl.rsqrt(sq_sum / D + eps) + + if HAS_WEIGHT: + w = tl.load(weight_ptr + offs, mask=mask, other=1.0).to(tl.float32) + out = x * rms_inv * w + else: + out = x * rms_inv + + tl.store(out_ptr + row * stride_x_row + offs, out.to(x.dtype), mask=mask) + + +def triton_rmsnorm(x, weight=None, eps=1e-6): + x_flat = x.contiguous().view(-1, x.shape[-1]) + out = torch.empty_like(x_flat) + M, D = x_flat.shape + has_weight = weight is not None + if not has_weight: + weight = torch.ones(D, device=x.device, dtype=x.dtype) + BLOCK_D = triton.next_power_of_2(D) + num_warps = 4 if BLOCK_D <= 1024 else (8 if BLOCK_D <= 4096 else 16) + _rmsnorm_kernel[(M,)]( + x_flat, weight, out, x_flat.stride(0), D, eps, has_weight, + BLOCK_D=BLOCK_D, num_warps=num_warps, num_stages=2, + ) + return out.view_as(x) + + +# ============================================================================ +# Attention Processor (uses Triton RMSNorm) +# ============================================================================ + +class TritonLTXVideoAttnProcessor: + def __call__(self, attn, hidden_states, encoder_hidden_states=None, + attention_mask=None, image_rotary_emb=None): + from diffusers.models.transformers.transformer_ltx import apply_rotary_emb + from diffusers.models.attention_dispatch import dispatch_attention_fn + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = triton_rmsnorm(query, attn.norm_q.weight, eps=attn.norm_q.eps) + key = triton_rmsnorm(key, attn.norm_k.weight, eps=attn.norm_k.eps) + + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + value = value.unflatten(2, (attn.heads, -1)) + + hidden_states = dispatch_attention_fn( + query, key, value, + attn_mask=attention_mask, dropout_p=0.0, is_causal=False, + ) + hidden_states = hidden_states.flatten(2, 3).to(query.dtype) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +# ============================================================================ +# Module Patchers +# ============================================================================ + +def patch_rmsnorm_modules(model): + patched = 0 + for name, module in model.named_modules(): + if type(module).__name__ == 'RMSNorm': + eps = getattr(module, 'eps', 1e-6) + has_weight = hasattr(module, 'weight') and module.weight is not None + if has_weight: + def make_fwd(mod, e): + def fwd(x): return triton_rmsnorm(x, mod.weight, eps=e) + return fwd + module.forward = make_fwd(module, eps) + else: + def make_fwd_nw(e): + def fwd(x): return triton_rmsnorm(x, None, eps=e) + return fwd + module.forward = make_fwd_nw(eps) + patched += 1 + return patched + + +def inject_triton_kernels(pipe): + stats = {'attention_processors': 0, 'rmsnorm_modules': 0} + if not hasattr(pipe, 'transformer'): + return stats + for name, module in pipe.transformer.named_modules(): + if hasattr(module, 'set_processor') and hasattr(module, 'processor'): + module.set_processor(TritonLTXVideoAttnProcessor()) + stats['attention_processors'] += 1 + stats['rmsnorm_modules'] = patch_rmsnorm_modules(pipe.transformer) + return stats + + +# ============================================================================ +# Benchmark Runner +# ============================================================================ + +def run_benchmark(mode, prompt, num_frames, height, width, steps, + guidance_scale, seed, warmup_iters): + from diffusers import LTXPipeline + from diffusers.utils import export_to_video + + device = "cuda" + dtype = torch.bfloat16 + + print(f"\n{'='*60}") + print(f"MODE: {mode.upper()}") + print(f"{'='*60}") + + pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=dtype) + pipe.to(device) + + if mode == "triton": + stats = inject_triton_kernels(pipe) + print(f" Attention processors: {stats['attention_processors']}") + print(f" RMSNorm patched: {stats['rmsnorm_modules']}") + elif mode == "compile": + pipe.transformer.compile_repeated_blocks(fullgraph=True) + print(" torch.compile enabled (fullgraph=True)") + else: + print(" Baseline (no optimization)") + + # Warmup + if warmup_iters > 0: + print(f"\n Warmup ({warmup_iters} iters, {min(steps, 5)} steps)...") + for i in range(warmup_iters): + _ = pipe(prompt=prompt, num_frames=num_frames, height=height, width=width, + num_inference_steps=min(steps, 5), guidance_scale=guidance_scale) + torch.cuda.synchronize() + + # Benchmark run + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + + print(f"\n Generating ({num_frames} frames, {steps} steps)...") + torch.cuda.synchronize() + start = time.time() + output = pipe( + prompt=prompt, num_frames=num_frames, height=height, width=width, + num_inference_steps=steps, guidance_scale=guidance_scale, + generator=torch.Generator(device=device).manual_seed(seed), + ) + torch.cuda.synchronize() + gen_time = time.time() - start + peak_mem = torch.cuda.max_memory_allocated() / 1e9 + + result = { + 'mode': mode, + 'gen_time_s': round(gen_time, 2), + 'time_per_frame_s': round(gen_time / num_frames, 3), + 'time_per_step_s': round(gen_time / steps, 3), + 'peak_memory_gb': round(peak_mem, 2), + } + + print(f"\n Results:") + print(f" Total: {result['gen_time_s']:.2f} s") + print(f" Per frame: {result['time_per_frame_s']:.3f} s") + print(f" Per step: {result['time_per_step_s']:.3f} s") + print(f" Peak mem: {result['peak_memory_gb']:.2f} GB") + + # Save video + out_path = f"ltx_video_{mode}.mp4" + export_to_video(output.frames[0], out_path, fps=24) + print(f" Saved to: {out_path}") + + del pipe + torch.cuda.empty_cache() + return result + + +def main(): + parser = argparse.ArgumentParser(description="E2E LTX-Video benchmark on ROCm") + parser.add_argument("--mode", type=str, default="all", + choices=["baseline", "triton", "compile", "all"]) + parser.add_argument("--prompt", type=str, + default="A cat sleeping in warm sunlight, cinematic, 4K") + parser.add_argument("--num-frames", type=int, default=25) + parser.add_argument("--height", type=int, default=480) + parser.add_argument("--width", type=int, default=704) + parser.add_argument("--steps", type=int, default=30) + parser.add_argument("--guidance-scale", type=float, default=7.5) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--warmup", type=int, default=1) + parser.add_argument("--output-json", type=str, default=None, + help="Save results to JSON for comparison") + args = parser.parse_args() + + print("=" * 60) + print("LTX-Video End-to-End Benchmark (ROCm)") + print("=" * 60) + print(f"Device: {torch.cuda.get_device_name(0)}") + print(f"ROCm: {torch.version.hip if hasattr(torch.version, 'hip') else 'N/A'}") + print(f"Config: {args.num_frames} frames, {args.height}x{args.width}, {args.steps} steps") + + modes = ["baseline", "triton", "compile"] if args.mode == "all" else [args.mode] + all_results = [] + + for mode in modes: + r = run_benchmark(mode, args.prompt, args.num_frames, args.height, + args.width, args.steps, args.guidance_scale, + args.seed, args.warmup) + all_results.append(r) + + # Comparison table + if len(all_results) > 1: + print(f"\n{'='*60}") + print("COMPARISON") + print(f"{'='*60}") + print(f"{'Mode':<12} {'Time (s)':<12} {'Per Step (s)':<15} {'Peak Mem (GB)':<15}") + print("-" * 54) + baseline_time = all_results[0]['gen_time_s'] + for r in all_results: + speedup = baseline_time / r['gen_time_s'] if r['gen_time_s'] > 0 else 0 + suffix = f" ({speedup:.2f}x)" if r['mode'] != 'baseline' else "" + print(f"{r['mode']:<12} {r['gen_time_s']:<12.2f} {r['time_per_step_s']:<15.3f} {r['peak_memory_gb']:<15.2f}{suffix}") + + if args.output_json: + with open(args.output_json, 'w') as f: + json.dump(all_results, f, indent=2) + print(f"\nResults saved to {args.output_json}") + + +if __name__ == "__main__": + main() diff --git a/skills/rocm-kernels/scripts/benchmark_kernels.py b/skills/rocm-kernels/scripts/benchmark_kernels.py new file mode 100644 index 00000000..bb72524a --- /dev/null +++ b/skills/rocm-kernels/scripts/benchmark_kernels.py @@ -0,0 +1,536 @@ +#!/usr/bin/env python3 +""" +Micro-benchmark for all 4 Triton kernels on ROCm: RMSNorm, RoPE 3D, GEGLU, AdaLN. + +Measures: + 1. Correctness vs PyTorch reference + 2. Latency (custom vs baseline, warmup + averaged) + 3. Memory bandwidth utilization + +Usage: + python benchmark_kernels.py + python benchmark_kernels.py --kernel rmsnorm + python benchmark_kernels.py --kernel rope + python benchmark_kernels.py --kernel geglu + python benchmark_kernels.py --kernel adaln + python benchmark_kernels.py --dtype float16 +""" +import os +os.environ['TRITON_HIP_USE_BLOCK_PINGPONG'] = '1' +os.environ['TRITON_HIP_USE_ASYNC_COPY'] = '1' + +import argparse +import time +from typing import Tuple + +import torch +import triton +import triton.language as tl + + +# ============================================================================ +# Kernel 1: RMSNorm +# ============================================================================ +# CRITICAL: BLOCK_D must be >= D (hidden dimension). +# Using autotune with fixed BLOCK_D configs is WRONG because autotune may +# pick BLOCK_D < D, causing only partial row processing. +# Fix: compute BLOCK_D = next_power_of_2(D) dynamically in the Python wrapper. + +@triton.jit +def rmsnorm_fwd_kernel( + x_ptr, weight_ptr, out_ptr, + stride_x, D, + eps: tl.constexpr, + HAS_WEIGHT: tl.constexpr, + BLOCK_D: tl.constexpr, +): + row = tl.program_id(0) + col_offsets = tl.arange(0, BLOCK_D) + mask = col_offsets < D + + x = tl.load(x_ptr + row * stride_x + col_offsets, mask=mask, other=0.0).to(tl.float32) + variance = tl.sum(x * x, axis=0) / D + rms_inv = tl.rsqrt(variance + eps) + + if HAS_WEIGHT: + w = tl.load(weight_ptr + col_offsets, mask=mask, other=1.0).to(tl.float32) + result = x * rms_inv * w + else: + result = x * rms_inv + + tl.store(out_ptr + row * stride_x + col_offsets, result.to(x.dtype), mask=mask) + + +def triton_rmsnorm(x, weight=None, eps=1e-6): + orig_shape = x.shape + x_2d = x.contiguous().view(-1, x.shape[-1]) + out = torch.empty_like(x_2d) + M, D = x_2d.shape + has_weight = weight is not None + if not has_weight: + weight = torch.empty(0, device=x.device) + + BLOCK_D = triton.next_power_of_2(D) + num_warps = 4 if BLOCK_D <= 1024 else (8 if BLOCK_D <= 4096 else 16) + rmsnorm_fwd_kernel[(M,)]( + x_2d, weight, out, + x_2d.stride(0), D, eps, has_weight, + BLOCK_D=BLOCK_D, num_warps=num_warps, num_stages=2, + ) + return out.view(orig_shape) + + +def pytorch_rmsnorm(x, weight=None, eps=1e-6): + variance = x.pow(2).mean(dim=-1, keepdim=True) + out = x * torch.rsqrt(variance + eps) + if weight is not None: + out = out * weight + return out + + +# ============================================================================ +# Kernel 2: RoPE 3D +# ============================================================================ +# CRITICAL: cos/sin have shape [seq_len, head_dim], NOT [batch*seq_len, ...]. +# When grid is (batch * seq_len, num_heads), we must use pid_s % seq_len +# to index into cos/sin to avoid out-of-bounds access for batch > 1. + +@triton.jit +def rope_3d_fwd_kernel( + qk_ptr, cos_ptr, sin_ptr, out_ptr, + seq_len, num_heads, head_dim, + stride_s, stride_h, stride_d, + BLOCK_HD: tl.constexpr, +): + pid_s = tl.program_id(0) + pid_h = tl.program_id(1) + half_dim = head_dim // 2 + offs = tl.arange(0, BLOCK_HD) + mask = offs < half_dim + + base = pid_s * stride_s + pid_h * stride_h + x0 = tl.load(qk_ptr + base + offs, mask=mask, other=0.0).to(tl.float32) + x1 = tl.load(qk_ptr + base + half_dim + offs, mask=mask, other=0.0).to(tl.float32) + + seq_idx = pid_s % seq_len + cos_val = tl.load(cos_ptr + seq_idx * head_dim + offs, mask=mask, other=1.0).to(tl.float32) + sin_val = tl.load(sin_ptr + seq_idx * head_dim + offs, mask=mask, other=0.0).to(tl.float32) + + out0 = x0 * cos_val - x1 * sin_val + out1 = x0 * sin_val + x1 * cos_val + + tl.store(out_ptr + base + offs, out0.to(x0.dtype), mask=mask) + tl.store(out_ptr + base + half_dim + offs, out1.to(x0.dtype), mask=mask) + + +def triton_rope_3d(qk, cos, sin): + qk = qk.contiguous() + out = torch.empty_like(qk) + batch, seq_len, num_heads, head_dim = qk.shape + half_dim = head_dim // 2 + qk_flat = qk.view(batch * seq_len, num_heads, head_dim) + out_flat = out.view(batch * seq_len, num_heads, head_dim) + grid = (batch * seq_len, num_heads) + BLOCK_HD = triton.next_power_of_2(half_dim) + num_warps = 4 if BLOCK_HD <= 64 else 8 + rope_3d_fwd_kernel[grid]( + qk_flat, cos, sin, out_flat, + seq_len, num_heads, head_dim, + qk_flat.stride(0), qk_flat.stride(1), qk_flat.stride(2), + BLOCK_HD=BLOCK_HD, num_warps=num_warps, num_stages=2, + ) + return out + + +def pytorch_rope(qk, cos, sin): + half = qk.shape[-1] // 2 + x0, x1 = qk[..., :half], qk[..., half:] + cos_exp = cos.unsqueeze(0).unsqueeze(2)[:, :qk.shape[1], :, :half] + sin_exp = sin.unsqueeze(0).unsqueeze(2)[:, :qk.shape[1], :, :half] + out0 = x0 * cos_exp - x1 * sin_exp + out1 = x0 * sin_exp + x1 * cos_exp + return torch.cat([out0, out1], dim=-1) + + +# ============================================================================ +# Kernel 3: GEGLU +# ============================================================================ +# Same BLOCK_SIZE fix as RMSNorm: compute dynamically, do NOT autotune. + +@triton.jit +def geglu_fwd_kernel( + input_ptr, output_ptr, + stride_in, stride_out, hidden_size, + BLOCK_H: tl.constexpr, +): + row = tl.program_id(0) + offs = tl.arange(0, BLOCK_H) + mask = offs < hidden_size + + gate = tl.load(input_ptr + row * stride_in + offs, mask=mask, other=0.0).to(tl.float32) + value = tl.load(input_ptr + row * stride_in + hidden_size + offs, mask=mask, other=0.0).to(tl.float32) + + # GELU approx: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) + # tl.math.tanh / tl.libdevice.tanh NOT available on ROCm — use manual formula + SQRT_2_OVER_PI = 0.7978845608028654 + tanh_arg = SQRT_2_OVER_PI * (gate + 0.044715 * gate * gate * gate) + e2x = tl.exp(2.0 * tanh_arg) + tanh_val = (e2x - 1.0) / (e2x + 1.0) + cdf = 0.5 * (1.0 + tanh_val) + gelu_gate = gate * cdf + result = gelu_gate * value + + tl.store(output_ptr + row * stride_out + offs, result.to(gate.dtype), mask=mask) + + +def triton_geglu(x): + x = x.contiguous() + *batch_dims, double_h = x.shape + hidden_size = double_h // 2 + x_2d = x.view(-1, double_h) + M = x_2d.shape[0] + out = torch.empty(M, hidden_size, device=x.device, dtype=x.dtype) + + BLOCK_H = triton.next_power_of_2(hidden_size) + num_warps = 4 if BLOCK_H <= 1024 else (8 if BLOCK_H <= 4096 else 16) + geglu_fwd_kernel[(M,)]( + x_2d, out, + x_2d.stride(0), out.stride(0), hidden_size, + BLOCK_H=BLOCK_H, num_warps=num_warps, num_stages=2, + ) + return out.view(*batch_dims, hidden_size) + + +def pytorch_geglu(x): + hidden_size = x.shape[-1] // 2 + gate, value = x[..., :hidden_size], x[..., hidden_size:] + return torch.nn.functional.gelu(gate, approximate='tanh') * value + + +# ============================================================================ +# Kernel 4: AdaLN +# ============================================================================ +# Same BLOCK_D fix: compute dynamically. + +@triton.jit +def adaln_fwd_kernel( + x_ptr, weight_ptr, scale_ptr, shift_ptr, out_ptr, + stride_x, stride_cond, D, + eps: tl.constexpr, + BLOCK_D: tl.constexpr, +): + row = tl.program_id(0) + offs = tl.arange(0, BLOCK_D) + mask = offs < D + + x = tl.load(x_ptr + row * stride_x + offs, mask=mask, other=0.0).to(tl.float32) + variance = tl.sum(x * x, axis=0) / D + rms_inv = tl.rsqrt(variance + eps) + x_norm = x * rms_inv + + w = tl.load(weight_ptr + offs, mask=mask, other=1.0).to(tl.float32) + scale = tl.load(scale_ptr + row * stride_cond + offs, mask=mask, other=0.0).to(tl.float32) + shift = tl.load(shift_ptr + row * stride_cond + offs, mask=mask, other=0.0).to(tl.float32) + + out = x_norm * w * (1.0 + scale) + shift + tl.store(out_ptr + row * stride_x + offs, out.to(x.dtype), mask=mask) + + +def triton_adaln(x, weight, scale, shift, eps=1e-6): + x_flat = x.contiguous().view(-1, x.shape[-1]) + scale_flat = scale.contiguous().view(-1, x.shape[-1]) + shift_flat = shift.contiguous().view(-1, x.shape[-1]) + out = torch.empty_like(x_flat) + M, D = x_flat.shape + + BLOCK_D = triton.next_power_of_2(D) + num_warps = 4 if BLOCK_D <= 1024 else (8 if BLOCK_D <= 4096 else 16) + adaln_fwd_kernel[(M,)]( + x_flat, weight, scale_flat, shift_flat, out, + x_flat.stride(0), scale_flat.stride(0), D, eps, + BLOCK_D=BLOCK_D, num_warps=num_warps, num_stages=2, + ) + return out.view_as(x) + + +def pytorch_adaln(x, weight, scale, shift, eps=1e-6): + variance = x.pow(2).mean(dim=-1, keepdim=True) + x_norm = x * torch.rsqrt(variance + eps) + return x_norm * weight * (1.0 + scale) + shift + + +# ============================================================================ +# Benchmark Utilities +# ============================================================================ + +def benchmark_fn(func, args, warmup=20, iterations=100) -> Tuple[float, float]: + for _ in range(warmup): + func(*args) + torch.cuda.synchronize() + + times = [] + for _ in range(iterations): + torch.cuda.synchronize() + start = time.perf_counter() + func(*args) + torch.cuda.synchronize() + end = time.perf_counter() + times.append((end - start) * 1000) + + return sum(times) / len(times), min(times) + + +def check_correctness(out, ref, name, dtype): + max_abs = (out.float() - ref.float()).abs().max().item() + max_rel = ((out.float() - ref.float()).abs() / (ref.float().abs() + 1e-8)).max().item() + + # BF16 has 7-bit mantissa; for values ~8-16 the ULP is 0.0625-0.125 + # FP16 has 10-bit mantissa; much tighter precision + atol = 0.15 if dtype == torch.bfloat16 else 0.01 + passed = max_abs < atol + status = "PASS" if passed else "FAIL" + print(f" [{status}] {name}: max_abs={max_abs:.6e}, max_rel={max_rel:.6e}") + return passed + + +# ============================================================================ +# Benchmark Runners +# ============================================================================ + +def benchmark_rmsnorm(dtype): + print("\n" + "=" * 70) + print("BENCHMARK: RMSNorm (168 instances in LTX-Video)") + print("=" * 70) + + configs = [ + (1, 1024, 2048), + (2, 1024, 2048), + (4, 1024, 2048), + (1, 4096, 2048), + (2, 4096, 3072), + (1, 8192, 2048), + (4, 4096, 3072), + ] + + print(f"\n{'Config':<25} {'Triton (ms)':<15} {'PyTorch (ms)':<15} {'Speedup':<10}") + print("-" * 70) + + all_correct = True + total_speedup = 0 + + for batch, seq, hidden in configs: + x = torch.randn(batch, seq, hidden, dtype=dtype, device="cuda") + w = torch.ones(hidden, dtype=dtype, device="cuda") + + ref = pytorch_rmsnorm(x, w) + out = triton_rmsnorm(x, w) + if not check_correctness(out, ref, f"[{batch}x{seq}x{hidden}]", dtype): + all_correct = False + + t_avg, _ = benchmark_fn(triton_rmsnorm, (x, w)) + p_avg, _ = benchmark_fn(pytorch_rmsnorm, (x, w)) + speedup = p_avg / t_avg + total_speedup += speedup + + print(f" [{batch}x{seq}x{hidden}]{'':<13} {t_avg:>10.3f} {p_avg:>10.3f} {speedup:>7.2f}x") + + # No-weight variant + print("\n -- No-weight variant (elementwise_affine=False) --") + x = torch.randn(2, 4096, 2048, dtype=dtype, device="cuda") + ref_nw = pytorch_rmsnorm(x, None) + out_nw = triton_rmsnorm(x, None) + check_correctness(out_nw, ref_nw, "no-weight [2x4096x2048]", dtype) + + avg_speedup = total_speedup / len(configs) + print(f"\n Average speedup: {avg_speedup:.2f}x") + + # Bandwidth analysis + batch, seq, hidden = 4, 4096, 3072 + x = torch.randn(batch, seq, hidden, dtype=dtype, device="cuda") + w = torch.ones(hidden, dtype=dtype, device="cuda") + bytes_per_elem = 2 if dtype in (torch.float16, torch.bfloat16) else 4 + total_bytes = batch * seq * hidden * bytes_per_elem * 2 + hidden * bytes_per_elem + t_avg, _ = benchmark_fn(triton_rmsnorm, (x, w)) + bw_gbps = (total_bytes / 1e9) / (t_avg / 1000) + print(f"\n Bandwidth analysis [{batch}x{seq}x{hidden}]:") + print(f" Data moved: {total_bytes / 1e6:.2f} MB") + print(f" Achieved: {bw_gbps:.1f} GB/s") + + return all_correct, avg_speedup + + +def benchmark_rope(dtype): + print("\n" + "=" * 70) + print("BENCHMARK: RoPE 3D (Video Position Encoding)") + print("=" * 70) + + configs = [ + (1, 1024, 16, 64), + (1, 4096, 16, 64), + (2, 4096, 16, 128), + (1, 8192, 32, 64), + ] + + print(f"\n{'Config':<30} {'Triton (ms)':<15} {'PyTorch (ms)':<15} {'Speedup':<10}") + print("-" * 75) + + all_correct = True + total_speedup = 0 + + for batch, seq, heads, hdim in configs: + qk = torch.randn(batch, seq, heads, hdim, dtype=dtype, device="cuda") + cos = torch.randn(seq, hdim, dtype=dtype, device="cuda") + sin = torch.randn(seq, hdim, dtype=dtype, device="cuda") + + ref = pytorch_rope(qk, cos, sin) + out = triton_rope_3d(qk, cos, sin) + if not check_correctness(out, ref, f"[{batch}x{seq}x{heads}x{hdim}]", dtype): + all_correct = False + + t_avg, _ = benchmark_fn(triton_rope_3d, (qk, cos, sin)) + p_avg, _ = benchmark_fn(pytorch_rope, (qk, cos, sin)) + speedup = p_avg / t_avg + total_speedup += speedup + + cfg = f"[{batch}x{seq}x{heads}x{hdim}]" + print(f" {cfg:<28} {t_avg:>10.3f} {p_avg:>10.3f} {speedup:>7.2f}x") + + avg_speedup = total_speedup / len(configs) + print(f"\n Average speedup: {avg_speedup:.2f}x") + return all_correct, avg_speedup + + +def benchmark_geglu(dtype): + print("\n" + "=" * 70) + print("BENCHMARK: GEGLU (For SD3/FLUX, NOT LTX-Video)") + print("=" * 70) + + configs = [ + (1, 1024, 2048), + (2, 1024, 4096), + (2, 4096, 3072), + (4, 4096, 4096), + ] + + print(f"\n{'Config':<30} {'Triton (ms)':<15} {'PyTorch (ms)':<15} {'Speedup':<10}") + print("-" * 75) + + all_correct = True + total_speedup = 0 + + for batch, seq, hidden in configs: + x = torch.randn(batch, seq, hidden * 2, dtype=dtype, device="cuda") + + ref = pytorch_geglu(x) + out = triton_geglu(x) + if not check_correctness(out, ref, f"[{batch}x{seq}x{hidden*2}]", dtype): + all_correct = False + + t_avg, _ = benchmark_fn(triton_geglu, (x,)) + p_avg, _ = benchmark_fn(pytorch_geglu, (x,)) + speedup = p_avg / t_avg + total_speedup += speedup + + cfg = f"[{batch}x{seq}x{hidden*2}->{hidden}]" + print(f" {cfg:<28} {t_avg:>10.3f} {p_avg:>10.3f} {speedup:>7.2f}x") + + avg_speedup = total_speedup / len(configs) + print(f"\n Average speedup: {avg_speedup:.2f}x") + return all_correct, avg_speedup + + +def benchmark_adaln(dtype): + print("\n" + "=" * 70) + print("BENCHMARK: AdaLN (Fused Norm + Conditioning for DiT)") + print("=" * 70) + + configs = [ + (1, 1024, 2048), + (2, 1024, 2048), + (2, 4096, 3072), + (4, 4096, 3072), + ] + + print(f"\n{'Config':<25} {'Triton (ms)':<15} {'PyTorch (ms)':<15} {'Speedup':<10}") + print("-" * 70) + + all_correct = True + total_speedup = 0 + + for batch, seq, hidden in configs: + x = torch.randn(batch, seq, hidden, dtype=dtype, device="cuda") + w = torch.ones(hidden, dtype=dtype, device="cuda") + scale = torch.randn(batch, seq, hidden, dtype=dtype, device="cuda") * 0.1 + shift = torch.randn(batch, seq, hidden, dtype=dtype, device="cuda") * 0.1 + + ref = pytorch_adaln(x, w, scale, shift) + out = triton_adaln(x, w, scale, shift) + if not check_correctness(out, ref, f"[{batch}x{seq}x{hidden}]", dtype): + all_correct = False + + t_avg, _ = benchmark_fn(triton_adaln, (x, w, scale, shift)) + p_avg, _ = benchmark_fn(pytorch_adaln, (x, w, scale, shift)) + speedup = p_avg / t_avg + total_speedup += speedup + + print(f" [{batch}x{seq}x{hidden}]{'':<13} {t_avg:>10.3f} {p_avg:>10.3f} {speedup:>7.2f}x") + + avg_speedup = total_speedup / len(configs) + print(f"\n Average speedup: {avg_speedup:.2f}x") + return all_correct, avg_speedup + + +# ============================================================================ +# Main +# ============================================================================ + +def main(): + parser = argparse.ArgumentParser(description="Benchmark Triton kernels on ROCm") + parser.add_argument("--kernel", type=str, default="all", + choices=["all", "rmsnorm", "rope", "geglu", "adaln"]) + parser.add_argument("--dtype", type=str, default="bfloat16", + choices=["bfloat16", "float16"]) + args = parser.parse_args() + + dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16 + + print("=" * 70) + print("ROCm Triton Kernel Micro-Benchmark") + print("=" * 70) + print(f"Device: {torch.cuda.get_device_name(0)}") + print(f"Dtype: {dtype}") + print(f"ROCm: {torch.version.hip if hasattr(torch.version, 'hip') else 'N/A'}") + + results = {} + runners = { + "rmsnorm": benchmark_rmsnorm, + "rope": benchmark_rope, + "geglu": benchmark_geglu, + "adaln": benchmark_adaln, + } + + if args.kernel == "all": + for name, runner in runners.items(): + correct, speedup = runner(dtype) + results[name] = {"correct": correct, "speedup": speedup} + else: + correct, speedup = runners[args.kernel](dtype) + results[args.kernel] = {"correct": correct, "speedup": speedup} + + # Summary + print("\n" + "=" * 70) + print("SUMMARY") + print("=" * 70) + print(f"{'Kernel':<15} {'Correct':<12} {'Avg Speedup':<15}") + print("-" * 42) + for name, r in results.items(): + status = "PASS" if r["correct"] else "FAIL" + print(f"{name:<15} {status:<12} {r['speedup']:.2f}x") + + all_pass = all(r["correct"] for r in results.values()) + print(f"\nOverall: {'ALL PASS' if all_pass else 'SOME FAILED'}") + print("=" * 70) + + +if __name__ == "__main__": + main() diff --git a/skills/rocm-kernels/scripts/huggingface_kernels_example.py b/skills/rocm-kernels/scripts/huggingface_kernels_example.py new file mode 100644 index 00000000..80a0acf5 --- /dev/null +++ b/skills/rocm-kernels/scripts/huggingface_kernels_example.py @@ -0,0 +1,344 @@ +#!/usr/bin/env python3 +""" +Example: Using HuggingFace Kernels library to load and use optimized kernels on ROCm. + +This script demonstrates how to: +1. Load kernels from the HuggingFace Hub using get_kernel() +2. Check kernel availability with has_kernel() +3. Integrate Hub kernels with transformers/diffusers models +4. Fall back to local Triton kernels when Hub builds are unavailable + +Requirements: + pip install kernels torch numpy + +Usage: + python scripts/huggingface_kernels_example.py +""" + +import os +import time +from typing import Optional + +os.environ['TRITON_HIP_USE_BLOCK_PINGPONG'] = '1' +os.environ['TRITON_HIP_USE_ASYNC_COPY'] = '1' + +import torch +import torch.nn as nn +import triton +import triton.language as tl + + +# ============================================================================= +# Local Triton RMSNorm (fallback when Hub kernel unavailable) +# ============================================================================= + +EPS_DEFAULT = 1e-6 + +@triton.jit +def rmsnorm_fwd_kernel( + x_ptr, weight_ptr, out_ptr, + stride_x, D, eps, + BLOCK_D: tl.constexpr, +): + row = tl.program_id(0) + offs = tl.arange(0, BLOCK_D) + mask = offs < D + x = tl.load(x_ptr + row * stride_x + offs, mask=mask, other=0.0).to(tl.float32) + variance = tl.sum(x * x, axis=0) / D + rms_inv = tl.rsqrt(variance + eps) + w = tl.load(weight_ptr + offs, mask=mask, other=1.0).to(tl.float32) + out = x * rms_inv * w + tl.store(out_ptr + row * stride_x + offs, out.to(x.dtype), mask=mask) + + +def local_triton_rmsnorm(x, weight, eps=EPS_DEFAULT): + """Local Triton RMSNorm — used as fallback when Hub kernel is unavailable.""" + x_2d = x.contiguous().view(-1, x.shape[-1]) + out = torch.empty_like(x_2d) + M, D = x_2d.shape + BLOCK_D = triton.next_power_of_2(D) + num_warps = 4 if BLOCK_D <= 1024 else (8 if BLOCK_D <= 4096 else 16) + rmsnorm_fwd_kernel[(M,)]( + x_2d, weight, out, x_2d.stride(0), D, float(eps), + BLOCK_D=BLOCK_D, num_warps=num_warps, num_stages=2, + ) + return out.view_as(x) + + +# ============================================================================= +# Part 1: Check Environment +# ============================================================================= + +def check_environment(): + """Print environment information for debugging.""" + print("=" * 60) + print("Environment") + print("=" * 60) + print(f"PyTorch: {torch.__version__}") + print(f"HIP version: {getattr(torch.version, 'hip', 'N/A')}") + print(f"CUDA available: {torch.cuda.is_available()}") + if torch.cuda.is_available(): + print(f"GPU: {torch.cuda.get_device_name()}") + print() + + +# ============================================================================= +# Part 2: Basic Kernel Loading from Hub +# ============================================================================= + +def demo_basic_kernel_loading(): + """Demonstrate basic kernel loading from Hub.""" + print("=" * 60) + print("Part 1: Basic Kernel Loading from Hub") + print("=" * 60) + + try: + from kernels import get_kernel, has_kernel + + repo_id = "kernels-community/triton-layer-norm" + + print(f"\n1. Checking kernel availability: {repo_id}") + if has_kernel(repo_id): + print(" Kernel is available for this ROCm environment") + + print(f"\n2. Loading kernel from Hub...") + kernel = get_kernel(repo_id) + + print(f"\n3. Available functions:") + functions = [f for f in dir(kernel) if not f.startswith('_')] + for func in functions[:10]: + print(f" - {func}") + + print(f"\n4. Testing RMSNorm kernel...") + x = torch.randn(2, 1024, 2048, dtype=torch.bfloat16, device="cuda") + w = torch.ones(2048, dtype=torch.bfloat16, device="cuda") + + rms_fn_name = None + for name in ('rms_norm', 'rms_norm_fn', 'rmsnorm'): + if hasattr(kernel, name): + rms_fn_name = name + break + + if rms_fn_name: + rms_fn = getattr(kernel, rms_fn_name) + try: + out = rms_fn(x, w, eps=1e-6) + except TypeError: + # rms_norm_fn(x, weight, bias, ...) requires bias argument + out = rms_fn(x, w, None, eps=1e-6) + print(f" Using: kernel.{rms_fn_name}()") + print(f" Input: {x.shape}, Output: {out.shape}") + print(f" Success!") + else: + print(f" No RMSNorm function found. Available: {functions}") + + return kernel + else: + print(" No compatible build for this ROCm environment") + print(" Will use local Triton kernel as fallback") + return None + + except ImportError: + print("\n kernels library not installed. Install with: pip install kernels") + return None + except Exception as e: + print(f"\n Error: {e}") + return None + + +# ============================================================================= +# Part 3: Benchmark Hub Kernel vs Local Triton vs PyTorch +# ============================================================================= + +def demo_benchmark(hub_kernel): + """Benchmark Hub kernel vs local Triton vs PyTorch.""" + print("\n" + "=" * 60) + print("Part 2: Benchmark Hub vs Local Triton vs PyTorch") + print("=" * 60) + + shapes = [(2, 1024, 2048), (4, 4096, 3072)] + warmup, iterations = 20, 100 + + for shape in shapes: + x = torch.randn(shape, dtype=torch.bfloat16, device="cuda") + w = torch.ones(shape[-1], dtype=torch.bfloat16, device="cuda") + + def _call_hub(fn, x, w, eps): + try: + return fn(x, w, eps=eps) + except TypeError: + return fn(x, w, None, eps=eps) + + hub_rms_fn_raw = None + if hub_kernel: + for fn_name in ('rms_norm', 'rms_norm_fn', 'rmsnorm'): + if hasattr(hub_kernel, fn_name): + hub_rms_fn_raw = getattr(hub_kernel, fn_name) + break + + # Warmup all implementations + for _ in range(warmup): + local_triton_rmsnorm(x, w, eps=1e-6) + variance = x.pow(2).mean(-1, keepdim=True) + _ = x * torch.rsqrt(variance + 1e-6) * w + if hub_rms_fn_raw: + _call_hub(hub_rms_fn_raw, x, w, 1e-6) + torch.cuda.synchronize() + + # PyTorch baseline + start = time.perf_counter() + for _ in range(iterations): + variance = x.pow(2).mean(-1, keepdim=True) + _ = x * torch.rsqrt(variance + 1e-6) * w + torch.cuda.synchronize() + pt_ms = (time.perf_counter() - start) / iterations * 1000 + + # Local Triton + start = time.perf_counter() + for _ in range(iterations): + local_triton_rmsnorm(x, w, eps=1e-6) + torch.cuda.synchronize() + local_ms = (time.perf_counter() - start) / iterations * 1000 + + print(f"\n Shape {shape}:") + print(f" PyTorch: {pt_ms:.4f} ms") + print(f" Local Triton: {local_ms:.4f} ms (speedup: {pt_ms/local_ms:.2f}x)") + + if hub_rms_fn_raw: + start = time.perf_counter() + for _ in range(iterations): + _call_hub(hub_rms_fn_raw, x, w, 1e-6) + torch.cuda.synchronize() + hub_ms = (time.perf_counter() - start) / iterations * 1000 + print(f" Hub kernel: {hub_ms:.4f} ms (speedup: {pt_ms/hub_ms:.2f}x)") + + +# ============================================================================= +# Part 4: Model Integration with Fallback +# ============================================================================= + +def demo_model_integration(hub_kernel): + """Demonstrate integrating kernels with models, with fallback.""" + print("\n" + "=" * 60) + print("Part 3: Model Integration with Fallback") + print("=" * 60) + + class SimpleModel(nn.Module): + def __init__(self, hidden_size=2048): + super().__init__() + self.norm = nn.RMSNorm(hidden_size) + self.linear = nn.Linear(hidden_size, hidden_size) + + def forward(self, x): + return self.linear(self.norm(x)) + + model = SimpleModel().cuda().to(torch.bfloat16) + + # Decide which RMSNorm to use + hub_rms_fn = None + if hub_kernel: + for fn_name in ('rms_norm', 'rms_norm_fn', 'rmsnorm'): + if hasattr(hub_kernel, fn_name): + hub_rms_fn = getattr(hub_kernel, fn_name) + break + + if hub_rms_fn: + def _hub_rmsnorm(x, w, eps): + try: + return hub_rms_fn(x, w, eps=eps) + except TypeError: + return hub_rms_fn(x, w, None, eps=eps) + rmsnorm_fn = _hub_rmsnorm + source = "Hub kernel" + else: + rmsnorm_fn = local_triton_rmsnorm + source = "Local Triton" + + print(f"\n1. Using {source} for RMSNorm") + + # Patch model + for name, module in model.named_modules(): + if isinstance(module, nn.RMSNorm): + raw_eps = getattr(module, 'eps', None) + eps = float(raw_eps) if raw_eps is not None else 1e-6 + + def make_forward(mod, epsilon, fn): + def forward(x): + return fn(x, mod.weight, epsilon) + return forward + + module.forward = make_forward(module, eps, rmsnorm_fn) + print(f" Patched: {name} (eps={eps})") + + # Test + print(f"\n2. Testing forward pass...") + x = torch.randn(2, 1024, 2048, dtype=torch.bfloat16, device="cuda") + with torch.inference_mode(): + y = model(x) + print(f" Input: {x.shape} -> Output: {y.shape}") + print(f" Success!") + + +# ============================================================================= +# Part 5: Publishing Info +# ============================================================================= + +def demo_publishing_info(): + """Show information about publishing kernels to Hub.""" + print("\n" + "=" * 60) + print("Part 4: Publishing Triton Kernels to Hub") + print("=" * 60) + + print(""" + For Triton kernels (best ROCm compatibility): + + 1. Create project structure: + my-triton-kernel/ + ├── build.toml + ├── kernel_src/ + │ └── rmsnorm.py # Triton kernel + └── torch-ext/ + ├── torch_binding.cpp + └── my_kernels/__init__.py + + 2. Configure build.toml with ROCm support: + [general] + name = "my_kernels" + backends = ["cuda", "rocm"] + + 3. Build and publish: + $ pip install kernel-builder + $ kernel-builder build + $ huggingface-cli upload my-username/my-kernel ./dist + + See: https://huggingface.co/docs/kernels + """) + + +# ============================================================================= +# Main +# ============================================================================= + +def main(): + print("=" * 60) + print("HuggingFace Kernels Integration Example (ROCm)") + print("=" * 60) + + check_environment() + + if not torch.cuda.is_available(): + print("GPU not available. This example requires an AMD GPU with ROCm.") + return + + hub_kernel = demo_basic_kernel_loading() + demo_benchmark(hub_kernel) + demo_model_integration(hub_kernel) + demo_publishing_info() + + print("\n" + "=" * 60) + print("Done!") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/skills/rocm-kernels/scripts/transformers_injection_example.py b/skills/rocm-kernels/scripts/transformers_injection_example.py new file mode 100644 index 00000000..6bbc3ce7 --- /dev/null +++ b/skills/rocm-kernels/scripts/transformers_injection_example.py @@ -0,0 +1,183 @@ +#!/usr/bin/env python3 +""" +Minimal example: Inject custom Triton kernels into HuggingFace Transformers models on ROCm. + +This script demonstrates the essential pattern for integrating custom Triton kernels +with transformers models like LLaMA, Mistral, and Qwen on AMD GPUs. + +Key lessons: +1. Transformers RMSNorm modules always have weights (unlike some diffusers modules) +2. Use 'RMSNorm' substring match to catch LlamaRMSNorm, MistralRMSNorm, etc. +3. Check for 'variance_epsilon' (LLaMA) or 'eps' (others) for epsilon value +4. Use Flash Attention 2 for attention optimization instead of custom processors +5. ROCm: tl.libdevice/tl.math.tanh NOT available — use manual math + +Usage: + python scripts/transformers_injection_example.py +""" + +import os +import sys +import time + +os.environ['TRITON_HIP_USE_BLOCK_PINGPONG'] = '1' +os.environ['TRITON_HIP_USE_ASYNC_COPY'] = '1' + +import torch +import torch.nn as nn +import triton +import triton.language as tl + + +# ============================================================================= +# Triton RMSNorm Kernel (ROCm compatible) +# ============================================================================= + +@triton.jit +def rmsnorm_fwd_kernel( + x_ptr, weight_ptr, out_ptr, + stride_x, D, eps, + BLOCK_D: tl.constexpr, +): + row = tl.program_id(0) + offs = tl.arange(0, BLOCK_D) + mask = offs < D + x = tl.load(x_ptr + row * stride_x + offs, mask=mask, other=0.0).to(tl.float32) + + variance = tl.sum(x * x, axis=0) / D + rms_inv = tl.rsqrt(variance + eps) + + w = tl.load(weight_ptr + offs, mask=mask, other=1.0).to(tl.float32) + out = x * rms_inv * w + + tl.store(out_ptr + row * stride_x + offs, out.to(x.dtype), mask=mask) + + +def triton_rmsnorm(x, weight, eps=1e-6): + x_2d = x.contiguous().view(-1, x.shape[-1]) + out = torch.empty_like(x_2d) + M, D = x_2d.shape + BLOCK_D = triton.next_power_of_2(D) + num_warps = 4 if BLOCK_D <= 1024 else (8 if BLOCK_D <= 4096 else 16) + rmsnorm_fwd_kernel[(M,)]( + x_2d, weight, out, x_2d.stride(0), D, float(eps), + BLOCK_D=BLOCK_D, num_warps=num_warps, num_stages=2, + ) + return out.view_as(x) + + +# ============================================================================= +# RMSNorm Module Patcher +# ============================================================================= + +def patch_rmsnorm_modules(model: nn.Module) -> int: + """ + Patch all RMSNorm modules to use Triton kernel on ROCm. + + Works with LlamaRMSNorm, MistralRMSNorm, Qwen2RMSNorm, etc. + Unlike diffusers, transformers RMSNorm always has weights. + """ + patched_count = 0 + + for name, module in model.named_modules(): + class_name = type(module).__name__ + + if 'RMSNorm' in class_name: + eps = getattr(module, 'variance_epsilon', None) + if eps is None: + eps = getattr(module, 'eps', 1e-6) + + has_weight = hasattr(module, 'weight') and module.weight is not None + + if has_weight: + def make_patched_forward(mod, epsilon): + def patched_forward(hidden_states): + return triton_rmsnorm(hidden_states, mod.weight, eps=epsilon) + return patched_forward + module.forward = make_patched_forward(module, eps) + patched_count += 1 + else: + print(f"WARNING: {name} has no weight, skipping") + + return patched_count + + +def inject_optimized_kernels(model) -> dict: + """Inject custom Triton kernels into a transformers model.""" + stats = {'rmsnorm_modules': 0} + stats['rmsnorm_modules'] = patch_rmsnorm_modules(model) + return stats + + +# ============================================================================= +# Main +# ============================================================================= + +def main(): + from transformers import AutoModelForCausalLM, AutoTokenizer + + print("=" * 60) + print("Transformers Triton Kernel Injection (ROCm)") + print("=" * 60) + + # Verify ROCm + print(f"\nROCm HIP version: {getattr(torch.version, 'hip', 'N/A')}") + print(f"GPU: {torch.cuda.get_device_name()}") + + model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" + + print(f"\n1. Loading model: {model_id}...") + model = AutoModelForCausalLM.from_pretrained( + model_id, + torch_dtype=torch.bfloat16, + device_map="cuda" + ) + tokenizer = AutoTokenizer.from_pretrained(model_id) + + rmsnorm_count = sum(1 for _, m in model.named_modules() if 'RMSNorm' in type(m).__name__) + print(f" Found {rmsnorm_count} RMSNorm modules") + + print("\n2. Injecting optimized Triton kernels...") + stats = inject_optimized_kernels(model) + print(f" RMSNorm modules patched: {stats['rmsnorm_modules']}") + + print("\n3. Verifying injection...") + x = torch.randn(1, 10, model.config.hidden_size, device='cuda', dtype=torch.bfloat16) + for name, module in model.named_modules(): + if 'RMSNorm' in type(module).__name__: + out = module(x) + print(f" RMSNorm forward pass: {x.shape} -> {out.shape}") + break + + print("\n4. Running generation test...") + prompt = "The capital of France is" + inputs = tokenizer(prompt, return_tensors="pt").to("cuda") + + with torch.inference_mode(): + _ = model.generate(**inputs, max_new_tokens=5, do_sample=False) + + num_tokens = 50 + start_time = time.perf_counter() + with torch.inference_mode(): + outputs = model.generate( + **inputs, + max_new_tokens=num_tokens, + do_sample=False, + pad_token_id=tokenizer.eos_token_id + ) + end_time = time.perf_counter() + + elapsed = end_time - start_time + tokens_per_second = num_tokens / elapsed + + print(f" Prompt: {prompt}") + print(f" Output: {tokenizer.decode(outputs[0], skip_special_tokens=True)}") + print(f" Generated {num_tokens} tokens in {elapsed:.2f}s ({tokens_per_second:.1f} tokens/s)") + + print("\n" + "=" * 60) + print("Success! Custom Triton kernels are being used on ROCm.") + print("=" * 60) + + +if __name__ == "__main__": + main() From 6e47b0b4fb97f9fc543519439561f8df5b3cccb4 Mon Sep 17 00:00:00 2001 From: PhlimosJW <220233704@seu.edu.cn> Date: Fri, 13 Mar 2026 06:45:26 -0500 Subject: [PATCH 2/5] Remove custom-cuda-kernels-agent-skills.md --- skills/custom-cuda-kernels-agent-skills.md | 298 --------------------- 1 file changed, 298 deletions(-) delete mode 100644 skills/custom-cuda-kernels-agent-skills.md diff --git a/skills/custom-cuda-kernels-agent-skills.md b/skills/custom-cuda-kernels-agent-skills.md deleted file mode 100644 index e9405b6f..00000000 --- a/skills/custom-cuda-kernels-agent-skills.md +++ /dev/null @@ -1,298 +0,0 @@ ---- -title: "Custom Kernels for All from Codex and Claude" -thumbnail: /blog/assets/custom-cuda-kernels/meme.png -authors: -- user: burtenshaw -- user: sayakpaul -- user: ariG23498 -- user: evalstate ---- - - - -# Custom Kernels for All from Codex and Claude - -![oprah custom cuda kernels](assets/custom-cuda-kernels/meme.png) - -tl;dr: We built an agent skill that teaches coding agents how to write production CUDA kernels. Then we pointed Claude and Codex at two real targets: a **diffusers** pipeline and a **transformers** model. The agents produced working kernels for both, with correct PyTorch bindings and benchmarks, end to end. - -Writing CUDA kernels is hard. Writing CUDA kernels that correctly integrate with `transformers` and `diffusers` is harder. There are architecture-specific memory access patterns, vectorization strategies, warp shuffle reductions, and a dozen integration pitfalls that trip up even experienced developers. It is exactly the kind of specialized, high-stakes problem where agent skills shine. - -We gave coding agents the domain knowledge they need, like which GPU architecture to target, how to structure a kernel-builder project, when to use shared memory versus registers, and how to write PyTorch bindings. The agents did the rest. If you have used the [LLM training skill](https://huggingface.co/blog/hf-skills-training) or read [We Got Claude to Teach Open Models](https://huggingface.co/blog/upskill), the pattern will feel familiar: package domain expertise into a skill, point the agent at a problem, and let it work. - -## Why a skill for kernels? - -The [Kernel Hub](https://huggingface.co/blog/hello-hf-kernels) solved the distribution of custom hardware kernels. You can load pre-compiled kernels from the Hub with a single `get_kernel` call. No builds, no flags. However, someone still needs to **write the kernels**. That is the gap this skill fills. - -CUDA kernel development has a brutal surface area: - -- Hardware-specific optimization guides for each generation of GPU. H100, A100, and T4 each have different compute capabilities, shared memory sizes, and bandwidth profiles -- In Libraries, `diffusers` and `transformers` have different module hierarchies, normalization conventions, and integration patterns. Custom kernels need to be registered in PyTorch for `torch.compile` to recognize. -- For distribution, kernels can depend on CUDA, Pytorch, and Python versions creating massive environment matrices. - -This is domain knowledge that gets lost in documentation tabs and Stack Overflow answers. An agent skill packages it into context that loads on demand. - -First, let's show how to use the skill right away, then we'll dive into the details of how we benchmarked the kernels. - -## Installing the skill - -The skill ships with the `kernels` library. Install it into your coding agent with a single command: - -```shell -# we need to install kernels from main for this -pip install git+https://github.com/huggingface/kernels.git#subdirectory=kernels -kernels skills add cuda-kernels --claude -``` - -This drops the skill into `.claude/skills/cuda-kernels/` where Claude Code and Cursor pick it up automatically. For other agents: - -```shell -# Codex -kernels skills add cuda-kernels --codex - -# OpenCode -kernels skills add cuda-kernels --opencode - -# Custom destination -kernels skills add cuda-kernels --dest ./my-agent/skills/ - -# Install globally (available across all projects) -kernels skills add cuda-kernels --global - -# Overwrite an existing installation -kernels skills add cuda-kernels --claude --force -``` - -Once installed, prompt your agent: - -``` -Build a vectorized RMSNorm kernel for H100 targeting the Qwen3-8B model in transformers. -``` - -Or, you can go for something more open-ended: - -``` -Build an optimized attention kernel for H100 targeting the Qwen3-8B model in transformers. Benchmark it against the PyTorch baseline and validate improvements in end-to-end performance. -``` - -The agent can read the skill, select the right architecture parameters, generate the CUDA source, write the PyTorch bindings, set up `build.toml`, and create a benchmark script. - -If you're working on more complex kernels, or architecture-specific optimizations, that aren't covered in the skill, then the skill supplies the fundamental building blocks and patterns to get you started. We are also open to contributions on the [skill itself](https://github.com/huggingface/kernels/tree/main/.docs/skills). - -## What is in the skill - -The skill is roughly **550 tokens** of structured guidance plus reference scripts, GPU optimization guides, troubleshooting docs, and complete working examples. Agentic coding tools like Codex and Claude can read this and produce a working kernel project. - -It covers: - -- NVIDIA GPU Architecture-aware optimization for H100, A100, and T4 (compute capabilities, memory bandwidth, shared memory sizes, block sizing) -- Integration patterns for both `diffusers` and `transformers`, including the pitfalls specific to each library -- Kernel templates with vectorized memory access patterns for BF16, FP16, and FP32 -- Benchmarking workflows for both isolated kernel micro-benchmarks and end-to-end pipeline comparisons -- HuggingFace Kernel Hub integration via `get_kernel` for loading community kernels - -``` -.claude/skills/cuda-kernels/ -├── SKILL.md # Main instructions (~550 tokens) -├── scripts/ -│ ├── benchmark_example.py # End-to-end benchmark template -│ ├── benchmark_rmsnorm.py # Isolated kernel micro-benchmark -│ ├── ltx_kernel_injection_example.py # Diffusers integration pattern -│ ├── transformers_injection_example.py # Transformers integration pattern -│ └── huggingface_kernels_example.py # Kernel Hub integration -└── references/ - ├── diffusers-integration.md # Diffusers guide with pitfalls - ├── transformers-integration.md # Transformers guide - ├── huggingface-kernels-integration.md - ├── h100-optimization-guide.md - ├── a100-optimization-guide.md - ├── t4-optimization-guide.md - ├── kernel-templates.md - └── troubleshooting.md -``` - -When an agent loads this, it gets everything it needs to go from "write me an RMSNorm kernel" to a buildable, benchmarkable project. It will grep and glob the skill to find the relevant files and directories. So it's important to structure the skill in a way that is easy to find. - -The agent is instructed to generate kernels that conform to the templates in `references/kernel-templates.md` and produce a complete kernel project: - -``` -examples/your_model/ -├── kernel_src/ -│ └── rmsnorm.cu # Vectorized CUDA kernel -├── torch-ext/ -│ ├── your_kernels/__init__.py -│ └── torch_binding.cpp # PyTorch C++ bindings -├── benchmark_rmsnorm.py # Micro-benchmark script -├── build.toml # kernel-builder config -├── setup.py # pip install -e . -└── pyproject.toml -``` - -We tested this on two real targets. - -## Benchmarking the kernels: Diffusers (LTX-Video on H100) - -The agent built RMSNorm, RoPE 3D, GEGLU, and AdaLN kernels for [LTX-Video](https://huggingface.co/Lightricks/LTX-Video), a video generation pipeline from `diffusers`. The full example is at `examples/ltx_video/`. We optimized the RMSNorm kernel for H100. Both benchmarks were run on H100 80GB HBM3 at precision BFloat16. - -If you want to check out the generated kernel, got to [this example](https://github.com/burtenshaw/kernel-skill/tree/main/examples/ltx_video) - -### Isolated RMSNorm benchmark - -First, we compare the isolated RMSNorm kernel performance against the PyTorch baseline. This is the main speedup in the optimized pipeline. - -![isolated rmsnorm benchmark ltx-video](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/kernels-skill-benchmark/rmsnorm_ltx_video.png) - -
-Table - -| Shape | Custom (ms) | PyTorch (ms) | Speedup | -| :---- | :---: | :---: | :---: | -| [1x1024x2048] | 0.039 | 0.064 | **1.64x** | -| [2x1024x2048] | 0.040 | 0.073 | **1.82x** | -| [4x1024x2048] | 0.052 | 0.093 | **1.78x** | -| [1x4096x2048] | 0.052 | 0.093 | **1.79x** | -| [2x4096x3072] | 0.102 | 0.209 | **2.04x** | -| [1x8192x2048] | 0.083 | 0.150 | **1.81x** | -| [4x4096x3072] | 0.173 | 0.393 | **2.26x** | - -**Average speedup: 1.88x** and a bandwidth efficiency: 34.7% of H100 theoretical (3,350 GB/s) - -
- -### End-to-end video generation (49 frames, 30 steps, H100 80GB) - -Next, we compare the end-to-end video generation performance of the optimized kernels against the baseline (no compile) and the `torch.compile` baseline. - -![e2e benchmark ltx-video](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/kernels-skill-benchmark/e2e_ltx_video.png) - -
-Table - -| Configuration | Time (s) | it/s | Speedup | -| :---- | :---: | :---: | :---: | -| Baseline (no compile) | 2.87 | 12.58 | 1.00x | -| **Generated Optimized Kernels** | 2.70 | 13.52 | **1.06x** | -| Baseline + torch.compile | 2.14 | 19.05 | 1.34x | -| Optimized + torch.compile | 2.01 | 18.45 | 1.43x | - -
- -RMSNorm accounts for ~5% of total compute in LTX-Video. The remaining time is spent in attention, linear projections, and VAE decode. The 6% end-to-end speedup from a single kernel type is consistent with that profile. - -## Benchmarking the kernels: Transformers (Qwen3-8B on H100) - -The agent built an RMSNorm kernel for [Qwen3-8B](https://huggingface.co/Qwen/Qwen3-8B), a large language model from `transformers` with 65 RMSNorm modules across 32 layers. The full example is at `examples/qwen3_8b/`. We optimized the RMSNorm kernel for H100. Both benchmarks were run on H100 80GB HBM3 at precision BFloat16. - -If you want to explore the kernel, check it out [here.](https://github.com/burtenshaw/kernel-skill/tree/main/examples/qwen3_8b) - -### Isolated RMSNorm benchmark - -Once again, we compare the isolated RMSNorm kernel performance against the PyTorch baseline. - -![isolated rmsnorm benchmark qwen3-8b](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/kernels-skill-benchmark/rmsnorm_qwen3.png) - -**Average speedup: 1.94x** and a bandwidth efficiency: 22.3% of H100 theoretical (3,350 GB/s) - -
-Table - -| Shape | Custom (ms) | PyTorch (ms) | Speedup | -| :---- | :---: | :---: | :---: | -| [1x128x4096] | 0.040 | 0.062 | **1.58x** | -| [1x512x4096] | 0.038 | 0.064 | **1.69x** | -| [1x1024x4096] | 0.037 | 0.071 | **1.90x** | -| [1x2048x4096] | 0.045 | 0.091 | **2.03x** | -| [1x4096x4096] | 0.071 | 0.150 | **2.12x** | -| [4x512x4096] | 0.056 | 0.093 | **1.67x** | -| [8x256x4096] | 0.045 | 0.092 | **2.06x** | -| [1x8192x4096] | 0.109 | 0.269 | **2.47x** | - -
- -Speedup scales with sequence length: 1.58x at 128 tokens, 2.47x at 8192 tokens. For long-context inference, the custom kernel roughly halves RMSNorm latency. - -## Publishing your kernel to the Hub - -The agent gives you a working kernel. The [Kernel Hub](https://huggingface.co/kernels-community) lets you share it so anyone can load it without compilation. Here is the full path from agent output to published kernel. - -### 1. Verify the project structure - -The agent produces a project that already follows the [kernel-builder](https://huggingface.co/docs/kernels/en/builder/writing-kernels) layout: - -``` -your_kernel/ -├── build.toml # Build configuration -├── kernel_src/ -│ └── rmsnorm.cu # CUDA kernel source -└── torch-ext/ - ├── torch_binding.cpp # Registers Torch ops - └── your_kernels/ - └── __init__.py # Python API wrapping _ops -``` - -The `build.toml` tells `kernel-builder` what to build. The agent generates this for you, including the correct `cuda-capabilities` for your target GPU: - -``` -[general] -name = "your_kernels" -backends = ["cuda"] - -[torch] -src = ["torch-ext/torch_binding.cpp"] - -[kernel.rmsnorm] -backend = "cuda" -src = ["kernel_src/rmsnorm.cu"] -depends = ["torch"] -cuda-capabilities = ["9.0"] # H100 -``` - -### 2. Build all variants with Nix - -Kernel Hub kernels must support all recent PyTorch and CUDA configurations. The kernel-builder Nix flake handles this automatically. Copy the [example `flake.nix`](https://github.com/huggingface/kernels/blob/main/builder/examples/relu/flake.nix) into your project and run: - -```shell -nix flake update -nix run .#build-and-copy -L -``` - -This builds the kernel for every required PyTorch/CUDA variant and places the results in `build/`. For faster builds, enable the HuggingFace Nix cache: - -```shell -nix run nixpkgs#cachix -- use huggingface -``` - -### 3. Create a Hub repo and push - -Create a model repo on the Hub and upload the built kernel: - -```shell -huggingface-cli repo create your-org/your-kernel --type model -huggingface-cli upload your-org/your-kernel ./build -``` - -### 4. Others load it in one line - -Once published, anyone can use your kernel with zero compilation: - -```py -from kernels import get_kernel - -rmsnorm = get_kernel("your-org/your-kernel") -``` - -`get_kernel` detects the user's Python, PyTorch, and CUDA versions and downloads the matching pre-compiled binary. No builds, no flags, typically ready in seconds. - -The skill and the Hub are complementary. The skill handles development. The Hub handles distribution. Build a kernel with the skill, validate it with the benchmark scripts, publish it to the Hub, and it becomes a one-liner for everyone else. - -## Conclusion - -We built an agent skill that teaches coding agents how to write production CUDA kernels. Then we pointed Claude and Codex at two real targets: a **diffusers** pipeline and a **transformers** model. The agents produced working kernels for both, with correct PyTorch bindings and benchmarks, end to end. We benchmarked the kernels and found that the optimized kernels can provide a speedup in both isolated and end-to-end performance. - -## Resources - -- [CUDA Kernels Skill in `kernels`](https://github.com/huggingface/kernels/tree/main/skills/cuda-kernels) -- [HuggingFace Kernel Hub Blog](https://huggingface.co/blog/hello-hf-kernels) -- [We Got Claude to Fine-Tune an Open Source LLM](https://huggingface.co/blog/hf-skills-training) -- [We Got Claude to Teach Open Models](https://huggingface.co/blog/upskill) -- [HuggingFace Kernels Community](https://huggingface.co/kernels-community) From 1caf811a58906811eb9e270f8a16eb344cfa3e64 Mon Sep 17 00:00:00 2001 From: PhlimosJW <220233704@seu.edu.cn> Date: Fri, 13 Mar 2026 06:49:10 -0500 Subject: [PATCH 3/5] Remove kernel-agent-knowledge-base and skill-evaluation-methodology --- .../references/kernel-agent-knowledge-base.md | 140 ---------- .../skill-evaluation-methodology.md | 251 ------------------ 2 files changed, 391 deletions(-) delete mode 100644 skills/rocm-kernels/references/kernel-agent-knowledge-base.md delete mode 100644 skills/rocm-kernels/references/skill-evaluation-methodology.md diff --git a/skills/rocm-kernels/references/kernel-agent-knowledge-base.md b/skills/rocm-kernels/references/kernel-agent-knowledge-base.md deleted file mode 100644 index 25ea2cb2..00000000 --- a/skills/rocm-kernels/references/kernel-agent-knowledge-base.md +++ /dev/null @@ -1,140 +0,0 @@ -# Kernel-Agent 项目知识提取 - -本文档记录从 `/home/jixiong/kernel-agent` 项目中提取的核心知识,作为 ROCm kernel skills 的基础。 - -## 1. 项目概况 - -kernel-agent 是一个 **LLM 驱动的 Triton/Helion kernel 生成与评测工作流**,专门面向 AMD ROCm 平台。 - -| 组件 | 说明 | -|------|------| -| **后端** | Triton (主要), Helion (实验性) | -| **目标平台** | AMD GPU (ROCm) | -| **评测基准** | KernelBench (Level 1-7) | -| **工作流** | 生成 → 执行 → 正确性检查 → 性能优化 → 迭代 | -| **LLM 提供商** | OpenAI, Anthropic, Google, AMD on-prem | - -## 2. AMD GPU 硬件参数 (来自 amd_gpu_specs.py) - -### MI355X (gfx950) - CDNA3+ - -| 参数 | 值 | 优化影响 | -|------|-----|---------| -| GPU 架构 | CDNA3+ (gfx950) | 编译目标 | -| GPU 显存 | 288GB HBM3e | 大模型无压力 | -| 内存带宽 | 8 TB/s | 内存受限 kernel 的上限 | -| XCD 数量 | **32** | XCD Swizzle 必须用 NUM_XCDS=32 | -| CU 总数 | 256 | Grid 大小的倍数 | -| 每 XCD 的 CU | 8 | XCD 间负载均衡 | -| LDS/CU | **160 KB** | 比 MI300X 大 2.5 倍 | -| L2 Cache | 256 MB | 大型共享缓存 | -| Wavefront | 64 | CDNA 固定 Wave64 | -| MFMA 指令 | 16x16 (最优), 32x32 | matrix_instr_nonkdim=16 | -| FP8 格式 | float8_e4m3fn (OCP) | 与 MI300X 不同! | -| 最优 num_warps | 4-16 | autotune 范围 | -| 最优 num_stages | 2-3 | 避免 LDS 溢出 | -| 最优 BLOCK_SIZE (1D) | 1024-8192 | 比 MI300X 更大 | -| 最优 BLOCK_M/N (2D) | 128-256 | GEMM tile 大小 | - -### R9700 (RDNA4, gfx1201) - -| 参数 | 值 | 优化影响 | -|------|-----|---------| -| GPU 架构 | RDNA4 (gfx1201) | Wave32 模式 | -| Wavefront | **32** | 归约代码需要不同偏移 | -| CU 总数 | 64 | Grid 大小的倍数 | -| LDS/CU | 64 KB | 标准大小 | -| L1 Cache | 32 KB | 每 CU 私有 | -| L2 Cache | 8 MB | 全 CU 共享 | -| L3 Cache | 64 MB | 末级缓存 | -| Cacheline | **256 B** | 比 RDNA3 更大,需更严格对齐 | -| Max Threads/Block | 1024 | 32 waves × 32 threads | -| Max Threads/CU | 2048 | 64 waves × 32 threads | -| FP16 矩阵 TFLOPS | 191 | 矩阵指令 | -| FP8 矩阵 TFLOPS | 383 | 推理加速 | -| 矩阵核心 | 有限 (无 FP8 MFMA) | 不支持高级矩阵指令 | - -## 3. 关键优化知识 (来自 prompt_constructor.py) - -### MI355X 必须的优化 - -1. **XCD Swizzle (GEMM 必须)**: NUM_XCDS=32,将 block ID 映射到 32 个 XCD -2. **L2 Cache Grouping**: GROUP_M=8 或 16,提高 L2 缓存命中率 -3. **MFMA 16x16**: matrix_instr_nonkdim=16 -4. **环境变量**: `TRITON_HIP_USE_BLOCK_PINGPONG=1`, `TRITON_HIP_USE_ASYNC_COPY=1` -5. **num_stages=2-3**: 避免 LDS 溢出 - -### Triton on ROCm 禁忌 - -- **禁止** `tl.libdevice.*` (CUDA 专属) -- **禁止** `tl.tanh` (不支持,用 `(exp(2x)-1)/(exp(2x)+1)`) -- **禁止** `break/continue` (用 `tl.where` 替代) -- **禁止** Python `min()/max()` (用 `tl.minimum()/tl.maximum()`) -- **必须** 用 `tl.float32` 做累加器 -- **必须** 对 exp/log/sqrt/rsqrt/除法 转换为 FP32 - -### Autotune 配置 - -#### 逐元素 (1D) - -**MI355X**: BLOCK_SIZE = [1024, 2048, 4096, 4096, 8192, 16384] -**R9700**: BLOCK_SIZE = [256, 512, 1024] (更小) - -#### GEMM (2D) - -**MI355X**: BLOCK_M/N = [128-256], BLOCK_K = [32-64], GROUP_M = 8 - -## 4. 问题分类体系 (来自 classify_problem) - -| 类别 | 匹配模式 | 典型算子 | -|------|---------|---------| -| elementwise | relu, gelu, swish, silu, sigmoid, tanh, elu... | 激活函数 | -| softmax | softmax, logsoftmax | Softmax 变体 | -| norm | layernorm, batchnorm, rmsnorm, groupnorm... | 归一化 | -| pooling | pool | 池化操作 | -| reduction | sum_reduction, mean_reduction, max_reduction... | 归约操作 | -| attention | attention, multihead | 注意力机制 | -| matvec | matrix_vector, matvec | 矩阵-向量乘 | -| batched_gemm | batch, bmm | 批量矩阵乘 | -| gemm_2d | matmul, gemm, mm_ | 2D 矩阵乘 | - -## 5. KernelBench 测试结果关键发现 - -### 表现优秀的类别 (在 kernel-agent 上) - -| 类别 | 最佳 Speedup | 代表算子 | -|------|-------------|---------| -| Reduction | 5.00x | Sum reduction | -| Pooling | 5.16x | Average Pooling 3D | -| 激活函数 | 2.94x | Softsign, Softplus, Swish | -| 归一化 | 1.73x | LayerNorm | -| 特殊 GEMM | 1.98x | 对角矩阵乘 | - -### 需要重点优化的类别 - -| 类别 | 当前 Speedup | 根本原因 | -|------|-------------|---------| -| 大 K GEMM | 0.04x | 寄存器压力、内存访问不优 | -| BatchNorm | 0.04x | HIP 运行时错误、同步问题 | -| 对称/三角矩阵乘 | 0.08-0.20x | 线程利用率低 | -| Argmax/Argmin | FAILED | Triton API 限制 | -| 融合算子 | 0.32x (平均) | 多操作组合复杂度 | - -### 常见错误类型 - -1. **HIP Runtime Error**: GPU 内存访问冲突 -2. **精度问题**: FP16 累积误差 -3. **program_id 限制**: 3D Grid 映射 -4. **tl.store() kwarg 错误**: Triton API 差异 -5. **max_contiguous 错误**: 内存访问模式 - -## 6. 性能分析工具链 - -| 工具 | 用途 | -|------|------| -| `rocprof` / `rocprofv3` | GPU kernel profiling | -| `rocm-bandwidth-test` | 内存带宽测试 | -| `rocminfo` | GPU 设备信息 | -| `rocm-smi` | GPU 状态监控 | -| `omniperf` | 全面性能分析 | -| `omnitrace` | 系统级追踪 | diff --git a/skills/rocm-kernels/references/skill-evaluation-methodology.md b/skills/rocm-kernels/references/skill-evaluation-methodology.md deleted file mode 100644 index 7e9548a6..00000000 --- a/skills/rocm-kernels/references/skill-evaluation-methodology.md +++ /dev/null @@ -1,251 +0,0 @@ -# Skill 评估与优化方法论 - -本文档描述如何系统性地评估和优化 ROCm kernel skills 的质量。 - -## 1. Skill 质量评估框架 - -### 1.1 评估维度 - -| 维度 | 权重 | 衡量标准 | 评估方法 | -|------|------|---------|---------| -| **正确性** | 40% | AI 生成的 kernel 能通过正确性测试 | KernelBench accuracy ratio | -| **性能** | 30% | 生成 kernel 的 speedup 相比 PyTorch baseline | KernelBench speedup ratio | -| **可运行率** | 20% | kernel 能成功编译并运行 | KernelBench runnable ratio | -| **触发准确率** | 10% | AI 在正确场景下使用了正确的 skill | 人工评审 | - -### 1.2 KernelBench 自动评估 - -```bash -# 运行 Level 1 全量评估 -python -m kernel_agent.evaluation.kernelbench_success_evaluator \ - --config examples/workflows/evaluation/kernelbench/config_eval.yml \ - --dataset datasets/kernelbench/kernel_bench_level_1.json - -# 输出指标 -# - runnable_ratio: 可运行比率 -# - accuracy_ratio: 正确性比率 -# - speed_ratio: 达到 speedup > 1.0x 的比率 -``` - -### 1.3 逐类别评估 - -对每个 operator 类别单独评估: - -```bash -# Level 1 Selected (按类别) -python -m kernel_agent.evaluation.kernelbench_success_evaluator \ - --config examples/workflows/evaluation/kernelbench/config_eval_level1_selected.yml -``` - -**评估记录表模板:** - -| 类别 | 算子数 | 可运行 | 正确 | Speedup>1x | 平均 Speedup | Skill 版本 | -|------|--------|--------|------|-----------|-------------|-----------| -| GEMM | 18 | ?/18 | ?/18 | ?/18 | ?x | v0.1 | -| Elementwise | 14 | ?/14 | ?/14 | ?/14 | ?x | v0.1 | -| Normalization | 8 | ?/8 | ?/8 | ?/8 | ?x | v0.1 | -| ... | ... | ... | ... | ... | ... | ... | - -## 2. 迭代优化流程 - -### 2.1 PDCA 循环 - -``` -Plan → 识别性能瓶颈或失败模式 -Do → 修改 SKILL.md 或 references 中的指引 -Check → 重新运行 KernelBench 评估 -Act → 确认改进,合并到 skill;或回滚 -``` - -### 2.2 具体优化步骤 - -#### Step 1: 收集失败案例 - -```python -# 从评估结果中提取失败案例 -# 分析每个失败的原因类型: -# - compilation_error: 编译错误 → 修改模板代码 -# - runtime_error: 运行时错误 → 添加 troubleshooting 条目 -# - accuracy_error: 精度问题 → 修改精度相关指引 -# - performance_low: 性能不达标 → 添加优化策略 -``` - -#### Step 2: 根因分析 - -| 失败类型 | 检查项 | 修改位置 | -|---------|--------|---------| -| tl.libdevice 错误 | 是否遗漏了 ROCm 禁忌 | SKILL.md "Critical ROCm Constraints" | -| LDS 溢出 | num_stages 建议是否正确 | GPU optimization guide | -| GEMM 极慢 | 是否缺少 XCD swizzle | kernel-templates.md Template 5 (GEMM with XCD Swizzle) | -| 精度不达标 | FP32 累加是否到位 | kernel-templates.md 所有模板 | -| Python min/max | 是否提醒了 tl.minimum | troubleshooting.md | - -#### Step 3: 修改 Skill 内容 - -按照失败原因修改对应文件: -- **模式错误** → 修改 `kernel-templates.md` -- **知识缺失** → 修改 `SKILL.md` 或 GPU 优化指南 -- **新陷阱** → 添加到 `troubleshooting.md` -- **性能数据过时** → 更新 benchmark 表格 - -#### Step 4: A/B 测试 - -```bash -# Version A: 原始 skill -cp -r rocm-kernels rocm-kernels-v1 - -# Version B: 修改后的 skill -# (直接编辑 rocm-kernels/) - -# 分别运行评估,对比结果 -# Compare: runnable_ratio, accuracy_ratio, speed_ratio -``` - -## 3. 与 AI 协作优化 Skill 的方法 - -### 3.1 "生成-评测-反馈" 循环 - -``` -你 (编写 Skill) - ↓ -AI (使用 Skill 生成 kernel) - ↓ -KernelBench (评测 kernel) - ↓ -你 (分析失败,改进 Skill) - ↓ -(重复) -``` - -### 3.2 具体协作方式 - -#### 方式 1: 让 AI 帮你分析失败 - -``` -提示词: "这是 KernelBench Level 1 的评测结果 [粘贴结果]。 -请分析以下失败案例的根因,并建议修改 SKILL.md 的哪些部分。" -``` - -#### 方式 2: 让 AI 帮你生成测试用例 - -``` -提示词: "根据 rocm-kernels skill,为 GEMM 类别生成一个测试 kernel, -目标是 4096x4096 方阵乘法,使用 MI355X 优化。" -``` - -然后手动运行测试,看结果是否符合预期。 - -#### 方式 3: 让 AI 帮你补充 Skill 内容 - -``` -提示词: "BatchNorm 在 AMD GPU 上的评测结果是 0.04x(极差)。 -错误类型是 HIP 运行时错误。请帮我分析原因,并在 troubleshooting.md -中添加对应的解决方案。" -``` - -### 3.3 Skill 版本管理 - -``` -rocm-kernels/ -├── SKILL.md # 主文件 (跟踪版本号) -├── CHANGELOG.md # 变更日志 (每次优化后记录) -├── references/ -└── scripts/ -``` - -**CHANGELOG 格式:** - -```markdown -## v0.2 (2026-03-15) -- Added XCD swizzle pattern for GEMM (fixed 0.3x → 1.1x speedup) -- Added tanh workaround for ROCm -- Fixed LDS overflow guidance for MI355X - -## v0.1 (2026-03-10) -- Initial version with basic templates -``` - -## 4. KernelBench 分类 Skill 开发计划 - -### 4.1 开发优先级 - -| 优先级 | 类别 | 原因 | 预期 Skill 文件 | -|--------|------|------|----------------| -| **P0** | Elementwise | 最多算子、最容易成功 | `elementwise-skill.md` | -| **P0** | GEMM | 最高影响、最频繁使用 | `gemm-skill.md` | -| **P1** | Normalization | 常用、中等难度 | `normalization-skill.md` | -| **P1** | Reduction | 常用、有成熟模式 | `reduction-skill.md` | -| **P1** | Softmax | Attention 基础 | `softmax-skill.md` | -| **P2** | Pooling | 中等频率、Grid 映射挑战 | `pooling-skill.md` | -| **P2** | Attention | 高复杂度、高价值 | `attention-skill.md` | -| **P2** | Fused | 多操作组合 | `fused-skill.md` | - -### 4.2 每个 Skill 文件结构 - -```markdown ---- -name: rocm-{category}-kernel -description: "..." ---- - -# {Category} Kernel Skill - -## Pattern Overview -[核心算法模式] - -## Template Code -[可复制的完整代码] - -## Autotune Configurations -[MI355X 和 R9700 的推荐配置] - -## Common Mistakes -[该类别特有的陷阱] - -## Benchmark Results -[该类别的已知性能数据] -``` - -### 4.3 评估指标目标 - -| 类别 | 可运行率目标 | 正确率目标 | 平均 Speedup 目标 | -|------|------------|-----------|------------------| -| Elementwise | >95% | >90% | >1.5x | -| GEMM | >80% | >70% | >0.8x | -| Normalization | >85% | >80% | >1.0x | -| Reduction | >90% | >85% | >1.5x | -| Softmax | >85% | >80% | >1.0x | -| Pooling | >70% | >60% | >1.0x | -| Attention | >60% | >50% | >0.8x | -| Fused | >50% | >40% | >0.8x | - -## 5. 持续监控 - -### 5.1 定期评估 - -- **每周**: 运行一次 Level 1 全量评估 -- **每次 Skill 修改后**: 运行修改类别的评估 -- **每月**: 运行 Level 1+2 全量评估 - -### 5.2 回归检测 - -修改 Skill 后,确保不会导致其他类别的性能下降: - -```bash -# 修改前: 保存基线 -python eval.py --save-baseline baseline_v1.json - -# 修改后: 对比 -python eval.py --compare-baseline baseline_v1.json -# 任何类别的指标下降 > 5% 需要调查 -``` - -## 6. 总结:优化 Skill 的核心原则 - -1. **数据驱动**: 每次修改都基于 KernelBench 评估数据 -2. **分类优化**: 按 operator 类别独立迭代 -3. **最小修改**: 每次只改一个点,方便归因 -4. **版本记录**: 每次修改记录 CHANGELOG -5. **A/B 测试**: 对比修改前后的评测结果 -6. **渐进式**: 先覆盖高优先级类别(Elementwise → GEMM → Norm) -7. **陷阱文档化**: 每个新发现的坑都写入 troubleshooting.md From 364d8856ffac4c7ffb52149ffc7ca92f3cfc1b56 Mon Sep 17 00:00:00 2001 From: M4jupitercannon Date: Tue, 17 Mar 2026 03:53:08 +0000 Subject: [PATCH 4/5] Add R9700 (RDNA4) benchmark results and fix eps constexpr issue --- skills/rocm-kernels/SKILL.md | 53 +- skills/rocm-kernels/scripts/benchmark_e2e.py | 4 +- .../rocm-kernels/scripts/benchmark_kernels.py | 536 ------------------ 3 files changed, 45 insertions(+), 548 deletions(-) delete mode 100644 skills/rocm-kernels/scripts/benchmark_kernels.py diff --git a/skills/rocm-kernels/SKILL.md b/skills/rocm-kernels/SKILL.md index f1d9ad58..3ea49a2b 100644 --- a/skills/rocm-kernels/SKILL.md +++ b/skills/rocm-kernels/SKILL.md @@ -83,10 +83,10 @@ python scripts/benchmark_e2e.py --mode all --output-json results.json ## Supported Hardware -| GPU | Architecture | Wave Size | LDS/CU | Mem BW | Key Feature | -|-----|-------------|-----------|--------|--------|-------------| -| **MI355X** | CDNA3+ (gfx950) | Wave64 | **160 KB** | 8 TB/s | 32 XCDs, XCD Swizzle for GEMM | -| **R9700** | RDNA4 (gfx1201) | **Wave32** | 64 KB | ~608 GB/s | 256B cacheline, inference-focused | +| GPU | Architecture | Wave Size | LDS/CU | Mem BW | Key Feature | Verified | +|-----|-------------|-----------|--------|--------|-------------|:--------:| +| **MI355X** | CDNA3+ (gfx950) | Wave64 | **160 KB** | 8 TB/s | 32 XCDs, XCD Swizzle for GEMM | Yes | +| **R9700** | RDNA4 (gfx1201) | **Wave32** | 64 KB | ~608 GB/s | 256B cacheline, inference-focused | Yes | > See [MI355X guide](references/mi355x-optimization-guide.md) | [R9700 guide](references/r9700-optimization-guide.md) @@ -393,6 +393,38 @@ RMSNorm bandwidth utilization: 3554 GB/s (MI355X theoretical: 8 TB/s, ~44%). **Key finding**: MI355X Triton E2E speedup (22%) is significantly higher than H100 CUDA reference (6%), because MI355X's default PyTorch RMSNorm path has more room for optimization. +### Micro-benchmark Results (R9700, BF16) + +| Kernel | Avg Speedup | Best Config Speedup | Status | +|--------|:-----------:|:-------------------:|:------:| +| **RMSNorm** | **2.90x** | 3.97x ([1×8192×2048]) | PASS | +| **RoPE 3D** | **2.09x** | 2.38x ([1×1024×16×64]) | PASS | +| **GEGLU** | **1.69x** | 1.93x ([2×1024×8192]) | PASS | +| **AdaLN** | **3.00x** | 3.67x ([4×4096×3072]) | PASS | + +RMSNorm bandwidth utilization: 483 GB/s (R9700 theoretical: ~608 GB/s, **~79%**). + +R9700 speedups are higher than MI355X because PyTorch's default RDNA4 backend is less mature, leaving more room for Triton optimization. The bandwidth utilization (79%) is also significantly better than MI355X (44%). + +### End-to-End LTX-Video (R9700, 25 frames, 30 steps) + +| Mode | Time (s) | Per Step (s) | Peak Mem (GB) | Speedup | +|------|:--------:|:------------:|:-------------:|:-------:| +| baseline | 6.89 | 0.230 | 18.58 | 1.00x | +| **triton** | **6.06** | **0.202** | **18.58** | **1.14x** | +| torch.compile | 5.07 | 0.169 | 18.58 | 1.36x | + +### R9700 Additional Validation + +| Test | Result | +|------|--------| +| Transformers injection (TinyLlama 1.1B) | PASS — 45 RMSNorm patched, 99.9 tokens/s | +| HuggingFace Kernels Hub integration | PASS — Hub kernel loads and runs on ROCm | +| Local Triton vs Hub kernel (small shape) | Local **5.92x** vs Hub 1.27x (lower launch overhead) | +| Local Triton vs Hub kernel (large shape) | Local 3.59x vs Hub 3.57x (comparable) | +| num_warps sweep (2/4/8/16/32) | Default heuristic (4/8/16) is near-optimal; nw=32 always worst | +| rocprof kernel fusion analysis | Triton fuses 4 PyTorch kernels (pow+mean+rsqrt+mul) into 1 | + ### CUDA Reference (H100, for comparison) | Shape | Custom (ms) | PyTorch (ms) | Speedup | @@ -404,12 +436,12 @@ H100 E2E: ~6% (RMSNorm is ~5% of total compute). ### Optimization Targets -| Kernel | Current | Target | Priority | -|--------|:-------:|:------:|:--------:| -| RMSNorm | 1.71x | >2x | P0 — increase bandwidth util (44%→60%+) | -| AdaLN | 2.22x | >2.5x | P1 — already strong | -| GEGLU | 1.43x | >1.5x | P1 — tanh overhead | -| RoPE 3D | 1.21x | >1.5x | P2 — small head_dim launch overhead | +| Kernel | MI355X | R9700 | Target | Priority | +|--------|:------:|:-----:|:------:|:--------:| +| RMSNorm | 1.71x | 2.90x | >3x (R9700) | P0 — MI355X bandwidth util (44%→60%+) | +| AdaLN | 2.22x | 3.00x | >3.5x (R9700) | P1 — already strong on both | +| GEGLU | 1.43x | 1.69x | >2x | P1 — tanh overhead | +| RoPE 3D | 1.21x | 2.09x | >2.5x (R9700) | P2 — small head_dim launch overhead | ## Common Issues on ROCm @@ -440,6 +472,7 @@ rocminfo | grep -E "Name|Compute Unit|Wavefront" ### Benchmark & Test Scripts - [benchmark_kernels.py](scripts/benchmark_kernels.py) - Micro-benchmark all 4 kernels (correctness + perf + bandwidth) - [benchmark_e2e.py](scripts/benchmark_e2e.py) - End-to-end LTX-Video pipeline benchmark (baseline vs Triton vs compile) +- [sweep_num_warps.py](scripts/sweep_num_warps.py) - num_warps sweep for R9700 Wave32 optimization - [ltx_kernel_injection_example.py](scripts/ltx_kernel_injection_example.py) - Minimal diffusers injection example - [transformers_injection_example.py](scripts/transformers_injection_example.py) - Minimal transformers injection example - [huggingface_kernels_example.py](scripts/huggingface_kernels_example.py) - HuggingFace Kernels Hub integration example diff --git a/skills/rocm-kernels/scripts/benchmark_e2e.py b/skills/rocm-kernels/scripts/benchmark_e2e.py index f9b00c6c..0a441914 100644 --- a/skills/rocm-kernels/scripts/benchmark_e2e.py +++ b/skills/rocm-kernels/scripts/benchmark_e2e.py @@ -45,7 +45,7 @@ def _rmsnorm_kernel( x_ptr, weight_ptr, out_ptr, stride_x_row, D, - eps: tl.constexpr, + eps, HAS_WEIGHT: tl.constexpr, BLOCK_D: tl.constexpr, ): @@ -76,7 +76,7 @@ def triton_rmsnorm(x, weight=None, eps=1e-6): BLOCK_D = triton.next_power_of_2(D) num_warps = 4 if BLOCK_D <= 1024 else (8 if BLOCK_D <= 4096 else 16) _rmsnorm_kernel[(M,)]( - x_flat, weight, out, x_flat.stride(0), D, eps, has_weight, + x_flat, weight, out, x_flat.stride(0), D, float(eps), has_weight, BLOCK_D=BLOCK_D, num_warps=num_warps, num_stages=2, ) return out.view_as(x) diff --git a/skills/rocm-kernels/scripts/benchmark_kernels.py b/skills/rocm-kernels/scripts/benchmark_kernels.py deleted file mode 100644 index bb72524a..00000000 --- a/skills/rocm-kernels/scripts/benchmark_kernels.py +++ /dev/null @@ -1,536 +0,0 @@ -#!/usr/bin/env python3 -""" -Micro-benchmark for all 4 Triton kernels on ROCm: RMSNorm, RoPE 3D, GEGLU, AdaLN. - -Measures: - 1. Correctness vs PyTorch reference - 2. Latency (custom vs baseline, warmup + averaged) - 3. Memory bandwidth utilization - -Usage: - python benchmark_kernels.py - python benchmark_kernels.py --kernel rmsnorm - python benchmark_kernels.py --kernel rope - python benchmark_kernels.py --kernel geglu - python benchmark_kernels.py --kernel adaln - python benchmark_kernels.py --dtype float16 -""" -import os -os.environ['TRITON_HIP_USE_BLOCK_PINGPONG'] = '1' -os.environ['TRITON_HIP_USE_ASYNC_COPY'] = '1' - -import argparse -import time -from typing import Tuple - -import torch -import triton -import triton.language as tl - - -# ============================================================================ -# Kernel 1: RMSNorm -# ============================================================================ -# CRITICAL: BLOCK_D must be >= D (hidden dimension). -# Using autotune with fixed BLOCK_D configs is WRONG because autotune may -# pick BLOCK_D < D, causing only partial row processing. -# Fix: compute BLOCK_D = next_power_of_2(D) dynamically in the Python wrapper. - -@triton.jit -def rmsnorm_fwd_kernel( - x_ptr, weight_ptr, out_ptr, - stride_x, D, - eps: tl.constexpr, - HAS_WEIGHT: tl.constexpr, - BLOCK_D: tl.constexpr, -): - row = tl.program_id(0) - col_offsets = tl.arange(0, BLOCK_D) - mask = col_offsets < D - - x = tl.load(x_ptr + row * stride_x + col_offsets, mask=mask, other=0.0).to(tl.float32) - variance = tl.sum(x * x, axis=0) / D - rms_inv = tl.rsqrt(variance + eps) - - if HAS_WEIGHT: - w = tl.load(weight_ptr + col_offsets, mask=mask, other=1.0).to(tl.float32) - result = x * rms_inv * w - else: - result = x * rms_inv - - tl.store(out_ptr + row * stride_x + col_offsets, result.to(x.dtype), mask=mask) - - -def triton_rmsnorm(x, weight=None, eps=1e-6): - orig_shape = x.shape - x_2d = x.contiguous().view(-1, x.shape[-1]) - out = torch.empty_like(x_2d) - M, D = x_2d.shape - has_weight = weight is not None - if not has_weight: - weight = torch.empty(0, device=x.device) - - BLOCK_D = triton.next_power_of_2(D) - num_warps = 4 if BLOCK_D <= 1024 else (8 if BLOCK_D <= 4096 else 16) - rmsnorm_fwd_kernel[(M,)]( - x_2d, weight, out, - x_2d.stride(0), D, eps, has_weight, - BLOCK_D=BLOCK_D, num_warps=num_warps, num_stages=2, - ) - return out.view(orig_shape) - - -def pytorch_rmsnorm(x, weight=None, eps=1e-6): - variance = x.pow(2).mean(dim=-1, keepdim=True) - out = x * torch.rsqrt(variance + eps) - if weight is not None: - out = out * weight - return out - - -# ============================================================================ -# Kernel 2: RoPE 3D -# ============================================================================ -# CRITICAL: cos/sin have shape [seq_len, head_dim], NOT [batch*seq_len, ...]. -# When grid is (batch * seq_len, num_heads), we must use pid_s % seq_len -# to index into cos/sin to avoid out-of-bounds access for batch > 1. - -@triton.jit -def rope_3d_fwd_kernel( - qk_ptr, cos_ptr, sin_ptr, out_ptr, - seq_len, num_heads, head_dim, - stride_s, stride_h, stride_d, - BLOCK_HD: tl.constexpr, -): - pid_s = tl.program_id(0) - pid_h = tl.program_id(1) - half_dim = head_dim // 2 - offs = tl.arange(0, BLOCK_HD) - mask = offs < half_dim - - base = pid_s * stride_s + pid_h * stride_h - x0 = tl.load(qk_ptr + base + offs, mask=mask, other=0.0).to(tl.float32) - x1 = tl.load(qk_ptr + base + half_dim + offs, mask=mask, other=0.0).to(tl.float32) - - seq_idx = pid_s % seq_len - cos_val = tl.load(cos_ptr + seq_idx * head_dim + offs, mask=mask, other=1.0).to(tl.float32) - sin_val = tl.load(sin_ptr + seq_idx * head_dim + offs, mask=mask, other=0.0).to(tl.float32) - - out0 = x0 * cos_val - x1 * sin_val - out1 = x0 * sin_val + x1 * cos_val - - tl.store(out_ptr + base + offs, out0.to(x0.dtype), mask=mask) - tl.store(out_ptr + base + half_dim + offs, out1.to(x0.dtype), mask=mask) - - -def triton_rope_3d(qk, cos, sin): - qk = qk.contiguous() - out = torch.empty_like(qk) - batch, seq_len, num_heads, head_dim = qk.shape - half_dim = head_dim // 2 - qk_flat = qk.view(batch * seq_len, num_heads, head_dim) - out_flat = out.view(batch * seq_len, num_heads, head_dim) - grid = (batch * seq_len, num_heads) - BLOCK_HD = triton.next_power_of_2(half_dim) - num_warps = 4 if BLOCK_HD <= 64 else 8 - rope_3d_fwd_kernel[grid]( - qk_flat, cos, sin, out_flat, - seq_len, num_heads, head_dim, - qk_flat.stride(0), qk_flat.stride(1), qk_flat.stride(2), - BLOCK_HD=BLOCK_HD, num_warps=num_warps, num_stages=2, - ) - return out - - -def pytorch_rope(qk, cos, sin): - half = qk.shape[-1] // 2 - x0, x1 = qk[..., :half], qk[..., half:] - cos_exp = cos.unsqueeze(0).unsqueeze(2)[:, :qk.shape[1], :, :half] - sin_exp = sin.unsqueeze(0).unsqueeze(2)[:, :qk.shape[1], :, :half] - out0 = x0 * cos_exp - x1 * sin_exp - out1 = x0 * sin_exp + x1 * cos_exp - return torch.cat([out0, out1], dim=-1) - - -# ============================================================================ -# Kernel 3: GEGLU -# ============================================================================ -# Same BLOCK_SIZE fix as RMSNorm: compute dynamically, do NOT autotune. - -@triton.jit -def geglu_fwd_kernel( - input_ptr, output_ptr, - stride_in, stride_out, hidden_size, - BLOCK_H: tl.constexpr, -): - row = tl.program_id(0) - offs = tl.arange(0, BLOCK_H) - mask = offs < hidden_size - - gate = tl.load(input_ptr + row * stride_in + offs, mask=mask, other=0.0).to(tl.float32) - value = tl.load(input_ptr + row * stride_in + hidden_size + offs, mask=mask, other=0.0).to(tl.float32) - - # GELU approx: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) - # tl.math.tanh / tl.libdevice.tanh NOT available on ROCm — use manual formula - SQRT_2_OVER_PI = 0.7978845608028654 - tanh_arg = SQRT_2_OVER_PI * (gate + 0.044715 * gate * gate * gate) - e2x = tl.exp(2.0 * tanh_arg) - tanh_val = (e2x - 1.0) / (e2x + 1.0) - cdf = 0.5 * (1.0 + tanh_val) - gelu_gate = gate * cdf - result = gelu_gate * value - - tl.store(output_ptr + row * stride_out + offs, result.to(gate.dtype), mask=mask) - - -def triton_geglu(x): - x = x.contiguous() - *batch_dims, double_h = x.shape - hidden_size = double_h // 2 - x_2d = x.view(-1, double_h) - M = x_2d.shape[0] - out = torch.empty(M, hidden_size, device=x.device, dtype=x.dtype) - - BLOCK_H = triton.next_power_of_2(hidden_size) - num_warps = 4 if BLOCK_H <= 1024 else (8 if BLOCK_H <= 4096 else 16) - geglu_fwd_kernel[(M,)]( - x_2d, out, - x_2d.stride(0), out.stride(0), hidden_size, - BLOCK_H=BLOCK_H, num_warps=num_warps, num_stages=2, - ) - return out.view(*batch_dims, hidden_size) - - -def pytorch_geglu(x): - hidden_size = x.shape[-1] // 2 - gate, value = x[..., :hidden_size], x[..., hidden_size:] - return torch.nn.functional.gelu(gate, approximate='tanh') * value - - -# ============================================================================ -# Kernel 4: AdaLN -# ============================================================================ -# Same BLOCK_D fix: compute dynamically. - -@triton.jit -def adaln_fwd_kernel( - x_ptr, weight_ptr, scale_ptr, shift_ptr, out_ptr, - stride_x, stride_cond, D, - eps: tl.constexpr, - BLOCK_D: tl.constexpr, -): - row = tl.program_id(0) - offs = tl.arange(0, BLOCK_D) - mask = offs < D - - x = tl.load(x_ptr + row * stride_x + offs, mask=mask, other=0.0).to(tl.float32) - variance = tl.sum(x * x, axis=0) / D - rms_inv = tl.rsqrt(variance + eps) - x_norm = x * rms_inv - - w = tl.load(weight_ptr + offs, mask=mask, other=1.0).to(tl.float32) - scale = tl.load(scale_ptr + row * stride_cond + offs, mask=mask, other=0.0).to(tl.float32) - shift = tl.load(shift_ptr + row * stride_cond + offs, mask=mask, other=0.0).to(tl.float32) - - out = x_norm * w * (1.0 + scale) + shift - tl.store(out_ptr + row * stride_x + offs, out.to(x.dtype), mask=mask) - - -def triton_adaln(x, weight, scale, shift, eps=1e-6): - x_flat = x.contiguous().view(-1, x.shape[-1]) - scale_flat = scale.contiguous().view(-1, x.shape[-1]) - shift_flat = shift.contiguous().view(-1, x.shape[-1]) - out = torch.empty_like(x_flat) - M, D = x_flat.shape - - BLOCK_D = triton.next_power_of_2(D) - num_warps = 4 if BLOCK_D <= 1024 else (8 if BLOCK_D <= 4096 else 16) - adaln_fwd_kernel[(M,)]( - x_flat, weight, scale_flat, shift_flat, out, - x_flat.stride(0), scale_flat.stride(0), D, eps, - BLOCK_D=BLOCK_D, num_warps=num_warps, num_stages=2, - ) - return out.view_as(x) - - -def pytorch_adaln(x, weight, scale, shift, eps=1e-6): - variance = x.pow(2).mean(dim=-1, keepdim=True) - x_norm = x * torch.rsqrt(variance + eps) - return x_norm * weight * (1.0 + scale) + shift - - -# ============================================================================ -# Benchmark Utilities -# ============================================================================ - -def benchmark_fn(func, args, warmup=20, iterations=100) -> Tuple[float, float]: - for _ in range(warmup): - func(*args) - torch.cuda.synchronize() - - times = [] - for _ in range(iterations): - torch.cuda.synchronize() - start = time.perf_counter() - func(*args) - torch.cuda.synchronize() - end = time.perf_counter() - times.append((end - start) * 1000) - - return sum(times) / len(times), min(times) - - -def check_correctness(out, ref, name, dtype): - max_abs = (out.float() - ref.float()).abs().max().item() - max_rel = ((out.float() - ref.float()).abs() / (ref.float().abs() + 1e-8)).max().item() - - # BF16 has 7-bit mantissa; for values ~8-16 the ULP is 0.0625-0.125 - # FP16 has 10-bit mantissa; much tighter precision - atol = 0.15 if dtype == torch.bfloat16 else 0.01 - passed = max_abs < atol - status = "PASS" if passed else "FAIL" - print(f" [{status}] {name}: max_abs={max_abs:.6e}, max_rel={max_rel:.6e}") - return passed - - -# ============================================================================ -# Benchmark Runners -# ============================================================================ - -def benchmark_rmsnorm(dtype): - print("\n" + "=" * 70) - print("BENCHMARK: RMSNorm (168 instances in LTX-Video)") - print("=" * 70) - - configs = [ - (1, 1024, 2048), - (2, 1024, 2048), - (4, 1024, 2048), - (1, 4096, 2048), - (2, 4096, 3072), - (1, 8192, 2048), - (4, 4096, 3072), - ] - - print(f"\n{'Config':<25} {'Triton (ms)':<15} {'PyTorch (ms)':<15} {'Speedup':<10}") - print("-" * 70) - - all_correct = True - total_speedup = 0 - - for batch, seq, hidden in configs: - x = torch.randn(batch, seq, hidden, dtype=dtype, device="cuda") - w = torch.ones(hidden, dtype=dtype, device="cuda") - - ref = pytorch_rmsnorm(x, w) - out = triton_rmsnorm(x, w) - if not check_correctness(out, ref, f"[{batch}x{seq}x{hidden}]", dtype): - all_correct = False - - t_avg, _ = benchmark_fn(triton_rmsnorm, (x, w)) - p_avg, _ = benchmark_fn(pytorch_rmsnorm, (x, w)) - speedup = p_avg / t_avg - total_speedup += speedup - - print(f" [{batch}x{seq}x{hidden}]{'':<13} {t_avg:>10.3f} {p_avg:>10.3f} {speedup:>7.2f}x") - - # No-weight variant - print("\n -- No-weight variant (elementwise_affine=False) --") - x = torch.randn(2, 4096, 2048, dtype=dtype, device="cuda") - ref_nw = pytorch_rmsnorm(x, None) - out_nw = triton_rmsnorm(x, None) - check_correctness(out_nw, ref_nw, "no-weight [2x4096x2048]", dtype) - - avg_speedup = total_speedup / len(configs) - print(f"\n Average speedup: {avg_speedup:.2f}x") - - # Bandwidth analysis - batch, seq, hidden = 4, 4096, 3072 - x = torch.randn(batch, seq, hidden, dtype=dtype, device="cuda") - w = torch.ones(hidden, dtype=dtype, device="cuda") - bytes_per_elem = 2 if dtype in (torch.float16, torch.bfloat16) else 4 - total_bytes = batch * seq * hidden * bytes_per_elem * 2 + hidden * bytes_per_elem - t_avg, _ = benchmark_fn(triton_rmsnorm, (x, w)) - bw_gbps = (total_bytes / 1e9) / (t_avg / 1000) - print(f"\n Bandwidth analysis [{batch}x{seq}x{hidden}]:") - print(f" Data moved: {total_bytes / 1e6:.2f} MB") - print(f" Achieved: {bw_gbps:.1f} GB/s") - - return all_correct, avg_speedup - - -def benchmark_rope(dtype): - print("\n" + "=" * 70) - print("BENCHMARK: RoPE 3D (Video Position Encoding)") - print("=" * 70) - - configs = [ - (1, 1024, 16, 64), - (1, 4096, 16, 64), - (2, 4096, 16, 128), - (1, 8192, 32, 64), - ] - - print(f"\n{'Config':<30} {'Triton (ms)':<15} {'PyTorch (ms)':<15} {'Speedup':<10}") - print("-" * 75) - - all_correct = True - total_speedup = 0 - - for batch, seq, heads, hdim in configs: - qk = torch.randn(batch, seq, heads, hdim, dtype=dtype, device="cuda") - cos = torch.randn(seq, hdim, dtype=dtype, device="cuda") - sin = torch.randn(seq, hdim, dtype=dtype, device="cuda") - - ref = pytorch_rope(qk, cos, sin) - out = triton_rope_3d(qk, cos, sin) - if not check_correctness(out, ref, f"[{batch}x{seq}x{heads}x{hdim}]", dtype): - all_correct = False - - t_avg, _ = benchmark_fn(triton_rope_3d, (qk, cos, sin)) - p_avg, _ = benchmark_fn(pytorch_rope, (qk, cos, sin)) - speedup = p_avg / t_avg - total_speedup += speedup - - cfg = f"[{batch}x{seq}x{heads}x{hdim}]" - print(f" {cfg:<28} {t_avg:>10.3f} {p_avg:>10.3f} {speedup:>7.2f}x") - - avg_speedup = total_speedup / len(configs) - print(f"\n Average speedup: {avg_speedup:.2f}x") - return all_correct, avg_speedup - - -def benchmark_geglu(dtype): - print("\n" + "=" * 70) - print("BENCHMARK: GEGLU (For SD3/FLUX, NOT LTX-Video)") - print("=" * 70) - - configs = [ - (1, 1024, 2048), - (2, 1024, 4096), - (2, 4096, 3072), - (4, 4096, 4096), - ] - - print(f"\n{'Config':<30} {'Triton (ms)':<15} {'PyTorch (ms)':<15} {'Speedup':<10}") - print("-" * 75) - - all_correct = True - total_speedup = 0 - - for batch, seq, hidden in configs: - x = torch.randn(batch, seq, hidden * 2, dtype=dtype, device="cuda") - - ref = pytorch_geglu(x) - out = triton_geglu(x) - if not check_correctness(out, ref, f"[{batch}x{seq}x{hidden*2}]", dtype): - all_correct = False - - t_avg, _ = benchmark_fn(triton_geglu, (x,)) - p_avg, _ = benchmark_fn(pytorch_geglu, (x,)) - speedup = p_avg / t_avg - total_speedup += speedup - - cfg = f"[{batch}x{seq}x{hidden*2}->{hidden}]" - print(f" {cfg:<28} {t_avg:>10.3f} {p_avg:>10.3f} {speedup:>7.2f}x") - - avg_speedup = total_speedup / len(configs) - print(f"\n Average speedup: {avg_speedup:.2f}x") - return all_correct, avg_speedup - - -def benchmark_adaln(dtype): - print("\n" + "=" * 70) - print("BENCHMARK: AdaLN (Fused Norm + Conditioning for DiT)") - print("=" * 70) - - configs = [ - (1, 1024, 2048), - (2, 1024, 2048), - (2, 4096, 3072), - (4, 4096, 3072), - ] - - print(f"\n{'Config':<25} {'Triton (ms)':<15} {'PyTorch (ms)':<15} {'Speedup':<10}") - print("-" * 70) - - all_correct = True - total_speedup = 0 - - for batch, seq, hidden in configs: - x = torch.randn(batch, seq, hidden, dtype=dtype, device="cuda") - w = torch.ones(hidden, dtype=dtype, device="cuda") - scale = torch.randn(batch, seq, hidden, dtype=dtype, device="cuda") * 0.1 - shift = torch.randn(batch, seq, hidden, dtype=dtype, device="cuda") * 0.1 - - ref = pytorch_adaln(x, w, scale, shift) - out = triton_adaln(x, w, scale, shift) - if not check_correctness(out, ref, f"[{batch}x{seq}x{hidden}]", dtype): - all_correct = False - - t_avg, _ = benchmark_fn(triton_adaln, (x, w, scale, shift)) - p_avg, _ = benchmark_fn(pytorch_adaln, (x, w, scale, shift)) - speedup = p_avg / t_avg - total_speedup += speedup - - print(f" [{batch}x{seq}x{hidden}]{'':<13} {t_avg:>10.3f} {p_avg:>10.3f} {speedup:>7.2f}x") - - avg_speedup = total_speedup / len(configs) - print(f"\n Average speedup: {avg_speedup:.2f}x") - return all_correct, avg_speedup - - -# ============================================================================ -# Main -# ============================================================================ - -def main(): - parser = argparse.ArgumentParser(description="Benchmark Triton kernels on ROCm") - parser.add_argument("--kernel", type=str, default="all", - choices=["all", "rmsnorm", "rope", "geglu", "adaln"]) - parser.add_argument("--dtype", type=str, default="bfloat16", - choices=["bfloat16", "float16"]) - args = parser.parse_args() - - dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16 - - print("=" * 70) - print("ROCm Triton Kernel Micro-Benchmark") - print("=" * 70) - print(f"Device: {torch.cuda.get_device_name(0)}") - print(f"Dtype: {dtype}") - print(f"ROCm: {torch.version.hip if hasattr(torch.version, 'hip') else 'N/A'}") - - results = {} - runners = { - "rmsnorm": benchmark_rmsnorm, - "rope": benchmark_rope, - "geglu": benchmark_geglu, - "adaln": benchmark_adaln, - } - - if args.kernel == "all": - for name, runner in runners.items(): - correct, speedup = runner(dtype) - results[name] = {"correct": correct, "speedup": speedup} - else: - correct, speedup = runners[args.kernel](dtype) - results[args.kernel] = {"correct": correct, "speedup": speedup} - - # Summary - print("\n" + "=" * 70) - print("SUMMARY") - print("=" * 70) - print(f"{'Kernel':<15} {'Correct':<12} {'Avg Speedup':<15}") - print("-" * 42) - for name, r in results.items(): - status = "PASS" if r["correct"] else "FAIL" - print(f"{name:<15} {status:<12} {r['speedup']:.2f}x") - - all_pass = all(r["correct"] for r in results.values()) - print(f"\nOverall: {'ALL PASS' if all_pass else 'SOME FAILED'}") - print("=" * 70) - - -if __name__ == "__main__": - main() From 2057f31d2f9f4ed01b51544fd035132cd1c3eee2 Mon Sep 17 00:00:00 2001 From: 01xjw <220233704@seu.edu.cn> Date: Tue, 17 Mar 2026 08:15:41 +0000 Subject: [PATCH 5/5] fix the benchmark_kernels file --- .../rocm-kernels/scripts/benchmark_kernels.py | 536 ++++++++++++++++++ 1 file changed, 536 insertions(+) create mode 100644 skills/rocm-kernels/scripts/benchmark_kernels.py diff --git a/skills/rocm-kernels/scripts/benchmark_kernels.py b/skills/rocm-kernels/scripts/benchmark_kernels.py new file mode 100644 index 00000000..26bb9f8d --- /dev/null +++ b/skills/rocm-kernels/scripts/benchmark_kernels.py @@ -0,0 +1,536 @@ +#!/usr/bin/env python3 +""" +Micro-benchmark for all 4 Triton kernels on ROCm: RMSNorm, RoPE 3D, GEGLU, AdaLN. + +Measures: + 1. Correctness vs PyTorch reference + 2. Latency (custom vs baseline, warmup + averaged) + 3. Memory bandwidth utilization + +Usage: + python benchmark_kernels.py + python benchmark_kernels.py --kernel rmsnorm + python benchmark_kernels.py --kernel rope + python benchmark_kernels.py --kernel geglu + python benchmark_kernels.py --kernel adaln + python benchmark_kernels.py --dtype float16 +""" +import os +os.environ['TRITON_HIP_USE_BLOCK_PINGPONG'] = '1' +os.environ['TRITON_HIP_USE_ASYNC_COPY'] = '1' + +import argparse +import time +from typing import Tuple + +import torch +import triton +import triton.language as tl + + +# ============================================================================ +# Kernel 1: RMSNorm +# ============================================================================ +# CRITICAL: BLOCK_D must be >= D (hidden dimension). +# Using autotune with fixed BLOCK_D configs is WRONG because autotune may +# pick BLOCK_D < D, causing only partial row processing. +# Fix: compute BLOCK_D = next_power_of_2(D) dynamically in the Python wrapper. + +@triton.jit +def rmsnorm_fwd_kernel( + x_ptr, weight_ptr, out_ptr, + stride_x, D, + eps, + HAS_WEIGHT: tl.constexpr, + BLOCK_D: tl.constexpr, +): + row = tl.program_id(0) + col_offsets = tl.arange(0, BLOCK_D) + mask = col_offsets < D + + x = tl.load(x_ptr + row * stride_x + col_offsets, mask=mask, other=0.0).to(tl.float32) + variance = tl.sum(x * x, axis=0) / D + rms_inv = tl.rsqrt(variance + eps) + + if HAS_WEIGHT: + w = tl.load(weight_ptr + col_offsets, mask=mask, other=1.0).to(tl.float32) + result = x * rms_inv * w + else: + result = x * rms_inv + + tl.store(out_ptr + row * stride_x + col_offsets, result.to(x.dtype), mask=mask) + + +def triton_rmsnorm(x, weight=None, eps=1e-6): + orig_shape = x.shape + x_2d = x.contiguous().view(-1, x.shape[-1]) + out = torch.empty_like(x_2d) + M, D = x_2d.shape + has_weight = weight is not None + if not has_weight: + weight = torch.empty(0, device=x.device) + + BLOCK_D = triton.next_power_of_2(D) + num_warps = 4 if BLOCK_D <= 1024 else (8 if BLOCK_D <= 4096 else 16) + rmsnorm_fwd_kernel[(M,)]( + x_2d, weight, out, + x_2d.stride(0), D, float(eps), has_weight, + BLOCK_D=BLOCK_D, num_warps=num_warps, num_stages=2, + ) + return out.view(orig_shape) + + +def pytorch_rmsnorm(x, weight=None, eps=1e-6): + variance = x.pow(2).mean(dim=-1, keepdim=True) + out = x * torch.rsqrt(variance + eps) + if weight is not None: + out = out * weight + return out + + +# ============================================================================ +# Kernel 2: RoPE 3D +# ============================================================================ +# CRITICAL: cos/sin have shape [seq_len, head_dim], NOT [batch*seq_len, ...]. +# When grid is (batch * seq_len, num_heads), we must use pid_s % seq_len +# to index into cos/sin to avoid out-of-bounds access for batch > 1. + +@triton.jit +def rope_3d_fwd_kernel( + qk_ptr, cos_ptr, sin_ptr, out_ptr, + seq_len, num_heads, head_dim, + stride_s, stride_h, stride_d, + BLOCK_HD: tl.constexpr, +): + pid_s = tl.program_id(0) + pid_h = tl.program_id(1) + half_dim = head_dim // 2 + offs = tl.arange(0, BLOCK_HD) + mask = offs < half_dim + + base = pid_s * stride_s + pid_h * stride_h + x0 = tl.load(qk_ptr + base + offs, mask=mask, other=0.0).to(tl.float32) + x1 = tl.load(qk_ptr + base + half_dim + offs, mask=mask, other=0.0).to(tl.float32) + + seq_idx = pid_s % seq_len + cos_val = tl.load(cos_ptr + seq_idx * head_dim + offs, mask=mask, other=1.0).to(tl.float32) + sin_val = tl.load(sin_ptr + seq_idx * head_dim + offs, mask=mask, other=0.0).to(tl.float32) + + out0 = x0 * cos_val - x1 * sin_val + out1 = x0 * sin_val + x1 * cos_val + + tl.store(out_ptr + base + offs, out0.to(x0.dtype), mask=mask) + tl.store(out_ptr + base + half_dim + offs, out1.to(x0.dtype), mask=mask) + + +def triton_rope_3d(qk, cos, sin): + qk = qk.contiguous() + out = torch.empty_like(qk) + batch, seq_len, num_heads, head_dim = qk.shape + half_dim = head_dim // 2 + qk_flat = qk.view(batch * seq_len, num_heads, head_dim) + out_flat = out.view(batch * seq_len, num_heads, head_dim) + grid = (batch * seq_len, num_heads) + BLOCK_HD = triton.next_power_of_2(half_dim) + num_warps = 4 if BLOCK_HD <= 64 else 8 + rope_3d_fwd_kernel[grid]( + qk_flat, cos, sin, out_flat, + seq_len, num_heads, head_dim, + qk_flat.stride(0), qk_flat.stride(1), qk_flat.stride(2), + BLOCK_HD=BLOCK_HD, num_warps=num_warps, num_stages=2, + ) + return out + + +def pytorch_rope(qk, cos, sin): + half = qk.shape[-1] // 2 + x0, x1 = qk[..., :half], qk[..., half:] + cos_exp = cos.unsqueeze(0).unsqueeze(2)[:, :qk.shape[1], :, :half] + sin_exp = sin.unsqueeze(0).unsqueeze(2)[:, :qk.shape[1], :, :half] + out0 = x0 * cos_exp - x1 * sin_exp + out1 = x0 * sin_exp + x1 * cos_exp + return torch.cat([out0, out1], dim=-1) + + +# ============================================================================ +# Kernel 3: GEGLU +# ============================================================================ +# Same BLOCK_SIZE fix as RMSNorm: compute dynamically, do NOT autotune. + +@triton.jit +def geglu_fwd_kernel( + input_ptr, output_ptr, + stride_in, stride_out, hidden_size, + BLOCK_H: tl.constexpr, +): + row = tl.program_id(0) + offs = tl.arange(0, BLOCK_H) + mask = offs < hidden_size + + gate = tl.load(input_ptr + row * stride_in + offs, mask=mask, other=0.0).to(tl.float32) + value = tl.load(input_ptr + row * stride_in + hidden_size + offs, mask=mask, other=0.0).to(tl.float32) + + # GELU approx: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) + # tl.math.tanh / tl.libdevice.tanh NOT available on ROCm — use manual formula + SQRT_2_OVER_PI = 0.7978845608028654 + tanh_arg = SQRT_2_OVER_PI * (gate + 0.044715 * gate * gate * gate) + e2x = tl.exp(2.0 * tanh_arg) + tanh_val = (e2x - 1.0) / (e2x + 1.0) + cdf = 0.5 * (1.0 + tanh_val) + gelu_gate = gate * cdf + result = gelu_gate * value + + tl.store(output_ptr + row * stride_out + offs, result.to(gate.dtype), mask=mask) + + +def triton_geglu(x): + x = x.contiguous() + *batch_dims, double_h = x.shape + hidden_size = double_h // 2 + x_2d = x.view(-1, double_h) + M = x_2d.shape[0] + out = torch.empty(M, hidden_size, device=x.device, dtype=x.dtype) + + BLOCK_H = triton.next_power_of_2(hidden_size) + num_warps = 4 if BLOCK_H <= 1024 else (8 if BLOCK_H <= 4096 else 16) + geglu_fwd_kernel[(M,)]( + x_2d, out, + x_2d.stride(0), out.stride(0), hidden_size, + BLOCK_H=BLOCK_H, num_warps=num_warps, num_stages=2, + ) + return out.view(*batch_dims, hidden_size) + + +def pytorch_geglu(x): + hidden_size = x.shape[-1] // 2 + gate, value = x[..., :hidden_size], x[..., hidden_size:] + return torch.nn.functional.gelu(gate, approximate='tanh') * value + + +# ============================================================================ +# Kernel 4: AdaLN +# ============================================================================ +# Same BLOCK_D fix: compute dynamically. + +@triton.jit +def adaln_fwd_kernel( + x_ptr, weight_ptr, scale_ptr, shift_ptr, out_ptr, + stride_x, stride_cond, D, + eps, + BLOCK_D: tl.constexpr, +): + row = tl.program_id(0) + offs = tl.arange(0, BLOCK_D) + mask = offs < D + + x = tl.load(x_ptr + row * stride_x + offs, mask=mask, other=0.0).to(tl.float32) + variance = tl.sum(x * x, axis=0) / D + rms_inv = tl.rsqrt(variance + eps) + x_norm = x * rms_inv + + w = tl.load(weight_ptr + offs, mask=mask, other=1.0).to(tl.float32) + scale = tl.load(scale_ptr + row * stride_cond + offs, mask=mask, other=0.0).to(tl.float32) + shift = tl.load(shift_ptr + row * stride_cond + offs, mask=mask, other=0.0).to(tl.float32) + + out = x_norm * w * (1.0 + scale) + shift + tl.store(out_ptr + row * stride_x + offs, out.to(x.dtype), mask=mask) + + +def triton_adaln(x, weight, scale, shift, eps=1e-6): + x_flat = x.contiguous().view(-1, x.shape[-1]) + scale_flat = scale.contiguous().view(-1, x.shape[-1]) + shift_flat = shift.contiguous().view(-1, x.shape[-1]) + out = torch.empty_like(x_flat) + M, D = x_flat.shape + + BLOCK_D = triton.next_power_of_2(D) + num_warps = 4 if BLOCK_D <= 1024 else (8 if BLOCK_D <= 4096 else 16) + adaln_fwd_kernel[(M,)]( + x_flat, weight, scale_flat, shift_flat, out, + x_flat.stride(0), scale_flat.stride(0), D, float(eps), + BLOCK_D=BLOCK_D, num_warps=num_warps, num_stages=2, + ) + return out.view_as(x) + + +def pytorch_adaln(x, weight, scale, shift, eps=1e-6): + variance = x.pow(2).mean(dim=-1, keepdim=True) + x_norm = x * torch.rsqrt(variance + eps) + return x_norm * weight * (1.0 + scale) + shift + + +# ============================================================================ +# Benchmark Utilities +# ============================================================================ + +def benchmark_fn(func, args, warmup=20, iterations=100) -> Tuple[float, float]: + for _ in range(warmup): + func(*args) + torch.cuda.synchronize() + + times = [] + for _ in range(iterations): + torch.cuda.synchronize() + start = time.perf_counter() + func(*args) + torch.cuda.synchronize() + end = time.perf_counter() + times.append((end - start) * 1000) + + return sum(times) / len(times), min(times) + + +def check_correctness(out, ref, name, dtype): + max_abs = (out.float() - ref.float()).abs().max().item() + max_rel = ((out.float() - ref.float()).abs() / (ref.float().abs() + 1e-8)).max().item() + + # BF16 has 7-bit mantissa; for values ~8-16 the ULP is 0.0625-0.125 + # FP16 has 10-bit mantissa; tighter but RoPE trig ops can accumulate 1-2 ULP error + atol = 0.15 if dtype == torch.bfloat16 else 0.02 + passed = max_abs < atol + status = "PASS" if passed else "FAIL" + print(f" [{status}] {name}: max_abs={max_abs:.6e}, max_rel={max_rel:.6e}") + return passed + + +# ============================================================================ +# Benchmark Runners +# ============================================================================ + +def benchmark_rmsnorm(dtype): + print("\n" + "=" * 70) + print("BENCHMARK: RMSNorm (168 instances in LTX-Video)") + print("=" * 70) + + configs = [ + (1, 1024, 2048), + (2, 1024, 2048), + (4, 1024, 2048), + (1, 4096, 2048), + (2, 4096, 3072), + (1, 8192, 2048), + (4, 4096, 3072), + ] + + print(f"\n{'Config':<25} {'Triton (ms)':<15} {'PyTorch (ms)':<15} {'Speedup':<10}") + print("-" * 70) + + all_correct = True + total_speedup = 0 + + for batch, seq, hidden in configs: + x = torch.randn(batch, seq, hidden, dtype=dtype, device="cuda") + w = torch.ones(hidden, dtype=dtype, device="cuda") + + ref = pytorch_rmsnorm(x, w) + out = triton_rmsnorm(x, w) + if not check_correctness(out, ref, f"[{batch}x{seq}x{hidden}]", dtype): + all_correct = False + + t_avg, _ = benchmark_fn(triton_rmsnorm, (x, w)) + p_avg, _ = benchmark_fn(pytorch_rmsnorm, (x, w)) + speedup = p_avg / t_avg + total_speedup += speedup + + print(f" [{batch}x{seq}x{hidden}]{'':<13} {t_avg:>10.3f} {p_avg:>10.3f} {speedup:>7.2f}x") + + # No-weight variant + print("\n -- No-weight variant (elementwise_affine=False) --") + x = torch.randn(2, 4096, 2048, dtype=dtype, device="cuda") + ref_nw = pytorch_rmsnorm(x, None) + out_nw = triton_rmsnorm(x, None) + check_correctness(out_nw, ref_nw, "no-weight [2x4096x2048]", dtype) + + avg_speedup = total_speedup / len(configs) + print(f"\n Average speedup: {avg_speedup:.2f}x") + + # Bandwidth analysis + batch, seq, hidden = 4, 4096, 3072 + x = torch.randn(batch, seq, hidden, dtype=dtype, device="cuda") + w = torch.ones(hidden, dtype=dtype, device="cuda") + bytes_per_elem = 2 if dtype in (torch.float16, torch.bfloat16) else 4 + total_bytes = batch * seq * hidden * bytes_per_elem * 2 + hidden * bytes_per_elem + t_avg, _ = benchmark_fn(triton_rmsnorm, (x, w)) + bw_gbps = (total_bytes / 1e9) / (t_avg / 1000) + print(f"\n Bandwidth analysis [{batch}x{seq}x{hidden}]:") + print(f" Data moved: {total_bytes / 1e6:.2f} MB") + print(f" Achieved: {bw_gbps:.1f} GB/s") + + return all_correct, avg_speedup + + +def benchmark_rope(dtype): + print("\n" + "=" * 70) + print("BENCHMARK: RoPE 3D (Video Position Encoding)") + print("=" * 70) + + configs = [ + (1, 1024, 16, 64), + (1, 4096, 16, 64), + (2, 4096, 16, 128), + (1, 8192, 32, 64), + ] + + print(f"\n{'Config':<30} {'Triton (ms)':<15} {'PyTorch (ms)':<15} {'Speedup':<10}") + print("-" * 75) + + all_correct = True + total_speedup = 0 + + for batch, seq, heads, hdim in configs: + qk = torch.randn(batch, seq, heads, hdim, dtype=dtype, device="cuda") + cos = torch.randn(seq, hdim, dtype=dtype, device="cuda") + sin = torch.randn(seq, hdim, dtype=dtype, device="cuda") + + ref = pytorch_rope(qk, cos, sin) + out = triton_rope_3d(qk, cos, sin) + if not check_correctness(out, ref, f"[{batch}x{seq}x{heads}x{hdim}]", dtype): + all_correct = False + + t_avg, _ = benchmark_fn(triton_rope_3d, (qk, cos, sin)) + p_avg, _ = benchmark_fn(pytorch_rope, (qk, cos, sin)) + speedup = p_avg / t_avg + total_speedup += speedup + + cfg = f"[{batch}x{seq}x{heads}x{hdim}]" + print(f" {cfg:<28} {t_avg:>10.3f} {p_avg:>10.3f} {speedup:>7.2f}x") + + avg_speedup = total_speedup / len(configs) + print(f"\n Average speedup: {avg_speedup:.2f}x") + return all_correct, avg_speedup + + +def benchmark_geglu(dtype): + print("\n" + "=" * 70) + print("BENCHMARK: GEGLU (For SD3/FLUX, NOT LTX-Video)") + print("=" * 70) + + configs = [ + (1, 1024, 2048), + (2, 1024, 4096), + (2, 4096, 3072), + (4, 4096, 4096), + ] + + print(f"\n{'Config':<30} {'Triton (ms)':<15} {'PyTorch (ms)':<15} {'Speedup':<10}") + print("-" * 75) + + all_correct = True + total_speedup = 0 + + for batch, seq, hidden in configs: + x = torch.randn(batch, seq, hidden * 2, dtype=dtype, device="cuda") + + ref = pytorch_geglu(x) + out = triton_geglu(x) + if not check_correctness(out, ref, f"[{batch}x{seq}x{hidden*2}]", dtype): + all_correct = False + + t_avg, _ = benchmark_fn(triton_geglu, (x,)) + p_avg, _ = benchmark_fn(pytorch_geglu, (x,)) + speedup = p_avg / t_avg + total_speedup += speedup + + cfg = f"[{batch}x{seq}x{hidden*2}->{hidden}]" + print(f" {cfg:<28} {t_avg:>10.3f} {p_avg:>10.3f} {speedup:>7.2f}x") + + avg_speedup = total_speedup / len(configs) + print(f"\n Average speedup: {avg_speedup:.2f}x") + return all_correct, avg_speedup + + +def benchmark_adaln(dtype): + print("\n" + "=" * 70) + print("BENCHMARK: AdaLN (Fused Norm + Conditioning for DiT)") + print("=" * 70) + + configs = [ + (1, 1024, 2048), + (2, 1024, 2048), + (2, 4096, 3072), + (4, 4096, 3072), + ] + + print(f"\n{'Config':<25} {'Triton (ms)':<15} {'PyTorch (ms)':<15} {'Speedup':<10}") + print("-" * 70) + + all_correct = True + total_speedup = 0 + + for batch, seq, hidden in configs: + x = torch.randn(batch, seq, hidden, dtype=dtype, device="cuda") + w = torch.ones(hidden, dtype=dtype, device="cuda") + scale = torch.randn(batch, seq, hidden, dtype=dtype, device="cuda") * 0.1 + shift = torch.randn(batch, seq, hidden, dtype=dtype, device="cuda") * 0.1 + + ref = pytorch_adaln(x, w, scale, shift) + out = triton_adaln(x, w, scale, shift) + if not check_correctness(out, ref, f"[{batch}x{seq}x{hidden}]", dtype): + all_correct = False + + t_avg, _ = benchmark_fn(triton_adaln, (x, w, scale, shift)) + p_avg, _ = benchmark_fn(pytorch_adaln, (x, w, scale, shift)) + speedup = p_avg / t_avg + total_speedup += speedup + + print(f" [{batch}x{seq}x{hidden}]{'':<13} {t_avg:>10.3f} {p_avg:>10.3f} {speedup:>7.2f}x") + + avg_speedup = total_speedup / len(configs) + print(f"\n Average speedup: {avg_speedup:.2f}x") + return all_correct, avg_speedup + + +# ============================================================================ +# Main +# ============================================================================ + +def main(): + parser = argparse.ArgumentParser(description="Benchmark Triton kernels on ROCm") + parser.add_argument("--kernel", type=str, default="all", + choices=["all", "rmsnorm", "rope", "geglu", "adaln"]) + parser.add_argument("--dtype", type=str, default="bfloat16", + choices=["bfloat16", "float16"]) + args = parser.parse_args() + + dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16 + + print("=" * 70) + print("ROCm Triton Kernel Micro-Benchmark") + print("=" * 70) + print(f"Device: {torch.cuda.get_device_name(0)}") + print(f"Dtype: {dtype}") + print(f"ROCm: {torch.version.hip if hasattr(torch.version, 'hip') else 'N/A'}") + + results = {} + runners = { + "rmsnorm": benchmark_rmsnorm, + "rope": benchmark_rope, + "geglu": benchmark_geglu, + "adaln": benchmark_adaln, + } + + if args.kernel == "all": + for name, runner in runners.items(): + correct, speedup = runner(dtype) + results[name] = {"correct": correct, "speedup": speedup} + else: + correct, speedup = runners[args.kernel](dtype) + results[args.kernel] = {"correct": correct, "speedup": speedup} + + # Summary + print("\n" + "=" * 70) + print("SUMMARY") + print("=" * 70) + print(f"{'Kernel':<15} {'Correct':<12} {'Avg Speedup':<15}") + print("-" * 42) + for name, r in results.items(): + status = "PASS" if r["correct"] else "FAIL" + print(f"{name:<15} {status:<12} {r['speedup']:.2f}x") + + all_pass = all(r["correct"] for r in results.values()) + print(f"\nOverall: {'ALL PASS' if all_pass else 'SOME FAILED'}") + print("=" * 70) + + +if __name__ == "__main__": + main()