Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,13 @@ Every stage writes artifacts to a run directory under `.optimize/<run_id>/`, inc
- **GPU Requirements (one of the following):**
- **CUDA**: NVIDIA GPU with CUDA support
- **XPU**: Intel GPU with oneAPI support (Arc, Data Center GPUs, or integrated Xe graphics)
- **Custom GPU**: Make sure that the GPU is supported by `torch.accelerator`
- Triton (installed separately: `pip install triton` or nightly from source)
- PyTorch (https://pytorch.org/get-started/locally/)
- LLM provider ([OpenAI](https://openai.com/api/), [Anthropic](https://www.anthropic.com/), or a self-hosted relay)

**Important note for custom GPUs**: Install compatible versions of PyTorch and Triton.

### Install
```bash
pip install -e .
Expand Down Expand Up @@ -234,6 +237,33 @@ When targeting Intel XPU, KernelAgent automatically:
- Generates appropriate device availability checks
- Removes CUDA-specific patterns from generated code

### Custom Platforms via JSON (Optional)

You can add/override platform configs via a JSON file loaded at import time:

```bash
export KERNELAGENT_PLATFORM_JSON=/abs/path/to/platforms.json
```

JSON format (top-level is a mapping; keys are platform names used by `--target-platform`):

```json
{
"npu": {
"name": "npu",
"device_string": "npu",
"guidance_block": "...",
"kernel_guidance": "...",
"cuda_hacks_to_strip": []
}
}
```

Notes:
- Set `KERNELAGENT_PLATFORM_JSON` before starting any KernelAgent/Fuser CLI.
- Select the platform with `--target-platform` (or the `target_platform` argument), e.g. `python -m Fuser.auto_agent --problem /abs/path/to/problem.py --target-platform npu`.

### Verifying Platform Setup
```python
# Check CUDA availability
Expand All @@ -242,6 +272,9 @@ print("CUDA available:", torch.cuda.is_available())

# Check XPU availability
print("XPU available:", hasattr(torch, 'xpu') and torch.xpu.is_available())

# Check custom GPU availability (assuming the platform name is 'npu')
print("NPU available:", torch.accelerator.is_available() and torch.accelerator.current_accelerator().type == 'npu')
```

## Run Artifacts
Expand Down
13 changes: 8 additions & 5 deletions triton_kernel_agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import json
import re
from pathlib import Path
from typing import Any
from typing import Any, Union
from datetime import datetime
import logging
from dotenv import load_dotenv
Expand All @@ -40,7 +40,7 @@ def __init__(
model_name: str | None = None,
high_reasoning_effort: bool = True,
preferred_provider: BaseProvider | None = None,
target_platform: PlatformConfig | None = None,
target_platform: Union[PlatformConfig, str] | None = None,
no_cusolver: bool = False,
test_timeout_s: int = 30,
):
Expand Down Expand Up @@ -87,9 +87,12 @@ def __init__(
self.log_dir.mkdir(exist_ok=True, parents=True)

# Normalize to PlatformConfig
self._platform_config = (
target_platform if target_platform else get_platform("cuda")
)
if not target_platform:
target_platform = get_platform("cuda")
elif isinstance(target_platform, str):
target_platform = get_platform(target_platform)
self._platform_config = target_platform

self.no_cusolver = no_cusolver
self.test_timeout_s = test_timeout_s

Expand Down
49 changes: 49 additions & 0 deletions triton_kernel_agent/platform_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@

from dataclasses import dataclass, field

import json
import os
from pathlib import Path


DEFAULT_PLATFORM = "cuda"


Expand Down Expand Up @@ -94,6 +99,46 @@ class PlatformConfig:
),
}

def load_platform_config_from_json(path: str | Path) -> dict[str, PlatformConfig]:
    """Load platform configurations from a JSON file.

    The file must contain a top-level JSON object mapping platform names
    (the values accepted by ``--target-platform``) to config objects with
    the keys ``name``, ``device_string``, ``guidance_block``,
    ``kernel_guidance`` and ``cuda_hacks_to_strip``.

    Args:
        path: Path to the JSON file.

    Returns:
        Mapping of platform name to :class:`PlatformConfig`.

    Raises:
        OSError: If the file cannot be read.
        json.JSONDecodeError: If the file is not valid JSON.
        ValueError: If the top-level value or a platform entry is not an object.
    """
    raw = json.loads(Path(path).read_text(encoding="utf-8"))
    if not isinstance(raw, dict):
        raise ValueError(f"Expected a JSON object at the top level of {path}")

    loaded: dict[str, PlatformConfig] = {}
    for name, cfg in raw.items():
        if not isinstance(cfg, dict):
            raise ValueError(f"Platform entry {name!r} in {path} must be a JSON object")

        loaded[name] = PlatformConfig(
            # Fall back to the mapping key so the config's name and the
            # --target-platform choice stay consistent when "name" is omitted.
            name=str(cfg.get("name", name)),
            device_string=str(cfg.get("device_string", "")),
            guidance_block=str(cfg.get("guidance_block", "")),
            kernel_guidance=str(cfg.get("kernel_guidance", "")),
            cuda_hacks_to_strip=tuple(cfg.get("cuda_hacks_to_strip", [])),
        )

    return loaded


def _maybe_load_external_platforms() -> None:
    """Merge JSON-defined platforms into the module-level ``PLATFORMS`` registry.

    Reads the ``KERNELAGENT_PLATFORM_JSON`` environment variable (after
    loading ``.env`` via python-dotenv). If it points to an existing file,
    the platforms defined there are merged into ``PLATFORMS``, overriding
    any existing entries with the same name. Returns silently when the
    variable is unset or the file does not exist.

    Raises:
        ValueError: If the file exists but cannot be parsed or loaded;
            the original exception is chained as the cause.
    """
    # Imported here, matching how the rest of the project pulls in dotenv.
    from dotenv import load_dotenv

    load_dotenv()

    env_path = os.getenv("KERNELAGENT_PLATFORM_JSON")
    if not env_path:
        return
    path = Path(env_path)
    if not path.exists():
        return
    try:
        PLATFORMS.update(load_platform_config_from_json(path))
    except Exception as exc:
        # Chain the underlying error so the root cause (bad JSON, wrong
        # structure, I/O failure) stays visible in the traceback.
        raise ValueError(f"Failed to load platforms from JSON at {path}") from exc


def get_platform(name: str) -> PlatformConfig:
"""Get platform configuration by name."""
Expand All @@ -106,3 +151,7 @@ def get_platform(name: str) -> PlatformConfig:
def get_platform_choices() -> list[str]:
    """Return the registered platform names, sorted for stable CLI choices."""
    # Iterating a dict yields its keys, so no explicit .keys() is needed.
    names = list(PLATFORMS)
    names.sort()
    return names


# Load external platforms from JSON if specified in environment
_maybe_load_external_platforms()
5 changes: 4 additions & 1 deletion triton_kernel_agent/templates/test_generation.j2
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,12 @@ def test_kernel():
{% if device_string == "xpu" %}
if not hasattr(torch, 'xpu') or not torch.xpu.is_available():
    raise RuntimeError("Intel XPU not available. Install PyTorch with Intel GPU support.")
{% elif device_string == "cuda" %}
if not torch.cuda.is_available():
    raise RuntimeError("CUDA not available")
{% else %}
if not torch.accelerator.is_available():
    {# The original rendered f-string referenced `e`, which is never bound
       here (no enclosing except block), and a runtime `device` variable;
       interpolate the template's device_string at render time instead. #}
    raise RuntimeError("Device '{{ device_string }}' not available")
{% endif %}

# Create test data using EXACT specifications from problem description
Expand Down
Loading