From 86839fa8cbcb620a86053257549bf438d72422d4 Mon Sep 17 00:00:00 2001 From: can-gaa-hou Date: Sat, 28 Feb 2026 10:17:17 +0000 Subject: [PATCH] Enhance platform configuration support with JSON loading and custom GPU checks --- README.md | 33 +++++++++++++ triton_kernel_agent/agent.py | 13 +++-- triton_kernel_agent/platform_config.py | 49 +++++++++++++++++++ .../templates/test_generation.j2 | 5 +- 4 files changed, 94 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 4bbd40e0..0c3dde8d 100644 --- a/README.md +++ b/README.md @@ -30,10 +30,13 @@ Every stage writes artifacts to a run directory under `.optimize//`, inc - **GPU Requirements (one of the following):** - **CUDA**: NVIDIA GPU with CUDA support - **XPU**: Intel GPU with oneAPI support (Arc, Data Center GPUs, or integrated Xe graphics) + - **Custom GPU**: Make sure that the GPU is supported by `torch.accelerator` - Triton (installed separately: `pip install triton` or nightly from source) - PyTorch (https://pytorch.org/get-started/locally/) - LLM provider ([OpenAI](https://openai.com/api/), [Anthropic](https://www.anthropic.com/), or a self-hosted relay) +**Important note for custom GPUs**: please install mutually compatible versions of PyTorch and Triton. + ### Install ```bash pip install -e . 
@@ -234,6 +237,33 @@ When targeting Intel XPU, KernelAgent automatically: - Generates appropriate device availability checks - Removes CUDA-specific patterns from generated code +### Custom Platforms via JSON (Optional) + +You can add/override platform configs via a JSON file loaded at import time: + +```bash +export KERNELAGENT_PLATFORM_JSON=/abs/path/to/platforms.json +``` + +JSON format (top-level is a mapping; keys are platform names used by `--target-platform`): + +```jsonc +// platforms.json +{ + "npu": { + "name": "npu", + "device_string": "npu", + "guidance_block": "...", + "kernel_guidance": "...", + "cuda_hacks_to_strip": [] + } +} +``` + +Notes: +- Set `KERNELAGENT_PLATFORM_JSON` before starting any KernelAgent/Fuser CLI. +- Select the platform by passing `--target-platform` or `target_platform` explicitly (e.g. `python -m Fuser.auto_agent --problem /abs/path/to/problem.py --target-platform npu`). + ### Verifying Platform Setup ```python # Check CUDA availability @@ -242,6 +272,9 @@ print("CUDA available:", torch.cuda.is_available()) # Check XPU availability print("XPU available:", hasattr(torch, 'xpu') and torch.xpu.is_available()) + +# Check custom GPU availability (suppose the platform name is "npu") +print("NPU available:", torch.accelerator.is_available() and torch.accelerator.current_accelerator().type == 'npu') ``` ## Run Artifacts diff --git a/triton_kernel_agent/agent.py b/triton_kernel_agent/agent.py index c90e1036..7233f6c0 100644 --- a/triton_kernel_agent/agent.py +++ b/triton_kernel_agent/agent.py @@ -18,7 +18,7 @@ import json import re from pathlib import Path -from typing import Any +from typing import Any, Union from datetime import datetime import logging from dotenv import load_dotenv @@ -40,7 +40,7 @@ def __init__( model_name: str | None = None, high_reasoning_effort: bool = True, preferred_provider: BaseProvider | None = None, - target_platform: PlatformConfig | None = None, + target_platform: Union[PlatformConfig, str] | None = None, no_cusolver: bool 
= False, test_timeout_s: int = 30, ): @@ -87,9 +87,12 @@ def __init__( self.log_dir.mkdir(exist_ok=True, parents=True) # Normalize to PlatformConfig - self._platform_config = ( - target_platform if target_platform else get_platform("cuda") - ) + if not target_platform: + target_platform = get_platform("cuda") + elif isinstance(target_platform, str): + target_platform = get_platform(target_platform) + self._platform_config = target_platform + self.no_cusolver = no_cusolver self.test_timeout_s = test_timeout_s diff --git a/triton_kernel_agent/platform_config.py b/triton_kernel_agent/platform_config.py index f69c6b5e..5d063d8c 100644 --- a/triton_kernel_agent/platform_config.py +++ b/triton_kernel_agent/platform_config.py @@ -25,6 +25,11 @@ from dataclasses import dataclass, field +import json +import os +from pathlib import Path + + DEFAULT_PLATFORM = "cuda" @@ -94,6 +99,46 @@ class PlatformConfig: ), } +def load_platform_config_from_json(path: str | Path) -> dict[str, PlatformConfig]: + """Load platform configs from a JSON file.""" + + p = Path(path) + raw = json.loads(p.read_text(encoding="utf-8")) + + loaded: dict[str, PlatformConfig] = {} + for name, cfg in raw.items(): + + # Default the config name to its mapping key so the stored + # PlatformConfig.name always matches the key used for lookup. + device_name = str(cfg.get("name", name)) + device_string = str(cfg.get("device_string", "")) + guidance_block = str(cfg.get("guidance_block", "")) + kernel_guidance = str(cfg.get("kernel_guidance", "")) + hacks = cfg.get("cuda_hacks_to_strip", []) + + loaded[name] = PlatformConfig( + name=device_name, + device_string=device_string, + guidance_block=guidance_block, + kernel_guidance=kernel_guidance, + cuda_hacks_to_strip=tuple(hacks), + ) + + return loaded + + +def _maybe_load_external_platforms() -> None: + """Optionally merge JSON-defined platforms into PLATFORMS.""" + from dotenv import load_dotenv + load_dotenv() + + env_path = os.getenv("KERNELAGENT_PLATFORM_JSON") + path = Path(env_path) if env_path else None + if not path or not path.exists(): + return + try: + 
PLATFORMS.update(load_platform_config_from_json(path)) + except Exception as e: + raise ValueError(f"Failed to load platforms from JSON at {path}") from e + def get_platform(name: str) -> PlatformConfig: """Get platform configuration by name.""" @@ -106,3 +151,7 @@ def get_platform_choices() -> list[str]: """Get list of available platform names for CLI choices.""" return sorted(PLATFORMS.keys()) + + +# Load external platforms from JSON if specified in environment +_maybe_load_external_platforms() diff --git a/triton_kernel_agent/templates/test_generation.j2 b/triton_kernel_agent/templates/test_generation.j2 index 61c9b2e2..a46bf28c 100644 --- a/triton_kernel_agent/templates/test_generation.j2 +++ b/triton_kernel_agent/templates/test_generation.j2 @@ -120,9 +120,12 @@ def test_kernel(): {% if device_string == "xpu" %} if not hasattr(torch, 'xpu') or not torch.xpu.is_available(): raise RuntimeError("Intel XPU not available. Install PyTorch with Intel GPU support.") -{% else %} +{% elif device_string == "cuda" %} if not torch.cuda.is_available(): raise RuntimeError("CUDA not available") +{% else %} + if not torch.accelerator.is_available(): + raise RuntimeError("Device '{{ device_string }}' not available") {% endif %} # Create test data using EXACT specifications from problem description