diff --git a/skills/rocm-kernels/CHANGELOG.md b/skills/rocm-kernels/CHANGELOG.md new file mode 100644 index 00000000..1d52b9a5 --- /dev/null +++ b/skills/rocm-kernels/CHANGELOG.md @@ -0,0 +1,32 @@ +# Changelog + +## v0.2 (2026-03-12) + +### Added +- **Transformers integration**: `references/transformers-integration.md` — LLaMA/Mistral/Qwen RMSNorm patching, Flash Attention 2, epsilon handling differences +- **Transformers injection script**: `scripts/transformers_injection_example.py` — minimal runnable example (~150 lines) +- **HuggingFace Kernels Hub integration**: `references/huggingface-kernels-integration.md` — `get_kernel`, `has_kernel`, publishing, ROCm compatibility notes +- **HuggingFace Kernels example script**: `scripts/huggingface_kernels_example.py` — Hub loading, benchmarking, model integration with fallback +- **GEMM template with XCD swizzle**: Template 5 in `kernel-templates.md` — full GEMM kernel with XCD swizzle for MI355X, L2 cache grouping, autotune configs, Python API, and benchmark +- **CHANGELOG.md**: Version tracking for skill iterations + +### Fixed +- Broken cross-references: "Template 2" for GEMM → corrected to "Template 5" in `troubleshooting.md`, `kernelbench-classification.md`, and `skill-evaluation-methodology.md` +- R9700 Memory Bandwidth: filled in ~608 GB/s (was TBD) in SKILL.md + +### Updated +- `SKILL.md` See Also section: added new integration guides, scripts, and Hub links +- `SKILL.md` argument-hint: added gemm, transformers, huggingface-kernels, get_kernel +- `manifest.txt`: added all new files + +## v0.1 (2026-03-10) + +### Added +- Initial skill with SKILL.md, 4 kernel templates (RMSNorm, RoPE 3D, GEGLU, AdaLN) +- MI355X and R9700 GPU optimization guides +- Diffusers integration guide (LTX-Video) +- Troubleshooting guide (14 ROCm-specific issues) +- Benchmark scripts: micro-benchmark (`benchmark_kernels.py`) and E2E (`benchmark_e2e.py`) +- LTX-Video injection example (`ltx_kernel_injection_example.py`) +- KernelBench classification and evaluation methodology docs +- Kernel-agent knowledge base diff --git a/skills/rocm-kernels/SKILL.md b/skills/rocm-kernels/SKILL.md new file mode 100644 index 00000000..3ea49a2b --- /dev/null +++ b/skills/rocm-kernels/SKILL.md @@ -0,0 +1,501 @@ +--- +name: rocm-kernels +description: "Provides guidance for writing and benchmarking optimized Triton kernels for AMD GPUs (MI355X, R9700) on ROCm, targeting HuggingFace diffusers (LTX-Video, SD3, FLUX) and transformers. Core kernels: RMSNorm, RoPE 3D, GEGLU, AdaLN. Includes XCD swizzle, autotune, diffusers integration patterns, and LTX-Video pipeline injection." +disable-model-invocation: false +user-invocable: true +allowed-tools: "Read, Grep, Glob, Bash" +argument-hint: "kernel type: rmsnorm, rope, rope-3d, geglu, adaln, gemm, benchmark, diffusers, transformers, ltx-video, huggingface-kernels, get_kernel, autotune, xcd-swizzle" +--- + +# ROCm Triton Kernels for Diffusers & Transformers + +This skill provides patterns and guidance for developing optimized Triton kernels targeting AMD GPUs (MI355X, R9700) on ROCm, for use with HuggingFace **diffusers** (LTX-Video, SD3, FLUX) and **transformers** libraries. + +## Quick Start + +### Diffusers (LTX-Video) + +**Inject optimized kernels into LTX-Video pipeline:** +```python +import os +os.environ['TRITON_HIP_USE_BLOCK_PINGPONG'] = '1' +os.environ['TRITON_HIP_USE_ASYNC_COPY'] = '1' + +from diffusers import LTXPipeline +pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16) +pipe.to("cuda") # ROCm uses same API via HIP +inject_optimized_kernels(pipe) # BEFORE CPU offloading +pipe.enable_model_cpu_offload() +``` + +**For a minimal integration example (~150 lines):** +```bash +python scripts/ltx_kernel_injection_example.py +``` + +### Isolated Kernel Micro-benchmarks +```bash +# All 4 kernels: correctness + performance + bandwidth +python scripts/benchmark_kernels.py + +# Single kernel +python scripts/benchmark_kernels.py --kernel rmsnorm +python scripts/benchmark_kernels.py --kernel rope +python scripts/benchmark_kernels.py --kernel geglu +python scripts/benchmark_kernels.py --kernel adaln +``` + +### End-to-End Pipeline Benchmark +```bash +# Compare baseline vs Triton vs torch.compile +python scripts/benchmark_e2e.py --mode all + +# Quick test +python scripts/benchmark_e2e.py --mode triton --num-frames 9 --steps 5 + +# Save results for comparison +python scripts/benchmark_e2e.py --mode all --output-json results.json +``` + +## Target Model: LTX-Video + +### Architecture Overview + +| Component | Class | Has Weight | Count | Kernel | +|-----------|-------|------------|-------|--------| +| `transformer_blocks.*.norm1` | RMSNorm | **No** (elementwise_affine=False) | 56 | RMSNorm | +| `transformer_blocks.*.norm2` | RMSNorm | **No** | 56 | RMSNorm | +| `transformer_blocks.*.attn1.norm_q` | torch.nn.RMSNorm | Yes | 28 | RMSNorm | +| `transformer_blocks.*.attn1.norm_k` | torch.nn.RMSNorm | Yes | 28 | RMSNorm | +| `transformer_blocks.*.ff` | FeedForward | - | 28 | **GELU** (not GEGLU!) | +| Rotary position encoding | LTXVideoRotaryPosEmbed | - | 1 | RoPE 3D | + +**Total RMSNorm modules: 168** (56 with weights, 112 without) + +### Target Kernels + +| Kernel | Use Case | Input Layout | Key Challenge | +|--------|----------|-------------|---------------| +| **RMSNorm** | Normalization | `[..., hidden_size]` | Weight may be None; 168 instances | +| **RoPE 3D** | Video position encoding | `[batch, t*h*w, heads, head_dim]` | 3D → temporal + spatial decomposition | +| **GEGLU** | Gated activation (SD3/FLUX) | `[batch, seq, 2*hidden]` → `[batch, seq, hidden]` | Gate/value split | +| **AdaLN** | Conditioned normalization (DiT) | `norm(x) * weight * (1+scale) + shift` | Fused norm + condition | + +## Supported Hardware + +| GPU | Architecture | Wave Size | LDS/CU | Mem BW | Key Feature | Verified | +|-----|-------------|-----------|--------|--------|-------------|:--------:| +| **MI355X** | CDNA3+ (gfx950) | Wave64 | **160 KB** | 8 TB/s | 32 XCDs, XCD Swizzle for GEMM | Yes | +| **R9700** | RDNA4 (gfx1201) | **Wave32** | 64 KB | ~608 GB/s | 256B cacheline, inference-focused | Yes | + +> See [MI355X guide](references/mi355x-optimization-guide.md) | [R9700 guide](references/r9700-optimization-guide.md) + +## When This Skill Applies + +Use this skill when: +- Writing Triton kernels for **RMSNorm, RoPE, GEGLU, AdaLN** on AMD GPUs +- Integrating custom kernels with **diffusers** pipelines (LTX-Video, SD3, FLUX) +- Benchmarking kernel performance against PyTorch baseline on ROCm +- Optimizing existing kernels for MI355X or R9700 architecture +- Debugging ROCm/HIP-specific kernel issues + +## Critical ROCm Constraints + +### Things That DON'T Work on AMD + +```python +# FORBIDDEN - CUDA only, NOT available on ROCm +tl.libdevice.tanh(x) # Use manual formula below +tl.libdevice.log1p(x) # Use: tl.log(1.0 + x) +tl.math.tanh(x) # Also NOT available on ROCm Triton + +# Manual tanh (ONLY reliable method on ROCm): +e2x = tl.exp(2.0 * x) +tanh_x = (e2x - 1.0) / (e2x + 1.0) + +# FORBIDDEN - Triton limitations on ROCm +break / continue # Use: tl.where() +min(a, b) / max(a, b) # Use: tl.minimum(a, b) / tl.maximum(a, b) +``` + +### Mandatory Environment Variables + +```python +import os +os.environ['TRITON_HIP_USE_BLOCK_PINGPONG'] = '1' +os.environ['TRITON_HIP_USE_ASYNC_COPY'] = '1' +``` + +## Core Kernel Implementations + +### 1. RMSNorm (Core Optimization Target) + +Row-wise reduction pattern. **168 instances** in LTX-Video, ~5% of total compute. + +**CRITICAL: Do NOT autotune BLOCK_D.** Autotune may pick `BLOCK_D < D`, causing partial row processing and wrong results. Always compute `BLOCK_D = triton.next_power_of_2(D)` in the Python wrapper. + +```python +@triton.jit +def rmsnorm_kernel( + x_ptr, weight_ptr, out_ptr, + stride_x, D, + eps: tl.constexpr, + HAS_WEIGHT: tl.constexpr, + BLOCK_D: tl.constexpr, +): + row = tl.program_id(0) + offs = tl.arange(0, BLOCK_D) + mask = offs < D + x = tl.load(x_ptr + row * stride_x + offs, mask=mask, other=0.0).to(tl.float32) + + variance = tl.sum(x * x, axis=0) / D + rms_inv = tl.rsqrt(variance + eps) + + if HAS_WEIGHT: + w = tl.load(weight_ptr + offs, mask=mask, other=1.0).to(tl.float32) + out = x * rms_inv * w + else: + out = x * rms_inv + + tl.store(out_ptr + row * stride_x + offs, out.to(x.dtype), mask=mask) + + +def triton_rmsnorm(x, weight=None, eps=1e-6): + x_2d = x.contiguous().view(-1, x.shape[-1]) + out = torch.empty_like(x_2d) + M, D = x_2d.shape + has_weight = weight is not None + if not has_weight: + weight = torch.empty(0, device=x.device) + + BLOCK_D = triton.next_power_of_2(D) + num_warps = 4 if BLOCK_D <= 1024 else (8 if BLOCK_D <= 4096 else 16) + rmsnorm_kernel[(M,)]( + x_2d, weight, out, x_2d.stride(0), D, eps, has_weight, + BLOCK_D=BLOCK_D, num_warps=num_warps, num_stages=2, + ) + return out.view_as(x) +``` + +**LTX-Video pitfall: Weight may be None!** +```python +has_weight = hasattr(module, 'weight') and module.weight is not None +``` + +### 2. RoPE 3D (Video Position Encoding) + +Element-wise pattern. LTX-Video splits `head_dim` into temporal + spatial components. + +**CRITICAL: cos/sin have shape `[seq_len, head_dim]`.** When grid flattens batch dimension (`batch * seq_len`), use `pid_s % seq_len` to index cos/sin, otherwise batch > 1 causes OOB GPU crash. + +```python +@triton.jit +def rope_3d_kernel( + qk_ptr, cos_ptr, sin_ptr, out_ptr, + seq_len, num_heads, head_dim, + stride_s, stride_h, stride_d, + BLOCK_HD: tl.constexpr, +): + pid_s = tl.program_id(0) # batch * seq_len + pid_h = tl.program_id(1) # head index + half_dim = head_dim // 2 + offs = tl.arange(0, BLOCK_HD) + mask = offs < half_dim + + base = pid_s * stride_s + pid_h * stride_h + x0 = tl.load(qk_ptr + base + offs, mask=mask, other=0.0).to(tl.float32) + x1 = tl.load(qk_ptr + base + half_dim + offs, mask=mask, other=0.0).to(tl.float32) + + seq_idx = pid_s % seq_len # wrap for batch > 1 + cos_val = tl.load(cos_ptr + seq_idx * head_dim + offs, mask=mask, other=1.0).to(tl.float32) + sin_val = tl.load(sin_ptr + seq_idx * head_dim + offs, mask=mask, other=0.0).to(tl.float32) + + out0 = x0 * cos_val - x1 * sin_val + out1 = x0 * sin_val + x1 * cos_val + + tl.store(out_ptr + base + offs, out0.to(x0.dtype), mask=mask) + tl.store(out_ptr + base + half_dim + offs, out1.to(x0.dtype), mask=mask) + + +def triton_rope_3d(qk, cos, sin): + qk = qk.contiguous() + out = torch.empty_like(qk) + batch, seq_len, num_heads, head_dim = qk.shape + qk_flat = qk.view(batch * seq_len, num_heads, head_dim) + out_flat = out.view(batch * seq_len, num_heads, head_dim) + BLOCK_HD = triton.next_power_of_2(head_dim // 2) + num_warps = 4 if BLOCK_HD <= 64 else 8 + rope_3d_kernel[(batch * seq_len, num_heads)]( + qk_flat, cos, sin, out_flat, + seq_len, num_heads, head_dim, + qk_flat.stride(0), qk_flat.stride(1), qk_flat.stride(2), + BLOCK_HD=BLOCK_HD, num_warps=num_warps, num_stages=2, + ) + return out +``` + +### 3. GEGLU (For SD3/FLUX, NOT LTX-Video) + +Element-wise gated activation. Input `[batch, seq, 2*hidden]` → Output `[batch, seq, hidden]`. + +**Same BLOCK_SIZE rule: compute dynamically, do NOT autotune.** + +```python +@triton.jit +def geglu_kernel( + input_ptr, output_ptr, + stride_in, stride_out, hidden_size, + BLOCK_H: tl.constexpr, +): + row = tl.program_id(0) + offs = tl.arange(0, BLOCK_H) + mask = offs < hidden_size + + gate = tl.load(input_ptr + row * stride_in + offs, mask=mask, other=0.0).to(tl.float32) + value = tl.load(input_ptr + row * stride_in + hidden_size + offs, mask=mask, other=0.0).to(tl.float32) + + # GELU approx — manual tanh (tl.math.tanh NOT available on ROCm) + k = 0.7978845608028654 # sqrt(2/pi) + tanh_arg = k * (gate + 0.044715 * gate * gate * gate) + e2x = tl.exp(2.0 * tanh_arg) + tanh_val = (e2x - 1.0) / (e2x + 1.0) + gate_gelu = 0.5 * gate * (1.0 + tanh_val) + result = gate_gelu * value + + tl.store(output_ptr + row * stride_out + offs, result.to(gate.dtype), mask=mask) + + +def triton_geglu(x): + x = x.contiguous() + *batch_dims, double_h = x.shape + hidden_size = double_h // 2 + x_2d = x.view(-1, double_h) + M = x_2d.shape[0] + out = torch.empty(M, hidden_size, device=x.device, dtype=x.dtype) + BLOCK_H = triton.next_power_of_2(hidden_size) + num_warps = 4 if BLOCK_H <= 1024 else (8 if BLOCK_H <= 4096 else 16) + geglu_kernel[(M,)]( + x_2d, out, x_2d.stride(0), out.stride(0), hidden_size, + BLOCK_H=BLOCK_H, num_warps=num_warps, num_stages=2, + ) + return out.view(*batch_dims, hidden_size) +``` + +**Warning: LTX-Video uses GELU, NOT GEGLU.** GEGLU is for SD3/FLUX. + +### 4. AdaLN (Adaptive Layer Normalization for DiT) + +Fused normalization + conditioning: `norm(x) * weight * (1 + scale) + shift` + +**Same BLOCK_D rule: compute dynamically.** + +```python +@triton.jit +def adaln_kernel( + x_ptr, weight_ptr, scale_ptr, shift_ptr, out_ptr, + stride_x, stride_cond, D, + eps: tl.constexpr, + BLOCK_D: tl.constexpr, +): + row = tl.program_id(0) + offs = tl.arange(0, BLOCK_D) + mask = offs < D + x = tl.load(x_ptr + row * stride_x + offs, mask=mask, other=0.0).to(tl.float32) + + variance = tl.sum(x * x, axis=0) / D + rms_inv = tl.rsqrt(variance + eps) + x_norm = x * rms_inv + + w = tl.load(weight_ptr + offs, mask=mask, other=1.0).to(tl.float32) + scale = tl.load(scale_ptr + row * stride_cond + offs, mask=mask, other=0.0).to(tl.float32) + shift = tl.load(shift_ptr + row * stride_cond + offs, mask=mask, other=0.0).to(tl.float32) + + out = x_norm * w * (1.0 + scale) + shift + tl.store(out_ptr + row * stride_x + offs, out.to(x.dtype), mask=mask) + + +def triton_adaln(x, weight, scale, shift, eps=1e-6): + x_flat = x.contiguous().view(-1, x.shape[-1]) + scale_flat = scale.contiguous().view(-1, x.shape[-1]) + shift_flat = shift.contiguous().view(-1, x.shape[-1]) + out = torch.empty_like(x_flat) + M, D = x_flat.shape + BLOCK_D = triton.next_power_of_2(D) + num_warps = 4 if BLOCK_D <= 1024 else (8 if BLOCK_D <= 4096 else 16) + adaln_kernel[(M,)]( + x_flat, weight, scale_flat, shift_flat, out, + x_flat.stride(0), scale_flat.stride(0), D, eps, + BLOCK_D=BLOCK_D, num_warps=num_warps, num_stages=2, + ) + return out.view_as(x) +``` + +## Diffusers Integration + +> **See [diffusers-integration.md](references/diffusers-integration.md) for the complete guide.** + +### Minimal Integration Pattern + +```python +def patch_rmsnorm_modules(model): + """Patch all RMSNorm modules to use custom Triton kernel.""" + for name, module in model.named_modules(): + if type(module).__name__ == 'RMSNorm': + eps = getattr(module, 'eps', 1e-6) + has_weight = hasattr(module, 'weight') and module.weight is not None + if has_weight: + def make_forward(mod, epsilon): + def forward(x): + return triton_rmsnorm(x, mod.weight, eps=epsilon) + return forward + module.forward = make_forward(module, eps) + else: + def make_forward(epsilon): + def forward(x): + w = torch.ones(x.shape[-1], device=x.device, dtype=x.dtype) + return triton_rmsnorm(x, w, eps=epsilon) + return forward + module.forward = make_forward(eps) + +pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16) +pipe.to("cuda") +patch_rmsnorm_modules(pipe.transformer) +pipe.enable_model_cpu_offload() +``` + +### Diffusers Critical Pitfalls + +1. **RMSNorm weight may be None** — LTX-Video uses `elementwise_affine=False` +2. **Diffusers RMSNorm != torch.nn.RMSNorm** — Use `type(module).__name__` not `isinstance()` +3. **LTX-Video uses GELU, not GEGLU** — Don't patch GEGLU for LTX-Video +4. **Inject BEFORE CPU offloading** — `inject_kernels()` then `enable_model_cpu_offload()` + +## Performance Expectations + +### Micro-benchmark Results (MI355X, BF16) + +| Kernel | Avg Speedup | Best Config Speedup | Status | +|--------|:-----------:|:-------------------:|:------:| +| **RMSNorm** | **1.71x** | 2.44x ([4×4096×3072]) | PASS | +| **RoPE 3D** | **1.21x** | 1.52x ([2×4096×16×128]) | PASS | +| **GEGLU** | **1.43x** | 2.13x ([4×4096×8192]) | PASS | +| **AdaLN** | **2.22x** | 2.77x ([4×4096×3072]) | PASS | + +RMSNorm bandwidth utilization: 3554 GB/s (MI355X theoretical: 8 TB/s, ~44%). + +### End-to-End LTX-Video (MI355X, 25 frames, 30 steps) + +| Mode | Time (s) | Per Step (s) | Peak Mem (GB) | Speedup | +|------|:--------:|:------------:|:-------------:|:-------:| +| baseline | 1.20 | 0.040 | 18.58 | 1.00x | +| **triton** | **0.98** | **0.033** | **18.58** | **1.22x** | +| torch.compile | 0.78 | 0.026 | 18.58 | 1.54x | + +**Key finding**: MI355X Triton E2E speedup (22%) is significantly higher than H100 CUDA reference (6%), because MI355X's default PyTorch RMSNorm path has more room for optimization. + +### Micro-benchmark Results (R9700, BF16) + +| Kernel | Avg Speedup | Best Config Speedup | Status | +|--------|:-----------:|:-------------------:|:------:| +| **RMSNorm** | **2.90x** | 3.97x ([1×8192×2048]) | PASS | +| **RoPE 3D** | **2.09x** | 2.38x ([1×1024×16×64]) | PASS | +| **GEGLU** | **1.69x** | 1.93x ([2×1024×8192]) | PASS | +| **AdaLN** | **3.00x** | 3.67x ([4×4096×3072]) | PASS | + +RMSNorm bandwidth utilization: 483 GB/s (R9700 theoretical: ~608 GB/s, **~79%**). + +R9700 speedups are higher than MI355X because PyTorch's default RDNA4 backend is less mature, leaving more room for Triton optimization. The bandwidth utilization (79%) is also significantly better than MI355X (44%). + +### End-to-End LTX-Video (R9700, 25 frames, 30 steps) + +| Mode | Time (s) | Per Step (s) | Peak Mem (GB) | Speedup | +|------|:--------:|:------------:|:-------------:|:-------:| +| baseline | 6.89 | 0.230 | 18.58 | 1.00x | +| **triton** | **6.06** | **0.202** | **18.58** | **1.14x** | +| torch.compile | 5.07 | 0.169 | 18.58 | 1.36x | + +### R9700 Additional Validation + +| Test | Result | +|------|--------| +| Transformers injection (TinyLlama 1.1B) | PASS — 45 RMSNorm patched, 99.9 tokens/s | +| HuggingFace Kernels Hub integration | PASS — Hub kernel loads and runs on ROCm | +| Local Triton vs Hub kernel (small shape) | Local **5.92x** vs Hub 1.27x (lower launch overhead) | +| Local Triton vs Hub kernel (large shape) | Local 3.59x vs Hub 3.57x (comparable) | +| num_warps sweep (2/4/8/16/32) | Default heuristic (4/8/16) is near-optimal; nw=32 always worst | +| rocprof kernel fusion analysis | Triton fuses 4 PyTorch kernels (pow+mean+rsqrt+mul) into 1 | + +### CUDA Reference (H100, for comparison) + +| Shape | Custom (ms) | PyTorch (ms) | Speedup | +|:---|:---:|:---:|:---:| +| [1×1024×2048] | 0.019 | 0.065 | **3.37x** | +| [2×4096×3072] | 0.087 | 0.208 | **2.41x** | + +H100 E2E: ~6% (RMSNorm is ~5% of total compute). + +### Optimization Targets + +| Kernel | MI355X | R9700 | Target | Priority | +|--------|:------:|:-----:|:------:|:--------:| +| RMSNorm | 1.71x | 2.90x | >3x (R9700) | P0 — MI355X bandwidth util (44%→60%+) | +| AdaLN | 2.22x | 3.00x | >3.5x (R9700) | P1 — already strong on both | +| GEGLU | 1.43x | 1.69x | >2x | P1 — tanh overhead | +| RoPE 3D | 1.21x | 2.09x | >2.5x (R9700) | P2 — small head_dim launch overhead | + +## Common Issues on ROCm + +| Issue | Symptom | Fix | +|-------|---------|-----| +| **Autotune BLOCK_D** | Wrong results (max_abs 4-8+) | **Never autotune BLOCK_D.** Use `triton.next_power_of_2(D)` | +| **RoPE batch OOB** | GPU crash (`Memory access fault`) | Use `pid_s % seq_len` for cos/sin indexing | +| `tl.libdevice` | Not found on AMD | Use manual math formulas | +| `tl.tanh` / `tl.math.tanh` | Not on ROCm | Manual: `e2x=exp(2x); (e2x-1)/(e2x+1)` | +| Python min/max | Runtime error | `tl.minimum()`/`tl.maximum()` | +| LDS overflow | HIP OOM | Reduce num_stages to 2 | +| Weight is None | AttributeError | Check `elementwise_affine` | +| isinstance() miss | RMSNorm not patched | Use `type(module).__name__` | + +> See [troubleshooting.md](references/troubleshooting.md) for all common issues. + +## Performance Profiling + +```bash +rocprof --stats python your_kernel.py +rocprofv3 -i metrics.txt python your_kernel.py +rocm-bandwidth-test +rocminfo | grep -E "Name|Compute Unit|Wavefront" +``` + +## See Also + +### Benchmark & Test Scripts +- [benchmark_kernels.py](scripts/benchmark_kernels.py) - Micro-benchmark all 4 kernels (correctness + perf + bandwidth) +- [benchmark_e2e.py](scripts/benchmark_e2e.py) - End-to-end LTX-Video pipeline benchmark (baseline vs Triton vs compile) +- [sweep_num_warps.py](scripts/sweep_num_warps.py) - num_warps sweep for R9700 Wave32 optimization +- [ltx_kernel_injection_example.py](scripts/ltx_kernel_injection_example.py) - Minimal diffusers injection example +- [transformers_injection_example.py](scripts/transformers_injection_example.py) - Minimal transformers injection example +- [huggingface_kernels_example.py](scripts/huggingface_kernels_example.py) - HuggingFace Kernels Hub integration example + +### Integration Guides +- [diffusers-integration.md](references/diffusers-integration.md) - LTX-Video pipeline integration +- [transformers-integration.md](references/transformers-integration.md) - LLaMA/Mistral/Qwen integration +- [huggingface-kernels-integration.md](references/huggingface-kernels-integration.md) - HuggingFace Kernels Hub (`get_kernel`) +- [kernel-templates.md](references/kernel-templates.md) - Complete Triton kernel templates (incl. GEMM with XCD Swizzle) + +### GPU Optimization Guides +- [mi355x-optimization-guide.md](references/mi355x-optimization-guide.md) - MI355X (gfx950) deep dive +- [r9700-optimization-guide.md](references/r9700-optimization-guide.md) - R9700 (RDNA4) deep dive + +### Reference +- [troubleshooting.md](references/troubleshooting.md) - Common issues and solutions +- [kernelbench-classification.md](references/kernelbench-classification.md) - KernelBench operator taxonomy +- [skill-evaluation-methodology.md](references/skill-evaluation-methodology.md) - How to evaluate and improve skills +- [kernel-agent-knowledge-base.md](references/kernel-agent-knowledge-base.md) - Knowledge from kernel-agent project + +### External Resources +- [Triton Documentation](https://triton-lang.org/) +- [ROCm Documentation](https://rocm.docs.amd.com/) +- [HuggingFace Kernels Hub](https://huggingface.co/kernels-community) +- [LTX-Video on HuggingFace](https://huggingface.co/Lightricks/LTX-Video) +- [HuggingFace Diffusers](https://huggingface.co/docs/diffusers/en/index) diff --git a/skills/rocm-kernels/manifest.txt b/skills/rocm-kernels/manifest.txt new file mode 100644 index 00000000..3e88dac8 --- /dev/null +++ b/skills/rocm-kernels/manifest.txt @@ -0,0 +1,18 @@ +# Files for rocm-kernels skill +SKILL.md +CHANGELOG.md +references/mi355x-optimization-guide.md +references/r9700-optimization-guide.md +references/kernel-templates.md +references/diffusers-integration.md +references/transformers-integration.md +references/huggingface-kernels-integration.md +references/troubleshooting.md +references/skill-evaluation-methodology.md +references/kernelbench-classification.md +references/kernel-agent-knowledge-base.md +scripts/benchmark_kernels.py +scripts/benchmark_e2e.py +scripts/ltx_kernel_injection_example.py +scripts/transformers_injection_example.py +scripts/huggingface_kernels_example.py diff --git a/skills/rocm-kernels/references/diffusers-integration.md b/skills/rocm-kernels/references/diffusers-integration.md new file mode 100644 index 00000000..1809e0d1 --- /dev/null +++ b/skills/rocm-kernels/references/diffusers-integration.md @@ -0,0 +1,252 @@ +# Diffusers Pipeline Integration Guide (ROCm) + +Integrating custom Triton kernels into HuggingFace diffusers pipelines on AMD GPUs. + +## Overview + +This guide covers injecting optimized Triton kernels (RMSNorm, RoPE 3D, GEGLU, AdaLN) into diffusers pipelines running on ROCm. The patterns are analogous to the CUDA kernel integration but use Triton instead of CUDA C. + +## LTX-Video Architecture + +### Module Inventory + +```python +from diffusers import LTXPipeline +pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16) + +# Analyze RMSNorm modules +for name, module in pipe.transformer.named_modules(): + if 'Norm' in type(module).__name__: + has_weight = hasattr(module, 'weight') and module.weight is not None + print(f"{name}: {type(module).__name__} (has_weight={has_weight})") +``` + +### Kernel Applicability in LTX-Video + +| Kernel | Used? | Count | Notes | +|--------|-------|-------|-------| +| **RMSNorm** | Yes | **168** | 56 with weights, 112 without | +| **RoPE 3D** | Indirect | 1 | Diffusers computes via LTXVideoRotaryPosEmbed | +| **GEGLU** | **No** | 0 | LTX uses `activation_fn="gelu-approximate"` | +| **AdaLN** | Partial | ~28 | Scale/shift pattern in transformer blocks | + +## Integration Pattern + +### Step 1: Triton RMSNorm Wrapper + +```python +import os +os.environ['TRITON_HIP_USE_BLOCK_PINGPONG'] = '1' +os.environ['TRITON_HIP_USE_ASYNC_COPY'] = '1' + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _rmsnorm_kernel( + x_ptr, weight_ptr, out_ptr, + stride_x_row, D, + eps: tl.constexpr, + HAS_WEIGHT: tl.constexpr, + BLOCK_D: tl.constexpr, +): + row = tl.program_id(0) + offs = tl.arange(0, BLOCK_D) + mask = offs < D + + x = tl.load(x_ptr + row * stride_x_row + offs, mask=mask, other=0.0).to(tl.float32) + sq_sum = tl.sum(x * x, axis=0) + rms_inv = tl.rsqrt(sq_sum / D + eps) + + if HAS_WEIGHT: + w = tl.load(weight_ptr + offs, mask=mask, other=1.0).to(tl.float32) + out = x * rms_inv * w + else: + out = x * rms_inv + + tl.store(out_ptr + row * stride_x_row + offs, out.to(x.dtype), mask=mask) + + +def triton_rmsnorm(x: torch.Tensor, weight: torch.Tensor = None, eps: float = 1e-6): + """Drop-in replacement for RMSNorm forward pass.""" + x_flat = x.contiguous().view(-1, x.shape[-1]) + out = torch.empty_like(x_flat) + M, D = x_flat.shape + has_weight = weight is not None + + if not has_weight: + weight = torch.ones(D, device=x.device, dtype=x.dtype) + + # CRITICAL: BLOCK_D must be >= D. Never autotune BLOCK_D. + BLOCK_D = triton.next_power_of_2(D) + num_warps = 4 if BLOCK_D <= 1024 else (8 if BLOCK_D <= 4096 else 16) + _rmsnorm_kernel[(M,)]( + x_flat, weight, out, + x_flat.stride(0), D, + eps, has_weight, + BLOCK_D=BLOCK_D, num_warps=num_warps, num_stages=2, + ) + return out.view_as(x) +``` + +### Step 2: Module Patcher + +```python +def patch_rmsnorm_modules(model) -> int: + """ + Patch all RMSNorm modules to use Triton kernel on ROCm. + + Handles both: + - Modules WITH weight (elementwise_affine=True) — attention norm_q/norm_k + - Modules WITHOUT weight (elementwise_affine=False) — transformer block norms + """ + patched = 0 + for name, module in model.named_modules(): + # IMPORTANT: Use class name, NOT isinstance + if type(module).__name__ == 'RMSNorm': + eps = getattr(module, 'eps', 1e-6) + has_weight = hasattr(module, 'weight') and module.weight is not None + + if has_weight: + def make_forward(mod, epsilon): + def forward(x): + return triton_rmsnorm(x, mod.weight, eps=epsilon) + return forward + module.forward = make_forward(module, eps) + else: + def make_forward_no_weight(epsilon): + def forward(x): + return triton_rmsnorm(x, None, eps=epsilon) + return forward + module.forward = make_forward_no_weight(eps) + + patched += 1 + return patched +``` + +### Step 3: Pipeline Injection + +```python +def inject_optimized_kernels(pipe) -> dict: + """ + Inject Triton kernels into LTX-Video pipeline. + + Call AFTER pipe.to("cuda"), BEFORE pipe.enable_model_cpu_offload(). + """ + stats = {'rmsnorm_modules': 0} + + if not hasattr(pipe, 'transformer'): + print("WARNING: Pipeline has no 'transformer' attribute!") + return stats + + stats['rmsnorm_modules'] = patch_rmsnorm_modules(pipe.transformer) + return stats +``` + +### Step 4: Usage + +```python +import torch +from diffusers import LTXPipeline +from diffusers.utils import export_to_video + +pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16) +pipe.to("cuda") # ROCm via HIP + +stats = inject_optimized_kernels(pipe) +print(f"RMSNorm modules patched: {stats['rmsnorm_modules']}") +# Expected: 168 + +pipe.enable_model_cpu_offload() # AFTER injection + +output = pipe( + prompt="A cat sleeping in the sun", + num_frames=25, height=480, width=704, + num_inference_steps=30, +) +export_to_video(output.frames[0], "output.mp4", fps=24) +``` + +## Model-Specific Notes + +### LTX-Video +- Uses **GELU** (`activation_fn="gelu-approximate"`), NOT GEGLU +- RMSNorm in blocks: `elementwise_affine=False` (no weight) +- RMSNorm in attention: `elementwise_affine=True` (has weight) +- RoPE: Computed by diffusers via `LTXVideoRotaryPosEmbed` + +### SD3 / FLUX +- Uses **GEGLU** in FeedForward blocks +- Different attention patterns +- May have different normalization conventions +- Verify architecture before applying LTX-Video patterns + +## ROCm-Specific Considerations + +### BF16 vs FP16 + +```python +# MI355X supports BF16 — use it for diffusers +pipe = LTXPipeline.from_pretrained(..., torch_dtype=torch.bfloat16) + +# R9700 (RDNA4) — check BF16 support, may need FP16 +# torch_dtype=torch.float16 +``` + +### ROCm Memory Management + +```python +# ROCm uses same API as CUDA via HIP +pipe.to("cuda") # Works on ROCm +pipe.enable_model_cpu_offload() # Works on ROCm +torch.cuda.empty_cache() # Works on ROCm +``` + +### Triton on ROCm vs CUDA C Kernels + +| Aspect | CUDA C (original skill) | Triton (this skill) | +|--------|------------------------|---------------------| +| Build system | setup.py + nvcc | No build needed | +| Portability | NVIDIA only | AMD + NVIDIA | +| Performance | Maximum | 80-95% of CUDA C | +| Complexity | High (C++/CUDA) | Lower (Python) | +| Autotune | Manual | `@triton.autotune` | +| torch.compile | Needs custom op | Automatic compatibility | + +## Verification + +```python +# Check injection worked +for name, module in pipe.transformer.named_modules(): + if type(module).__name__ == 'RMSNorm': + x = torch.randn(1, 10, 2048, device='cuda', dtype=torch.bfloat16) + out = module(x) + print(f"RMSNorm forward: {x.shape} -> {out.shape}") + break + +# Compare with PyTorch reference +def pytorch_rmsnorm(x, weight, eps=1e-6): + rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) + if weight is not None: + return x * rms * weight + return x * rms + +# Verify correctness +torch.testing.assert_close( + triton_rmsnorm(x, weight, eps=1e-6), + pytorch_rmsnorm(x, weight, eps=1e-6), + rtol=1e-2, atol=1e-3 +) +``` + +## Troubleshooting + +| Issue | Fix | +|-------|-----| +| `NoneType has no attribute contiguous` | RMSNorm weight is None, pass `None` to kernel | +| `isinstance()` not matching | Use `type(module).__name__ == 'RMSNorm'` | +| GEGLU not called | LTX-Video uses GELU, not GEGLU | +| Patching doesn't persist | Inject BEFORE `enable_model_cpu_offload()` | +| HIP error during inference | Check ROCm version compatibility with PyTorch | diff --git a/skills/rocm-kernels/references/huggingface-kernels-integration.md b/skills/rocm-kernels/references/huggingface-kernels-integration.md new file mode 100644 index 00000000..e8287ddb --- /dev/null +++ b/skills/rocm-kernels/references/huggingface-kernels-integration.md @@ -0,0 +1,351 @@ +# HuggingFace Kernels Integration Guide (ROCm) + +Complete guide for using and publishing kernels with the HuggingFace Kernels library (`get_kernel`) on ROCm. + +> **Quick Start:** See [huggingface_kernels_example.py](../scripts/huggingface_kernels_example.py) for a minimal working example. + +## Overview + +The [HuggingFace Kernels](https://huggingface.co/docs/kernels/en/index) library enables dynamic loading of pre-compiled kernels from the Hugging Face Hub. This eliminates the need for local compilation and ensures compatibility across different Python, PyTorch, and CUDA/ROCm versions. + +**Key Benefits:** +- **No local compilation** — download pre-built binaries +- **Version management** — load specific kernel versions +- **Multi-version support** — multiple versions coexist in one Python process +- **Automatic compatibility** — matches your PyTorch/ROCm configuration + +**ROCm Note:** Not all Hub kernels have ROCm builds. Triton-based kernels (e.g., `triton-layer-norm`) are more likely to work on ROCm than CUDA C kernels. Always check with `has_kernel()` first. + +## Installation + +```bash +pip install kernels torch numpy +``` + +Requirements: +- PyTorch >= 2.5 (ROCm build) +- ROCm-capable AMD GPU +- Python 3.8+ + +## Core API + +### get_kernel + +Download and load a kernel from the Hub: + +```python +from kernels import get_kernel + +kernel = get_kernel("kernels-community/triton-layer-norm") + +# With specific version +kernel = get_kernel("kernels-community/triton-layer-norm", version=1) + +# With specific revision +kernel = get_kernel("kernels-community/flash-attn", revision="v2.0.0") +``` + +**Parameters:** + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `repo_id` | str | required | Hub repository (e.g., "kernels-community/activation") | +| `revision` | str | "main" | Branch, tag, or commit hash | +| `version` | int/str | None | Kernel version number (mutually exclusive with `revision`) | + +**Returns:** `ModuleType` — the imported kernel module + +### has_kernel + +Check if a kernel build exists for your environment: + +```python +from kernels import has_kernel + +if has_kernel("kernels-community/triton-layer-norm"): + kernel = get_kernel("kernels-community/triton-layer-norm") +else: + print("No compatible build for this ROCm/PyTorch version") +``` + +### get_local_kernel + +Load a kernel from a local path (useful during development): + +```python +from kernels import get_local_kernel + +kernel = get_local_kernel("/path/to/my-kernel") +``` + +### load_kernel & get_locked_kernel + +For reproducible, offline-capable deployments using lockfiles: + +```python +from kernels import load_kernel, get_locked_kernel + +kernel = load_kernel("lockfile.json") +kernel = get_locked_kernel("kernels-community/activation", lockfile="kernel.lock") +``` + +## Usage Examples + +### 1. RMSNorm Kernel from Hub + +**Note:** The actual function name may vary by kernel version. Use `dir(kernel)` to inspect, and check for `rms_norm_fn`, `rms_norm`, or `rmsnorm`. + +```python +import torch +from kernels import get_kernel, has_kernel + +repo_id = "kernels-community/triton-layer-norm" + +if has_kernel(repo_id): + layer_norm = get_kernel(repo_id) + + # Inspect available functions + print([f for f in dir(layer_norm) if not f.startswith('_')]) + # e.g. ['layer_norm', 'layer_norm_fn', 'rms_norm_fn', ...] + + x = torch.randn(2, 1024, 2048, dtype=torch.bfloat16, device="cuda") + weight = torch.ones(2048, dtype=torch.bfloat16, device="cuda") + + # Use the actual function name (rms_norm_fn in current version) + out = layer_norm.rms_norm_fn(x, weight, eps=1e-6) + print(f"Output shape: {out.shape}") +else: + print("No ROCm-compatible build available") +``` + +### 2. Integration with Transformers Models + +```python +import torch +from kernels import get_kernel, has_kernel + +repo_id = "kernels-community/triton-layer-norm" + +if has_kernel(repo_id): + rmsnorm_kernel = get_kernel(repo_id) + + def patch_rmsnorm_with_hub_kernel(model): + """Patch model's RMSNorm to use Hub kernel.""" + patched = 0 + for name, module in model.named_modules(): + if 'RMSNorm' in type(module).__name__: + eps = getattr(module, 'variance_epsilon', None) or getattr(module, 'eps', 1e-6) + + def make_forward(mod, epsilon): + def forward(hidden_states): + return rmsnorm_kernel.rms_norm(hidden_states, mod.weight, eps=epsilon) + return forward + + module.forward = make_forward(module, eps) + patched += 1 + return patched +``` + +### 3. Integration with Diffusers Pipelines + +```python +import torch +from diffusers import LTXPipeline +from kernels import get_kernel, has_kernel + +if has_kernel("kernels-community/triton-layer-norm"): + rmsnorm_kernel = get_kernel("kernels-community/triton-layer-norm") + + def patch_rmsnorm(model): + for name, module in model.named_modules(): + if type(module).__name__ == 'RMSNorm': + eps = getattr(module, 'eps', 1e-6) + has_weight = hasattr(module, 'weight') and module.weight is not None + + if has_weight: + def make_forward(mod, epsilon): + def forward(x): + return rmsnorm_kernel.rms_norm(x, mod.weight, eps=epsilon) + return forward + module.forward = make_forward(module, eps) + + pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16) + pipe.to("cuda") + patch_rmsnorm(pipe.transformer) +``` + +### 4. Benchmark Hub Kernel vs PyTorch + +```python +import time +import torch +from kernels import get_kernel + +kernel = get_kernel("kernels-community/triton-layer-norm") + +sizes = [(2, 1024, 2048), (4, 4096, 4096)] +for shape in sizes: + x = torch.randn(shape, dtype=torch.bfloat16, device="cuda") + w = torch.ones(shape[-1], dtype=torch.bfloat16, device="cuda") + + for _ in range(10): + kernel.rms_norm(x, w, eps=1e-6) + variance = x.pow(2).mean(-1, keepdim=True) + _ = x * torch.rsqrt(variance + 1e-6) * w + torch.cuda.synchronize() + + iters = 100 + start = time.perf_counter() + for _ in range(iters): + kernel.rms_norm(x, w, eps=1e-6) + torch.cuda.synchronize() + hub_ms = (time.perf_counter() - start) / iters * 1000 + + start = time.perf_counter() + for _ in range(iters): + variance = x.pow(2).mean(-1, keepdim=True) + _ = x * torch.rsqrt(variance + 1e-6) * w + torch.cuda.synchronize() + pt_ms = (time.perf_counter() - start) / iters * 1000 + + print(f"Shape {shape}: Hub={hub_ms:.3f}ms, PyTorch={pt_ms:.3f}ms, Speedup={pt_ms/hub_ms:.2f}x") +``` + +## ROCm-Specific Notes + +### Kernel Compatibility + +Not all Hub kernels have ROCm builds: + +| Kernel Type | ROCm Support | Notes | +|-------------|:------------:|-------| +| Triton-based (e.g., `triton-layer-norm`) | Likely | Triton compiles to HIP | +| CUDA C-based (e.g., `flash-attn`) | Check | Needs explicit ROCm build | +| Custom CUDA ops | Unlikely | CUDA-only unless HIP-ported | + +**Always check availability first:** +```python +from kernels import has_kernel + +if has_kernel("kernels-community/triton-layer-norm"): + print("ROCm build available") +else: + print("No ROCm build — use local Triton kernel instead") +``` + +### Fallback Strategy + +When a Hub kernel is not available for ROCm, fall back to the local Triton implementation: + +```python +from kernels import has_kernel, get_kernel + +def get_rmsnorm_function(): + """Get best available RMSNorm implementation.""" + if has_kernel("kernels-community/triton-layer-norm"): + kernel = get_kernel("kernels-community/triton-layer-norm") + return lambda x, w, eps: kernel.rms_norm(x, w, eps=eps) + else: + from your_local_kernels import triton_rmsnorm + return triton_rmsnorm +``` + +### Environment Check + +```python +import torch +print(f"PyTorch: {torch.__version__}") +print(f"HIP version: {getattr(torch.version, 'hip', 'N/A')}") +print(f"GPU: {torch.cuda.get_device_name()}") +print(f"GPU arch: {torch.cuda.get_device_capability()}") +``` + +## Publishing Kernels to Hub + +### Triton Kernel Project Structure + +For Triton-based kernels (best ROCm compatibility): + +``` +my-triton-kernel/ +├── build.toml +├── kernel_src/ +│ └── rmsnorm.py # Triton kernel source +└── torch-ext/ + ├── torch_binding.cpp + └── my_kernels/ + └── __init__.py +``` + +### build.toml for Triton Kernels + +```toml +[general] +name = "my_triton_kernels" +backends = ["cuda", "rocm"] # Include ROCm backend + +[torch] +src = ["torch-ext/torch_binding.cpp"] + +[kernel.rmsnorm] +backend = "triton" +src = ["kernel_src/rmsnorm.py"] +depends = ["torch"] +``` + +### Build and Publish + +```bash +pip install kernel-builder +kernel-builder build + +huggingface-cli repo create your-org/your-kernel --type model +huggingface-cli upload your-org/your-kernel ./dist +``` + +### Others Load It + +```python +from kernels import get_kernel + +rmsnorm = get_kernel("your-org/your-kernel") +``` + +## Available Community Kernels + +Popular kernels from `kernels-community`: + +| Kernel | Description | ROCm? | +|--------|-------------|:-----:| +| `triton-layer-norm` | LayerNorm, RMSNorm | Likely | +| `activation` | GELU, SiLU, etc. | Check | +| `flash-attn` | Flash Attention 2 | Check | +| `quantization` | INT8/INT4 ops | Check | + +Browse all kernels: https://huggingface.co/kernels-community + +## Caching and Offline Usage + +```python +import os +os.environ["HF_HUB_OFFLINE"] = "1" + +# Will only use cached kernels +kernel = get_kernel("kernels-community/triton-layer-norm") +``` + +## Best Practices + +1. **Always check availability** — `has_kernel()` before `get_kernel()` +2. **Pin versions** — `get_kernel(repo, version=1)` for reproducibility +3. **Have a fallback** — local Triton kernel when Hub build is unavailable +4. **Use lockfiles in production** — `load_kernel("kernel.lock")` +5. **Test on your GPU** — verify correctness after loading + +## See Also + +- [HuggingFace Kernels Documentation](https://huggingface.co/docs/kernels/en/index) +- [HuggingFace Kernels GitHub](https://github.com/huggingface/kernels) +- [Kernel Builder Documentation](https://github.com/huggingface/kernel-builder) +- [Community Kernels](https://huggingface.co/kernels-community) +- [Blog: Learn the Kernel Hub in 5 Minutes](https://huggingface.co/blog/hello-hf-kernels) diff --git a/skills/rocm-kernels/references/kernel-templates.md b/skills/rocm-kernels/references/kernel-templates.md new file mode 100644 index 00000000..702debbf --- /dev/null +++ b/skills/rocm-kernels/references/kernel-templates.md @@ -0,0 +1,512 @@ +# Triton Kernel Templates for ROCm (LTX-Video Operators) + +Copy-paste ready Triton kernel templates for RMSNorm, RoPE 3D, GEGLU, and AdaLN on AMD GPUs. + +## Required Header + +**Every kernel file MUST start with:** + +```python +import os +os.environ['TRITON_HIP_USE_BLOCK_PINGPONG'] = '1' +os.environ['TRITON_HIP_USE_ASYNC_COPY'] = '1' + +import torch +import torch.nn as nn +import triton +import triton.language as tl +``` + +## Template 1: RMSNorm (Core Target) + +Row-wise reduction. **168 instances** in LTX-Video. Handles both with-weight and no-weight variants. + +**CRITICAL: Do NOT autotune BLOCK_D.** Autotune may select `BLOCK_D < D`, causing partial row processing and completely wrong results. Always compute `BLOCK_D = triton.next_power_of_2(D)` dynamically. + +### Triton Kernel + +```python +@triton.jit +def rmsnorm_fwd_kernel( + x_ptr, weight_ptr, out_ptr, + stride_x, D, + eps: tl.constexpr, + HAS_WEIGHT: tl.constexpr, + BLOCK_D: tl.constexpr, +): + row = tl.program_id(0) + col_offsets = tl.arange(0, BLOCK_D) + mask = col_offsets < D + + x = tl.load(x_ptr + row * stride_x + col_offsets, mask=mask, other=0.0).to(tl.float32) + + variance = tl.sum(x * x, axis=0) / D + rms_inv = tl.rsqrt(variance + eps) + + if HAS_WEIGHT: + w = tl.load(weight_ptr + col_offsets, mask=mask, other=1.0).to(tl.float32) + result = x * rms_inv * w + else: + result = x * rms_inv + + tl.store(out_ptr + row * stride_x + col_offsets, result.to(x.dtype), mask=mask) +``` + +### Python API + +```python +def triton_rmsnorm( + x: torch.Tensor, + weight: torch.Tensor = None, + eps: float = 1e-6, +) -> torch.Tensor: + orig_shape = x.shape + x_2d = x.contiguous().view(-1, x.shape[-1]) + out = torch.empty_like(x_2d) + M, D = x_2d.shape + + has_weight = weight is not None + if not has_weight: + weight = torch.empty(0, device=x.device) + + BLOCK_D = triton.next_power_of_2(D) + num_warps = 4 if BLOCK_D <= 1024 else (8 if BLOCK_D <= 4096 else 16) + rmsnorm_fwd_kernel[(M,)]( + x_2d, weight, out, + x_2d.stride(0), D, eps, has_weight, + BLOCK_D=BLOCK_D, num_warps=num_warps, num_stages=2, + ) + return out.view(orig_shape) +``` + +### Benchmark + +```python +def benchmark_rmsnorm(): + configs = [ + (1, 1024, 2048), + (2, 1024, 2048), + (2, 4096, 3072), + ] + for batch, seq, hidden in configs: + x = torch.randn(batch, seq, hidden, device='cuda', dtype=torch.float16) + w = torch.ones(hidden, device='cuda', dtype=torch.float16) + + # Reference + ref = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6) * w + + # Custom + out = triton_rmsnorm(x, w, eps=1e-6) + + # Verify + torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-3) + print(f"[{batch}x{seq}x{hidden}] ✓ Correct") +``` + +## Template 2: RoPE 3D (Video Position Encoding) + +Element-wise rotation. Splits head_dim into temporal + spatial (height + width) components. + +**CRITICAL: cos/sin have shape `[seq_len, head_dim]`, NOT `[batch*seq_len, ...]`.** When the grid flattens the batch dimension, use `pid_s % seq_len` to index cos/sin, otherwise batch > 1 causes out-of-bounds GPU crash. + +### Triton Kernel + +```python +@triton.jit +def rope_3d_fwd_kernel( + qk_ptr, cos_ptr, sin_ptr, out_ptr, + seq_len, num_heads, head_dim, + stride_s, stride_h, stride_d, + BLOCK_HD: tl.constexpr, +): + pid_s = tl.program_id(0) # ranges [0, batch * seq_len) + pid_h = tl.program_id(1) + + half_dim = head_dim // 2 + offs = tl.arange(0, BLOCK_HD) + mask = offs < half_dim + + base = pid_s * stride_s + pid_h * stride_h + x0 = tl.load(qk_ptr + base + offs, mask=mask, other=0.0).to(tl.float32) + x1 = tl.load(qk_ptr + base + half_dim + offs, mask=mask, other=0.0).to(tl.float32) + + seq_idx = pid_s % seq_len # wrap for batch > 1 + cos_val = tl.load(cos_ptr + seq_idx * head_dim + offs, mask=mask, other=1.0).to(tl.float32) + sin_val = tl.load(sin_ptr + seq_idx * head_dim + offs, mask=mask, other=0.0).to(tl.float32) + + out0 = x0 * cos_val - x1 * sin_val + out1 = x0 * sin_val + x1 * cos_val + + tl.store(out_ptr + base + offs, out0.to(x0.dtype), mask=mask) + tl.store(out_ptr + base + half_dim + offs, out1.to(x0.dtype), mask=mask) +``` + +### Python API + +```python +def triton_rope_3d( + qk: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, +) -> torch.Tensor: + """ + Apply 3D RoPE to Q or K tensor. + + Args: + qk: [batch, seq_len, num_heads, head_dim] + cos: [seq_len, head_dim] — NOT batch-expanded! + sin: [seq_len, head_dim] + """ + qk = qk.contiguous() + out = torch.empty_like(qk) + batch, seq_len, num_heads, head_dim = qk.shape + + qk_flat = qk.view(batch * seq_len, num_heads, head_dim) + out_flat = out.view(batch * seq_len, num_heads, head_dim) + + BLOCK_HD = triton.next_power_of_2(head_dim // 2) + num_warps = 4 if BLOCK_HD <= 64 else 8 + + rope_3d_fwd_kernel[(batch * seq_len, num_heads)]( + qk_flat, cos, sin, out_flat, + seq_len, num_heads, head_dim, + qk_flat.stride(0), qk_flat.stride(1), qk_flat.stride(2), + BLOCK_HD=BLOCK_HD, num_warps=num_warps, num_stages=2, + ) + return out +``` + +## Template 3: GEGLU (For SD3/FLUX) + +Gated activation: `GELU(gate) * value`. Input splits in half along last dim. + +**Note: LTX-Video uses GELU, NOT GEGLU. This template is for SD3/FLUX.** + +### Triton Kernel + +```python +@triton.jit +def geglu_fwd_kernel( + input_ptr, output_ptr, + stride_in, stride_out, hidden_size, + BLOCK_H: tl.constexpr, +): + row = tl.program_id(0) + offs = tl.arange(0, BLOCK_H) + mask = offs < hidden_size + + gate = tl.load(input_ptr + row * stride_in + offs, + mask=mask, other=0.0).to(tl.float32) + value = tl.load(input_ptr + row * stride_in + hidden_size + offs, + mask=mask, other=0.0).to(tl.float32) + + # Manual tanh — tl.math.tanh / tl.libdevice.tanh NOT available on ROCm + SQRT_2_OVER_PI = 0.7978845608028654 + tanh_arg = SQRT_2_OVER_PI * (gate + 0.044715 * gate * gate * gate) + e2x = tl.exp(2.0 * tanh_arg) + tanh_val = (e2x - 1.0) / (e2x + 1.0) + cdf = 0.5 * (1.0 + tanh_val) + gelu_gate = gate * cdf + + result = gelu_gate * value + tl.store(output_ptr + row * stride_out + offs, result.to(gate.dtype), mask=mask) +``` + +### Python API + +```python +def triton_geglu(x: torch.Tensor) -> torch.Tensor: + """ + GEGLU activation: GELU(x[..., :H]) * x[..., H:] + + Input: [..., 2*hidden_size] → Output: [..., hidden_size] + """ + x = x.contiguous() + *batch_dims, double_h = x.shape + hidden_size = double_h // 2 + + x_2d = x.view(-1, double_h) + M = x_2d.shape[0] + out = torch.empty(M, hidden_size, device=x.device, dtype=x.dtype) + + BLOCK_H = triton.next_power_of_2(hidden_size) + num_warps = 4 if BLOCK_H <= 1024 else (8 if BLOCK_H <= 4096 else 16) + geglu_fwd_kernel[(M,)]( + x_2d, out, + x_2d.stride(0), out.stride(0), hidden_size, + BLOCK_H=BLOCK_H, num_warps=num_warps, num_stages=2, + ) + return out.view(*batch_dims, hidden_size) +``` + +## Template 4: AdaLN (Adaptive Layer Normalization) + +Fused RMSNorm + adaptive conditioning for DiT blocks. +Formula: `norm(x) * weight * (1 + scale) + shift` + +### Triton Kernel + +```python +@triton.jit +def adaln_fwd_kernel( + x_ptr, weight_ptr, scale_ptr, shift_ptr, out_ptr, + stride_x, stride_cond, D, + eps: tl.constexpr, + BLOCK_D: tl.constexpr, +): + row = tl.program_id(0) + offs = tl.arange(0, BLOCK_D) + mask = offs < D + + x = tl.load(x_ptr + row * stride_x + offs, mask=mask, other=0.0).to(tl.float32) + + variance = tl.sum(x * x, axis=0) / D + rms_inv = tl.rsqrt(variance + eps) + x_norm = x * rms_inv + + w = tl.load(weight_ptr + offs, mask=mask, other=1.0).to(tl.float32) + scale = tl.load(scale_ptr + row * stride_cond + offs, mask=mask, other=0.0).to(tl.float32) + shift = tl.load(shift_ptr + row * stride_cond + offs, mask=mask, other=0.0).to(tl.float32) + + out = x_norm * w * (1.0 + scale) + shift + tl.store(out_ptr + row * stride_x + offs, out.to(x.dtype), mask=mask) +``` + +### Python API + +```python +def triton_adaln( + x: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + shift: torch.Tensor, + eps: float = 1e-6, +) -> torch.Tensor: + """ + Adaptive Layer Normalization for DiT blocks. + + Args: + x: [batch, seq, hidden] + weight: [hidden] + scale: [batch, seq, hidden] or [batch, 1, hidden] + shift: [batch, seq, hidden] or [batch, 1, hidden] + """ + x_flat = x.contiguous().view(-1, x.shape[-1]) + scale_flat = scale.contiguous().view(-1, x.shape[-1]) + shift_flat = shift.contiguous().view(-1, x.shape[-1]) + out = torch.empty_like(x_flat) + M, D = x_flat.shape + + BLOCK_D = triton.next_power_of_2(D) + num_warps = 4 if BLOCK_D <= 1024 else (8 if BLOCK_D <= 4096 else 16) + adaln_fwd_kernel[(M,)]( + x_flat, weight, scale_flat, shift_flat, out, + x_flat.stride(0), scale_flat.stride(0), D, eps, + BLOCK_D=BLOCK_D, num_warps=num_warps, num_stages=2, + ) + return out.view_as(x) +``` + +## Common Math Replacements for ROCm + +| Standard | ROCm Triton Replacement | +|----------|------------------------| +| `tl.tanh(x)` | Manual: `e2x = tl.exp(2.0*x); (e2x-1)/(e2x+1)` | +| `tl.math.tanh(x)` | **Also NOT available on ROCm** — use manual formula above | +| `tl.libdevice.*` | Remove entirely, use manual implementations | +| `min(a, b)` | `tl.minimum(a, b)` | +| `max(a, b)` | `tl.maximum(a, b)` | +| GELU exact | `0.5 * x * (1 + erf(x / sqrt(2)))` | +| GELU approx | `0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))` | + +## Kernel-Specific Guidelines + +### RMSNorm +- Input: `[..., hidden_size]` — flatten to 2D `[M, D]` +- Epsilon default: 1e-6 +- **Weight may be None** if `elementwise_affine=False` +- Always accumulate `x*x` sum in FP32 +- **BLOCK_D = `triton.next_power_of_2(D)`** — compute in wrapper, NEVER autotune +- Autotuning BLOCK_D is dangerous: if BLOCK_D < D, only partial row is processed → wrong results + +### RoPE 3D +- 1D: `[batch, seq, heads, head_dim]` for text +- 3D: `[batch, t*h*w, heads, head_dim]` for video +- LTX-Video computes RoPE via `LTXVideoRotaryPosEmbed` — kernel replaces the apply step +- head_dim typically 64 or 128 +- **cos/sin shape is `[seq_len, head_dim]`** — use `pid_s % seq_len` for batch > 1 + +### GEGLU vs GELU +- **GEGLU**: Input `[B, S, 2*H]` → Output `[B, S, H]` — gate/value split +- **GELU**: Standard activation, no split +- **LTX-Video uses GELU, NOT GEGLU** +- GEGLU is for SD3/FLUX + +### AdaLN +- Formula: `norm(x) * weight * (1 + scale) + shift` +- Scale/shift come from timestep embedding MLP +- DiT computes 6 values per block: `(scale1, shift1, gate1, scale2, shift2, gate2)` +- Fusing norm + conditioning saves one memory round-trip + +## Template 5: GEMM with XCD Swizzle (MI355X) + +Tiled matrix multiplication with XCD swizzle for MI355X (32 XCDs). **Mandatory** for any GEMM-like operation on MI355X — without it, work clusters on a few chiplets, wasting 90%+ of the GPU. + +> See [mi355x-optimization-guide.md](mi355x-optimization-guide.md) for architecture details. + +**When to use XCD swizzle:** GEMM, batched GEMM, attention (Q@K, score@V). NOT needed for elementwise, reduction, or normalization kernels. + +### Triton Kernel + +```python +NUM_XCDS = 32 # MI355X has 32 XCDs + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_M': 8}, + num_stages=2, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, + num_stages=2, num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_M': 8}, + num_stages=2, num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, + num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, + num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_M': 8}, + num_stages=3, num_warps=8), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def gemm_xcd_swizzle_kernel( + a_ptr, b_ptr, c_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, +): + pid = tl.program_id(0) + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + num_pids = num_pid_m * num_pid_n + + # --- XCD Swizzle: distribute blocks across 32 chiplets --- + pids_per_xcd = (num_pids + NUM_XCDS - 1) // NUM_XCDS + xcd_id = pid % NUM_XCDS + local_pid = pid // NUM_XCDS + if local_pid < pids_per_xcd: + remapped_pid = xcd_id * pids_per_xcd + local_pid + if remapped_pid < num_pids: + pid = remapped_pid + + # --- L2 Cache Grouping (after XCD swizzle) --- + num_pid_in_group = GROUP_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_M + group_size_m = tl.minimum(num_pid_m - first_pid_m, GROUP_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # --- Compute GEMM tile --- + offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M + offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N + offs_k = tl.arange(0, BLOCK_K) + + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0) + acc += tl.dot(a, b) + a_ptrs += BLOCK_K * stride_ak + b_ptrs += BLOCK_K * stride_bk + + # --- Store result --- + offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, acc.to(tl.float16), mask=mask) +``` + +### Python API + +```python +def triton_gemm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + """ + Matrix multiplication C = A @ B with XCD swizzle for MI355X. + + Args: + a: [M, K] input matrix + b: [K, N] input matrix + Returns: + c: [M, N] output matrix + """ + assert a.shape[1] == b.shape[0], "Inner dimensions must match" + assert a.is_contiguous() and b.is_contiguous() + M, K = a.shape + K, N = b.shape + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),) + gemm_xcd_swizzle_kernel[grid]( + a, b, c, + M, N, K, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + ) + return c +``` + +### Benchmark + +```python +def benchmark_gemm(): + configs = [(4096, 4096, 4096), (8192, 8192, 4096), (2048, 8192, 2048)] + for M, N, K in configs: + a = torch.randn(M, K, device='cuda', dtype=torch.float16) + b = torch.randn(K, N, device='cuda', dtype=torch.float16) + + ref = torch.mm(a, b) + out = triton_gemm(a, b) + torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-1) + + # Benchmark + for _ in range(10): + triton_gemm(a, b) + torch.mm(a, b) + torch.cuda.synchronize() + + import time + iters = 50 + start = time.perf_counter() + for _ in range(iters): + triton_gemm(a, b) + torch.cuda.synchronize() + custom_ms = (time.perf_counter() - start) / iters * 1000 + + start = time.perf_counter() + for _ in range(iters): + torch.mm(a, b) + torch.cuda.synchronize() + torch_ms = (time.perf_counter() - start) / iters * 1000 + + print(f"[{M}x{N}x{K}] Custom: {custom_ms:.2f}ms, Torch: {torch_ms:.2f}ms, " + f"Speedup: {torch_ms/custom_ms:.2f}x") +``` + +### GEMM-Specific Guidelines + +- **XCD Swizzle is MANDATORY** on MI355X for any GEMM — without it, expect 0.3-0.5x +- **L2 Cache Grouping** (`GROUP_M=8-16`): Improves L2 hit rate after XCD swizzle +- **MFMA**: Use `matrix_instr_nonkdim=16` for MI355X matrix cores +- **FP32 accumulation**: Always accumulate in FP32, cast at store +- **LDS budget**: Check `BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N` * dtype * num_stages < 160 KB +- **Autotune**: GEMM benefits heavily from autotuning — always include 4+ configs +- **R9700**: Does NOT have XCDs — remove the XCD swizzle section for RDNA4 diff --git a/skills/rocm-kernels/references/kernelbench-classification.md b/skills/rocm-kernels/references/kernelbench-classification.md new file mode 100644 index 00000000..67ba8a76 --- /dev/null +++ b/skills/rocm-kernels/references/kernelbench-classification.md @@ -0,0 +1,162 @@ +# KernelBench Operator Classification & Skill Mapping + +This document classifies KernelBench operators into categories and maps each to the appropriate kernel skill/pattern. + +## Classification Taxonomy + +### Level 1: Basic Operators (53 operators) + +#### Category A: GEMM / Matrix Multiplication (18 operators) + +| ID | Name | Sub-type | Key Skill | +|----|------|----------|-----------| +| 1 | Square matrix multiplication | Dense GEMM | XCD Swizzle + Autotune | +| 2 | Standard matrix multiplication | Dense GEMM (M!=N) | XCD Swizzle + Autotune | +| 3 | Batched matrix multiplication | BMM | Batch-indexed GEMM | +| 4 | Matrix-vector multiplication | MatVec | 1D reduction pattern | +| 5 | Matrix-scalar multiplication | Elementwise | Scale kernel | +| 6 | Matmul with large K | Large-K GEMM | K-dimension blocking | +| 7 | Matmul with small K | Small-K GEMM | Fewer K-iterations | +| 8 | Matmul with irregular shapes | Non-square GEMM | Mask handling | +| 9 | Tall-skinny matmul | Tall-skinny GEMM | Tile shape tuning | +| 10 | 3D tensor-matrix mul | Batched GEMM | Reshape + GEMM | +| 11 | 4D tensor-matrix mul | Batched GEMM | Einsum decomposition | +| 12 | Diagonal matrix mul | Special GEMM | Elementwise pattern | +| 13 | Symmetric matrices | Dense GEMM | Standard GEMM | +| 14 | Upper triangular mul | Masked GEMM | Triangle mask | +| 15 | Lower triangular mul | Masked GEMM | Triangle mask | +| 16 | Transposed A | Transposed GEMM | Stride adjustment | +| 17 | Transposed B | Transposed GEMM | Stride adjustment | +| 18 | Both transposed | Transposed GEMM | Stride adjustment | + +**Key Pattern**: Template 5 (GEMM with XCD Swizzle) +**Critical Optimization**: XCD swizzle + L2 cache grouping + MFMA 16x16 + +#### Category B: Elementwise / Activation Functions (14 operators) + +| ID | Name | Sub-type | Key Skill | +|----|------|----------|-----------| +| 19 | ReLU | Branching | `tl.where(x > 0, x, 0)` | +| 20 | LeakyReLU | Branching | `tl.where(x > 0, x, alpha*x)` | +| 21 | Sigmoid | Transcendental | `1/(1+exp(-x))` | +| 22 | Tanh | Transcendental | `(exp(2x)-1)/(exp(2x)+1)` | +| 23 | Softmax | Row reduction | Online softmax | +| 24 | LogSoftmax | Row reduction | Online softmax + log | +| 25 | Swish/SiLU | Transcendental | `x * sigmoid(x)` | +| 26 | GELU | Transcendental | `0.5*x*(1+erf(x/sqrt(2)))` | +| 27 | SELU | Branching + exp | `scale * where(x>0, x, alpha*(exp(x)-1))` | +| 28 | HardSigmoid | Clamp | `clamp((x+3)/6, 0, 1)` | +| 29 | Softplus | Transcendental | `log(1+exp(x))` | +| 30 | Softsign | Division | `x/(1+abs(x))` | +| 31 | ELU | Branching + exp | `where(x>0, x, alpha*(exp(x)-1))` | +| 32 | HardTanh | Clamp | `clamp(x, -1, 1)` | + +**Key Pattern**: Template 1 (Elementwise) +**Critical Optimization**: Large BLOCK_SIZE (4096-16384), FP32 compute + +#### Category C: Normalization (8 operators) + +| ID | Name | Sub-type | Key Skill | +|----|------|----------|-----------| +| 33 | BatchNorm | Multi-dim reduction | Welford algorithm | +| 34 | InstanceNorm | Per-instance reduction | Per-sample norm | +| 35 | GroupNorm | Group reduction | Grouped channels | +| 36 | RMSNorm | Row reduction | `x * rsqrt(mean(x^2) + eps)` | +| 37 | FrobeniusNorm | Full reduction | `sqrt(sum(x^2))` | +| 38 | L1 Norm | Full reduction | `sum(abs(x))` | +| 39 | L2 Norm | Full reduction | `sqrt(sum(x^2))` | +| 40 | LayerNorm | Row reduction | `(x-mean)/std * w + b` | + +**Key Pattern**: Template 3 (Row-wise Reduction) +**Critical Optimization**: FP32 accumulation, proper reduction + +#### Category D: Pooling (6 operators) + +| ID | Name | Sub-type | Key Skill | +|----|------|----------|-----------| +| 41 | Max Pooling 1D | Sliding window | Max reduction | +| 42 | Max Pooling 2D | 2D window | 2D index mapping | +| 43 | Max Pooling 3D | 3D window | Program_id flattening | +| 44 | Average Pooling 1D | Sliding window | Sum + divide | +| 45 | Average Pooling 2D | 2D window | 2D index mapping | +| 46 | Average Pooling 3D | 3D window | Program_id flattening | + +**Key Challenge**: 3D grid mapping with Triton's program_id limits + +#### Category E: Reduction (7 operators) + +| ID | Name | Sub-type | Key Skill | +|----|------|----------|-----------| +| 47 | Sum reduction | Sum | `tl.sum()` | +| 48 | Mean reduction | Mean | `tl.sum() / count` | +| 49 | Max reduction | Max | `tl.max()` | +| 50 | Min reduction | Min | `tl.min()` | +| 51 | Argmax | Index + max | Two-pass or manual | +| 52 | Argmin | Index + min | Two-pass or manual | +| 53 | Min (duplicate) | Min | `tl.min()` | + +**Key Pattern**: Template 5 (Dimension Reduction) +**Key Challenge**: Argmax/Argmin require manual implementation + +### Level 2: Fused Operators (20+ operators) + +Combine multiple operations into single kernels. + +| Category | Examples | Strategy | +|----------|---------|----------| +| GEMM + Activation | Gemm_ReLU, Gemm_GELU | Fuse activation into GEMM epilogue | +| GEMM + Norm | Gemm_BatchNorm, Gemm_GroupNorm | Two-phase kernel | +| GEMM + Scale | Gemm_Scale, Gemm_Divide | Fuse into GEMM store | +| Multi-op fusion | Matmul_Sum_Max_AvgPool | Sequential fusion | + +**Key Pattern**: Template 6 (Fused GEMM + Activation) + +### Level 3-4: Network Models / Transformers + +Full models requiring multiple kernel types. Decompose into Level 1 operators. + +### Level 6-7: Advanced / Expert + +| Operator | Type | Strategy | +|----------|------|----------| +| MinGPTNewGelu | Fused activation | GELU approximation kernel | +| ScaledDotProductAttention | Attention | Flash Attention pattern | +| GELU_And_Mul | Fused activation | `gelu(x) * y` | +| MoE_TopK_Softmax | MoE routing | Specialized kernel | +| Gemm_A8W8_Blockwise | Quantized GEMM | INT8 with block scaling | + +## Category → Skill Mapping + +| Category | Skill File | Priority | +|----------|-----------|----------| +| **GEMM** | `gemm-skill.md` (planned) | P0 - Most impactful | +| **Elementwise** | `elementwise-skill.md` (planned) | P0 - Most common | +| **Normalization** | `normalization-skill.md` (planned) | P1 - Frequently used | +| **Reduction** | `reduction-skill.md` (planned) | P1 - Common pattern | +| **Softmax** | `softmax-skill.md` (planned) | P1 - Critical for attention | +| **Pooling** | `pooling-skill.md` (planned) | P2 - Moderate complexity | +| **Attention** | `attention-skill.md` (planned) | P2 - High complexity | +| **Fused** | `fused-skill.md` (planned) | P2 - Combination patterns | + +## Performance Expectations by Category + +Based on kernel-agent test results: + +| Category | Achievable Speedup | Difficulty | Notes | +|----------|-------------------|------------|-------| +| Elementwise | 1.0-3.0x | Low | Large blocks, memory-bound | +| Reduction (sum/mean) | 1.5-5.0x | Medium | Good parallelism | +| Pooling | 1.5-5.0x | Medium | Grid mapping challenge | +| LayerNorm/RMSNorm | 1.5-2.0x | Medium | Row-wise reduction | +| Dense GEMM | 0.8-1.2x | High | XCD swizzle critical | +| Batched GEMM | 0.6-0.9x | High | Memory bandwidth limited | +| BatchNorm | <0.1x | Very High | HIP sync issues | +| Argmax/Argmin | FAIL | Very High | Triton API limitation | +| Fused operators | 0.3-1.0x | Very High | Correctness challenges | + +## Recommended Skill Development Order + +1. **Phase 1 (Quick wins)**: Elementwise activations, Sum/Mean reduction +2. **Phase 2 (Core)**: GEMM with XCD swizzle, LayerNorm/RMSNorm +3. **Phase 3 (Advanced)**: Softmax, Pooling, Attention +4. **Phase 4 (Expert)**: Fused operators, BatchNorm, Quantized GEMM diff --git a/skills/rocm-kernels/references/mi355x-optimization-guide.md b/skills/rocm-kernels/references/mi355x-optimization-guide.md new file mode 100644 index 00000000..966ff7a3 --- /dev/null +++ b/skills/rocm-kernels/references/mi355x-optimization-guide.md @@ -0,0 +1,233 @@ +# MI355X (gfx950) Optimization Guide + +Deep dive into MI355X-specific optimizations for Triton kernels on ROCm. + +## MI355X CDNA3+ Architecture + +### Key Specifications + +| Component | Value | vs MI300X | +|-----------|-------|-----------| +| Compute Capability | gfx950 | gfx942 | +| Architecture | CDNA3+ | CDNA3 | +| **XCDs (Chiplets)** | **32** | 8 | +| CUs Total | 256 | 228 | +| CUs per XCD | 8 | 28 | +| **LDS per CU** | **160 KB** | 64 KB | +| L2 Cache | 256 MB | 256 MB | +| Wavefront Size | 64 | 64 | +| GPU Memory | 288 GB HBM3e | 192 GB HBM3 | +| **Memory Bandwidth** | **8 TB/s** | 5.3 TB/s | +| FP16/BF16 Matrix TFLOPS | ~2500 | 1307 | +| FP8 Matrix TFLOPS | ~5000 | 2615 | +| MFMA Instructions | 16x16, 32x32 | 16x16, 32x32 | +| FP8 Format | float8_e4m3fn (OCP) | float8_e4m3fnuz (AMD) | + +### Critical Architecture Differences from MI300X + +1. **32 XCDs vs 8**: XCD swizzle must use `NUM_XCDS=32` +2. **8 CUs per XCD vs 28**: Finer-grained chiplet distribution +3. **160 KB LDS vs 64 KB**: 2.5x larger local memory per CU +4. **8 TB/s vs 5.3 TB/s**: 50% more memory bandwidth +5. **OCP FP8 vs AMD FP8**: Different FP8 format + +## XCD Swizzle (MANDATORY for GEMM) + +MI355X has 32 XCDs. Without proper swizzle, GEMM blocks cluster on a few XCDs, wasting 90%+ of the GPU. + +### When to Use XCD Swizzle + +| Kernel Type | XCD Swizzle? | Why | +|-------------|-------------|-----| +| GEMM / matmul | **YES, MANDATORY** | Multi-block work distribution | +| Elementwise | No | Single-block independent | +| Reduction | No | Row-independent | +| Normalization | No | Row-independent | +| Attention | **YES** (for Q@K and score@V) | Contains GEMM | + +### XCD Swizzle Implementation + +```python +NUM_XCDS = 32 + +@triton.jit +def gemm_with_xcd_swizzle(...): + pid = tl.program_id(0) + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + num_pids = num_pid_m * num_pid_n + + # Step 1: XCD Swizzle + pids_per_xcd = (num_pids + NUM_XCDS - 1) // NUM_XCDS + xcd_id = pid % NUM_XCDS + local_pid = pid // NUM_XCDS + if local_pid < pids_per_xcd: + remapped_pid = xcd_id * pids_per_xcd + local_pid + if remapped_pid < num_pids: + pid = remapped_pid + + # Step 2: L2 Cache Grouping (after XCD swizzle) + num_pid_in_group = GROUP_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_M + group_size_m = tl.minimum(num_pid_m - first_pid_m, GROUP_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m +``` + +### Performance Impact + +| Config | Without XCD Swizzle | With XCD Swizzle | Improvement | +|--------|-------------------|-----------------|-------------| +| Square GEMM 4096x4096 | 0.3-0.5x | 0.8-1.2x | 2-4x | +| Tall-skinny GEMM | 0.4-0.6x | 0.7-1.0x | 1.5-2.5x | + +## MFMA Instructions + +Use 16x16 MFMA for optimal matrix core utilization: + +```python +# Launch kernel with MFMA hint +kernel[grid](..., matrix_instr_nonkdim=16) +``` + +## LDS Optimization + +MI355X has 160 KB LDS per CU—2.5x more than MI300X. + +### LDS Budget Calculation + +``` +LDS usage = BLOCK_M × BLOCK_K × dtype_size + BLOCK_K × BLOCK_N × dtype_size + × num_stages + +Example (BLOCK_M=256, BLOCK_N=256, BLOCK_K=64, FP16, num_stages=2): + = (256×64×2 + 64×256×2) × 2 = 131,072 bytes = 128 KB < 160 KB ✓ + +Same config on MI300X (64 KB LDS): + 128 KB > 64 KB ✗ → Need num_stages=1 or smaller blocks +``` + +### Stage Configuration + +| LDS Budget | MI355X num_stages | MI300X num_stages | +|------------|------------------|------------------| +| < 80 KB | 2-3 | 2 | +| 80-160 KB | 2 | 1 (or reduce blocks) | +| > 160 KB | 1 (or reduce blocks) | Not possible | + +## Autotune Configurations + +### Elementwise Operations + +```python +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE': 1024}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE': 2048}, num_warps=8, num_stages=2), + triton.Config({'BLOCK_SIZE': 4096}, num_warps=8, num_stages=2), + triton.Config({'BLOCK_SIZE': 4096}, num_warps=16, num_stages=2), + triton.Config({'BLOCK_SIZE': 8192}, num_warps=16, num_stages=2), + triton.Config({'BLOCK_SIZE': 16384}, num_warps=16, num_stages=2), + ], + key=['n_elements'], +) +``` + +### GEMM Operations + +```python +@triton.autotune( + configs=[ + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_M': 8}, + num_stages=2, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, + num_stages=2, num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_M': 8}, + num_stages=2, num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, + num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, + num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_M': 8}, + num_stages=3, num_warps=8), + ], + key=['M', 'N', 'K'], +) +``` + +### Problem-Specific Block Sizes + +| Problem Type | BLOCK_M | BLOCK_N | BLOCK_K | num_stages | num_warps | GROUP_M | +|-------------|---------|---------|---------|------------|-----------|---------| +| Square GEMM (M,N>=4096) | 256 | 256 | 32 | 3 | 8 | 16 | +| Large K (K > max(M,N)) | 128 | 128 | 64 | 2 | 8 | 8 | +| Fused GEMM+Activation | 128 | 128 | 64 | 2 | 8 | 8 | +| Element-wise ops | - | - | - | 2 | 4-16 | - | + +## Precision and Numerical Stability + +### FP32 Accumulation (Required) + +```python +# Always accumulate in FP32 +acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) +for k in range(...): + acc += tl.dot(a, b) +# Cast at store +c = acc.to(tl.float16) +``` + +### Math Operations + +```python +# Cast to FP32 for transcendental functions +x_f32 = x.to(tl.float32) +result = tl.exp(x_f32) # ✓ +result = tl.log(x_f32) # ✓ +result = tl.sqrt(x_f32) # ✓ +result = 1.0 / x_f32 # ✓ (division in FP32) + +# tanh workaround (tl.tanh not supported on AMD) +e2x = tl.exp(2.0 * x_f32) +tanh_x = (e2x - 1.0) / (e2x + 1.0) +``` + +## Performance Profiling + +```bash +# Basic kernel profiling +rocprof --stats python your_kernel.py + +# Detailed metrics +rocprofv3 -i metrics.txt python your_kernel.py + +# Key metrics to watch: +# - L2 cache hit rate (target >70%) +# - VGPR usage (128+ may limit occupancy) +# - LDS usage (max 160 KB on MI355X) +# - Memory bandwidth utilization (target 40-60% of 8 TB/s) +``` + +## Environment Variables + +```python +import os +# Block ping-pong for better latency hiding +os.environ['TRITON_HIP_USE_BLOCK_PINGPONG'] = '1' +# Async memory copies +os.environ['TRITON_HIP_USE_ASYNC_COPY'] = '1' +``` + +## Best Practices Summary + +1. **XCD Swizzle**: Always for GEMM, never for elementwise +2. **MFMA**: Use matrix_instr_nonkdim=16 +3. **LDS**: Leverage 160 KB, but check with num_stages +4. **num_stages**: 2-3 (safe), up to 4 if LDS permits +5. **num_warps**: 8 is default, autotune 4-16 +6. **BLOCK_SIZE**: Larger than MI300X (1024-16384 for 1D) +7. **GROUP_M**: 8 or 16 for L2 cache grouping +8. **FP32 acc**: Always accumulate in FP32 +9. **Env vars**: Set BLOCK_PINGPONG and ASYNC_COPY +10. **Profile**: Use rocprof to validate optimizations diff --git a/skills/rocm-kernels/references/r9700-optimization-guide.md b/skills/rocm-kernels/references/r9700-optimization-guide.md new file mode 100644 index 00000000..8eea3489 --- /dev/null +++ b/skills/rocm-kernels/references/r9700-optimization-guide.md @@ -0,0 +1,172 @@ +# R9700 (RDNA4, gfx1201) Optimization Guide + +Deep dive into R9700-specific optimizations for Triton kernels on ROCm. + +## R9700 RDNA4 Architecture + +### Key Specifications + +| Component | R9700 | vs MI355X | +|-----------|-------|-----------| +| Compute Capability | gfx1201 | gfx950 | +| Architecture | RDNA4 | CDNA3+ | +| **Wavefront Size** | **32 (Wave32)** | 64 (Wave64) | +| CUs | 64 | 256 | +| Stream Processors | 4096 | - | +| LDS per CU | 64 KB | 160 KB | +| L1 Cache | 32 KB | - | +| L2 Cache | 8 MB | 256 MB | +| L3 Cache | 64 MB | - | +| **Cacheline Size** | **256 B** | - | +| Max Threads/Block | 1024 | 1024 | +| Max Threads/CU | 2048 | 2048 | +| Max Waves/CU | 32 | - | +| SIMDs per CU | 2 | - | +| FP32 Vector TFLOPS | 47.8 | ~200 | +| FP16 Vector TFLOPS | 95.7 | ~2500 | +| FP16 Matrix TFLOPS | 191 | ~2500 | +| Matrix Cores | Limited (no FP8 MFMA) | Full MFMA | + +### Critical RDNA4 vs CDNA3+ Differences + +1. **Wave32 vs Wave64**: Warp size is 32, same as NVIDIA +2. **No XCD Swizzle**: Single die, no chiplet distribution needed +3. **Limited Matrix Cores**: No FP8 MFMA support +4. **Smaller LDS**: 64 KB vs 160 KB +5. **Smaller L2 Cache**: 8 MB vs 256 MB +6. **256B Cacheline**: Stricter memory alignment requirements +7. **Consumer GPU**: Optimized for inference, not training + +## Wave32 Implications + +### num_warps Mapping + +On RDNA4, `num_warps` still means "number of wavefronts per block": +- 1 warp = 32 threads (Wave32) +- Max 32 waves per CU +- num_warps range: 2-8 (smaller than CDNA) + +```python +# CDNA (MI355X): 1 warp = 64 threads +# num_warps=8 → 512 threads/block + +# RDNA4 (R9700): 1 warp = 32 threads +# num_warps=8 → 256 threads/block +# Use higher num_warps if needed for same thread count +``` + +### Reduction Code + +Warp-level reductions use different offsets: + +```python +# CDNA (Wave64): offsets = 32, 16, 8, 4, 2, 1 +# RDNA4 (Wave32): offsets = 16, 8, 4, 2, 1 + +# In Triton this is handled automatically by tl.sum(), tl.max(), etc. +# No manual shuffle code needed in Triton +``` + +## Memory Hierarchy + +### 256B Cacheline Alignment + +R9700 uses 256-byte cachelines (vs 128B on RDNA3). Misaligned accesses are penalized more. + +```python +# Ensure contiguous memory access +x = x.contiguous() + +# In kernel: sequential access pattern +offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) +x = tl.load(x_ptr + offsets, mask=mask) # Coalesced +``` + +### L2 Cache Strategy + +With only 8 MB L2, cache reuse is limited: + +```python +# For GEMM: use smaller tiles to fit in L2 +# BLOCK_M=64, BLOCK_N=64, BLOCK_K=32 +# Tile = 64×32×2 + 32×64×2 = 8 KB per stage +# With 2 stages: 16 KB fits in L2 per block +``` + +### LDS (64 KB) Budget + +``` +Max LDS per CU = 64 KB + +GEMM example (BLOCK_M=64, BLOCK_N=128, BLOCK_K=32, FP16, num_stages=2): + = (64×32×2 + 32×128×2) × 2 = 24,576 bytes = 24 KB ✓ + +GEMM example (BLOCK_M=128, BLOCK_N=128, BLOCK_K=64, FP16, num_stages=2): + = (128×64×2 + 64×128×2) × 2 = 65,536 bytes = 64 KB → Borderline! +``` + +## Autotune Configurations + +### Elementwise Operations + +```python +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE': 256}, num_warps=2, num_stages=2), + triton.Config({'BLOCK_SIZE': 512}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE': 1024}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE': 1024}, num_warps=8, num_stages=2), + triton.Config({'BLOCK_SIZE': 2048}, num_warps=8, num_stages=2), + ], + key=['n_elements'], +) +``` + +### GEMM Operations + +```python +@triton.autotune( + configs=[ + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_M': 8}, + num_stages=2, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, + num_stages=2, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_M': 8}, + num_stages=2, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, + num_stages=2, num_warps=8), + ], + key=['M', 'N', 'K'], +) +``` + +## Grid Sizing + +With 64 CUs: + +```python +# Aim for multiples of 64 blocks +grid = (triton.cdiv(N, BLOCK_SIZE),) +# For GEMM: grid = (cdiv(M, BLOCK_M) * cdiv(N, BLOCK_N),) +``` + +## Precision Considerations + +- FP16 Matrix TFLOPS = 191 (2x FP32 vector) +- FP16 Vector TFLOPS = 95.7 (2x FP32 vector) +- **No FP8 MFMA**: Cannot use FP8 matrix operations +- INT8 Matrix TOPS = 383 (quantized inference) +- Use FP16 for compute, FP32 for accumulation + +## Best Practices Summary + +1. **Wave32 awareness**: Use num_warps=2-8 +2. **No XCD Swizzle**: Not needed on single-die +3. **Smaller blocks**: 64-128 for GEMM tiles +4. **256B alignment**: Ensure contiguous memory access +5. **LDS budget**: Max 64 KB, keep num_stages=2 +6. **Grid sizing**: Multiples of 64 CUs +7. **FP16 preferred**: Best throughput, no FP8 MFMA +8. **L3 cache**: 64 MB can help with model weights +9. **Inference focus**: Best suited for inference workloads +10. **Cacheline**: 256B alignment is stricter than MI355X diff --git a/skills/rocm-kernels/references/transformers-integration.md b/skills/rocm-kernels/references/transformers-integration.md new file mode 100644 index 00000000..3841c489 --- /dev/null +++ b/skills/rocm-kernels/references/transformers-integration.md @@ -0,0 +1,340 @@ +# Transformers Library Integration Guide (ROCm / Triton) + +Complete guide for integrating custom Triton kernels into HuggingFace transformers models on AMD GPUs. + +> **Quick Start:** See [transformers_injection_example.py](../scripts/transformers_injection_example.py) for a minimal working example (~150 lines). + +## Overview + +The HuggingFace transformers library has different architecture patterns than diffusers. Understanding these patterns is critical for successful kernel integration with models like LLaMA, Mistral, Qwen, and other LLMs on ROCm. + +**Key difference from diffusers:** All transformers RMSNorm modules have weights (`elementwise_affine=True`). No need to handle the weight-less variant. + +## Model Architecture Analysis + +```python +from transformers import AutoModelForCausalLM, AutoConfig +import torch + +config = AutoConfig.from_pretrained("Qwen/Qwen3-8B") +print(f"Hidden size: {config.hidden_size}") # 4096 +print(f"Num layers: {config.num_hidden_layers}") # 32 +print(f"Num heads: {config.num_attention_heads}") # 32 +print(f"RMS norm eps: {config.rms_norm_eps}") # 1e-6 + +model = AutoModelForCausalLM.from_pretrained( + "Qwen/Qwen3-8B", + torch_dtype=torch.bfloat16, + device_map="cuda" # ROCm uses same API via HIP +) + +for name, module in model.named_modules(): + class_name = type(module).__name__ + if 'Norm' in class_name: + has_weight = hasattr(module, 'weight') and module.weight is not None + print(f"{name}: {class_name} (has_weight={has_weight})") +``` + +## Common Transformers Architectures + +### LLaMA / Llama-2 / Llama-3 + +| Component | Class | Has Weight | Notes | +|-----------|-------|------------|-------| +| `model.norm` | LlamaRMSNorm | Yes | Final layer norm | +| `model.layers.*.input_layernorm` | LlamaRMSNorm | Yes | Pre-attention norm | +| `model.layers.*.post_attention_layernorm` | LlamaRMSNorm | Yes | Pre-FFN norm | +| `model.layers.*.mlp` | LlamaMLP | - | Uses SiLU gating | + +### Mistral / Mixtral + +| Component | Class | Has Weight | Notes | +|-----------|-------|------------|-------| +| `model.norm` | MistralRMSNorm | Yes | Final layer norm | +| `model.layers.*.input_layernorm` | MistralRMSNorm | Yes | Pre-attention norm | +| `model.layers.*.post_attention_layernorm` | MistralRMSNorm | Yes | Pre-FFN norm | + +### Qwen / Qwen2 / Qwen3 + +| Component | Class | Has Weight | Notes | +|-----------|-------|------------|-------| +| `model.norm` | Qwen2RMSNorm | Yes | Final layer norm | +| `model.layers.*.input_layernorm` | Qwen2RMSNorm | Yes | Pre-attention norm | +| `model.layers.*.post_attention_layernorm` | Qwen2RMSNorm | Yes | Pre-FFN norm | + +### Kernel Applicability + +| Kernel | LLaMA | Mistral | Qwen | Notes | +|--------|-------|---------|------|-------| +| RMSNorm | **Yes** | **Yes** | **Yes** | All use RMSNorm with weights | +| GEGLU | No | No | No | Uses SiLU gating instead | +| RoPE | Indirect | Indirect | Indirect | Computed by transformers internally | +| Attention | Via SDPA | Via SDPA | Via SDPA | Use Flash Attention 2 | + +## Integration Pattern + +### Step 1: Set ROCm Environment Variables + +```python +import os +os.environ['TRITON_HIP_USE_BLOCK_PINGPONG'] = '1' +os.environ['TRITON_HIP_USE_ASYNC_COPY'] = '1' +``` + +### Step 2: Define the Triton RMSNorm Kernel + +```python +import torch +import triton +import triton.language as tl + +@triton.jit +def rmsnorm_fwd_kernel( + x_ptr, weight_ptr, out_ptr, + stride_x, D, + eps: tl.constexpr, + BLOCK_D: tl.constexpr, +): + row = tl.program_id(0) + offs = tl.arange(0, BLOCK_D) + mask = offs < D + x = tl.load(x_ptr + row * stride_x + offs, mask=mask, other=0.0).to(tl.float32) + + variance = tl.sum(x * x, axis=0) / D + rms_inv = tl.rsqrt(variance + eps) + + w = tl.load(weight_ptr + offs, mask=mask, other=1.0).to(tl.float32) + out = x * rms_inv * w + + tl.store(out_ptr + row * stride_x + offs, out.to(x.dtype), mask=mask) + + +def triton_rmsnorm(x, weight, eps=1e-6): + x_2d = x.contiguous().view(-1, x.shape[-1]) + out = torch.empty_like(x_2d) + M, D = x_2d.shape + BLOCK_D = triton.next_power_of_2(D) + num_warps = 4 if BLOCK_D <= 1024 else (8 if BLOCK_D <= 4096 else 16) + rmsnorm_fwd_kernel[(M,)]( + x_2d, weight, out, x_2d.stride(0), D, eps, + BLOCK_D=BLOCK_D, num_warps=num_warps, num_stages=2, + ) + return out.view_as(x) +``` + +### Step 3: Create RMSNorm Patcher + +```python +def patch_rmsnorm_modules(model) -> int: + """ + Patch all RMSNorm modules to use Triton kernel on ROCm. + + Works with LlamaRMSNorm, MistralRMSNorm, Qwen2RMSNorm, etc. + """ + patched_count = 0 + + for name, module in model.named_modules(): + class_name = type(module).__name__ + + if 'RMSNorm' in class_name: + # LLaMA uses 'variance_epsilon', others use 'eps' + eps = getattr(module, 'variance_epsilon', None) + if eps is None: + eps = getattr(module, 'eps', 1e-6) + + has_weight = hasattr(module, 'weight') and module.weight is not None + + if has_weight: + def make_patched_forward(mod, epsilon): + def patched_forward(hidden_states): + return triton_rmsnorm(hidden_states, mod.weight, eps=epsilon) + return patched_forward + module.forward = make_patched_forward(module, eps) + patched_count += 1 + else: + print(f"WARNING: {name} has no weight, skipping") + + return patched_count +``` + +### Step 4: Use in Script + +```python +from transformers import AutoModelForCausalLM, AutoTokenizer + +model = AutoModelForCausalLM.from_pretrained( + "Qwen/Qwen3-8B", + torch_dtype=torch.bfloat16, + device_map="cuda" # ROCm uses same device API via HIP +) +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B") + +count = patch_rmsnorm_modules(model) +print(f"Patched {count} RMSNorm modules") +# Expected: 65 modules (32 layers * 2 + 1 final) + +inputs = tokenizer("The capital of France is", return_tensors="pt").to("cuda") +outputs = model.generate(**inputs, max_new_tokens=20) +print(tokenizer.decode(outputs[0])) +``` + +## Key Differences from Diffusers + +### 1. RMSNorm Always Has Weight + +Unlike diffusers (where some RMSNorm modules have `elementwise_affine=False`), transformers RMSNorm modules **always** have weights. The `HAS_WEIGHT` branch is always true, so you can simplify the kernel to always load weights. + +### 2. Different Epsilon Attribute Names + +```python +# LLaMA uses 'variance_epsilon' +eps = getattr(module, 'variance_epsilon', 1e-6) + +# Some models use 'eps' +eps = getattr(module, 'eps', 1e-6) + +# Safe pattern +eps = getattr(module, 'variance_epsilon', None) or getattr(module, 'eps', 1e-6) +``` + +### 3. No Attention Processor Pattern + +Diffusers uses `set_processor()` for attention modules. Transformers does not: + +```python +# Transformers: Use Flash Attention 2 instead of custom processors +model = AutoModelForCausalLM.from_pretrained( + model_id, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2" +) +``` + +### 4. Device Map vs Manual Move + +```python +# Transformers — use device_map for large models +model = AutoModelForCausalLM.from_pretrained( + model_id, + device_map="auto" # Handles multi-GPU automatically +) + +# Diffusers — manual move then CPU offload +pipe = DiffusionPipeline.from_pretrained(model_id) +pipe.to("cuda") +pipe.enable_model_cpu_offload() +``` + +## ROCm-Specific Considerations + +### 1. ROCm Environment Setup + +```python +import os +os.environ['TRITON_HIP_USE_BLOCK_PINGPONG'] = '1' +os.environ['TRITON_HIP_USE_ASYNC_COPY'] = '1' +``` + +### 2. No tl.libdevice / tl.math.tanh + +If you extend beyond RMSNorm (e.g., custom SiLU activation), remember tanh is not available: + +```python +# Manual tanh for ROCm +e2x = tl.exp(2.0 * x) +tanh_x = (e2x - 1.0) / (e2x + 1.0) +``` + +### 3. Verify HIP Backend + +```python +import torch +print(f"HIP version: {torch.version.hip}") # Should show ROCm version +print(f"GPU: {torch.cuda.get_device_name()}") +``` + +### 4. torch.compile on ROCm + +Custom Triton kernels and `torch.compile` can coexist on ROCm since Triton is already the compilation backend. However, test thoroughly as behavior may differ from eager mode. + +## Model-Specific Integration + +### LLaMA Models + +```python +from transformers import LlamaForCausalLM + +model = LlamaForCausalLM.from_pretrained( + "meta-llama/Llama-2-7b-hf", + torch_dtype=torch.bfloat16, + device_map="cuda" +) + +count = patch_rmsnorm_modules(model) +print(f"Patched {count} LlamaRMSNorm modules") +# Expected: 65 modules (32 layers * 2 + 1 final) +``` + +### Qwen3-8B + +```python +model = AutoModelForCausalLM.from_pretrained( + "Qwen/Qwen3-8B", + torch_dtype=torch.bfloat16, + device_map="cuda" +) + +count = patch_rmsnorm_modules(model) +print(f"Patched {count} Qwen2RMSNorm modules") +# Expected: 65 modules (32 layers * 2 + 1 final) +``` + +## Verification + +### Verify Injection Worked + +```python +x = torch.randn(1, 10, model.config.hidden_size, device='cuda', dtype=torch.bfloat16) +for name, module in model.named_modules(): + if 'RMSNorm' in type(module).__name__: + out = module(x) + print(f"RMSNorm forward pass: {x.shape} -> {out.shape}") + break +``` + +### Run Generation Test + +```python +inputs = tokenizer("Hello, my name is", return_tensors="pt").to("cuda") +with torch.inference_mode(): + outputs = model.generate(**inputs, max_new_tokens=20) +print(tokenizer.decode(outputs[0])) +``` + +### Profile on ROCm + +```bash +rocprof --stats python your_script.py +rocprofv3 -i metrics.txt python your_script.py +``` + +## Performance Optimization + +### Enable Flash Attention 2 + +```python +model = AutoModelForCausalLM.from_pretrained( + model_id, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", + device_map="cuda" +) +``` + +### Combine with Custom Kernels + +```python +model = AutoModelForCausalLM.from_pretrained(model_id, ...) +patch_rmsnorm_modules(model) # Inject Triton RMSNorm +# Flash Attention 2 handles attention optimization +``` diff --git a/skills/rocm-kernels/references/troubleshooting.md b/skills/rocm-kernels/references/troubleshooting.md new file mode 100644 index 00000000..b0f0e443 --- /dev/null +++ b/skills/rocm-kernels/references/troubleshooting.md @@ -0,0 +1,292 @@ +# ROCm Triton Kernel Troubleshooting Guide + +Common issues and solutions when developing Triton kernels for AMD GPUs. + +## Build / Import Issues + +### 1. `tl.libdevice` Not Found + +**Error:** `AttributeError: module 'triton.language' has no attribute 'libdevice'` + +**Cause:** `tl.libdevice` is CUDA-only (NVIDIA's libdevice library). + +**Fix:** Replace with manual implementations: +```python +# WRONG (CUDA only) +tl.libdevice.tanh(x) +tl.libdevice.log1p(x) + +# CORRECT (ROCm compatible) +e2x = tl.exp(2.0 * x); tanh_x = (e2x - 1.0) / (e2x + 1.0) +log1p_x = tl.log(1.0 + x) +``` + +### 2. `tl.tanh` / `tl.math.tanh` Not Available + +**Error:** `AttributeError: module 'triton.language.math' has no attribute 'tanh'` + +**Cause:** Neither `tl.tanh`, `tl.math.tanh`, nor `tl.libdevice.tanh` exist on ROCm Triton. This is the most common GEGLU compilation failure. + +**Fix — manual tanh (ONLY reliable method):** +```python +x_f32 = x.to(tl.float32) +e2x = tl.exp(2.0 * x_f32) +tanh_x = (e2x - 1.0) / (e2x + 1.0) +``` + +## Runtime Errors + +### 3. HIP Runtime Error: Invalid Argument + +**Error:** `hipErrorInvalidValue` or `HIP Error: invalid argument` + +**Common causes:** +- Grid/block size exceeds hardware limits +- Mismatched tensor shapes +- LDS overflow + +**Fix:** +```python +# Check grid size +grid = (triton.cdiv(N, BLOCK_SIZE),) +assert grid[0] > 0, f"Grid size must be > 0, got {grid[0]}" + +# Ensure contiguous tensors +x = x.contiguous() + +# Reduce num_stages to avoid LDS overflow +# num_stages=2 is safest +``` + +### 4. HIP Out of Memory (LDS) + +**Error:** `AMDGPU_KERNEL_ERROR_OUT_OF_MEMORY` or `LDS size exceeds limit` + +**Cause:** Kernel uses more LDS than available (64 KB on R9700, 160 KB on MI355X). + +**Fix:** +```python +# Reduce num_stages +num_stages=2 # instead of 3 or 4 + +# Reduce block sizes +BLOCK_M=64, BLOCK_N=64, BLOCK_K=32 # smaller tiles +``` + +### 5. Kernel Timeout + +**Error:** Kernel hangs or times out. + +**Common cause:** Grid and Program ID mismatch. + +```python +# WRONG: 1D grid but 2D program_id +grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),) +pid_m = tl.program_id(0) # OK +pid_n = tl.program_id(1) # ERROR: axis 1 doesn't exist in 1D grid + +# CORRECT: Compute 2D indices from 1D grid +grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),) +pid = tl.program_id(0) +pid_m = pid // triton.cdiv(N, BLOCK_N) +pid_n = pid % triton.cdiv(N, BLOCK_N) +``` + +## Correctness Issues + +### 6. Autotuning BLOCK_D Causes Wrong Results + +**Symptom:** RMSNorm/AdaLN/GEGLU correctness fails with large `max_abs` errors (4-8+). Kernel runs fast but produces garbage. + +**Cause:** `@triton.autotune` with `BLOCK_D` configs (e.g., 512, 1024, 2048, 4096) may select a `BLOCK_D < D` (hidden dimension). Since `tl.arange(0, BLOCK_D)` only covers `BLOCK_D` elements, the kernel processes a partial row, computing wrong variance and writing incomplete output. + +**Fix:** Never autotune `BLOCK_D` for row-reduction kernels. Compute it dynamically: +```python +# WRONG — autotune may pick BLOCK_D=512 when D=2048 +@triton.autotune(configs=[ + triton.Config({'BLOCK_D': 512}, num_warps=4), + triton.Config({'BLOCK_D': 1024}, num_warps=8), +], key=['D']) + +# CORRECT — compute in Python wrapper +BLOCK_D = triton.next_power_of_2(D) +num_warps = 4 if BLOCK_D <= 1024 else (8 if BLOCK_D <= 4096 else 16) +kernel[(M,)](..., BLOCK_D=BLOCK_D, num_warps=num_warps, num_stages=2) +``` + +### 7. RoPE cos/sin Out-of-Bounds GPU Crash (batch > 1) + +**Symptom:** `Memory access fault by GPU node` crash. Only happens when batch_size > 1. + +**Cause:** cos/sin tensors have shape `[seq_len, head_dim]`, but when the grid is `(batch * seq_len, num_heads)`, `pid_s` ranges `[0, batch * seq_len)`. For `pid_s >= seq_len`, `cos_ptr + pid_s * head_dim` is out of bounds. + +**Fix:** Use modular indexing for cos/sin: +```python +# WRONG — crashes when pid_s >= seq_len +cos_val = tl.load(cos_ptr + pid_s * head_dim + offs, ...) + +# CORRECT — wrap position index for batch dimension +seq_idx = pid_s % seq_len +cos_val = tl.load(cos_ptr + seq_idx * head_dim + offs, ...) +``` + +### 8. FP16/BF16 Precision Loss + +**Symptom:** Results differ from PyTorch reference by more than tolerance. + +**Fix:** Always accumulate in FP32: +```python +# WRONG: Accumulate in FP16 +acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float16) + +# CORRECT: Accumulate in FP32, cast at store +acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) +# ... computation ... +result = acc.to(tl.float16) +tl.store(out_ptr + ..., result, mask=mask) +``` + +**Tolerance guidelines:** +- BF16 (7-bit mantissa): `atol=0.1`, `rtol=1e-2` +- FP16 (10-bit mantissa): `atol=0.01`, `rtol=1e-3` + +### 9. Mask Errors + +**Error:** `ValueError: Mask argument cannot be block type` + +**Fix:** Ensure mask dimensions match pointer dimensions: +```python +# 1D kernel +mask = offsets < n_elements +x = tl.load(x_ptr + offsets, mask=mask, other=0.0) + +# 2D kernel +mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) +x = tl.load(ptr + offs_m[:, None] * stride + offs_n[None, :], mask=mask, other=0.0) +``` + +### 10. Python min/max Inside Kernel + +**Error:** `TypeError` or incorrect results. + +**Fix:** +```python +# WRONG: Python builtins +result = min(a, b) +result = max(a, b) + +# CORRECT: Triton functions +result = tl.minimum(a, b) +result = tl.maximum(a, b) +``` + +## Performance Issues + +### 11. GEMM Extremely Slow (0.3-0.5x) + +**Cause:** Missing XCD swizzle on MI355X. + +**Fix:** Add XCD swizzle pattern (see Template 5: GEMM with XCD Swizzle in kernel-templates.md). + +### 12. Elementwise Kernel Slow + +**Common causes:** +1. BLOCK_SIZE too small → not utilizing bandwidth +2. Internal loops → should process full block +3. Missing autotune → not finding optimal config + +**Fix:** +```python +# Use large BLOCK_SIZE for elementwise +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE': 2048}, num_warps=8, num_stages=2), + triton.Config({'BLOCK_SIZE': 4096}, num_warps=8, num_stages=2), + triton.Config({'BLOCK_SIZE': 8192}, num_warps=16, num_stages=2), + ], + key=['n_elements'], +) +``` + +### 13. Missing @triton.autotune (for elementwise) + +**Symptom:** Kernel runs but performance is poor. + +**Fix:** **EVERY kernel must have autotune with 4+ configs.** Fixed block sizes are almost never optimal. + +### 14. tl.store() Keyword Argument Error + +**Error:** `TypeError: store() got an unexpected keyword argument` + +**Fix:** Check Triton version API. Use positional arguments if needed: +```python +# Check your Triton version +# tl.store(ptr, value, mask=mask) # Most versions +# tl.store(ptr, value, mask) # Some older versions +``` + +### 15. eps: tl.constexpr Causes Recompilation Crash + +**Error:** `AttributeError("'NoneType' object has no attribute 'type'")` during Triton compilation + +**Cause:** When `eps` is declared as `tl.constexpr`, the kernel is compiled separately for each unique eps value. If the kernel first compiles with `eps=1e-6` and later is called with `eps=1e-8` (e.g., from `nn.RMSNorm.eps`), the recompilation on ROCm Triton can crash. + +**Fix:** Remove `tl.constexpr` from `eps` and pass it as a regular runtime parameter: +```python +# WRONG — triggers recompilation for each eps value, may crash on ROCm +@triton.jit +def rmsnorm_kernel(x_ptr, ..., eps: tl.constexpr, BLOCK_D: tl.constexpr): + ... + +# CORRECT — eps is a regular runtime float, no recompilation +@triton.jit +def rmsnorm_kernel(x_ptr, ..., eps, BLOCK_D: tl.constexpr): + ... + +# Also ensure eps is a plain float in the wrapper +rmsnorm_kernel[(M,)](..., float(eps), BLOCK_D=BLOCK_D, ...) +``` + +**Note:** Only `BLOCK_D`, `HAS_WEIGHT`, and other values that change kernel structure should be `tl.constexpr`. Parameters like `eps` that only affect numerical values should be regular parameters. + +## Debugging Tips + +### Check GPU Architecture + +```bash +rocminfo | grep "Name" +# Should show gfx950 (MI355X) or gfx1201 (R9700) +``` + +### Verify ROCm Triton Installation + +```python +import triton +print(triton.__version__) +import torch +print(torch.version.hip) # Should show ROCm version +print(torch.cuda.get_device_properties(0)) +``` + +### Profile Kernel + +```bash +# Basic profiling +rocprof --stats python your_kernel.py + +# Detailed kernel metrics +rocprofv3 -i metrics.txt python your_kernel.py +``` + +### Test Kernel Correctness + +```python +# Compare with PyTorch reference +ref_output = reference_model(inputs) +custom_output = custom_model(inputs) + +torch.testing.assert_close( + custom_output, ref_output, + rtol=1e-2, atol=1e-3 # FP16 tolerance +) +``` diff --git a/skills/rocm-kernels/scripts/benchmark_e2e.py b/skills/rocm-kernels/scripts/benchmark_e2e.py new file mode 100644 index 00000000..0a441914 --- /dev/null +++ b/skills/rocm-kernels/scripts/benchmark_e2e.py @@ -0,0 +1,296 @@ +#!/usr/bin/env python3 +""" +End-to-end benchmark: LTX-Video pipeline with/without custom Triton kernels on ROCm. + +Measures total generation time, per-step latency, and peak memory. + +Requirements: + pip install diffusers transformers accelerate torch triton + +Usage: + # Baseline (no custom kernels) + python benchmark_e2e.py --mode baseline + + # With custom Triton kernels + python benchmark_e2e.py --mode triton + + # With torch.compile + python benchmark_e2e.py --mode compile + + # Compare all three + python benchmark_e2e.py --mode all + + # Quick test (fewer frames/steps) + python benchmark_e2e.py --mode triton --num-frames 9 --steps 5 +""" +import os +os.environ['TRITON_HIP_USE_BLOCK_PINGPONG'] = '1' +os.environ['TRITON_HIP_USE_ASYNC_COPY'] = '1' + +import argparse +import json +import time + +import torch +import torch.nn as nn +import triton +import triton.language as tl + + +# ============================================================================ +# Triton RMSNorm Kernel (same as in benchmark_kernels.py) +# ============================================================================ + +@triton.jit +def _rmsnorm_kernel( + x_ptr, weight_ptr, out_ptr, + stride_x_row, D, + eps, + HAS_WEIGHT: tl.constexpr, + BLOCK_D: tl.constexpr, +): + row = tl.program_id(0) + offs = tl.arange(0, BLOCK_D) + mask = offs < D + + x = tl.load(x_ptr + row * stride_x_row + offs, mask=mask, other=0.0).to(tl.float32) + sq_sum = tl.sum(x * x, axis=0) + rms_inv = tl.rsqrt(sq_sum / D + eps) + + if HAS_WEIGHT: + w = tl.load(weight_ptr + offs, mask=mask, other=1.0).to(tl.float32) + out = x * rms_inv * w + else: + out = x * rms_inv + + tl.store(out_ptr + row * stride_x_row + offs, out.to(x.dtype), mask=mask) + + +def triton_rmsnorm(x, weight=None, eps=1e-6): + x_flat = x.contiguous().view(-1, x.shape[-1]) + out = torch.empty_like(x_flat) + M, D = x_flat.shape + has_weight = weight is not None + if not has_weight: + weight = torch.ones(D, device=x.device, dtype=x.dtype) + BLOCK_D = triton.next_power_of_2(D) + num_warps = 4 if BLOCK_D <= 1024 else (8 if BLOCK_D <= 4096 else 16) + _rmsnorm_kernel[(M,)]( + x_flat, weight, out, x_flat.stride(0), D, float(eps), has_weight, + BLOCK_D=BLOCK_D, num_warps=num_warps, num_stages=2, + ) + return out.view_as(x) + + +# ============================================================================ +# Attention Processor (uses Triton RMSNorm) +# ============================================================================ + +class TritonLTXVideoAttnProcessor: + def __call__(self, attn, hidden_states, encoder_hidden_states=None, + attention_mask=None, image_rotary_emb=None): + from diffusers.models.transformers.transformer_ltx import apply_rotary_emb + from diffusers.models.attention_dispatch import dispatch_attention_fn + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = triton_rmsnorm(query, attn.norm_q.weight, eps=attn.norm_q.eps) + key = triton_rmsnorm(key, attn.norm_k.weight, eps=attn.norm_k.eps) + + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + value = value.unflatten(2, (attn.heads, -1)) + + hidden_states = dispatch_attention_fn( + query, key, value, + attn_mask=attention_mask, dropout_p=0.0, is_causal=False, + ) + hidden_states = hidden_states.flatten(2, 3).to(query.dtype) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +# ============================================================================ +# Module Patchers +# ============================================================================ + +def patch_rmsnorm_modules(model): + patched = 0 + for name, module in model.named_modules(): + if type(module).__name__ == 'RMSNorm': + eps = getattr(module, 'eps', 1e-6) + has_weight = hasattr(module, 'weight') and module.weight is not None + if has_weight: + def make_fwd(mod, e): + def fwd(x): return triton_rmsnorm(x, mod.weight, eps=e) + return fwd + module.forward = make_fwd(module, eps) + else: + def make_fwd_nw(e): + def fwd(x): return triton_rmsnorm(x, None, eps=e) + return fwd + module.forward = make_fwd_nw(eps) + patched += 1 + return patched + + +def inject_triton_kernels(pipe): + stats = {'attention_processors': 0, 'rmsnorm_modules': 0} + if not hasattr(pipe, 'transformer'): + return stats + for name, module in pipe.transformer.named_modules(): + if hasattr(module, 'set_processor') and hasattr(module, 'processor'): + module.set_processor(TritonLTXVideoAttnProcessor()) + stats['attention_processors'] += 1 + stats['rmsnorm_modules'] = patch_rmsnorm_modules(pipe.transformer) + return stats + + +# ============================================================================ +# Benchmark Runner +# ============================================================================ + +def run_benchmark(mode, prompt, num_frames, height, width, steps, + guidance_scale, seed, warmup_iters): + from diffusers import LTXPipeline + from diffusers.utils import export_to_video + + device = "cuda" + dtype = torch.bfloat16 + + print(f"\n{'='*60}") + print(f"MODE: {mode.upper()}") + print(f"{'='*60}") + + pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=dtype) + pipe.to(device) + + if mode == "triton": + stats = inject_triton_kernels(pipe) + print(f" Attention processors: {stats['attention_processors']}") + print(f" RMSNorm patched: {stats['rmsnorm_modules']}") + elif mode == "compile": + pipe.transformer.compile_repeated_blocks(fullgraph=True) + print(" torch.compile enabled (fullgraph=True)") + else: + print(" Baseline (no optimization)") + + # Warmup + if warmup_iters > 0: + print(f"\n Warmup ({warmup_iters} iters, {min(steps, 5)} steps)...") + for i in range(warmup_iters): + _ = pipe(prompt=prompt, num_frames=num_frames, height=height, width=width, + num_inference_steps=min(steps, 5), guidance_scale=guidance_scale) + torch.cuda.synchronize() + + # Benchmark run + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + + print(f"\n Generating ({num_frames} frames, {steps} steps)...") + torch.cuda.synchronize() + start = time.time() + output = pipe( + prompt=prompt, num_frames=num_frames, height=height, width=width, + num_inference_steps=steps, guidance_scale=guidance_scale, + generator=torch.Generator(device=device).manual_seed(seed), + ) + torch.cuda.synchronize() + gen_time = time.time() - start + peak_mem = torch.cuda.max_memory_allocated() / 1e9 + + result = { + 'mode': mode, + 'gen_time_s': round(gen_time, 2), + 'time_per_frame_s': round(gen_time / num_frames, 3), + 'time_per_step_s': round(gen_time / steps, 3), + 'peak_memory_gb': round(peak_mem, 2), + } + + print(f"\n Results:") + print(f" Total: {result['gen_time_s']:.2f} s") + print(f" Per frame: {result['time_per_frame_s']:.3f} s") + print(f" Per step: {result['time_per_step_s']:.3f} s") + print(f" Peak mem: {result['peak_memory_gb']:.2f} GB") + + # Save video + out_path = f"ltx_video_{mode}.mp4" + export_to_video(output.frames[0], out_path, fps=24) + print(f" Saved to: {out_path}") + + del pipe + torch.cuda.empty_cache() + return result + + +def main(): + parser = argparse.ArgumentParser(description="E2E LTX-Video benchmark on ROCm") + parser.add_argument("--mode", type=str, default="all", + choices=["baseline", "triton", "compile", "all"]) + parser.add_argument("--prompt", type=str, + default="A cat sleeping in warm sunlight, cinematic, 4K") + parser.add_argument("--num-frames", type=int, default=25) + parser.add_argument("--height", type=int, default=480) + parser.add_argument("--width", type=int, default=704) + parser.add_argument("--steps", type=int, default=30) + parser.add_argument("--guidance-scale", type=float, default=7.5) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--warmup", type=int, default=1) + parser.add_argument("--output-json", type=str, default=None, + help="Save results to JSON for comparison") + args = parser.parse_args() + + print("=" * 60) + print("LTX-Video End-to-End Benchmark (ROCm)") + print("=" * 60) + print(f"Device: {torch.cuda.get_device_name(0)}") + print(f"ROCm: {torch.version.hip if hasattr(torch.version, 'hip') else 'N/A'}") + print(f"Config: {args.num_frames} frames, {args.height}x{args.width}, {args.steps} steps") + + modes = ["baseline", "triton", "compile"] if args.mode == "all" else [args.mode] + all_results = [] + + for mode in modes: + r = run_benchmark(mode, args.prompt, args.num_frames, args.height, + args.width, args.steps, args.guidance_scale, + args.seed, args.warmup) + all_results.append(r) + + # Comparison table + if len(all_results) > 1: + print(f"\n{'='*60}") + print("COMPARISON") + print(f"{'='*60}") + print(f"{'Mode':<12} {'Time (s)':<12} {'Per Step (s)':<15} {'Peak Mem (GB)':<15}") + print("-" * 54) + baseline_time = all_results[0]['gen_time_s'] + for r in all_results: + speedup = baseline_time / r['gen_time_s'] if r['gen_time_s'] > 0 else 0 + suffix = f" ({speedup:.2f}x)" if r['mode'] != 'baseline' else "" + print(f"{r['mode']:<12} {r['gen_time_s']:<12.2f} {r['time_per_step_s']:<15.3f} {r['peak_memory_gb']:<15.2f}{suffix}") + + if args.output_json: + with open(args.output_json, 'w') as f: + json.dump(all_results, f, indent=2) + print(f"\nResults saved to {args.output_json}") + + +if __name__ == "__main__": + main() diff --git a/skills/rocm-kernels/scripts/benchmark_kernels.py b/skills/rocm-kernels/scripts/benchmark_kernels.py new file mode 100644 index 00000000..26bb9f8d --- /dev/null +++ b/skills/rocm-kernels/scripts/benchmark_kernels.py @@ -0,0 +1,536 @@ +#!/usr/bin/env python3 +""" +Micro-benchmark for all 4 Triton kernels on ROCm: RMSNorm, RoPE 3D, GEGLU, AdaLN. + +Measures: + 1. Correctness vs PyTorch reference + 2. Latency (custom vs baseline, warmup + averaged) + 3. Memory bandwidth utilization + +Usage: + python benchmark_kernels.py + python benchmark_kernels.py --kernel rmsnorm + python benchmark_kernels.py --kernel rope + python benchmark_kernels.py --kernel geglu + python benchmark_kernels.py --kernel adaln + python benchmark_kernels.py --dtype float16 +""" +import os +os.environ['TRITON_HIP_USE_BLOCK_PINGPONG'] = '1' +os.environ['TRITON_HIP_USE_ASYNC_COPY'] = '1' + +import argparse +import time +from typing import Tuple + +import torch +import triton +import triton.language as tl + + +# ============================================================================ +# Kernel 1: RMSNorm +# ============================================================================ +# CRITICAL: BLOCK_D must be >= D (hidden dimension). +# Using autotune with fixed BLOCK_D configs is WRONG because autotune may +# pick BLOCK_D < D, causing only partial row processing. +# Fix: compute BLOCK_D = next_power_of_2(D) dynamically in the Python wrapper. + +@triton.jit +def rmsnorm_fwd_kernel( + x_ptr, weight_ptr, out_ptr, + stride_x, D, + eps, + HAS_WEIGHT: tl.constexpr, + BLOCK_D: tl.constexpr, +): + row = tl.program_id(0) + col_offsets = tl.arange(0, BLOCK_D) + mask = col_offsets < D + + x = tl.load(x_ptr + row * stride_x + col_offsets, mask=mask, other=0.0).to(tl.float32) + variance = tl.sum(x * x, axis=0) / D + rms_inv = tl.rsqrt(variance + eps) + + if HAS_WEIGHT: + w = tl.load(weight_ptr + col_offsets, mask=mask, other=1.0).to(tl.float32) + result = x * rms_inv * w + else: + result = x * rms_inv + + tl.store(out_ptr + row * stride_x + col_offsets, result.to(x.dtype), mask=mask) + + +def triton_rmsnorm(x, weight=None, eps=1e-6): + orig_shape = x.shape + x_2d = x.contiguous().view(-1, x.shape[-1]) + out = torch.empty_like(x_2d) + M, D = x_2d.shape + has_weight = weight is not None + if not has_weight: + weight = torch.empty(0, device=x.device) + + BLOCK_D = triton.next_power_of_2(D) + num_warps = 4 if BLOCK_D <= 1024 else (8 if BLOCK_D <= 4096 else 16) + rmsnorm_fwd_kernel[(M,)]( + x_2d, weight, out, + x_2d.stride(0), D, float(eps), has_weight, + BLOCK_D=BLOCK_D, num_warps=num_warps, num_stages=2, + ) + return out.view(orig_shape) + + +def pytorch_rmsnorm(x, weight=None, eps=1e-6): + variance = x.pow(2).mean(dim=-1, keepdim=True) + out = x * torch.rsqrt(variance + eps) + if weight is not None: + out = out * weight + return out + + +# ============================================================================ +# Kernel 2: RoPE 3D +# ============================================================================ +# CRITICAL: cos/sin have shape [seq_len, head_dim], NOT [batch*seq_len, ...]. +# When grid is (batch * seq_len, num_heads), we must use pid_s % seq_len +# to index into cos/sin to avoid out-of-bounds access for batch > 1. + +@triton.jit +def rope_3d_fwd_kernel( + qk_ptr, cos_ptr, sin_ptr, out_ptr, + seq_len, num_heads, head_dim, + stride_s, stride_h, stride_d, + BLOCK_HD: tl.constexpr, +): + pid_s = tl.program_id(0) + pid_h = tl.program_id(1) + half_dim = head_dim // 2 + offs = tl.arange(0, BLOCK_HD) + mask = offs < half_dim + + base = pid_s * stride_s + pid_h * stride_h + x0 = tl.load(qk_ptr + base + offs, mask=mask, other=0.0).to(tl.float32) + x1 = tl.load(qk_ptr + base + half_dim + offs, mask=mask, other=0.0).to(tl.float32) + + seq_idx = pid_s % seq_len + cos_val = tl.load(cos_ptr + seq_idx * head_dim + offs, mask=mask, other=1.0).to(tl.float32) + sin_val = tl.load(sin_ptr + seq_idx * head_dim + offs, mask=mask, other=0.0).to(tl.float32) + + out0 = x0 * cos_val - x1 * sin_val + out1 = x0 * sin_val + x1 * cos_val + + tl.store(out_ptr + base + offs, out0.to(x0.dtype), mask=mask) + tl.store(out_ptr + base + half_dim + offs, out1.to(x0.dtype), mask=mask) + + +def triton_rope_3d(qk, cos, sin): + qk = qk.contiguous() + out = torch.empty_like(qk) + batch, seq_len, num_heads, head_dim = qk.shape + half_dim = head_dim // 2 + qk_flat = qk.view(batch * seq_len, num_heads, head_dim) + out_flat = out.view(batch * seq_len, num_heads, head_dim) + grid = (batch * seq_len, num_heads) + BLOCK_HD = triton.next_power_of_2(half_dim) + num_warps = 4 if BLOCK_HD <= 64 else 8 + rope_3d_fwd_kernel[grid]( + qk_flat, cos, sin, out_flat, + seq_len, num_heads, head_dim, + qk_flat.stride(0), qk_flat.stride(1), qk_flat.stride(2), + BLOCK_HD=BLOCK_HD, num_warps=num_warps, num_stages=2, + ) + return out + + +def pytorch_rope(qk, cos, sin): + half = qk.shape[-1] // 2 + x0, x1 = qk[..., :half], qk[..., half:] + cos_exp = cos.unsqueeze(0).unsqueeze(2)[:, :qk.shape[1], :, :half] + sin_exp = sin.unsqueeze(0).unsqueeze(2)[:, :qk.shape[1], :, :half] + out0 = x0 * cos_exp - x1 * sin_exp + out1 = x0 * sin_exp + x1 * cos_exp + return torch.cat([out0, out1], dim=-1) + + +# ============================================================================ +# Kernel 3: GEGLU +# ============================================================================ +# Same BLOCK_SIZE fix as RMSNorm: compute dynamically, do NOT autotune. + +@triton.jit +def geglu_fwd_kernel( + input_ptr, output_ptr, + stride_in, stride_out, hidden_size, + BLOCK_H: tl.constexpr, +): + row = tl.program_id(0) + offs = tl.arange(0, BLOCK_H) + mask = offs < hidden_size + + gate = tl.load(input_ptr + row * stride_in + offs, mask=mask, other=0.0).to(tl.float32) + value = tl.load(input_ptr + row * stride_in + hidden_size + offs, mask=mask, other=0.0).to(tl.float32) + + # GELU approx: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) + # tl.math.tanh / tl.libdevice.tanh NOT available on ROCm — use manual formula + SQRT_2_OVER_PI = 0.7978845608028654 + tanh_arg = SQRT_2_OVER_PI * (gate + 0.044715 * gate * gate * gate) + e2x = tl.exp(2.0 * tanh_arg) + tanh_val = (e2x - 1.0) / (e2x + 1.0) + cdf = 0.5 * (1.0 + tanh_val) + gelu_gate = gate * cdf + result = gelu_gate * value + + tl.store(output_ptr + row * stride_out + offs, result.to(gate.dtype), mask=mask) + + +def triton_geglu(x): + x = x.contiguous() + *batch_dims, double_h = x.shape + hidden_size = double_h // 2 + x_2d = x.view(-1, double_h) + M = x_2d.shape[0] + out = torch.empty(M, hidden_size, device=x.device, dtype=x.dtype) + + BLOCK_H = triton.next_power_of_2(hidden_size) + num_warps = 4 if BLOCK_H <= 1024 else (8 if BLOCK_H <= 4096 else 16) + geglu_fwd_kernel[(M,)]( + x_2d, out, + x_2d.stride(0), out.stride(0), hidden_size, + BLOCK_H=BLOCK_H, num_warps=num_warps, num_stages=2, + ) + return out.view(*batch_dims, hidden_size) + + +def pytorch_geglu(x): + hidden_size = x.shape[-1] // 2 + gate, value = x[..., :hidden_size], x[..., hidden_size:] + return torch.nn.functional.gelu(gate, approximate='tanh') * value + + +# ============================================================================ +# Kernel 4: AdaLN +# ============================================================================ +# Same BLOCK_D fix: compute dynamically. + +@triton.jit +def adaln_fwd_kernel( + x_ptr, weight_ptr, scale_ptr, shift_ptr, out_ptr, + stride_x, stride_cond, D, + eps, + BLOCK_D: tl.constexpr, +): + row = tl.program_id(0) + offs = tl.arange(0, BLOCK_D) + mask = offs < D + + x = tl.load(x_ptr + row * stride_x + offs, mask=mask, other=0.0).to(tl.float32) + variance = tl.sum(x * x, axis=0) / D + rms_inv = tl.rsqrt(variance + eps) + x_norm = x * rms_inv + + w = tl.load(weight_ptr + offs, mask=mask, other=1.0).to(tl.float32) + scale = tl.load(scale_ptr + row * stride_cond + offs, mask=mask, other=0.0).to(tl.float32) + shift = tl.load(shift_ptr + row * stride_cond + offs, mask=mask, other=0.0).to(tl.float32) + + out = x_norm * w * (1.0 + scale) + shift + tl.store(out_ptr + row * stride_x + offs, out.to(x.dtype), mask=mask) + + +def triton_adaln(x, weight, scale, shift, eps=1e-6): + x_flat = x.contiguous().view(-1, x.shape[-1]) + scale_flat = scale.contiguous().view(-1, x.shape[-1]) + shift_flat = shift.contiguous().view(-1, x.shape[-1]) + out = torch.empty_like(x_flat) + M, D = x_flat.shape + + BLOCK_D = triton.next_power_of_2(D) + num_warps = 4 if BLOCK_D <= 1024 else (8 if BLOCK_D <= 4096 else 16) + adaln_fwd_kernel[(M,)]( + x_flat, weight, scale_flat, shift_flat, out, + x_flat.stride(0), scale_flat.stride(0), D, float(eps), + BLOCK_D=BLOCK_D, num_warps=num_warps, num_stages=2, + ) + return out.view_as(x) + + +def pytorch_adaln(x, weight, scale, shift, eps=1e-6): + variance = x.pow(2).mean(dim=-1, keepdim=True) + x_norm = x * torch.rsqrt(variance + eps) + return x_norm * weight * (1.0 + scale) + shift + + +# ============================================================================ +# Benchmark Utilities +# ============================================================================ + +def benchmark_fn(func, args, warmup=20, iterations=100) -> Tuple[float, float]: + for _ in range(warmup): + func(*args) + torch.cuda.synchronize() + + times = [] + for _ in range(iterations): + torch.cuda.synchronize() + start = time.perf_counter() + func(*args) + torch.cuda.synchronize() + end = time.perf_counter() + times.append((end - start) * 1000) + + return sum(times) / len(times), min(times) + + +def check_correctness(out, ref, name, dtype): + max_abs = (out.float() - ref.float()).abs().max().item() + max_rel = ((out.float() - ref.float()).abs() / (ref.float().abs() + 1e-8)).max().item() + + # BF16 has 7-bit mantissa; for values ~8-16 the ULP is 0.0625-0.125 + # FP16 has 10-bit mantissa; tighter but RoPE trig ops can accumulate 1-2 ULP error + atol = 0.15 if dtype == torch.bfloat16 else 0.02 + passed = max_abs < atol + status = "PASS" if passed else "FAIL" + print(f" [{status}] {name}: max_abs={max_abs:.6e}, max_rel={max_rel:.6e}") + return passed + + +# ============================================================================ +# Benchmark Runners +# ============================================================================ + +def benchmark_rmsnorm(dtype): + print("\n" + "=" * 70) + print("BENCHMARK: RMSNorm (168 instances in LTX-Video)") + print("=" * 70) + + configs = [ + (1, 1024, 2048), + (2, 1024, 2048), + (4, 1024, 2048), + (1, 4096, 2048), + (2, 4096, 3072), + (1, 8192, 2048), + (4, 4096, 3072), + ] + + print(f"\n{'Config':<25} {'Triton (ms)':<15} {'PyTorch (ms)':<15} {'Speedup':<10}") + print("-" * 70) + + all_correct = True + total_speedup = 0 + + for batch, seq, hidden in configs: + x = torch.randn(batch, seq, hidden, dtype=dtype, device="cuda") + w = torch.ones(hidden, dtype=dtype, device="cuda") + + ref = pytorch_rmsnorm(x, w) + out = triton_rmsnorm(x, w) + if not check_correctness(out, ref, f"[{batch}x{seq}x{hidden}]", dtype): + all_correct = False + + t_avg, _ = benchmark_fn(triton_rmsnorm, (x, w)) + p_avg, _ = benchmark_fn(pytorch_rmsnorm, (x, w)) + speedup = p_avg / t_avg + total_speedup += speedup + + print(f" [{batch}x{seq}x{hidden}]{'':<13} {t_avg:>10.3f} {p_avg:>10.3f} {speedup:>7.2f}x") + + # No-weight variant + print("\n -- No-weight variant (elementwise_affine=False) --") + x = torch.randn(2, 4096, 2048, dtype=dtype, device="cuda") + ref_nw = pytorch_rmsnorm(x, None) + out_nw = triton_rmsnorm(x, None) + check_correctness(out_nw, ref_nw, "no-weight [2x4096x2048]", dtype) + + avg_speedup = total_speedup / len(configs) + print(f"\n Average speedup: {avg_speedup:.2f}x") + + # Bandwidth analysis + batch, seq, hidden = 4, 4096, 3072 + x = torch.randn(batch, seq, hidden, dtype=dtype, device="cuda") + w = torch.ones(hidden, dtype=dtype, device="cuda") + bytes_per_elem = 2 if dtype in (torch.float16, torch.bfloat16) else 4 + total_bytes = batch * seq * hidden * bytes_per_elem * 2 + hidden * bytes_per_elem + t_avg, _ = benchmark_fn(triton_rmsnorm, (x, w)) + bw_gbps = (total_bytes / 1e9) / (t_avg / 1000) + print(f"\n Bandwidth analysis [{batch}x{seq}x{hidden}]:") + print(f" Data moved: {total_bytes / 1e6:.2f} MB") + print(f" Achieved: {bw_gbps:.1f} GB/s") + + return all_correct, avg_speedup + + +def benchmark_rope(dtype): + print("\n" + "=" * 70) + print("BENCHMARK: RoPE 3D (Video Position Encoding)") + print("=" * 70) + + configs = [ + (1, 1024, 16, 64), + (1, 4096, 16, 64), + (2, 4096, 16, 128), + (1, 8192, 32, 64), + ] + + print(f"\n{'Config':<30} {'Triton (ms)':<15} {'PyTorch (ms)':<15} {'Speedup':<10}") + print("-" * 75) + + all_correct = True + total_speedup = 0 + + for batch, seq, heads, hdim in configs: + qk = torch.randn(batch, seq, heads, hdim, dtype=dtype, device="cuda") + cos = torch.randn(seq, hdim, dtype=dtype, device="cuda") + sin = torch.randn(seq, hdim, dtype=dtype, device="cuda") + + ref = pytorch_rope(qk, cos, sin) + out = triton_rope_3d(qk, cos, sin) + if not check_correctness(out, ref, f"[{batch}x{seq}x{heads}x{hdim}]", dtype): + all_correct = False + + t_avg, _ = benchmark_fn(triton_rope_3d, (qk, cos, sin)) + p_avg, _ = benchmark_fn(pytorch_rope, (qk, cos, sin)) + speedup = p_avg / t_avg + total_speedup += speedup + + cfg = f"[{batch}x{seq}x{heads}x{hdim}]" + print(f" {cfg:<28} {t_avg:>10.3f} {p_avg:>10.3f} {speedup:>7.2f}x") + + avg_speedup = total_speedup / len(configs) + print(f"\n Average speedup: {avg_speedup:.2f}x") + return all_correct, avg_speedup + + +def benchmark_geglu(dtype): + print("\n" + "=" * 70) + print("BENCHMARK: GEGLU (For SD3/FLUX, NOT LTX-Video)") + print("=" * 70) + + configs = [ + (1, 1024, 2048), + (2, 1024, 4096), + (2, 4096, 3072), + (4, 4096, 4096), + ] + + print(f"\n{'Config':<30} {'Triton (ms)':<15} {'PyTorch (ms)':<15} {'Speedup':<10}") + print("-" * 75) + + all_correct = True + total_speedup = 0 + + for batch, seq, hidden in configs: + x = torch.randn(batch, seq, hidden * 2, dtype=dtype, device="cuda") + + ref = pytorch_geglu(x) + out = triton_geglu(x) + if not check_correctness(out, ref, f"[{batch}x{seq}x{hidden*2}]", dtype): + all_correct = False + + t_avg, _ = benchmark_fn(triton_geglu, (x,)) + p_avg, _ = benchmark_fn(pytorch_geglu, (x,)) + speedup = p_avg / t_avg + total_speedup += speedup + + cfg = f"[{batch}x{seq}x{hidden*2}->{hidden}]" + print(f" {cfg:<28} {t_avg:>10.3f} {p_avg:>10.3f} {speedup:>7.2f}x") + + avg_speedup = total_speedup / len(configs) + print(f"\n Average speedup: {avg_speedup:.2f}x") + return all_correct, avg_speedup + + +def benchmark_adaln(dtype): + print("\n" + "=" * 70) + print("BENCHMARK: AdaLN (Fused Norm + Conditioning for DiT)") + print("=" * 70) + + configs = [ + (1, 1024, 2048), + (2, 1024, 2048), + (2, 4096, 3072), + (4, 4096, 3072), + ] + + print(f"\n{'Config':<25} {'Triton (ms)':<15} {'PyTorch (ms)':<15} {'Speedup':<10}") + print("-" * 70) + + all_correct = True + total_speedup = 0 + + for batch, seq, hidden in configs: + x = torch.randn(batch, seq, hidden, dtype=dtype, device="cuda") + w = torch.ones(hidden, dtype=dtype, device="cuda") + scale = torch.randn(batch, seq, hidden, dtype=dtype, device="cuda") * 0.1 + shift = torch.randn(batch, seq, hidden, dtype=dtype, device="cuda") * 0.1 + + ref = pytorch_adaln(x, w, scale, shift) + out = triton_adaln(x, w, scale, shift) + if not check_correctness(out, ref, f"[{batch}x{seq}x{hidden}]", dtype): + all_correct = False + + t_avg, _ = benchmark_fn(triton_adaln, (x, w, scale, shift)) + p_avg, _ = benchmark_fn(pytorch_adaln, (x, w, scale, shift)) + speedup = p_avg / t_avg + total_speedup += speedup + + print(f" [{batch}x{seq}x{hidden}]{'':<13} {t_avg:>10.3f} {p_avg:>10.3f} {speedup:>7.2f}x") + + avg_speedup = total_speedup / len(configs) + print(f"\n Average speedup: {avg_speedup:.2f}x") + return all_correct, avg_speedup + + +# ============================================================================ +# Main +# ============================================================================ + +def main(): + parser = argparse.ArgumentParser(description="Benchmark Triton kernels on ROCm") + parser.add_argument("--kernel", type=str, default="all", + choices=["all", "rmsnorm", "rope", "geglu", "adaln"]) + parser.add_argument("--dtype", type=str, default="bfloat16", + choices=["bfloat16", "float16"]) + args = parser.parse_args() + + dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16 + + print("=" * 70) + print("ROCm Triton Kernel Micro-Benchmark") + print("=" * 70) + print(f"Device: {torch.cuda.get_device_name(0)}") + print(f"Dtype: {dtype}") + print(f"ROCm: {torch.version.hip if hasattr(torch.version, 'hip') else 'N/A'}") + + results = {} + runners = { + "rmsnorm": benchmark_rmsnorm, + "rope": benchmark_rope, + "geglu": benchmark_geglu, + "adaln": benchmark_adaln, + } + + if args.kernel == "all": + for name, runner in runners.items(): + correct, speedup = runner(dtype) + results[name] = {"correct": correct, "speedup": speedup} + else: + correct, speedup = runners[args.kernel](dtype) + results[args.kernel] = {"correct": correct, "speedup": speedup} + + # Summary + print("\n" + "=" * 70) + print("SUMMARY") + print("=" * 70) + print(f"{'Kernel':<15} {'Correct':<12} {'Avg Speedup':<15}") + print("-" * 42) + for name, r in results.items(): + status = "PASS" if r["correct"] else "FAIL" + print(f"{name:<15} {status:<12} {r['speedup']:.2f}x") + + all_pass = all(r["correct"] for r in results.values()) + print(f"\nOverall: {'ALL PASS' if all_pass else 'SOME FAILED'}") + print("=" * 70) + + +if __name__ == "__main__": + main() diff --git a/skills/rocm-kernels/scripts/huggingface_kernels_example.py b/skills/rocm-kernels/scripts/huggingface_kernels_example.py new file mode 100644 index 00000000..80a0acf5 --- /dev/null +++ b/skills/rocm-kernels/scripts/huggingface_kernels_example.py @@ -0,0 +1,344 @@ +#!/usr/bin/env python3 +""" +Example: Using HuggingFace Kernels library to load and use optimized kernels on ROCm. + +This script demonstrates how to: +1. Load kernels from the HuggingFace Hub using get_kernel() +2. Check kernel availability with has_kernel() +3. Integrate Hub kernels with transformers/diffusers models +4. Fall back to local Triton kernels when Hub builds are unavailable + +Requirements: + pip install kernels torch numpy + +Usage: + python scripts/huggingface_kernels_example.py +""" + +import os +import time +from typing import Optional + +os.environ['TRITON_HIP_USE_BLOCK_PINGPONG'] = '1' +os.environ['TRITON_HIP_USE_ASYNC_COPY'] = '1' + +import torch +import torch.nn as nn +import triton +import triton.language as tl + + +# ============================================================================= +# Local Triton RMSNorm (fallback when Hub kernel unavailable) +# ============================================================================= + +EPS_DEFAULT = 1e-6 + +@triton.jit +def rmsnorm_fwd_kernel( + x_ptr, weight_ptr, out_ptr, + stride_x, D, eps, + BLOCK_D: tl.constexpr, +): + row = tl.program_id(0) + offs = tl.arange(0, BLOCK_D) + mask = offs < D + x = tl.load(x_ptr + row * stride_x + offs, mask=mask, other=0.0).to(tl.float32) + variance = tl.sum(x * x, axis=0) / D + rms_inv = tl.rsqrt(variance + eps) + w = tl.load(weight_ptr + offs, mask=mask, other=1.0).to(tl.float32) + out = x * rms_inv * w + tl.store(out_ptr + row * stride_x + offs, out.to(x.dtype), mask=mask) + + +def local_triton_rmsnorm(x, weight, eps=EPS_DEFAULT): + """Local Triton RMSNorm — used as fallback when Hub kernel is unavailable.""" + x_2d = x.contiguous().view(-1, x.shape[-1]) + out = torch.empty_like(x_2d) + M, D = x_2d.shape + BLOCK_D = triton.next_power_of_2(D) + num_warps = 4 if BLOCK_D <= 1024 else (8 if BLOCK_D <= 4096 else 16) + rmsnorm_fwd_kernel[(M,)]( + x_2d, weight, out, x_2d.stride(0), D, float(eps), + BLOCK_D=BLOCK_D, num_warps=num_warps, num_stages=2, + ) + return out.view_as(x) + + +# ============================================================================= +# Part 1: Check Environment +# ============================================================================= + +def check_environment(): + """Print environment information for debugging.""" + print("=" * 60) + print("Environment") + print("=" * 60) + print(f"PyTorch: {torch.__version__}") + print(f"HIP version: {getattr(torch.version, 'hip', 'N/A')}") + print(f"CUDA available: {torch.cuda.is_available()}") + if torch.cuda.is_available(): + print(f"GPU: {torch.cuda.get_device_name()}") + print() + + +# ============================================================================= +# Part 2: Basic Kernel Loading from Hub +# ============================================================================= + +def demo_basic_kernel_loading(): + """Demonstrate basic kernel loading from Hub.""" + print("=" * 60) + print("Part 1: Basic Kernel Loading from Hub") + print("=" * 60) + + try: + from kernels import get_kernel, has_kernel + + repo_id = "kernels-community/triton-layer-norm" + + print(f"\n1. Checking kernel availability: {repo_id}") + if has_kernel(repo_id): + print(" Kernel is available for this ROCm environment") + + print(f"\n2. Loading kernel from Hub...") + kernel = get_kernel(repo_id) + + print(f"\n3. Available functions:") + functions = [f for f in dir(kernel) if not f.startswith('_')] + for func in functions[:10]: + print(f" - {func}") + + print(f"\n4. Testing RMSNorm kernel...") + x = torch.randn(2, 1024, 2048, dtype=torch.bfloat16, device="cuda") + w = torch.ones(2048, dtype=torch.bfloat16, device="cuda") + + rms_fn_name = None + for name in ('rms_norm', 'rms_norm_fn', 'rmsnorm'): + if hasattr(kernel, name): + rms_fn_name = name + break + + if rms_fn_name: + rms_fn = getattr(kernel, rms_fn_name) + try: + out = rms_fn(x, w, eps=1e-6) + except TypeError: + # rms_norm_fn(x, weight, bias, ...) requires bias argument + out = rms_fn(x, w, None, eps=1e-6) + print(f" Using: kernel.{rms_fn_name}()") + print(f" Input: {x.shape}, Output: {out.shape}") + print(f" Success!") + else: + print(f" No RMSNorm function found. Available: {functions}") + + return kernel + else: + print(" No compatible build for this ROCm environment") + print(" Will use local Triton kernel as fallback") + return None + + except ImportError: + print("\n kernels library not installed. Install with: pip install kernels") + return None + except Exception as e: + print(f"\n Error: {e}") + return None + + +# ============================================================================= +# Part 3: Benchmark Hub Kernel vs Local Triton vs PyTorch +# ============================================================================= + +def demo_benchmark(hub_kernel): + """Benchmark Hub kernel vs local Triton vs PyTorch.""" + print("\n" + "=" * 60) + print("Part 2: Benchmark Hub vs Local Triton vs PyTorch") + print("=" * 60) + + shapes = [(2, 1024, 2048), (4, 4096, 3072)] + warmup, iterations = 20, 100 + + for shape in shapes: + x = torch.randn(shape, dtype=torch.bfloat16, device="cuda") + w = torch.ones(shape[-1], dtype=torch.bfloat16, device="cuda") + + def _call_hub(fn, x, w, eps): + try: + return fn(x, w, eps=eps) + except TypeError: + return fn(x, w, None, eps=eps) + + hub_rms_fn_raw = None + if hub_kernel: + for fn_name in ('rms_norm', 'rms_norm_fn', 'rmsnorm'): + if hasattr(hub_kernel, fn_name): + hub_rms_fn_raw = getattr(hub_kernel, fn_name) + break + + # Warmup all implementations + for _ in range(warmup): + local_triton_rmsnorm(x, w, eps=1e-6) + variance = x.pow(2).mean(-1, keepdim=True) + _ = x * torch.rsqrt(variance + 1e-6) * w + if hub_rms_fn_raw: + _call_hub(hub_rms_fn_raw, x, w, 1e-6) + torch.cuda.synchronize() + + # PyTorch baseline + start = time.perf_counter() + for _ in range(iterations): + variance = x.pow(2).mean(-1, keepdim=True) + _ = x * torch.rsqrt(variance + 1e-6) * w + torch.cuda.synchronize() + pt_ms = (time.perf_counter() - start) / iterations * 1000 + + # Local Triton + start = time.perf_counter() + for _ in range(iterations): + local_triton_rmsnorm(x, w, eps=1e-6) + torch.cuda.synchronize() + local_ms = (time.perf_counter() - start) / iterations * 1000 + + print(f"\n Shape {shape}:") + print(f" PyTorch: {pt_ms:.4f} ms") + print(f" Local Triton: {local_ms:.4f} ms (speedup: {pt_ms/local_ms:.2f}x)") + + if hub_rms_fn_raw: + start = time.perf_counter() + for _ in range(iterations): + _call_hub(hub_rms_fn_raw, x, w, 1e-6) + torch.cuda.synchronize() + hub_ms = (time.perf_counter() - start) / iterations * 1000 + print(f" Hub kernel: {hub_ms:.4f} ms (speedup: {pt_ms/hub_ms:.2f}x)") + + +# ============================================================================= +# Part 4: Model Integration with Fallback +# ============================================================================= + +def demo_model_integration(hub_kernel): + """Demonstrate integrating kernels with models, with fallback.""" + print("\n" + "=" * 60) + print("Part 3: Model Integration with Fallback") + print("=" * 60) + + class SimpleModel(nn.Module): + def __init__(self, hidden_size=2048): + super().__init__() + self.norm = nn.RMSNorm(hidden_size) + self.linear = nn.Linear(hidden_size, hidden_size) + + def forward(self, x): + return self.linear(self.norm(x)) + + model = SimpleModel().cuda().to(torch.bfloat16) + + # Decide which RMSNorm to use + hub_rms_fn = None + if hub_kernel: + for fn_name in ('rms_norm', 'rms_norm_fn', 'rmsnorm'): + if hasattr(hub_kernel, fn_name): + hub_rms_fn = getattr(hub_kernel, fn_name) + break + + if hub_rms_fn: + def _hub_rmsnorm(x, w, eps): + try: + return hub_rms_fn(x, w, eps=eps) + except TypeError: + return hub_rms_fn(x, w, None, eps=eps) + rmsnorm_fn = _hub_rmsnorm + source = "Hub kernel" + else: + rmsnorm_fn = local_triton_rmsnorm + source = "Local Triton" + + print(f"\n1. Using {source} for RMSNorm") + + # Patch model + for name, module in model.named_modules(): + if isinstance(module, nn.RMSNorm): + raw_eps = getattr(module, 'eps', None) + eps = float(raw_eps) if raw_eps is not None else 1e-6 + + def make_forward(mod, epsilon, fn): + def forward(x): + return fn(x, mod.weight, epsilon) + return forward + + module.forward = make_forward(module, eps, rmsnorm_fn) + print(f" Patched: {name} (eps={eps})") + + # Test + print(f"\n2. Testing forward pass...") + x = torch.randn(2, 1024, 2048, dtype=torch.bfloat16, device="cuda") + with torch.inference_mode(): + y = model(x) + print(f" Input: {x.shape} -> Output: {y.shape}") + print(f" Success!") + + +# ============================================================================= +# Part 5: Publishing Info +# ============================================================================= + +def demo_publishing_info(): + """Show information about publishing kernels to Hub.""" + print("\n" + "=" * 60) + print("Part 4: Publishing Triton Kernels to Hub") + print("=" * 60) + + print(""" + For Triton kernels (best ROCm compatibility): + + 1. Create project structure: + my-triton-kernel/ + ├── build.toml + ├── kernel_src/ + │ └── rmsnorm.py # Triton kernel + └── torch-ext/ + ├── torch_binding.cpp + └── my_kernels/__init__.py + + 2. Configure build.toml with ROCm support: + [general] + name = "my_kernels" + backends = ["cuda", "rocm"] + + 3. Build and publish: + $ pip install kernel-builder + $ kernel-builder build + $ huggingface-cli upload my-username/my-kernel ./dist + + See: https://huggingface.co/docs/kernels + """) + + +# ============================================================================= +# Main +# ============================================================================= + +def main(): + print("=" * 60) + print("HuggingFace Kernels Integration Example (ROCm)") + print("=" * 60) + + check_environment() + + if not torch.cuda.is_available(): + print("GPU not available. This example requires an AMD GPU with ROCm.") + return + + hub_kernel = demo_basic_kernel_loading() + demo_benchmark(hub_kernel) + demo_model_integration(hub_kernel) + demo_publishing_info() + + print("\n" + "=" * 60) + print("Done!") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/skills/rocm-kernels/scripts/transformers_injection_example.py b/skills/rocm-kernels/scripts/transformers_injection_example.py new file mode 100644 index 00000000..6bbc3ce7 --- /dev/null +++ b/skills/rocm-kernels/scripts/transformers_injection_example.py @@ -0,0 +1,183 @@ +#!/usr/bin/env python3 +""" +Minimal example: Inject custom Triton kernels into HuggingFace Transformers models on ROCm. + +This script demonstrates the essential pattern for integrating custom Triton kernels +with transformers models like LLaMA, Mistral, and Qwen on AMD GPUs. + +Key lessons: +1. Transformers RMSNorm modules always have weights (unlike some diffusers modules) +2. Use 'RMSNorm' substring match to catch LlamaRMSNorm, MistralRMSNorm, etc. +3. Check for 'variance_epsilon' (LLaMA) or 'eps' (others) for epsilon value +4. Use Flash Attention 2 for attention optimization instead of custom processors +5. ROCm: tl.libdevice/tl.math.tanh NOT available — use manual math + +Usage: + python scripts/transformers_injection_example.py +""" + +import os +import sys +import time + +os.environ['TRITON_HIP_USE_BLOCK_PINGPONG'] = '1' +os.environ['TRITON_HIP_USE_ASYNC_COPY'] = '1' + +import torch +import torch.nn as nn +import triton +import triton.language as tl + + +# ============================================================================= +# Triton RMSNorm Kernel (ROCm compatible) +# ============================================================================= + +@triton.jit +def rmsnorm_fwd_kernel( + x_ptr, weight_ptr, out_ptr, + stride_x, D, eps, + BLOCK_D: tl.constexpr, +): + row = tl.program_id(0) + offs = tl.arange(0, BLOCK_D) + mask = offs < D + x = tl.load(x_ptr + row * stride_x + offs, mask=mask, other=0.0).to(tl.float32) + + variance = tl.sum(x * x, axis=0) / D + rms_inv = tl.rsqrt(variance + eps) + + w = tl.load(weight_ptr + offs, mask=mask, other=1.0).to(tl.float32) + out = x * rms_inv * w + + tl.store(out_ptr + row * stride_x + offs, out.to(x.dtype), mask=mask) + + +def triton_rmsnorm(x, weight, eps=1e-6): + x_2d = x.contiguous().view(-1, x.shape[-1]) + out = torch.empty_like(x_2d) + M, D = x_2d.shape + BLOCK_D = triton.next_power_of_2(D) + num_warps = 4 if BLOCK_D <= 1024 else (8 if BLOCK_D <= 4096 else 16) + rmsnorm_fwd_kernel[(M,)]( + x_2d, weight, out, x_2d.stride(0), D, float(eps), + BLOCK_D=BLOCK_D, num_warps=num_warps, num_stages=2, + ) + return out.view_as(x) + + +# ============================================================================= +# RMSNorm Module Patcher +# ============================================================================= + +def patch_rmsnorm_modules(model: nn.Module) -> int: + """ + Patch all RMSNorm modules to use Triton kernel on ROCm. + + Works with LlamaRMSNorm, MistralRMSNorm, Qwen2RMSNorm, etc. + Unlike diffusers, transformers RMSNorm always has weights. + """ + patched_count = 0 + + for name, module in model.named_modules(): + class_name = type(module).__name__ + + if 'RMSNorm' in class_name: + eps = getattr(module, 'variance_epsilon', None) + if eps is None: + eps = getattr(module, 'eps', 1e-6) + + has_weight = hasattr(module, 'weight') and module.weight is not None + + if has_weight: + def make_patched_forward(mod, epsilon): + def patched_forward(hidden_states): + return triton_rmsnorm(hidden_states, mod.weight, eps=epsilon) + return patched_forward + module.forward = make_patched_forward(module, eps) + patched_count += 1 + else: + print(f"WARNING: {name} has no weight, skipping") + + return patched_count + + +def inject_optimized_kernels(model) -> dict: + """Inject custom Triton kernels into a transformers model.""" + stats = {'rmsnorm_modules': 0} + stats['rmsnorm_modules'] = patch_rmsnorm_modules(model) + return stats + + +# ============================================================================= +# Main +# ============================================================================= + +def main(): + from transformers import AutoModelForCausalLM, AutoTokenizer + + print("=" * 60) + print("Transformers Triton Kernel Injection (ROCm)") + print("=" * 60) + + # Verify ROCm + print(f"\nROCm HIP version: {getattr(torch.version, 'hip', 'N/A')}") + print(f"GPU: {torch.cuda.get_device_name()}") + + model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" + + print(f"\n1. Loading model: {model_id}...") + model = AutoModelForCausalLM.from_pretrained( + model_id, + torch_dtype=torch.bfloat16, + device_map="cuda" + ) + tokenizer = AutoTokenizer.from_pretrained(model_id) + + rmsnorm_count = sum(1 for _, m in model.named_modules() if 'RMSNorm' in type(m).__name__) + print(f" Found {rmsnorm_count} RMSNorm modules") + + print("\n2. Injecting optimized Triton kernels...") + stats = inject_optimized_kernels(model) + print(f" RMSNorm modules patched: {stats['rmsnorm_modules']}") + + print("\n3. Verifying injection...") + x = torch.randn(1, 10, model.config.hidden_size, device='cuda', dtype=torch.bfloat16) + for name, module in model.named_modules(): + if 'RMSNorm' in type(module).__name__: + out = module(x) + print(f" RMSNorm forward pass: {x.shape} -> {out.shape}") + break + + print("\n4. Running generation test...") + prompt = "The capital of France is" + inputs = tokenizer(prompt, return_tensors="pt").to("cuda") + + with torch.inference_mode(): + _ = model.generate(**inputs, max_new_tokens=5, do_sample=False) + + num_tokens = 50 + start_time = time.perf_counter() + with torch.inference_mode(): + outputs = model.generate( + **inputs, + max_new_tokens=num_tokens, + do_sample=False, + pad_token_id=tokenizer.eos_token_id + ) + end_time = time.perf_counter() + + elapsed = end_time - start_time + tokens_per_second = num_tokens / elapsed + + print(f" Prompt: {prompt}") + print(f" Output: {tokenizer.decode(outputs[0], skip_special_tokens=True)}") + print(f" Generated {num_tokens} tokens in {elapsed:.2f}s ({tokens_per_second:.1f} tokens/s)") + + print("\n" + "=" * 60) + print("Success! Custom Triton kernels are being used on ROCm.") + print("=" * 60) + + +if __name__ == "__main__": + main()