diff --git a/Fuser/auto_agent.py b/Fuser/auto_agent.py index 1e6daff..0b6e43d 100644 --- a/Fuser/auto_agent.py +++ b/Fuser/auto_agent.py @@ -349,7 +349,7 @@ def __init__( ignore_router_config: bool = False, use_router_cache: bool = True, no_cusolver: bool = False, - test_timeout_s: int = 30, + test_timeout_s: int = 300, test_code: str | None = None, ) -> None: self.ka_model = ka_model diff --git a/Fuser/config/autoagent_default.yml b/Fuser/config/autoagent_default.yml index 0947d89..f0c6d82 100644 --- a/Fuser/config/autoagent_default.yml +++ b/Fuser/config/autoagent_default.yml @@ -33,5 +33,5 @@ target_platform: cuda ignore_router_config: false use_router_cache: true no_cusolver: false -test_timeout_s: 30 +test_timeout_s: 300 test_code: null diff --git a/Fuser/dispatch_kernel_agent.py b/Fuser/dispatch_kernel_agent.py index 167dffe..46fd8e1 100644 --- a/Fuser/dispatch_kernel_agent.py +++ b/Fuser/dispatch_kernel_agent.py @@ -334,7 +334,7 @@ def run( target_platform: str = "cuda", max_iters: int = 10, no_cusolver: bool = False, - test_timeout_s: int = 30, + test_timeout_s: int = 300, ) -> Path: """Dispatch subgraphs to KernelAgent with optional parallelism. 
diff --git a/Fuser/pipeline.py b/Fuser/pipeline.py index f444473..9d89ba4 100644 --- a/Fuser/pipeline.py +++ b/Fuser/pipeline.py @@ -57,7 +57,7 @@ def run_pipeline( verify: bool = True, compose_max_iters: int = 5, target_platform: str = "cuda", - test_timeout_s: int = 30, + test_timeout_s: int = 300, ) -> dict: # Select default KernelAgent model if not provided: prefer GPT-5 for Level 2/3 if dispatch_model is None: diff --git a/README.md b/README.md index 5e2eaa7..159a801 100644 --- a/README.md +++ b/README.md @@ -30,6 +30,7 @@ Every stage writes artifacts to a run directory under `.optimize//`, inc - Linux or macOS - **GPU Requirements (one of the following):** - **CUDA**: NVIDIA GPU with CUDA support + - **ROCm**: AMD GPU with ROCm 6.x+ (e.g., Instinct MI300X) - **XPU**: Intel GPU with oneAPI support (Arc, Data Center GPUs, or integrated Xe graphics) - Triton (installed separately: `pip install triton` or nightly from source) - PyTorch (https://pytorch.org/get-started/locally/) @@ -42,6 +43,29 @@ pip install -e . 
### Platform-Specific PyTorch Installation +#### AMD ROCm (AMD GPUs) +```bash +pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.2 +``` + +**Note:** AMD ROCm support requires: +- ROCm 6.x installed and `rocprofv3` (or `rocprof`) on `$PATH` +- Compatible AMD GPU (e.g., Instinct MI300X) + +For optimization, use the bundled config as a quickstart: +```bash +python examples/run_opt_manager.py \ + --kernel-dir examples/optimize_01_matvec/ \ + --config examples/configs/amd.yaml +``` + +Verify your ROCm installation: +```python +import torch +print(torch.cuda.is_available()) # True if ROCm PyTorch detects GPU +print(torch.version.hip) # Should print the HIP/ROCm version +``` + #### Intel XPU (Intel GPUs) ```bash pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/xpu @@ -225,8 +249,18 @@ KernelAgent supports multiple GPU platforms for Triton kernel execution: | Platform | Device String | Flag | Status | |----------|---------------|------|--------| | NVIDIA CUDA | `cuda` | `--target-platform cuda` (default) | Fully supported | +| AMD ROCm | `rocm` | `--target-platform rocm` | Supported | | Intel XPU | `xpu` | `--target-platform xpu` | Supported | +### AMD ROCm Notes + +When targeting AMD ROCm, KernelAgent automatically: +- Uses `rocprofv3` (or `rocprof` as fallback) for hardware profiling +- Applies ROCm-specific Triton block/wave occupancy hints +- Generates appropriate device availability checks + +See `examples/configs/amd.yaml` for a ready-to-use MI300X configuration. 
+
 ### Intel XPU Notes
 
 When targeting Intel XPU, KernelAgent automatically:
@@ -237,9 +271,9 @@ When targeting Intel XPU, KernelAgent automatically:
 ### Verifying Platform Setup
 
 ```python
-# Check CUDA availability
+# Check CUDA/ROCm availability
 import torch
-print("CUDA available:", torch.cuda.is_available())
+print("CUDA/ROCm available:", torch.cuda.is_available())
 
 # Check XPU availability
 print("XPU available:", hasattr(torch, 'xpu') and torch.xpu.is_available())
diff --git a/examples/configs/amd.yaml b/examples/configs/amd.yaml
new file mode 100644
index 0000000..c292abf
--- /dev/null
+++ b/examples/configs/amd.yaml
@@ -0,0 +1,39 @@
+# AMD platform config for MI300X
+#
+# Usage:
+#   python examples/run_opt_manager.py \
+#     --kernel-dir examples/optimize_01_matvec \
+#     --strategy amd \
+#     --config examples/configs/amd.yaml
+
+strategy: beam_search
+num_workers: 4
+strategy_config:
+  num_top_kernels: 2
+  num_bottlenecks: 2
+openai_model: gpt-5
+high_reasoning_effort: true
+
+# Worker configuration
+benchmark_warmup: 25
+benchmark_repeat: 100
+divergence_threshold: 50.0
+target_platform: rocm
+gpu_name: "AMD Instinct MI300X"
+
+platform:
+  # Manager-level components
+  verifier: rocm
+  benchmarker: rocm
+  worker_runner: rocm
+  # Worker-level components
+  specs_provider: rocm
+  profiler: rocm
+  roofline_analyzer: rocm
+  bottleneck_analyzer: rocm
+  rag_prescriber: rocm
+
+templates:
+  kernel_optimization: triton_kernel_agent/templates/kernel_optimization.j2
+  reflexion_prompt: triton_kernel_agent/templates/reflexion_prompt.j2
+  triton_guidelines: triton_kernel_agent/templates/triton_guidelines.j2
diff --git a/examples/configs/nvidia.yaml b/examples/configs/nvidia.yaml
index b06b8b6..196caab 100644
--- a/examples/configs/nvidia.yaml
+++ b/examples/configs/nvidia.yaml
@@ -7,6 +7,7 @@
 # Usage:
 #   python examples/run_opt_manager.py \
 #     --kernel-dir examples/optimize_01_matvec \
+#     --strategy nvidia \
 #     --config examples/configs/nvidia.yaml
 
 strategy: beam_search
diff --git
a/examples/run_opt_manager.py b/examples/run_opt_manager.py index 14277b5..b0976bf 100644 --- a/examples/run_opt_manager.py +++ b/examples/run_opt_manager.py @@ -40,7 +40,7 @@ _CONFIGS_DIR = Path(__file__).resolve().parent / "configs" # Available strategies and their config files. -_STRATEGIES = ["beam_search", "greedy", "noop", "nvidia"] +_STRATEGIES = ["beam_search", "greedy", "noop", "nvidia", "amd"] def _run_strategy( diff --git a/kernel_perf_agent/kernel_opt/diagnose_prompt/gpu_specs_database.py b/kernel_perf_agent/kernel_opt/diagnose_prompt/gpu_specs_database.py index 0c984b4..84ce5e9 100644 --- a/kernel_perf_agent/kernel_opt/diagnose_prompt/gpu_specs_database.py +++ b/kernel_perf_agent/kernel_opt/diagnose_prompt/gpu_specs_database.py @@ -17,14 +17,16 @@ This module contains the GPU hardware specifications database used for performance analysis and bottleneck identification. Updated to include -specific SKU variants for multi-SKU GPUs like A100 and H100. +specific SKU variants for multi-SKU GPUs like A100 and H100, and AMD +Instinct GPUs for ROCm support. 
Sources: - NVIDIA official specifications and datasheets +- AMD official specifications and datasheets - TechPowerUp GPU Database - Manufacturer datasheets -Last Updated: January 2026 +Last Updated: March 2026 """ from types import MappingProxyType @@ -181,6 +183,95 @@ "form_factor": "PCIe", "tdp_w": 360, }, + # ----------------------------------------------------------------------- + # AMD Instinct GPU SKUs (ROCm / HIP) + # ----------------------------------------------------------------------- + # AMD Instinct MI300X (CDNA3 / gfx942) + # Sources: AMD product page, Hot Chips 35 (2023) + "AMD Instinct MI300X": { + "name": "AMD Instinct MI300X", + "architecture": "CDNA3", + "gfx_target": "gfx942", + "peak_fp32_tflops": 163.4, + "peak_fp16_tflops": 1307.4, # BF16/FP16 matrix (without sparsity) + "peak_bf16_tflops": 1307.4, + "peak_memory_bw_gbps": 5300, # 5.3 TB/s HBM3 + "cu_count": 304, # Compute Units (AMD equiv of SM) + "sm_count": 304, # Alias for compatibility + "max_threads_per_cu": 2048, + "max_threads_per_sm": 2048, # Alias for compatibility + "wavefront_size": 64, + "l1_cache_kb": 32, # L1 per CU (vector L1D) + "l2_cache_mb": 256, # Total Infinity Cache (across all dies) + "memory_gb": 192, + "memory_type": "HBM3", + "form_factor": "OAM", + "tdp_w": 750, + }, + # AMD Instinct MI300A (CDNA3 / gfx942, APU variant) + "AMD Instinct MI300A": { + "name": "AMD Instinct MI300A", + "architecture": "CDNA3", + "gfx_target": "gfx942", + "peak_fp32_tflops": 122.6, + "peak_fp16_tflops": 980.6, + "peak_bf16_tflops": 980.6, + "peak_memory_bw_gbps": 3200, # Unified HBM3 (shared with CPU) + "cu_count": 228, + "sm_count": 228, + "max_threads_per_cu": 2048, + "max_threads_per_sm": 2048, + "wavefront_size": 64, + "l1_cache_kb": 32, + "l2_cache_mb": 192, + "memory_gb": 128, + "memory_type": "HBM3", + "form_factor": "OAM", + "tdp_w": 550, + }, + # AMD Instinct MI350X (CDNA4 / gfx950) + # Sources: AMD press release (Nov 2024), estimated specs + "AMD Instinct MI350X": { + "name": 
"AMD Instinct MI350X", + "architecture": "CDNA4", + "gfx_target": "gfx950", + "peak_fp32_tflops": 288.0, + "peak_fp16_tflops": 2304.0, # BF16 matrix estimate + "peak_bf16_tflops": 2304.0, + "peak_memory_bw_gbps": 8000, # ~8 TB/s HBM3E + "cu_count": 304, + "sm_count": 304, + "max_threads_per_cu": 2048, + "max_threads_per_sm": 2048, + "wavefront_size": 64, + "l1_cache_kb": 32, + "l2_cache_mb": 256, + "memory_gb": 288, + "memory_type": "HBM3E", + "form_factor": "OAM", + "tdp_w": 1000, + }, + # AMD Instinct MI250X (CDNA2 / gfx90a) + "AMD Instinct MI250X": { + "name": "AMD Instinct MI250X", + "architecture": "CDNA2", + "gfx_target": "gfx90a", + "peak_fp32_tflops": 47.9, + "peak_fp16_tflops": 383.0, + "peak_bf16_tflops": 383.0, + "peak_memory_bw_gbps": 3277, # HBM2e + "cu_count": 220, + "sm_count": 220, + "max_threads_per_cu": 2048, + "max_threads_per_sm": 2048, + "wavefront_size": 64, + "l1_cache_kb": 16, + "l2_cache_mb": 32, + "memory_gb": 128, + "memory_type": "HBM2e", + "form_factor": "OAM", + "tdp_w": 560, + }, } # Make database read-only to prevent accidental modification diff --git a/kernel_perf_agent/kernel_opt/profiler/rocprof_profiler.py b/kernel_perf_agent/kernel_opt/profiler/rocprof_profiler.py new file mode 100644 index 0000000..760b5ee --- /dev/null +++ b/kernel_perf_agent/kernel_opt/profiler/rocprof_profiler.py @@ -0,0 +1,518 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +rocprof Profiling Module for Triton Kernels on AMD GPUs. + +This module wraps rocprof (AMD ROCm profiler) to collect hardware performance +counters for Triton kernels running on AMD GPUs. Supports both rocprof v1/v2 +(``rocprof``) and rocprofiler-sdk (``rocprofv3``) CLI interfaces. + +Counter Categories +------------------ +- **SQ counters**: Shader processor activity (SQ_WAVES, SQ_INSTS_VALU, etc.) +- **TCC counters**: L2 cache / memory controller (TCC_HIT, TCC_MISS, etc.) +- **EA / GRBM counters**: Memory bandwidth and bus activity + +The collected counters are mapped to a metric dictionary compatible with the +rest of the KernelAgent optimization pipeline. +""" + +from __future__ import annotations + +import json +import os +import shutil +import subprocess +import sys +from pathlib import Path +from typing import Any, Dict, List, Optional + +# --------------------------------------------------------------------------- +# Hardware counter sets collected via rocprof -i +# --------------------------------------------------------------------------- + +# PMC counters requested from the hardware. Each line in the input file +# specifies one pass of the hardware; rocprof runs the kernel once per pass. +# Keep pass count low to avoid excessive re-runs. 
+ +# Pass 1: Wavefront / shader utilisation +_PMC_PASS1 = [ + "SQ_WAVES", # total wavefronts launched + "SQ_INSTS_VALU", # VALU (vector ALU) instructions executed + "SQ_INSTS_SALU", # SALU (scalar ALU) instructions executed + "SQ_INSTS_VMEM_RD", # vector memory read instructions + "SQ_INSTS_VMEM_WR", # vector memory write instructions + "SQ_INSTS_LDS", # LDS instructions + "SQ_WAIT_INST_ANY", # stall cycles waiting on any instruction +] + +# Pass 2: L2 cache activity +_PMC_PASS2 = [ + "TCC_HIT_sum", # L2 cache hits (sum over all TCC slices) + "TCC_MISS_sum", # L2 cache misses + "TCC_EA_RDREQ_sum", # EA read requests (to HBM) + "TCC_EA_WRREQ_sum", # EA write requests (to HBM) +] + +# Pass 3: Memory bandwidth (derived counters) +# FETCH_SIZE and WRITE_SIZE use the same internal hardware PMC block and +# cannot be collected together in a single pass on rocprofv3 / gfx942. +# Collecting them on separate lines avoids error 38 ("Request exceeds the +# capabilities of the hardware to collect") and the associated hang. 
+_PMC_PASS3 = [ + "FETCH_SIZE", # bytes fetched from memory (in KB, vendor metric) +] + +_PMC_PASS4 = [ + "WRITE_SIZE", # bytes written to memory (in KB, vendor metric) +] + +# Derived metric keys returned in the normalized metrics dict +ROCM_METRIC_KEYS = [ + "sq_waves", + "sq_insts_valu", + "sq_insts_salu", + "sq_insts_vmem_rd", + "sq_insts_vmem_wr", + "sq_insts_lds", + "sq_wait_inst_any", + "tcc_hit", + "tcc_miss", + "tcc_cache_hit_rate_pct", + "tcc_ea_rdreq", + "tcc_ea_wrreq", + "fetch_size_kb", + "write_size_kb", + # Derived + "valu_utilization_pct", # SQ_INSTS_VALU / (SQ_WAVES * max_valu_per_wave) + "memory_bound_pct", # heuristic: vmem instructions / total instructions + "compute_sol_pct", # heuristic compute SOL for roofline compat + "memory_sol_pct", # heuristic memory SOL for roofline compat +] + + +# --------------------------------------------------------------------------- +# rocprof input file generation +# --------------------------------------------------------------------------- + + +def _write_rocprof_input(output_dir: Path) -> Path: + """Write a rocprof PMC input file requesting our counter passes. + + rocprof reads this file and schedules one hardware pass per ``pmc:`` line. + Counter names listed on the same line are collected in a single pass. + + Args: + output_dir: Directory to write the input file. + + Returns: + Path to the written input file. 
+ """ + lines = [ + "# rocprof PMC input file — generated by KernelAgent", + f"pmc: {' '.join(_PMC_PASS1)}", + f"pmc: {' '.join(_PMC_PASS2)}", + f"pmc: {' '.join(_PMC_PASS3)}", + f"pmc: {' '.join(_PMC_PASS4)}", + ] + input_file = output_dir / "rocprof_input.txt" + input_file.write_text("\n".join(lines) + "\n") + return input_file + + +# --------------------------------------------------------------------------- +# rocprof invocation +# --------------------------------------------------------------------------- + + +def _find_rocprof() -> str: + """Find the rocprof binary, preferring rocprofv3 when available.""" + for candidate in ("rocprofv3", "rocprof"): + found = shutil.which(candidate) + if found: + return found + # Fallback to common installation path + for path in ("/opt/rocm/bin/rocprof", "/usr/bin/rocprof"): + if Path(path).exists(): + return path + raise FileNotFoundError( + "rocprof binary not found. Install ROCm and ensure rocprof is on PATH." + ) + + +def profile_triton_kernel_rocm( + benchmark_script: Path, + workdir: Path, + out_prefix: str = "rocprof_output", + python_executable: Optional[str] = None, + rocprof_bin: Optional[str] = None, + timeout: int = 360, +) -> Path: + """Profile a Triton kernel using rocprof. + + Runs the benchmark script twice: + 1. ``rocprof --stats`` — collects kernel timing (duration in ns). + 2. ``rocprof -i input.txt`` — collects hardware PMC counters. + + The two CSV outputs are merged into a single JSON metrics file. + + Args: + benchmark_script: Path to the Python script that calls the kernel. + workdir: Working directory for execution and output files. + out_prefix: Prefix for output filenames. + python_executable: Python executable (default: sys.executable). + rocprof_bin: Path to rocprof binary (auto-detect if None). + timeout: Timeout in seconds per rocprof invocation. + + Returns: + Path to the merged JSON metrics file. + + Raises: + FileNotFoundError: If rocprof binary or benchmark script not found. 
+ RuntimeError: If rocprof profiling fails. + """ + if python_executable is None: + python_executable = sys.executable + + if rocprof_bin is None: + rocprof_bin = _find_rocprof() + + benchmark_script = benchmark_script.resolve() + if not benchmark_script.exists(): + raise FileNotFoundError(f"Benchmark script not found: {benchmark_script}") + + workdir = workdir.resolve() + workdir.mkdir(parents=True, exist_ok=True) + + env = os.environ.copy() + env["TRITON_CACHE_DIR"] = str(workdir / ".triton_cache") + + stats_csv = workdir / f"{out_prefix}_stats.csv" + pmc_csv = workdir / f"{out_prefix}_pmc.csv" + metrics_json = workdir / f"{out_prefix}_metrics.json" + + # ------------------------------------------------------------------ + # Pass 1: timing stats (--stats) + # ------------------------------------------------------------------ + _run_rocprof_stats( + rocprof_bin=rocprof_bin, + benchmark_script=benchmark_script, + workdir=workdir, + out_csv=stats_csv, + python_executable=python_executable, + env=env, + timeout=timeout, + ) + + # ------------------------------------------------------------------ + # Pass 2: hardware PMC counters (-i input.txt) + # ------------------------------------------------------------------ + input_file = _write_rocprof_input(workdir) + _run_rocprof_pmc( + rocprof_bin=rocprof_bin, + benchmark_script=benchmark_script, + workdir=workdir, + input_file=input_file, + out_csv=pmc_csv, + python_executable=python_executable, + env=env, + timeout=timeout, + ) + + # ------------------------------------------------------------------ + # Merge outputs into metrics JSON + # ------------------------------------------------------------------ + metrics = _merge_rocprof_outputs(stats_csv, pmc_csv) + metrics_json.write_text(json.dumps(metrics, indent=2)) + + print(f"[rocprof] Metrics written: {metrics_json}") + return metrics_json + + +def _run_rocprof_stats( + rocprof_bin: str, + benchmark_script: Path, + workdir: Path, + out_csv: Path, + python_executable: 
str, + env: dict, + timeout: int, +) -> None: + """Run rocprof --stats to collect kernel timing.""" + # rocprofv3 requires --kernel-trace for --stats to work + is_v3 = "rocprofv3" in rocprof_bin + cmd = [rocprof_bin] + if is_v3: + cmd.extend(["--kernel-trace", "--stats", "-o", str(out_csv)]) + else: + cmd.extend(["--stats", "-o", str(out_csv)]) + cmd.extend(["--", python_executable, str(benchmark_script)]) + print(f"[rocprof] Running stats: {' '.join(cmd[:6])}...") + result = subprocess.run( + cmd, + cwd=str(workdir), + env=env, + capture_output=True, + text=True, + timeout=timeout, + ) + if result.returncode != 0: + raise RuntimeError( + f"rocprof --stats failed (rc={result.returncode}):\n" + f"{(result.stderr or result.stdout)[:500]}" + ) + print("[rocprof] Stats pass completed") + + +def _run_rocprof_pmc( + rocprof_bin: str, + benchmark_script: Path, + workdir: Path, + input_file: Path, + out_csv: Path, + python_executable: str, + env: dict, + timeout: int, +) -> None: + """Run rocprof -i to collect PMC counters.""" + # rocprofv3 uses --pmc or -i for counter collection + cmd = [ + rocprof_bin, + "-i", + str(input_file), + "-o", + str(out_csv), + "--", + python_executable, + str(benchmark_script), + ] + print(f"[rocprof] Running PMC counters: {' '.join(cmd[:6])}...") + result = subprocess.run( + cmd, + cwd=str(workdir), + env=env, + capture_output=True, + text=True, + timeout=timeout, + ) + if result.returncode != 0: + raise RuntimeError( + f"rocprof PMC pass failed (rc={result.returncode}):\n" + f"{(result.stderr or result.stdout)[:500]}" + ) + print("[rocprof] PMC pass completed") + + +# --------------------------------------------------------------------------- +# Output parsing +# --------------------------------------------------------------------------- + + +def _parse_stats_csv(csv_path: Path) -> Dict[str, Any]: + """Parse rocprof --stats CSV to extract kernel durations.""" + if not csv_path.exists(): + return {} + + import csv + + rows: List[Dict[str, 
str]] = [] + try: + with csv_path.open(newline="") as f: + reader = csv.DictReader(f) + rows = list(reader) + except Exception: + return {} + + if not rows: + return {} + + # rocprof stats CSV columns: Name, Calls, TotalDurationNs, AverageNs, ... + # Pick the last kernel invocation (after warmup) + row = rows[-1] + result: Dict[str, Any] = {} + for key in ("Name", "Calls", "TotalDurationNs", "AverageNs", "Percentage"): + if key in row: + try: + result[key.lower()] = float(row[key]) if key != "Name" else row[key] + except (ValueError, TypeError): + result[key.lower()] = row[key] + + if "averagens" in result: + result["duration_ns"] = result["averagens"] + result["duration_ms"] = result["averagens"] / 1e6 + + return result + + +def _parse_pmc_csv(csv_path: Path) -> Dict[str, Any]: + """Parse rocprof PMC CSV to extract hardware counter values.""" + if not csv_path.exists(): + return {} + + import csv + + rows: List[Dict[str, str]] = [] + try: + with csv_path.open(newline="") as f: + reader = csv.DictReader(f) + rows = list(reader) + except Exception: + return {} + + if not rows: + return {} + + # rocprof PMC CSV columns: Index, KernelName, gpu-id, queue-id, + # queue-index, pid, tid, grd, wgr, lds, scr, arch_vgpr, accum_vgpr, + # sgpr, wave_size, sig, obj, , , ... + # Aggregate all rows by summing numeric counter columns + aggregated: Dict[str, float] = {} + counter_cols = [ + k + for k in rows[0].keys() + if k + not in ( + "Index", + "KernelName", + "gpu-id", + "queue-id", + "queue-index", + "pid", + "tid", + "sig", + "obj", + ) + ] + + for row in rows: + for col in counter_cols: + try: + aggregated[col] = aggregated.get(col, 0.0) + float(row[col]) + except (ValueError, TypeError, KeyError): + pass + + return aggregated + + +def _merge_rocprof_outputs( + stats_csv: Path, + pmc_csv: Path, +) -> Dict[str, Any]: + """Merge stats and PMC outputs into a normalized metrics dictionary. 
+ + The returned dict is compatible with the rest of the KernelAgent pipeline + (same structure expected by RooflineAnalyzer and BottleneckAnalyzer). + """ + stats = _parse_stats_csv(stats_csv) + pmc = _parse_pmc_csv(pmc_csv) + + # Raw counters (lower-cased for consistency) + raw: Dict[str, Any] = {} + for k, v in pmc.items(): + raw[k.lower()] = v + + # Normalize to our metric keys + metrics: Dict[str, Any] = { + "sq_waves": raw.get("sq_waves", 0.0), + "sq_insts_valu": raw.get("sq_insts_valu", 0.0), + "sq_insts_salu": raw.get("sq_insts_salu", 0.0), + "sq_insts_vmem_rd": raw.get("sq_insts_vmem_rd", 0.0), + "sq_insts_vmem_wr": raw.get("sq_insts_vmem_wr", 0.0), + "sq_insts_lds": raw.get("sq_insts_lds", 0.0), + "sq_wait_inst_any": raw.get("sq_wait_inst_any", 0.0), + "tcc_hit": raw.get("tcc_hit_sum", 0.0), + "tcc_miss": raw.get("tcc_miss_sum", 0.0), + "tcc_ea_rdreq": raw.get("tcc_ea_rdreq_sum", 0.0), + "tcc_ea_wrreq": raw.get("tcc_ea_wrreq_sum", 0.0), + "fetch_size_kb": raw.get("fetch_size", 0.0), + "write_size_kb": raw.get("write_size", 0.0), + # Timing from stats pass + "duration_ms": stats.get("duration_ms", 0.0), + } + + # Derived: L2 cache hit rate + tcc_total = metrics["tcc_hit"] + metrics["tcc_miss"] + metrics["tcc_cache_hit_rate_pct"] = ( + 100.0 * metrics["tcc_hit"] / tcc_total if tcc_total > 0 else 0.0 + ) + + # Derived: VALU utilization (fraction of cycles spent on vector ALU) + total_insts = ( + metrics["sq_insts_valu"] + + metrics["sq_insts_salu"] + + metrics["sq_insts_vmem_rd"] + + metrics["sq_insts_vmem_wr"] + + metrics["sq_insts_lds"] + ) + metrics["valu_utilization_pct"] = ( + 100.0 * metrics["sq_insts_valu"] / total_insts if total_insts > 0 else 0.0 + ) + + # Derived: memory-bound heuristic — fraction of instructions that are memory ops + vmem_insts = metrics["sq_insts_vmem_rd"] + metrics["sq_insts_vmem_wr"] + metrics["memory_bound_pct"] = ( + 100.0 * vmem_insts / total_insts if total_insts > 0 else 0.0 + ) + + # Derive roofline-compatible SOL 
estimates from counter ratios. + # These are heuristic approximations since rocprof does not provide + # "Speed of Light" metrics directly (unlike NCU). + # + # Compute SOL: use VALU utilization as a proxy for compute utilization. + # Memory SOL: use cache miss rate × memory traffic as a bandwidth proxy. + # Both are normalized to [0, 100]. + metrics["compute_sol_pct"] = min(metrics["valu_utilization_pct"], 100.0) + metrics["memory_sol_pct"] = min(metrics["memory_bound_pct"], 100.0) + + # Include raw counter dump for debugging / LLM consumption + metrics["_raw_counters"] = {k: v for k, v in raw.items()} + + return metrics + + +# --------------------------------------------------------------------------- +# Public helper: load metrics from a previously written JSON file +# --------------------------------------------------------------------------- + + +def load_rocm_metrics(json_path: Path) -> Dict[str, Any]: + """Load rocprof metrics from a JSON file produced by :func:`profile_triton_kernel_rocm`. + + Args: + json_path: Path to the JSON file. + + Returns: + Metrics dictionary. + + Raises: + FileNotFoundError: If the file does not exist. + """ + if not json_path.exists(): + raise FileNotFoundError(f"ROCm metrics JSON not found: {json_path}") + return json.loads(json_path.read_text()) + + +def metrics_to_prompt_rocm(metrics: Dict[str, Any]) -> str: + """Convert ROCm metrics dict to a JSON string for LLM prompts. + + Args: + metrics: Metrics dictionary from :func:`load_rocm_metrics`. + + Returns: + JSON string (drops internal ``_raw_counters`` key for brevity). 
+ """ + # Exclude raw counter dump from LLM context (too verbose) + filtered = {k: v for k, v in metrics.items() if k != "_raw_counters"} + return json.dumps(filtered, indent=2) diff --git a/kernel_perf_agent/kernel_opt/roofline/rocm_roofline.py b/kernel_perf_agent/kernel_opt/roofline/rocm_roofline.py new file mode 100644 index 0000000..167f076 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/roofline/rocm_roofline.py @@ -0,0 +1,232 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +ROCm Roofline Analysis using rocprof hardware counters. + +Unlike the NVIDIA path (which uses NCU's built-in "Speed of Light" percent +metrics), AMD rocprof does not expose normalized SOL values directly. This +module derives heuristic SOL equivalents from raw hardware counters collected +by :mod:`kernel_perf_agent.kernel_opt.profiler.rocprof_profiler`. + +Counter → SOL mapping +--------------------- +*Compute SOL* ← VALU instruction utilization (SQ_INSTS_VALU fraction of + total shader instructions). A fully VALU-saturated kernel + approaches 100 %. + +*Memory SOL* ← Memory instruction fraction (SQ_INSTS_VMEM_RD + + SQ_INSTS_VMEM_WR as fraction of total instructions), amplified + by the L2 miss rate (a high miss rate means more HBM traffic, + pushing the kernel closer to the memory roof). 
+ +Both are heuristic and share the same interface as +:class:`kernel_perf_agent.kernel_opt.roofline.ncu_roofline.RooflineAnalyzer` +so the rest of the optimization pipeline can use them transparently. +""" + +from __future__ import annotations + +import logging +from dataclasses import asdict, dataclass, field +from typing import Any + + +@dataclass +class ROCmRooflineConfig: + """Configuration for ROCm roofline analysis.""" + + threshold_pct: float = 85.0 # Lower than NCU (heuristic, not exact SOL) + early_stop: bool = True + convergence_rounds: int = 5 + min_improvement_pct: float = 0.1 + underutilized_threshold: float = 40.0 # Both SOL < this → underutilized + miss_rate_amplifier: float = 1.5 # Scales memory SOL by L2 miss rate + + +@dataclass +class ROCmRooflineResult: + """Result of ROCm roofline analysis from rocprof counters.""" + + # Heuristic SOL equivalents (0–100 %) + compute_sol_pct: float + memory_sol_pct: float + + # Derived efficiency + efficiency_pct: float + at_roofline: bool + headroom_pct: float + + # Classification + bottleneck: str # "compute" | "memory" | "underutilized" + uses_tensor_cores: bool # Always False on ROCm (matrix cores tracked differently) + + # Supporting data + tcc_cache_hit_rate_pct: float + valu_utilization_pct: float + memory_bound_pct: float + + warnings: list[str] = field(default_factory=list) + + def to_dict(self) -> dict[str, Any]: + return asdict(self) + + +class ROCmRooflineAnalyzer: + """Analyses ROCm kernel performance using rocprof hardware counters. + + Interface is compatible with + :class:`kernel_perf_agent.kernel_opt.roofline.ncu_roofline.RooflineAnalyzer` + so it can be used as a drop-in replacement in the optimization pipeline. 
+ """ + + def __init__( + self, + config: ROCmRooflineConfig | None = None, + logger: logging.Logger | None = None, + ) -> None: + self.config = config or ROCmRooflineConfig() + self.logger = logger or logging.getLogger(__name__) + self._efficiency_history: list[float] = [] + + # ------------------------------------------------------------------ + # Core analysis + # ------------------------------------------------------------------ + + def analyze(self, rocm_metrics: dict[str, Any]) -> ROCmRooflineResult: + """Analyse rocprof metrics and return a roofline result. + + Args: + rocm_metrics: Metrics dict from + :func:`kernel_perf_agent.kernel_opt.profiler.rocprof_profiler.load_rocm_metrics`. + + Returns: + :class:`ROCmRooflineResult` with heuristic SOL estimates. + """ + warnings: list[str] = [] + + # Retrieve pre-computed heuristic SOL values (set by rocprof_profiler) + compute_sol = float(rocm_metrics.get("compute_sol_pct", 0.0)) + memory_sol = float(rocm_metrics.get("memory_sol_pct", 0.0)) + + valu_pct = float(rocm_metrics.get("valu_utilization_pct", compute_sol)) + mem_bound_pct = float(rocm_metrics.get("memory_bound_pct", memory_sol)) + cache_hit_rate = float(rocm_metrics.get("tcc_cache_hit_rate_pct", 100.0)) + + # Amplify memory SOL by L2 miss rate to better reflect HBM pressure. + # A 100 % cache hit rate means all traffic is served from L2 → low HBM + # utilisation despite many VMEM instructions. 
+ l2_miss_rate = max(0.0, 100.0 - cache_hit_rate) + amplified_memory_sol = min( + 100.0, + memory_sol + * (1.0 + (l2_miss_rate / 100.0) * (self.config.miss_rate_amplifier - 1.0)), + ) + + if compute_sol == 0.0 and amplified_memory_sol == 0.0: + warnings.append( + "No compute or memory activity detected in rocprof counters" + ) + + efficiency = max(compute_sol, amplified_memory_sol) + bottleneck = self._classify_bottleneck(compute_sol, amplified_memory_sol) + at_roofline = efficiency >= self.config.threshold_pct + + return ROCmRooflineResult( + compute_sol_pct=round(compute_sol, 2), + memory_sol_pct=round(amplified_memory_sol, 2), + efficiency_pct=round(efficiency, 2), + at_roofline=at_roofline, + headroom_pct=round(max(0.0, 100.0 - efficiency), 2), + bottleneck=bottleneck, + uses_tensor_cores=False, # Matrix core detection not implemented yet + tcc_cache_hit_rate_pct=round(cache_hit_rate, 2), + valu_utilization_pct=round(valu_pct, 2), + memory_bound_pct=round(mem_bound_pct, 2), + warnings=warnings, + ) + + def _classify_bottleneck(self, compute_sol: float, memory_sol: float) -> str: + """Classify bottleneck based on heuristic SOL values.""" + threshold = self.config.underutilized_threshold + if compute_sol < threshold and memory_sol < threshold: + return "underutilized" + if memory_sol >= compute_sol: + return "memory" + return "compute" + + # ------------------------------------------------------------------ + # Convergence tracking (same interface as NvidiaRooflineAnalyzer) + # ------------------------------------------------------------------ + + def should_stop(self, result: ROCmRooflineResult) -> tuple[bool, str]: + """Check whether optimization should stop. + + Args: + result: :class:`ROCmRooflineResult` from :meth:`analyze`. + + Returns: + ``(should_stop, reason)`` tuple. 
+ """ + self._efficiency_history.append(result.efficiency_pct) + + if self.config.early_stop and result.at_roofline: + return ( + True, + f"At roofline ({result.efficiency_pct:.1f}% SOL >= " + f"{self.config.threshold_pct}%)", + ) + + if len(self._efficiency_history) >= self.config.convergence_rounds: + recent = self._efficiency_history[-self.config.convergence_rounds :] + improvement = max(recent) - min(recent) + if improvement < self.config.min_improvement_pct: + return ( + True, + f"Converged (improvement {improvement:.2f}% < " + f"{self.config.min_improvement_pct}%)", + ) + + return False, "" + + def reset_history(self) -> None: + """Reset convergence tracking for a new optimization run.""" + self._efficiency_history = [] + + +def format_rocm_roofline_summary(result: ROCmRooflineResult) -> str: + """Format a human-readable summary of ROCm roofline analysis.""" + lines = [ + "=== ROCm Roofline Analysis (rocprof counters) ===", + f"SOL Efficiency (heuristic): {result.efficiency_pct:.1f}%", + f" Compute SOL: {result.compute_sol_pct:.1f}% (VALU utilization)", + f" Memory SOL: {result.memory_sol_pct:.1f}% (VMEM + L2 miss amplification)", + f" Bottleneck: {result.bottleneck}", + f" L2 Hit Rate: {result.tcc_cache_hit_rate_pct:.1f}%", + "", + ] + + if result.at_roofline: + lines.append("Status: AT ROOFLINE (heuristic)") + else: + lines.append(f"Headroom: {result.headroom_pct:.1f}%") + + if result.warnings: + lines.append(f"Warnings: {'; '.join(result.warnings)}") + + lines.append("") + lines.append("Note: ROCm SOL values are heuristic estimates from rocprof counters,") + lines.append("not exact hardware-reported percentages like NCU SOL metrics.") + + return "\n".join(lines) diff --git a/triton_kernel_agent/agent.py b/triton_kernel_agent/agent.py index 84bd5d5..0f75bad 100644 --- a/triton_kernel_agent/agent.py +++ b/triton_kernel_agent/agent.py @@ -43,7 +43,7 @@ def __init__( preferred_provider: BaseProvider | None = None, target_platform: PlatformConfig | None = None, 
no_cusolver: bool = False, - test_timeout_s: int = 30, + test_timeout_s: int = 300, ): """ Initialize the Triton Kernel Agent. diff --git a/triton_kernel_agent/manager.py b/triton_kernel_agent/manager.py index 015fbac..d48329c 100644 --- a/triton_kernel_agent/manager.py +++ b/triton_kernel_agent/manager.py @@ -39,7 +39,7 @@ def __init__( high_reasoning_effort: bool = True, target_platform: str = "cuda", no_cusolver: bool = False, - test_timeout_s: int = 30, + test_timeout_s: int = 300, ): """ Initialize the worker manager. @@ -237,7 +237,7 @@ def worker_process( high_reasoning_effort: bool, target_platform: str, no_cusolver: bool = False, - test_timeout_s: int = 30, + test_timeout_s: int = 300, ): """ Worker process for kernel verification and refinement. diff --git a/triton_kernel_agent/opt_worker_component/orchestrator/optimization_orchestrator.py b/triton_kernel_agent/opt_worker_component/orchestrator/optimization_orchestrator.py index 219b60c..6e97b05 100644 --- a/triton_kernel_agent/opt_worker_component/orchestrator/optimization_orchestrator.py +++ b/triton_kernel_agent/opt_worker_component/orchestrator/optimization_orchestrator.py @@ -177,6 +177,11 @@ def _get_triton_kernel_metrics(ncu_metrics: dict[str, Any]) -> dict[str, Any]: NCU profiles all CUDA kernels including PyTorch internals (at::*). This function finds the actual Triton kernel metrics. + Note: This helper is NCU/CUDA-specific. On the ROCm path, + ``profiler_results.metrics`` is already a flat dict and should be + passed directly to ``ROCmRooflineAnalyzer.analyze`` without going + through this function. 
+ Args: ncu_metrics: Dict keyed by kernel name with metric dicts as values @@ -359,6 +364,7 @@ def optimize_kernel( # Two-kernel tracking: track best-by-runtime and best-by-SOL independently # This prevents mixing metrics from different kernels + profiling_globally_available = baseline_sol > 0.0 best_runtime_kernel = kernel_code best_runtime_time = best_time best_runtime_sol = baseline_sol @@ -374,23 +380,67 @@ def optimize_kernel( self.logger.info(f"ROUND {round_num}/{max_opt_rounds}") self.logger.info("=" * 80) - # Profile and analyze bottleneck - bottleneck_results, roofline_result, ncu_metrics = ( - self._profile_and_analyze(current_kernel, problem_file, round_num) - ) - - # Log roofline for the kernel we just profiled - if ncu_metrics: - flat_metrics = _get_triton_kernel_metrics(ncu_metrics) - roofline_check = self.roofline_analyzer.analyze( - ncu_metrics=flat_metrics, + # Profile and analyze bottleneck unless a prior ROCm profiling failure + # already forced this optimization session into synthetic mode. 
+ if profiling_globally_available: + bottleneck_results, roofline_result, ncu_metrics = ( + self._profile_and_analyze(current_kernel, problem_file, round_num) ) - self.logger.info( - f"[{round_num}] Roofline (kernel_round_{round_num - 1}): " - f"{roofline_check.bottleneck}-bound, {roofline_check.efficiency_pct:.1f}% SOL " - f"(Compute: {roofline_check.compute_sol_pct:.1f}%, " - f"Memory: {roofline_check.memory_sol_pct:.1f}%)" + profiling_available = bool(ncu_metrics) + if not profiling_available: + profiling_globally_available = False + else: + self.logger.warning( + f"[{round_num}] Profiling unavailable from earlier ROCm failure; reusing synthetic fallback mode for this round" ) + synthetic_category = self.bottleneck_override or "underutilized" + bottleneck_results = [ + BottleneckResult( + category=synthetic_category, + summary=f"Synthetic session fallback ({synthetic_category}-bound).", + reasoning="Earlier ROCm profiling failed in this optimization session, so subsequent rounds stay in synthetic mode instead of retrying the profiler.", + root_causes=[ + { + "cause": "ROCm profiling unavailable for this optimization session", + "evidence": [], + "fixes": [ + { + "fix": "Continue with conservative AMD-friendly heuristic tuning and benchmark feedback only.", + "rationale": "Avoids repeated rocprof failures after the first confirmed profiler failure.", + } + ], + } + ], + recommended_fixes=[ + { + "fix": "Do not re-enter rocprof in later rounds of the same optimization session.", + "rationale": "Keeps the search alive without repeating the same failing profiler path.", + } + ], + ) + ] + roofline_result = None + ncu_metrics = None + profiling_available = False + + # Log roofline for the kernel we just profiled. + # _get_triton_kernel_metrics / roofline_analyzer are NCU-specific; + # on the ROCm path this block is a no-op (roofline is already logged + # inside _profile_and_analyze via bottleneck_analyzer.roofline). 
+ if ncu_metrics: + try: + flat_metrics = _get_triton_kernel_metrics(ncu_metrics) + roofline_check = self.roofline_analyzer.analyze( + ncu_metrics=flat_metrics, + ) + self.logger.info( + f"[{round_num}] Roofline (kernel_round_{round_num - 1}): " + f"{roofline_check.bottleneck}-bound, {roofline_check.efficiency_pct:.1f}% SOL " + f"(Compute: {roofline_check.compute_sol_pct:.1f}%, " + f"Memory: {roofline_check.memory_sol_pct:.1f}%)" + ) + except Exception: + pass # ROCm path: roofline already logged in _profile_and_analyze if not bottleneck_results: self.logger.warning( @@ -505,10 +555,18 @@ def optimize_kernel( ) new_time = bench_results["time_ms"] - # Profile the NEW kernel to get its SOL metrics - new_kernel_metrics = self._profile_kernel_for_sol( - optimized_kernel, problem_file, round_num - ) + # Profile the NEW kernel to get its SOL metrics only when profiling + # is available for this round. Synthetic fallback rounds must not + # re-enter the profiler path. + if profiling_available: + new_kernel_metrics = self._profile_kernel_for_sol( + optimized_kernel, problem_file, round_num + ) + else: + self.logger.info( + f"[{round_num}] Skipping post-benchmark SOL profiling because this round is running on synthetic fallback analysis" + ) + new_kernel_metrics = None new_sol = ( new_kernel_metrics.get("efficiency_pct", 0.0) if new_kernel_metrics @@ -614,8 +672,9 @@ def optimize_kernel( early_stop_reason = stop_reason break - # Profile the final best kernel to get its roofline - if best_round_num > 0: + # Profile the final best kernel to get its roofline only when profiler + # data was actually available during optimization. + if best_round_num > 0 and best_ncu_metrics is not None: final_kernel_file = self.artifact_dir / f"kernel_round_{best_round_num}.py" if final_kernel_file.exists(): self.logger.info( @@ -726,7 +785,14 @@ def _profile_and_analyze( Tuple of (bottleneck_results, roofline_result, ncu_metrics). All can be None if profiling fails. 
""" - self.logger.info(f"[{round_num}] Profiling current kernel with NCU...") + profiler_name = ( + "rocprof" + if self.profiler.__class__.__name__ == "ROCmKernelProfiler" + else "NCU" + ) + self.logger.info( + f"[{round_num}] Profiling current kernel with {profiler_name}..." + ) kernel_file_round = self.artifact_dir / f"kernel_round_{round_num - 1}.py" kernel_file_round.write_text(current_kernel) @@ -735,12 +801,80 @@ def _profile_and_analyze( ) if profiler_results is None: + if self.profiler.__class__.__name__ == "ROCmKernelProfiler": + self.logger.warning( + f"[{round_num}] ROCm profiling unavailable; using synthetic fallback bottleneck analysis" + ) + synthetic_category = self.bottleneck_override or "underutilized" + synthetic = BottleneckResult( + category=synthetic_category, + summary=f"Synthetic fallback after rocprof failure ({synthetic_category}-bound). Continue optimization without profiling metrics.", + reasoning="rocprof profiling failed on the ROCm path, so optimization continues with a conservative synthetic diagnosis instead of re-entering another profiling stage.", + root_causes=[ + { + "cause": "ROCm profiling unavailable for this round", + "evidence": [ + { + "metric": "rocprof", + "value": 0.0, + "interpretation": "profiling failed or timed out, so no hardware counters are available", + } + ], + "fixes": [ + { + "fix": "Apply conservative AMD-friendly tuning changes without depending on profiler counters, and avoid any immediate re-profiling loop in the same analysis step.", + "rationale": "Keeps the optimization loop moving forward when rocprof is unstable or blocked.", + } + ], + } + ], + recommended_fixes=[ + { + "fix": "Prefer AMD-safe heuristic tuning and continue to verification/benchmarking without another profiling pass in this analysis stage.", + "rationale": "Avoids the bad fallback->reprofile control flow.", + }, + { + "fix": "Reduce launch overhead and improve occupancy with conservative Triton meta-parameter tuning (e.g. 
BLOCK sizes, num_warps, num_stages) consistent with AMD execution characteristics.", + "rationale": "Provides actionable optimization direction even when hardware counters are unavailable.", + }, + ], + ) + return [synthetic], None, None self.logger.warning(f"[{round_num}] Profiling failed") return None, None, None ncu_metrics = profiler_results.metrics if not ncu_metrics: + if self.profiler.__class__.__name__ == "ROCmKernelProfiler": + self.logger.warning( + f"[{round_num}] ROCm profiling returned no metrics; using synthetic fallback bottleneck analysis" + ) + synthetic_category = self.bottleneck_override or "underutilized" + synthetic = BottleneckResult( + category=synthetic_category, + summary=f"Synthetic fallback after empty rocprof metrics ({synthetic_category}-bound).", + reasoning="rocprof returned no usable metrics on ROCm, so optimization continues with a conservative synthetic diagnosis.", + root_causes=[ + { + "cause": "ROCm profiler produced no usable metrics", + "evidence": [], + "fixes": [ + { + "fix": "Continue with heuristic AMD-friendly tuning instead of aborting the round.", + "rationale": "Preserves optimization progress when profiling data is unavailable.", + } + ], + } + ], + recommended_fixes=[ + { + "fix": "Use conservative tuning guided by the selected bottleneck class and benchmark feedback.", + "rationale": "Lets the round proceed without profiler-derived counters.", + } + ], + ) + return [synthetic], None, None return None, None, ncu_metrics # Run roofline analysis @@ -818,6 +952,14 @@ def _profile_kernel_for_sol( Dict with efficiency_pct, roofline_result, ncu_metrics, or None if profiling fails """ try: + # On ROCm, skip SOL profiling entirely when the profiler backend is rocprof. + # This avoids re-entering the same fragile profiling path after fallback. 
+ if self.profiler.__class__.__name__ == "ROCmKernelProfiler": + self.logger.info( + f"[{round_num}] Skipping SOL profiling on ROCm to avoid re-entering rocprof after fallback" + ) + return None + # Write kernel to temp file for profiling kernel_file = self.artifact_dir / f"kernel_round_{round_num}_sol.py" kernel_file.write_text(kernel_code) diff --git a/triton_kernel_agent/opt_worker_component/profiling/__init__.py b/triton_kernel_agent/opt_worker_component/profiling/__init__.py index 8f8d18b..c255bab 100644 --- a/triton_kernel_agent/opt_worker_component/profiling/__init__.py +++ b/triton_kernel_agent/opt_worker_component/profiling/__init__.py @@ -12,9 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Profiling infrastructure for NCU-based kernel analysis.""" +"""Profiling infrastructure for hardware-specific kernel analysis.""" from .kernel_profiler import KernelProfiler from .ncu_wrapper_factory import NCUWrapperFactory +from .rocm_kernel_profiler import ROCmKernelProfiler +from .rocprof_wrapper_factory import ROCmWrapperFactory -__all__ = ["NCUWrapperFactory", "KernelProfiler"] +__all__ = [ + "NCUWrapperFactory", + "KernelProfiler", + "ROCmWrapperFactory", + "ROCmKernelProfiler", +] diff --git a/triton_kernel_agent/opt_worker_component/profiling/rocm_kernel_profiler.py b/triton_kernel_agent/opt_worker_component/profiling/rocm_kernel_profiler.py new file mode 100644 index 0000000..e29c74e --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/profiling/rocm_kernel_profiler.py @@ -0,0 +1,272 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Profiles Triton kernels on AMD GPUs using rocprof.""" + +from __future__ import annotations + +import json +import logging +import subprocess +import time +from dataclasses import asdict, dataclass +from datetime import datetime +from functools import cached_property +from pathlib import Path +from typing import Any, Dict + +from kernel_perf_agent.kernel_opt.profiler.rocprof_profiler import ( + load_rocm_metrics, + profile_triton_kernel_rocm, +) +from triton_kernel_agent.opt_worker_component.profiling.rocprof_wrapper_factory import ( + ROCmWrapperFactory, +) + +# Default timeout for rocprof profiling in seconds (two passes = 2x NVIDIA time) +DEFAULT_ROCPROF_TIMEOUT_SECONDS = 600 + +# Default timeout for profiling semaphore (15 minutes) +DEFAULT_SEMAPHORE_TIMEOUT_SECONDS = 900 + + +@dataclass +class ROCmProfilerMetadata: + """Metadata about a ROCm profiling run.""" + + kernel_file: str + problem_file: str + round_num: int + timestamp: str + rocprof_bin: str | None + + +@dataclass +class ROCmProfilerResults: + """Results from a ROCm kernel profiling run. + + Designed to be a drop-in replacement for + :class:`triton_kernel_agent.opt_worker_component.profiling.kernel_profiler.ProfilerResults` + in the optimization pipeline. 
+ """ + + metrics: Dict[str, Any] + metadata: ROCmProfilerMetadata + + def to_dict(self) -> Dict[str, Any]: + return { + "metrics": self.metrics, + "metadata": asdict(self.metadata), + } + + def to_json(self) -> str: + return json.dumps(self.to_dict(), indent=2) + + +class ROCmKernelProfiler: + """Profiles Triton kernels on AMD GPUs using rocprof. + + Drop-in replacement for + :class:`triton_kernel_agent.opt_worker_component.profiling.kernel_profiler.KernelProfiler` + for the ROCm/HIP platform. + """ + + def __init__( + self, + logger: logging.Logger, + artifacts_dir: Path, + logs_dir: Path, + rocprof_bin_path: str | None = None, + rocprof_timeout_seconds: int = DEFAULT_ROCPROF_TIMEOUT_SECONDS, + profiling_semaphore: Any | None = None, + ) -> None: + """ + Initialize the ROCm kernel profiler. + + Args: + logger: Logger instance. + artifacts_dir: Directory for optimization artifacts. + logs_dir: Directory for saving profiling logs. + rocprof_bin_path: Path to rocprof binary (auto-detect if None). + rocprof_timeout_seconds: Timeout per rocprof invocation. + profiling_semaphore: Semaphore to limit concurrent rocprof runs. 
+ """ + self.logger = logger + self.artifacts_dir = artifacts_dir + self.logs_dir = logs_dir + self.rocprof_bin_path = rocprof_bin_path + self.rocprof_timeout_seconds = rocprof_timeout_seconds + self.profiling_semaphore = profiling_semaphore + self.wrapper_factory = ROCmWrapperFactory(logger) + + @cached_property + def rocprof_bin(self) -> str | None: + """Resolved rocprof binary path (cached).""" + import shutil + + if self.rocprof_bin_path: + return self.rocprof_bin_path + for candidate in ("rocprofv3", "rocprof"): + found = shutil.which(candidate) + if found: + return found + return None + + def _wait_with_backoff(self, attempt: int) -> None: + wait_time = 2**attempt + self.logger.warning(f"Retrying in {wait_time}s...") + time.sleep(wait_time) + + def profile_kernel( + self, + kernel_file: Path, + problem_file: Path, + round_num: int, + max_retries: int = 2, + ) -> ROCmProfilerResults | None: + """Profile a Triton kernel with rocprof (with retry logic). + + Args: + kernel_file: Path to kernel file. + problem_file: Path to problem file. + round_num: Current optimization round number. + max_retries: Maximum number of retry attempts. + + Returns: + :class:`ROCmProfilerResults` or ``None`` on failure. 
+ """ + wrapper_file = self.wrapper_factory.create_rocprof_wrapper( + kernel_file, problem_file, self.artifacts_dir + ) + + semaphore_acquired = False + if self.profiling_semaphore is not None: + self.logger.info(f"[Round {round_num}] Waiting for profiling semaphore...") + semaphore_acquired = self.profiling_semaphore.acquire( + timeout=DEFAULT_SEMAPHORE_TIMEOUT_SECONDS + ) + if not semaphore_acquired: + self.logger.warning( + f"[Round {round_num}] Semaphore timeout after " + f"{DEFAULT_SEMAPHORE_TIMEOUT_SECONDS}s, skipping profiling" + ) + return None + self.logger.info(f"[Round {round_num}] Acquired profiling semaphore") + + try: + return self._profile_kernel_impl( + wrapper_file, kernel_file, problem_file, round_num, max_retries + ) + finally: + if semaphore_acquired: + self.profiling_semaphore.release() + self.logger.debug(f"[Round {round_num}] Released profiling semaphore") + + def _profile_kernel_impl( + self, + wrapper_file: Path, + kernel_file: Path, + problem_file: Path, + round_num: int, + max_retries: int, + ) -> ROCmProfilerResults | None: + """Internal profiling implementation (called with semaphore held).""" + + for attempt in range(1, max_retries + 1): + try: + self.logger.info( + f"[Round {round_num}] rocprof profiling attempt {attempt}/{max_retries}..." 
+ ) + + metrics_json = profile_triton_kernel_rocm( + benchmark_script=wrapper_file, + workdir=self.artifacts_dir, + out_prefix=f"rocprof_round_{round_num}", + rocprof_bin=self.rocprof_bin, + timeout=self.rocprof_timeout_seconds, + ) + + metrics = load_rocm_metrics(metrics_json) + + results = ROCmProfilerResults( + metrics=metrics, + metadata=ROCmProfilerMetadata( + kernel_file=str(kernel_file), + problem_file=str(problem_file), + round_num=round_num, + timestamp=datetime.utcnow().isoformat() + "Z", + rocprof_bin=self.rocprof_bin, + ), + ) + + self._save_profiler_results(results) + self.logger.info( + f"✅ rocprof profiling completed for round {round_num}" + ) + return results + + except FileNotFoundError as e: + self.logger.error(f"❌ File not found during profiling: {e}") + return None + + except subprocess.TimeoutExpired: + is_final = attempt >= max_retries + if is_final: + self.logger.error( + f"❌ rocprof timed out after {self.rocprof_timeout_seconds}s " + f"(final attempt {attempt}/{max_retries})" + ) + return None + self.logger.debug( + f"rocprof timed out (attempt {attempt}/{max_retries})" + ) + self._wait_with_backoff(attempt) + + except Exception as e: + is_final = attempt >= max_retries + err_str = str(e) + if ( + "signal" in err_str.lower() + or "segfault" in err_str.lower() + or "sigsegv" in err_str.lower() + ): + self.logger.error( + "❌ rocprof crashed with segfault (likely ROCm/PyTorch version mismatch). " + "Profiling unavailable — optimization will use timing-only fallback." 
+ ) + return None + if is_final: + self.logger.error( + f"❌ Unexpected error during profiling (final attempt): {e}", + exc_info=True, + ) + return None + self.logger.debug( + f"Unexpected error (attempt {attempt}/{max_retries}): {e}" + ) + self._wait_with_backoff(attempt) + + self.logger.error( + f"❌ rocprof profiling failed after {max_retries} attempts for round {round_num}" + ) + return None + + def _save_profiler_results(self, results: ROCmProfilerResults) -> None: + metrics_file = ( + self.logs_dir + / f"round{results.metadata.round_num:03d}_rocprof_metrics.json" + ) + with open(metrics_file, "w") as f: + f.write(results.to_json()) + self.logger.debug(f"Saved ROCm metrics: {metrics_file}") diff --git a/triton_kernel_agent/opt_worker_component/profiling/rocprof_wrapper_factory.py b/triton_kernel_agent/opt_worker_component/profiling/rocprof_wrapper_factory.py new file mode 100644 index 0000000..90199b6 --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/profiling/rocprof_wrapper_factory.py @@ -0,0 +1,81 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""rocprof wrapper script generation for ROCm kernel profiling.""" + +import logging +from functools import cached_property +from pathlib import Path + +from jinja2 import Template + + +class ROCmWrapperFactory: + """Factory for creating rocprof wrapper scripts for profiling Triton kernels on AMD GPUs. 
+ + Mirrors :class:`triton_kernel_agent.opt_worker_component.profiling.ncu_wrapper_factory.NCUWrapperFactory` + but targets ROCm / HIP instead of NVIDIA NCU. + """ + + WRAPPER_TEMPLATE = Path(__file__).parent / "rocprof_wrapper_template.j2" + + def __init__(self, logger: logging.Logger) -> None: + self.logger = logger + + @cached_property + def template(self) -> Template: + if not self.WRAPPER_TEMPLATE.exists(): + raise FileNotFoundError(f"Template not found: {self.WRAPPER_TEMPLATE}") + return Template(self.WRAPPER_TEMPLATE.read_text()) + + def create_rocprof_wrapper( + self, + kernel_file: Path, + problem_file: Path, + output_dir: Path, + dtype_inference: bool = True, + model_extraction: bool = True, + ) -> Path: + """Create a rocprof wrapper script for profiling. + + Args: + kernel_file: Path to kernel file. + problem_file: Path to problem file. + output_dir: Directory to write wrapper script. + dtype_inference: Enable automatic dtype inference from kernel source. + model_extraction: Enable model weight extraction for Conv/Linear kernels. + + Returns: + Path to created wrapper script. 
+ """ + if not kernel_file.exists(): + raise FileNotFoundError(f"Kernel file not found: {kernel_file}") + if not problem_file.exists(): + raise FileNotFoundError(f"Problem file not found: {problem_file}") + + output_dir.mkdir(parents=True, exist_ok=True) + + wrapper_file = output_dir / "rocprof_wrapper.py" + wrapper_content = self.template.render( + kernel_file_parent=repr(str(kernel_file.parent)), + problem_file_parent=repr(str(problem_file.parent)), + kernel_module=kernel_file.stem, + problem_module=problem_file.stem, + dtype_inference=dtype_inference, + model_extraction=model_extraction, + ) + + wrapper_file.write_text(wrapper_content) + self.logger.info(f"Created rocprof wrapper: {wrapper_file}") + return wrapper_file diff --git a/triton_kernel_agent/opt_worker_component/profiling/rocprof_wrapper_template.j2 b/triton_kernel_agent/opt_worker_component/profiling/rocprof_wrapper_template.j2 new file mode 100644 index 0000000..6194871 --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/profiling/rocprof_wrapper_template.j2 @@ -0,0 +1,236 @@ +{# +Copyright (c) Meta Platforms, Inc. and affiliates. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+#} + +"""rocprof profiling wrapper for AMD ROCm GPUs.""" +import importlib +import sys +import torch +import inspect +sys.path.insert(0, str({{ kernel_file_parent }})) +sys.path.insert(0, str({{ problem_file_parent }})) + +from {{ kernel_module }} import kernel_function + +_problem_mod = importlib.import_module({{ problem_module | tojson }}) +get_inputs = _problem_mod.get_inputs +get_init_inputs = _problem_mod.get_init_inputs + +# Try to get Model if it exists (for Conv, Linear, etc.) +has_model = hasattr(_problem_mod, 'Model') +if has_model: + Model = _problem_mod.Model + +# Get inputs +inputs = get_inputs() + +# Get additional initialization inputs (e.g., features, eps for RMSNorm) +init_inputs = get_init_inputs() + +{% if dtype_inference %} +# Infer required dtype from kernel function signature/docstring +required_dtype = None +try: + kernel_source = inspect.getsource(kernel_function) + if 'bfloat16' in kernel_source.lower(): + required_dtype = torch.bfloat16 + elif 'float16' in kernel_source.lower() or 'half' in kernel_source.lower(): + required_dtype = torch.float16 + elif 'float32' in kernel_source.lower(): + required_dtype = torch.float32 +except Exception: + pass +{% else %} +required_dtype = None +{% endif %} + +# Validate dtype +if required_dtype is not None and not isinstance(required_dtype, torch.dtype): + required_dtype = None + +# Prepare inputs: move to CUDA (ROCm uses torch.cuda via HIP) and convert dtype. +# IMPORTANT: Only convert floating-point tensors; preserve integer tensors. 
+cuda_inputs = [] +for inp in inputs: + if isinstance(inp, torch.Tensor): + if not inp.is_cuda: + inp = inp.cuda() + if required_dtype is not None and inp.is_floating_point() and inp.dtype != required_dtype: + inp = inp.to(required_dtype) + cuda_inputs.append(inp) + else: + cuda_inputs.append(inp) + +{% if model_extraction %} +# Check if this is a conv-like kernel that needs a Model to extract weights +needs_model = False +has_var_positional = False +has_var_keyword = False +kernel_params = [] +try: + sig = inspect.signature(kernel_function) + kernel_params = [ + name for name, p in sig.parameters.items() + if p.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD) + ] + param_kinds = [p.kind for p in sig.parameters.values()] + has_var_positional = any(k == inspect.Parameter.VAR_POSITIONAL for k in param_kinds) + has_var_keyword = any(k == inspect.Parameter.VAR_KEYWORD for k in param_kinds) + _MODEL_PARAM_NAMES = {'weight', 'w', 'kernel_size', 'stride', 'padding', 'dilation'} + if _MODEL_PARAM_NAMES.intersection(kernel_params): + needs_model = True + if not needs_model and (has_var_positional or has_var_keyword): + try: + src = inspect.getsource(kernel_function) + needs_model = any(kw in src for kw in ("weight", "is_weight", + "w.shape", "w.ndim", + "kernel_size", "dilation")) + except (OSError, TypeError): + pass +except Exception: + pass + +_CONV_TYPES = ( + torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, + torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d, +) +_NORM_TYPES = ( + torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, + torch.nn.LayerNorm, torch.nn.GroupNorm, + torch.nn.InstanceNorm1d, torch.nn.InstanceNorm2d, torch.nn.InstanceNorm3d, +) +_POOL_TYPES = ( + torch.nn.MaxPool1d, torch.nn.MaxPool2d, torch.nn.MaxPool3d, + torch.nn.AvgPool1d, torch.nn.AvgPool2d, torch.nn.AvgPool3d, + torch.nn.AdaptiveAvgPool1d, torch.nn.AdaptiveAvgPool2d, torch.nn.AdaptiveAvgPool3d, + 
torch.nn.AdaptiveMaxPool1d, torch.nn.AdaptiveMaxPool2d, torch.nn.AdaptiveMaxPool3d, +) + +model_params = {} +_all_weights = [] + +if needs_model and has_model and init_inputs: + try: + model = Model(*init_inputs) if init_inputs else Model() + except TypeError: + model = Model() + + model = model.cuda() # ROCm: torch.cuda works via HIP + if required_dtype is not None: + model = model.to(required_dtype) + + for _, module in model.named_modules(): + if isinstance(module, (*_CONV_TYPES, torch.nn.Linear)): + if hasattr(module, 'weight') and module.weight is not None: + _all_weights.append(module.weight) + model_params.setdefault('weight', module.weight) + model_params.setdefault('w', module.weight) + if getattr(module, 'bias', None) is not None: + model_params.setdefault('conv_bias', module.bias) + model_params.setdefault('bias', module.bias) + for attr in ('stride', 'padding', 'dilation', 'output_padding'): + val = getattr(module, attr, None) + if val is not None: + model_params.setdefault(attr, val) + if hasattr(module, 'groups'): + model_params.setdefault('groups', module.groups) + + elif isinstance(module, _NORM_TYPES): + if getattr(module, 'weight', None) is not None: + model_params.setdefault('weight', module.weight) + model_params.setdefault('w', module.weight) + if getattr(module, 'bias', None) is not None: + model_params.setdefault('bias', module.bias) + if hasattr(module, 'eps'): + model_params['eps'] = module.eps + if hasattr(module, 'num_groups'): + model_params['num_groups'] = module.num_groups + if hasattr(module, 'normalized_shape'): + model_params['normalized_shape'] = module.normalized_shape + + elif isinstance(module, _POOL_TYPES): + for attr in ('kernel_size', 'stride', 'padding', 'dilation'): + val = getattr(module, attr, None) + if val is not None: + model_params.setdefault(attr, val) + + if hasattr(model, 'bias') and isinstance(model.bias, (torch.Tensor, torch.nn.Parameter)): + model_params['add_bias'] = model.bias + model_params.setdefault('bias', 
model.bias) + if not model_params['add_bias'].is_cuda: + model_params['add_bias'] = model_params['add_bias'].cuda() + if required_dtype is not None and model_params['add_bias'].is_floating_point(): + model_params['add_bias'] = model_params['add_bias'].to(required_dtype) + + +def run_kernel(): + """Helper to run the kernel with appropriate arguments.""" + {% if model_extraction %} + if needs_model and model_params: + if has_var_positional and _all_weights: + pos_args = list(cuda_inputs) + list(_all_weights) + config_kwargs = {} + for k, v in model_params.items(): + if k not in ('weight', 'w', 'bias', 'conv_bias', 'add_bias'): + if isinstance(v, (tuple, list)) and len(v) >= 1 and all(e == v[0] for e in v): + v = v[0] + config_kwargs[k] = v + return kernel_function(*pos_args, **config_kwargs) + else: + bound = {} + positional_idx = 0 + for pname in kernel_params: + if pname in model_params: + v = model_params[pname] + if isinstance(v, (tuple, list)) and len(v) >= 1 and all(e == v[0] for e in v): + v = v[0] + bound[pname] = v + elif positional_idx < len(cuda_inputs): + bound[pname] = cuda_inputs[positional_idx] + positional_idx += 1 + return kernel_function(**bound) + else: + return kernel_function(*cuda_inputs, *init_inputs) + {% else %} + return kernel_function(*cuda_inputs, *init_inputs) + {% endif %} + + +# Warmup phase: complete autotuning and JIT compilation BEFORE rocprof profiling. +# rocprof traces the entire process; these warmup runs ensure that JIT overhead +# is not included in the final profile capture. 
+WARMUP_ITERATIONS = 3 +for _ in range(WARMUP_ITERATIONS): + _ = run_kernel() +torch.cuda.synchronize() # Works on ROCm via HIP + +# Main execution: rocprof will trace this invocation +{% else %} +def run_kernel(): + """Helper to run the kernel with appropriate arguments.""" + return kernel_function(*cuda_inputs, *init_inputs) + + +# Warmup phase +WARMUP_ITERATIONS = 3 +for _ in range(WARMUP_ITERATIONS): + _ = run_kernel() +torch.cuda.synchronize() + +# Main execution +{% endif %} +output = run_kernel() + +print("Kernel executed successfully, output shape: " + str(output.shape if hasattr(output, 'shape') else type(output))) diff --git a/triton_kernel_agent/platform/__init__.py b/triton_kernel_agent/platform/__init__.py index a4b5bcc..467fe54 100644 --- a/triton_kernel_agent/platform/__init__.py +++ b/triton_kernel_agent/platform/__init__.py @@ -27,6 +27,11 @@ NvidiaAcceleratorSpecsProvider, NvidiaKernelProfiler, NvidiaRooflineAnalyzer, NvidiaBottleneckAnalyzer, NvidiaRAGPrescriber +AMD ROCm implementations: + ROCmVerifier, ROCmBenchmarker, ROCmWorkerRunner, + ROCmAcceleratorSpecsProvider, ROCmKernelProfilerWrapper, + ROCmRooflineAnalyzerWrapper, ROCmBottleneckAnalyzer, ROCmRAGPrescriber + No-op implementations (for dry-run / CI / testing): NoOpVerifier, NoOpBenchmarker, NoOpWorkerRunner, NoOpSpecsProvider, NoOpProfiler, NoOpRooflineAnalyzer, @@ -67,6 +72,16 @@ NoOpVerifier, NoOpWorkerRunner, ) +from triton_kernel_agent.platform.rocm import ( + ROCmAcceleratorSpecsProvider, + ROCmBenchmarker, + ROCmBottleneckAnalyzer, + ROCmKernelProfilerWrapper, + ROCmRAGPrescriber, + ROCmRooflineAnalyzerWrapper, + ROCmVerifier, + ROCmWorkerRunner, +) from triton_kernel_agent.platform.registry import PlatformRegistry, registry __all__ = [ @@ -91,6 +106,16 @@ "NvidiaRooflineAnalyzer", "NvidiaBottleneckAnalyzer", "NvidiaRAGPrescriber", + # ROCm implementations (manager) + "ROCmVerifier", + "ROCmBenchmarker", + "ROCmWorkerRunner", + # ROCm implementations (worker) + 
"ROCmAcceleratorSpecsProvider", + "ROCmKernelProfilerWrapper", + "ROCmRooflineAnalyzerWrapper", + "ROCmBottleneckAnalyzer", + "ROCmRAGPrescriber", # No-op implementations (manager) "NoOpVerifier", "NoOpBenchmarker", diff --git a/triton_kernel_agent/platform/registry.py b/triton_kernel_agent/platform/registry.py index c5e70a7..ca42481 100644 --- a/triton_kernel_agent/platform/registry.py +++ b/triton_kernel_agent/platform/registry.py @@ -251,5 +251,31 @@ def _register_builtins() -> None: for component, factory in _noop.items(): registry.register(component, "noop", factory) + from triton_kernel_agent.platform.rocm import ( + ROCmAcceleratorSpecsProvider, + ROCmBenchmarker, + ROCmBottleneckAnalyzer, + ROCmKernelProfilerWrapper, + ROCmRAGPrescriber, + ROCmRooflineAnalyzerWrapper, + ROCmVerifier, + ROCmWorkerRunner, + ) + + _rocm = { + # Manager-level + "verifier": ROCmVerifier, + "benchmarker": ROCmBenchmarker, + "worker_runner": ROCmWorkerRunner, + # Worker-level + "specs_provider": ROCmAcceleratorSpecsProvider, + "profiler": ROCmKernelProfilerWrapper, + "roofline_analyzer": ROCmRooflineAnalyzerWrapper, + "bottleneck_analyzer": ROCmBottleneckAnalyzer, + "rag_prescriber": ROCmRAGPrescriber, + } + for component, factory in _rocm.items(): + registry.register(component, "rocm", factory) + _register_builtins() diff --git a/triton_kernel_agent/platform/rocm.py b/triton_kernel_agent/platform/rocm.py new file mode 100644 index 0000000..5875127 --- /dev/null +++ b/triton_kernel_agent/platform/rocm.py @@ -0,0 +1,651 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""AMD ROCm / HIP implementations of platform interfaces. + +These wrap the ROCm-specific code (rocprof profiling, AMD GPU specs, +ROCm roofline analysis) behind the same abstract interfaces used by the +NVIDIA CUDA path in ``triton_kernel_agent/platform/nvidia.py``. + +ROCm-specific notes +------------------- +- ``torch.cuda`` works on ROCm via the HIP compatibility layer, so CUDA + event timing, ``torch.cuda.synchronize()``, etc. all work unchanged. +- Profiling uses ``rocprof`` instead of NVIDIA NCU. +- GPU names are detected via ``torch.cuda.get_device_name()``; AMD Instinct + GPUs return strings like ``"AMD Instinct MI300X"``. +- Wavefront size is 64 (vs 32 for NVIDIA warps). +""" + +from __future__ import annotations + +import logging +import multiprocessing as mp +import shutil +import time +import traceback +from pathlib import Path +from typing import Any + +from triton_kernel_agent.platform.interfaces import ( + AcceleratorSpecsProvider, + BottleneckAnalyzerBase, + KernelBenchmarker, + KernelProfilerBase, + KernelVerifier, + RAGPrescriberBase, + RooflineAnalyzerBase, + WorkerRunner, +) + + +# --------------------------------------------------------------------------- +# GPU name detection helper +# --------------------------------------------------------------------------- + + +def detect_amd_gpu_name() -> str | None: + """Detect the AMD GPU name via ``torch.cuda.get_device_name()``. + + On ROCm, PyTorch's CUDA layer is backed by HIP, and + ``torch.cuda.get_device_name()`` returns the real AMD GPU name + (e.g., ``"AMD Instinct MI300X"``). 
+ + Returns: + GPU name string, or ``None`` if no GPU is found or the GPU is NVIDIA. + """ + try: + import torch + + if not torch.cuda.is_available(): + return None + name = torch.cuda.get_device_name(0) + if "AMD" in name or "Instinct" in name or "Radeon" in name: + return name + return None + except Exception: + return None + + +def _normalize_amd_gpu_name(raw_name: str) -> str: + """Normalize the raw torch GPU name to a key in the GPU specs database. + + ``torch.cuda.get_device_name()`` may return strings like + ``"AMD Instinct MI300X OAM"`` or ``"Instinct MI300X"``. + We strip the OAM/PCIe suffix and ensure the ``"AMD Instinct"`` prefix. + + Args: + raw_name: Raw name from ``torch.cuda.get_device_name()``. + + Returns: + Normalized name string. + """ + # Strip trailing form-factor suffixes + for suffix in (" OAM", " PCIe", " SXM", " NVL"): + if raw_name.endswith(suffix): + raw_name = raw_name[: -len(suffix)].strip() + + # Ensure "AMD Instinct" prefix + if not raw_name.startswith("AMD "): + raw_name = "AMD " + raw_name + + return raw_name + + +# --------------------------------------------------------------------------- +# Verifier +# --------------------------------------------------------------------------- + + +class ROCmVerifier(KernelVerifier): + """Verifies kernel correctness on AMD GPUs. + + Uses the same ``VerificationWorker`` as the NVIDIA path — ROCm exposes + ``torch.cuda`` via HIP so no changes are needed at the verification level. 
+ """ + + def __init__(self, log_dir: Path, logger: logging.Logger) -> None: + self.log_dir = log_dir + self.logger = logger + + def verify( + self, + kernel_code: str, + problem_file: Path, + test_code: list[str], + ) -> bool: + from triton_kernel_agent.worker import VerificationWorker + + verify_dir = self.log_dir / "initial_verify" + verify_dir.mkdir(parents=True, exist_ok=True) + + shutil.copy(problem_file, verify_dir / "problem.py") + + worker = VerificationWorker( + worker_id=-1, + workdir=verify_dir, + log_dir=verify_dir, + target_platform="rocm", + ) + + success, _, error = worker.verify_with_refinement( + kernel_code=kernel_code, + test_code=test_code, + problem_description=problem_file.read_text(), + max_refine_attempts=0, + ) + + if not success: + self.logger.error( + f"Initial kernel failed correctness verification: {error[:200]}" + ) + else: + self.logger.info("Initial kernel passed correctness verification") + + return success + + +# --------------------------------------------------------------------------- +# Benchmarker +# --------------------------------------------------------------------------- + + +class ROCmBenchmarker(KernelBenchmarker): + """Benchmarks Triton kernels on AMD GPUs. + + Uses ``triton.testing.do_bench`` / ``torch.cuda.Event`` timing, which + works on ROCm via the HIP compatibility layer. 
+ """ + + def __init__( + self, + log_dir: Path, + logger: logging.Logger, + benchmark_lock: Any, + warmup: int = 25, + repeat: int = 100, + ) -> None: + self.log_dir = log_dir + self.logger = logger + self.benchmark_lock = benchmark_lock + self.warmup = warmup + self.repeat = repeat + + def _get_benchmarker(self): + from triton_kernel_agent.opt_worker_component.benchmarking.benchmark import ( + Benchmark, + ) + + artifacts_dir = self.log_dir / "artifacts" + artifacts_dir.mkdir(parents=True, exist_ok=True) + return Benchmark( + logger=self.logger, + artifacts_dir=artifacts_dir, + benchmark_lock=self.benchmark_lock, + worker_id=-1, + warmup=self.warmup, + repeat=self.repeat, + ) + + def benchmark_kernel(self, kernel_code: str, problem_file: Path) -> float: + benchmarker = self._get_benchmarker() + artifacts_dir = self.log_dir / "artifacts" + artifacts_dir.mkdir(parents=True, exist_ok=True) + + kernel_file = artifacts_dir / "initial_kernel.py" + kernel_file.write_text(kernel_code, encoding="utf-8") + + result = benchmarker.benchmark_kernel(kernel_file, problem_file) + kernel_time = result.get("time_ms", float("inf")) + + if kernel_time != float("inf"): + self.logger.info(f"Initial kernel time: {kernel_time:.4f}ms") + + return kernel_time + + def benchmark_reference(self, problem_file: Path) -> float: + benchmarker = self._get_benchmarker() + result = benchmarker.benchmark_pytorch(problem_file) + pytorch_time = result.get("time_ms", float("inf")) + + if pytorch_time != float("inf"): + self.logger.info(f"PyTorch baseline: {pytorch_time:.4f}ms") + + return pytorch_time + + def benchmark_reference_compiled(self, problem_file: Path) -> float: + benchmarker = self._get_benchmarker() + result = benchmarker.benchmark_pytorch_compile(problem_file) + compile_time = result.get("time_ms", float("inf")) + + if compile_time != float("inf"): + self.logger.info(f"PyTorch compile baseline: {compile_time:.4f}ms") + + return compile_time + + +# 
--------------------------------------------------------------------------- +# Worker runner +# --------------------------------------------------------------------------- + + +class ROCmWorkerRunner(WorkerRunner): + """Spawns ``OptimizationWorker`` processes on AMD GPUs.""" + + def __init__( + self, + log_dir: Path, + logger: logging.Logger, + benchmark_lock: Any, + profiling_semaphore: Any, + openai_model: str, + high_reasoning_effort: bool, + bottleneck_override: str | None, + worker_kwargs: dict[str, Any], + ) -> None: + self.log_dir = log_dir + self.logger = logger + self.benchmark_lock = benchmark_lock + self.profiling_semaphore = profiling_semaphore + self.openai_model = openai_model + self.high_reasoning_effort = high_reasoning_effort + self.bottleneck_override = bottleneck_override + self.worker_kwargs = worker_kwargs + + def run_workers( + self, + candidates: list[dict[str, Any]], + round_num: int, + problem_file: Path, + test_code: list[str], + pytorch_baseline: float, + shared_history: list[dict], + shared_reflexions: list[dict], + ) -> list[dict[str, Any]]: + result_queue = mp.Queue() + workers = [] + + for i, candidate in enumerate(candidates): + workdir = self.log_dir / "workers" / f"w{i}" / f"r{round_num}" + workdir.mkdir(parents=True, exist_ok=True) + + args = ( + i, + candidate["parent"].kernel_code, + candidate["parent"].metrics.time_ms, + candidate["parent"].program_id, + problem_file, + test_code, + workdir, + workdir / "logs", + result_queue, + self.benchmark_lock, + self.profiling_semaphore, + pytorch_baseline, + candidate["bottleneck_id"], + self.openai_model, + self.high_reasoning_effort, + self.bottleneck_override, + self.worker_kwargs, + shared_history, + shared_reflexions, + ) + + p = mp.Process(target=_rocm_worker_process, args=args) + p.start() + workers.append(p) + + worker_timeout = 1800 # 30 minutes (longer due to two rocprof passes) + deadline = time.time() + worker_timeout + for w in workers: + remaining = max(0, deadline - 
time.time()) + w.join(timeout=remaining) + if w.is_alive(): + self.logger.warning(f"Worker {w.pid} timed out, terminating") + w.terminate() + w.join(timeout=5) + if w.is_alive(): + self.logger.warning(f"Worker {w.pid} still alive, killing") + w.kill() + w.join(timeout=2) + w.close() + + results: list[dict[str, Any]] = [] + while not result_queue.empty(): + try: + results.append(result_queue.get_nowait()) + except Exception: + break + + result_queue.close() + result_queue.join_thread() + + successful = sum(1 for r in results if r.get("success")) + self.logger.info( + f"Round {round_num}: {successful}/{len(candidates)} workers succeeded " + f"({len(results)} results received)" + ) + return results + + +def _rocm_worker_process( + worker_id: int, + kernel_code: str, + known_time: float, + parent_id: str, + problem_file: Path, + test_code: list[str], + workdir: Path, + log_dir: Path, + result_queue: mp.Queue, + benchmark_lock: Any, + profiling_semaphore: Any, + pytorch_baseline: float, + bottleneck_id: int, + openai_model: str, + high_reasoning_effort: bool, + bottleneck_override: str | None, + worker_kwargs: dict, + prior_history: list[dict], + prior_reflexions: list[dict], +) -> None: + """Worker process function for AMD ROCm GPUs.""" + import sys + + kernel_agent_path = Path(__file__).parent.parent.parent + if str(kernel_agent_path) not in sys.path: + sys.path.insert(0, str(kernel_agent_path)) + + try: + from triton_kernel_agent.opt_worker import OptimizationWorker + + workdir.mkdir(parents=True, exist_ok=True) + log_dir.mkdir(parents=True, exist_ok=True) + + shutil.copy(problem_file, workdir / "problem.py") + + worker_kwargs.pop("target_platform", None) + worker = OptimizationWorker( + worker_id=worker_id, + workdir=workdir, + log_dir=log_dir, + openai_model=openai_model, + high_reasoning_effort=high_reasoning_effort, + bottleneck_id=bottleneck_id, + benchmark_lock=benchmark_lock, + profiling_semaphore=profiling_semaphore, + pytorch_baseline_time=pytorch_baseline, 
+ bottleneck_override=bottleneck_override, + prior_history=prior_history, + prior_reflexions=prior_reflexions, + target_platform="rocm", + **worker_kwargs, + ) + + success, best_kernel, metrics = worker.optimize_kernel( + kernel_code=kernel_code, + problem_file=problem_file, + test_code=test_code, + known_kernel_time=known_time, + max_opt_rounds=1, + ) + + result_queue.put( + { + "success": success, + "worker_id": worker_id, + "kernel_code": best_kernel, + "time_ms": metrics.get("best_time_ms", float("inf")), + "parent_id": parent_id, + "attempt": metrics.get("last_attempt"), + "reflexion": metrics.get("last_reflexion"), + } + ) + + except Exception as e: + result_queue.put( + { + "success": False, + "worker_id": worker_id, + "error": str(e), + "traceback": traceback.format_exc(), + } + ) + + +# --------------------------------------------------------------------------- +# Worker-level ROCm implementations +# --------------------------------------------------------------------------- + + +class ROCmAcceleratorSpecsProvider(AcceleratorSpecsProvider): + """Looks up AMD GPU specs from the GPU specs database. + + Auto-detects the GPU name via ``torch.cuda.get_device_name()`` when + ``device_name`` is not provided. + """ + + def get_specs(self, device_name: str | None = None) -> dict[str, Any]: + from kernel_perf_agent.kernel_opt.diagnose_prompt.gpu_specs import ( + get_gpu_specs, + ) + + if not device_name: + raw = detect_amd_gpu_name() + if raw is None: + raise ValueError( + "Could not detect AMD GPU name. Provide gpu_name explicitly." + ) + device_name = _normalize_amd_gpu_name(raw) + + specs = get_gpu_specs(device_name) + if specs is None: + # Try normalization in case the raw name was passed directly + normalized = _normalize_amd_gpu_name(device_name) + specs = get_gpu_specs(normalized) + + if specs is None: + raise ValueError( + f"AMD GPU '{device_name}' not found in specs database. 
" + "Add it to kernel_perf_agent/kernel_opt/diagnose_prompt/gpu_specs_database.py" + ) + return specs + + +class ROCmKernelProfilerWrapper(KernelProfilerBase): + """Wraps :class:`ROCmKernelProfiler` with lazy construction.""" + + def __init__( + self, + logger: logging.Logger | None = None, + log_dir: Path | None = None, + artifacts_dir: Path | None = None, + rocprof_bin_path: str | None = None, + rocprof_timeout_seconds: int | None = None, + profiling_semaphore: Any | None = None, + ) -> None: + self._logger = logger or logging.getLogger(__name__) + self._log_dir = Path(log_dir) if log_dir else Path(".") + self._artifacts_dir = Path(artifacts_dir) if artifacts_dir else None + self._rocprof_bin_path = rocprof_bin_path + self._rocprof_timeout_seconds = rocprof_timeout_seconds + self._profiling_semaphore = profiling_semaphore + self._delegate: Any | None = None + + def _get_delegate(self) -> Any: + if self._delegate is None: + from triton_kernel_agent.opt_worker_component.profiling.rocm_kernel_profiler import ( + ROCmKernelProfiler, + ) + + artifacts_dir = self._artifacts_dir or self._log_dir / "artifacts" + artifacts_dir.mkdir(parents=True, exist_ok=True) + kwargs: dict[str, Any] = { + "logger": self._logger, + "artifacts_dir": artifacts_dir, + "logs_dir": self._log_dir, + "rocprof_bin_path": self._rocprof_bin_path, + "profiling_semaphore": self._profiling_semaphore, + } + if self._rocprof_timeout_seconds is not None: + kwargs["rocprof_timeout_seconds"] = self._rocprof_timeout_seconds + self._delegate = ROCmKernelProfiler(**kwargs) + return self._delegate + + def profile_kernel( + self, + kernel_file: Path, + problem_file: Path, + round_num: int, + max_retries: int = 2, + ) -> Any | None: + return self._get_delegate().profile_kernel( + kernel_file, problem_file, round_num, max_retries + ) + + +class ROCmRooflineAnalyzerWrapper(RooflineAnalyzerBase): + """Wraps :class:`ROCmRooflineAnalyzer` with lazy construction.""" + + def __init__( + self, + logger: logging.Logger 
| None = None, + roofline_config: Any | None = None, + ) -> None: + self._logger = logger + self._roofline_config = roofline_config + self._delegate: Any | None = None + + def _get_delegate(self) -> Any: + if self._delegate is None: + from kernel_perf_agent.kernel_opt.roofline.rocm_roofline import ( + ROCmRooflineAnalyzer, + ROCmRooflineConfig, + ) + + kwargs: dict[str, Any] = {"logger": self._logger} + if self._roofline_config is not None: + # Allow passing either a ROCmRooflineConfig or dict + if isinstance(self._roofline_config, dict): + kwargs["config"] = ROCmRooflineConfig(**self._roofline_config) + else: + kwargs["config"] = self._roofline_config + self._delegate = ROCmRooflineAnalyzer(**kwargs) + return self._delegate + + def analyze(self, rocm_metrics: dict[str, Any]) -> Any: + return self._get_delegate().analyze(rocm_metrics) + + def should_stop(self, result: Any) -> tuple[bool, str]: + return self._get_delegate().should_stop(result) + + def reset_history(self) -> None: + self._get_delegate().reset_history() + + +class ROCmBottleneckAnalyzer(BottleneckAnalyzerBase): + """Bottleneck analyzer for AMD GPUs. + + Uses the same LLM-based ``BottleneckAnalyzer`` as the NVIDIA path but + passes AMD GPU specs and ROCm-specific counter names in the context. 
+ """ + + def __init__( + self, + logger: logging.Logger | None = None, + log_dir: Path | None = None, + openai_model: str = "gpt-5", + gpu_name: str | None = None, + roofline_config: Any | None = None, + ) -> None: + self._logger = logger or logging.getLogger(__name__) + self._log_dir = Path(log_dir) if log_dir else None + self._openai_model = openai_model + self._gpu_name = gpu_name + self._delegate: Any | None = None + self.roofline = ROCmRooflineAnalyzerWrapper( + logger=self._logger, roofline_config=roofline_config + ) + + def _get_delegate(self) -> Any: + if self._delegate is None: + from kernel_perf_agent.kernel_opt.diagnose_prompt.gpu_specs import ( + get_gpu_specs, + ) + from triton_kernel_agent.opt_worker_component.prescribing.bottleneck_analyzer import ( + BottleneckAnalyzer, + ) + from utils.providers import get_model_provider + + gpu_name = self._gpu_name + if not gpu_name: + raw = detect_amd_gpu_name() + if raw: + gpu_name = _normalize_amd_gpu_name(raw) + if not gpu_name: + raise ValueError("gpu_name is required for ROCmBottleneckAnalyzer") + + provider = get_model_provider(self._openai_model) + gpu_specs = get_gpu_specs(gpu_name) + + self._delegate = BottleneckAnalyzer( + provider=provider, + model=self._openai_model, + gpu_specs=gpu_specs, + logs_dir=self._log_dir, + logger=self._logger, + ) + return self._delegate + + def analyze( + self, + kernel_code: str, + rocm_metrics: dict[str, Any], + round_num: int = 0, + roofline_result: Any | None = None, + ) -> list[Any]: + return self._get_delegate().analyze( + kernel_code, rocm_metrics, round_num, roofline_result + ) + + +class ROCmRAGPrescriber(RAGPrescriberBase): + """RAG prescriber for ROCm — delegates to the same implementation as NVIDIA.""" + + def __init__( + self, + logger: logging.Logger | None = None, + database_path: Path | None = None, + ) -> None: + self._logger = logger + self._database_path = database_path + self._delegate: Any | None = None + + def _get_delegate(self) -> Any: + if 
self._delegate is None: + from triton_kernel_agent.opt_worker_component.prescribing.RAG_based_prescriber import ( + RAGPrescriber, + ) + + kwargs: dict[str, Any] = {"logger": self._logger} + if self._database_path is not None: + kwargs["database_path"] = self._database_path + self._delegate = RAGPrescriber(**kwargs) + return self._delegate + + def retrieve(self, query: str) -> tuple[Any | None, Any]: + return self._get_delegate().retrieve(query) + + def build_context(self, opt_node: Any, **kwargs: Any) -> str: + return self._get_delegate().build_context(opt_node, **kwargs) diff --git a/triton_kernel_agent/platform_config.py b/triton_kernel_agent/platform_config.py index f69c6b5..7254008 100644 --- a/triton_kernel_agent/platform_config.py +++ b/triton_kernel_agent/platform_config.py @@ -76,6 +76,38 @@ class PlatformConfig: "XPUDriver.is_available = classmethod(lambda cls: False)", ) +# ROCm/AMD GPU platform constants +_ROCM_GUIDANCE = """\ +**CRITICAL PLATFORM REQUIREMENTS FOR AMD ROCm:** +- Default tensor allocations to device='cuda' (ROCm exposes HIP as CUDA-compatible via torch.cuda) +- Check availability with: torch.cuda.is_available() (returns True on ROCm/HIP) +- AMD wavefront size is 64 (not 32 like NVIDIA warps) — account for this in tiling +- Do NOT assume NVIDIA-specific ISA features (e.g., warp shuffle semantics differ) +- Use torch.cuda.synchronize() for synchronization (works on ROCm via HIP) +- Preferred block sizes: 64, 128, 256, or 512 (multiples of wavefront size 64) +- triton.language.constexpr BLOCK_SIZE should be a power of 2 >= 64""" + +_ROCM_KERNEL_GUIDANCE = """\ +## AMD ROCm-Specific Optimizations + +You are generating a Triton kernel for AMD GPUs (ROCm/HIP). Follow these guidelines: + +1. **Device Context**: Use 'cuda' as the device string (ROCm provides HIP-CUDA compatibility) +2. 
**Wavefront Size**: AMD GPUs use wavefront size 64 (vs NVIDIA warp size 32) + - Prefer BLOCK_SIZE multiples of 64 (64, 128, 256, 512) + - num_warps maps to num_wavefronts on AMD +3. **Memory Hierarchy**: AMD CDNA GPUs have HBM memory with very high bandwidth + - MI300X: 5.3 TB/s, MI350X: ~8 TB/s + - Optimize for memory coalescing and avoid strided access patterns +4. **Compute Units**: AMD uses Compute Units (CUs), each with 64-lane SIMD + - MI300X has 304 CUs, MI350X has 256 CUs +5. **Data Types**: AMD CDNA supports fp32, fp16, bf16, fp8 (gfx942+) + - BF16 matrix units available on MI300X (CDNA3) and later +6. **Thread Configuration**: + - BLOCK_SIZE: prefer 64, 128, 256, or 512 + - num_warps: typically 4, 8 for AMD (maps to wavefronts) +7. **Avoid NVIDIA-specific patterns**: Do not use warp-level primitives that assume warp_size=32""" + # Platform registry PLATFORMS: dict[str, PlatformConfig] = { "cuda": PlatformConfig( @@ -85,6 +117,13 @@ class PlatformConfig: kernel_guidance="", cuda_hacks_to_strip=(), ), + "rocm": PlatformConfig( + name="rocm", + device_string="cuda", # ROCm uses torch.cuda (HIP compatibility layer) + guidance_block=_ROCM_GUIDANCE, + kernel_guidance=_ROCM_KERNEL_GUIDANCE, + cuda_hacks_to_strip=(), + ), "xpu": PlatformConfig( name="xpu", device_string="xpu", diff --git a/triton_kernel_agent/worker.py b/triton_kernel_agent/worker.py index ea5d982..011ea32 100644 --- a/triton_kernel_agent/worker.py +++ b/triton_kernel_agent/worker.py @@ -135,7 +135,7 @@ def __init__( high_reasoning_effort: bool = True, target_platform: str = "cuda", no_cusolver: bool = False, - test_timeout_s: int = 30, + test_timeout_s: int = 300, ):