153 changes: 153 additions & 0 deletions .github/workflows/rocm.yml
@@ -0,0 +1,153 @@
name: ROCm backend CI

# Runs on AMD GPU runners when available and falls back to a CPU-only smoke
# test. GitHub-hosted runners don't have AMD GPUs; this workflow is designed
# to run on self-hosted runners with RDNA/CDNA hardware labeled 'rocm'.

on:
  push:
    paths:
      - 'comfy_kitchen/backends/rocm/**'
      - 'tests/test_rocm_backend.py'
      - '.github/workflows/rocm.yml'
  pull_request:
    paths:
      - 'comfy_kitchen/backends/rocm/**'
      - 'tests/test_rocm_backend.py'
      - '.github/workflows/rocm.yml'

jobs:

  # -------------------------------------------------------------------------
  # Smoke test: import and registration (no GPU required)
  # -------------------------------------------------------------------------
  smoke-test:
    name: Import smoke test (CPU)
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v4

      - uses: actions/setup-python@v5
        with:
          python-version: '3.12'

      - name: Install CPU PyTorch
        run: pip install torch packaging --index-url https://download.pytorch.org/whl/cpu

      - name: Install comfy-kitchen (no CUDA)
        env:
          COMFY_NO_CUDA: "1"
        run: pip install -e . --no-build-isolation

      - name: Check ROCm backend imports without error on CPU
        run: |
          python -c "
          import torch
          print('PyTorch:', torch.__version__)
          print('ROCm:', getattr(torch.version, 'hip', None))

          import comfy_kitchen as ck
          import json
          backends = ck.list_backends()
          print(json.dumps(backends, indent=2))

          # rocm should be present but unavailable (not a ROCm build)
          assert 'rocm' in backends, 'rocm backend missing from registry'
          assert not backends['rocm']['available'], 'rocm should not be available on CPU torch'
          assert backends['eager']['available'], 'eager backend must always be available'
          print('OK: ROCm backend registered correctly on CPU')
          "

  # -------------------------------------------------------------------------
  # Full GPU test: runs on self-hosted AMD runner
  # -------------------------------------------------------------------------
  gpu-test:
    name: GPU test (ROCm, self-hosted)
    runs-on: [self-hosted, rocm]
    # Only run if a self-hosted runner is available; skip otherwise
    if: ${{ vars.ROCM_RUNNER_AVAILABLE == 'true' }}

    container:
      image: rocm/pytorch:rocm6.4_ubuntu22.04_py3.12_pytorch_release_2.6.0
      options: --device /dev/kfd --device /dev/dri --group-add video --group-add render

    steps:
      - uses: actions/checkout@v4

      - name: Show ROCm info
        run: |
          rocminfo | head -30
          python -c "import torch; print('PyTorch:', torch.__version__); print('ROCm:', torch.version.hip)"

      - name: Install comfy-kitchen from source
        run: |
          pip install nanobind cmake
          pip install -e . --no-build-isolation

      - name: List detected backends
        run: |
          python -c "
          import comfy_kitchen as ck
          import json
          print(json.dumps(ck.list_backends(), indent=2, default=str))
          "

      - name: Smoke test all ops
        run: |
          python -c "
          import torch, comfy_kitchen as ck

          # ROCm builds of PyTorch expose AMD GPUs via the 'cuda' device string
          x = torch.randn(128, 128, device='cuda', dtype=torch.bfloat16)
          scale = torch.tensor(1.0, device='cuda', dtype=torch.float32)
          q = ck.quantize_per_tensor_fp8(x, scale)
          dq = ck.dequantize_per_tensor_fp8(q, scale)
          print('quantize_fp8:', q.dtype, 'roundtrip_err:', (x - dq).abs().max().item())

          a = torch.randn(64, 128, device='cuda', dtype=torch.bfloat16)
          b = torch.randn(64, 128, device='cuda', dtype=torch.bfloat16)
          qa, sa = ck.quantize_mxfp8(a)
          qb, sb = ck.quantize_mxfp8(b)
          out = ck.scaled_mm_mxfp8(qa, qb, sa, sb)
          print('scaled_mm_mxfp8:', out.dtype, out.shape)

          xq = torch.randn(4, 8, 8, 64, device='cuda', dtype=torch.bfloat16)
          xk = torch.randn(4, 8, 8, 64, device='cuda', dtype=torch.bfloat16)
          freqs = torch.randn(1, 8, 1, 32, 2, 2, device='cuda', dtype=torch.float32)
          rq, rk = ck.apply_rope(xq, xk, freqs)
          print('apply_rope:', rq.dtype, rq.shape)
          print('All ops OK')
          "

  # -------------------------------------------------------------------------
  # Windows ROCm smoke test (no GPU, import only)
  # -------------------------------------------------------------------------
  windows-smoke:
    name: Windows ROCm smoke test
    runs-on: windows-latest

    steps:
      - uses: actions/checkout@v4

      - uses: actions/setup-python@v5
        with:
          python-version: '3.12'

      - name: Install CPU PyTorch
        run: pip install torch --index-url https://download.pytorch.org/whl/cpu

      - name: Install package
        env:
          COMFY_NO_CUDA: "1"
        run: pip install -e . --no-build-isolation

      - name: Verify import and backend registration
        run: |
          python -c "
          import comfy_kitchen as ck
          import json
          backends = ck.list_backends()
          print(json.dumps(backends, indent=2))
          assert 'rocm' in backends
          print('OK')
          "
78 changes: 54 additions & 24 deletions README.md
@@ -4,29 +4,35 @@ Fast kernel library for Diffusion inference with multiple compute backends.

## Backend Capabilities Matrix

| Function                    | eager | cuda | triton | rocm |
|-----------------------------|-------|------|--------|------|
| `quantize_per_tensor_fp8`   | ✓     | ✓    | ✓      | ✓    |
| `dequantize_per_tensor_fp8` | ✓     | ✓    | ✓      | ✓    |
| `quantize_nvfp4`            | ✓     | ✓    | ✓      | ✓    |
| `dequantize_nvfp4`          | ✓     | ✓    |        | ✓    |
| `scaled_mm_nvfp4`           | ✓     | ✓    |        | ✓ ¹  |
| `quantize_mxfp8`            | ✓     | ✓    | ✓      | ✓    |
| `dequantize_mxfp8`          | ✓     |      |        | ✓    |
| `scaled_mm_mxfp8`           | ✓     |      |        | ✓ ²  |
| `apply_rope`                | ✓     | ✓    | ✓      | ✓    |
| `apply_rope1`               | ✓     | ✓    | ✓      | ✓    |

> ¹ AMD RDNA hardware lacks native FP4. `scaled_mm_nvfp4` dequantizes to BF16 and then runs `torch.mm`.
>
> ² Uses `torch._scaled_mm` → hipBLASLt FP8 GEMM on RDNA3 (gfx1100+) and RDNA4 (gfx1200+).
> Falls back to dequant + `torch.mm` on older hardware. TensorWise scaling is used as an
> approximation until PyTorch ROCm exposes block-scaled hipBLASLt directly.
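
As a rough illustration of footnote ¹, the fallback is conceptually equivalent to the sketch below. This is a hedged sketch, not the backend's actual code; the helper name and the exact `dequantize_nvfp4` argument layout are assumptions.

```python
import torch
import comfy_kitchen as ck

# Hypothetical sketch of the footnote-1 path: RDNA has no native FP4,
# so both NVFP4 operands are dequantized to BF16 and multiplied with a
# plain matmul. (The footnote-2 path instead keeps FP8 data and calls
# torch._scaled_mm, which routes to hipBLASLt.)
def scaled_mm_nvfp4_fallback(qa, qb, scale_a, scale_b):
    a = ck.dequantize_nvfp4(qa, scale_a).to(torch.bfloat16)  # assumed (q, scale) signature
    b = ck.dequantize_nvfp4(qb, scale_b).to(torch.bfloat16)
    return torch.mm(a, b.t())  # transpose convention assumed from the mxfp8 example
```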


## Quantized Tensors

The library provides `QuantizedTensor`, a `torch.Tensor` subclass that transparently intercepts PyTorch operations and dispatches them to optimized quantized kernels when available.

| Layout                  | Format     | HW Requirement                 | Description                                             |
|-------------------------|------------|--------------------------------|---------------------------------------------------------|
| `TensorCoreFP8Layout`   | FP8 E4M3   | SM ≥ 8.9 (Ada) / RDNA3+        | Per-tensor scaling, 1:1 element mapping                 |
| `TensorCoreNVFP4Layout` | NVFP4 E2M1 | SM ≥ 10.0 (Blackwell)          | Block quantization with 16-element blocks               |
| `TensorCoreMXFP8Layout` | MXFP8 E4M3 | SM ≥ 10.0 (Blackwell) / RDNA3+ | Block quantization with 32-element blocks, E8M0 scales  |

```python
from comfy_kitchen.tensor import QuantizedTensor, TensorCoreFP8Layout, TensorCoreNVFP4Layout
@@ -45,6 +51,8 @@ dq = qt.dequantize()

## Installation

> Note: If you are on a system with a non-UTF-8 locale, builds may fail with a `UnicodeDecodeError`. Set `PYTHONUTF8=1` in your environment.

### From PyPI

```bash
@@ -62,6 +70,16 @@ pip install comfy-kitchen[cublas]

Wheels are built for Python 3.10, 3.11, and 3.12+ (using Stable ABI for 3.12+).

### AMD ROCm

The ROCm backend is pure Python and requires no compilation. Install the
ROCm build of PyTorch for your platform, then install comfy-kitchen normally;
the backend is detected and configured automatically, with no extra setup needed.

**Supported hardware:** RDNA3 (RX 7000 series, gfx1100+), RDNA4 (RX 9000 series,
gfx1200+), and CDNA3 (MI300X, gfx940+). Older AMD GPUs are supported via
eager fallback for all ops.
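
To confirm the backend was picked up after installation, a quick check (mirroring what the CI smoke test asserts) looks like this:

```python
import torch
import comfy_kitchen as ck

# torch.version.hip is a version string on ROCm builds of PyTorch
# and None on CUDA/CPU builds.
print("ROCm:", getattr(torch.version, "hip", None))

backends = ck.list_backends()
print(backends["rocm"])  # 'available' should be True on a supported ROCm setup
```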

### From Source

```bash
@@ -73,6 +91,9 @@ pip install -e ".[dev]"

# For faster rebuilds during development (skip build isolation)
pip install -e . --no-build-isolation -v

# Installation with ROCm support
pip install . --no-build-isolation -v
```

#### Build Options
@@ -95,6 +116,10 @@ python setup.py build_ext --cuda-archs="80;89" bdist_wheel

# Debug build with line info for profiling
python setup.py build_ext --debug-build --lineinfo bdist_wheel

# Build with ROCm
python setup.py bdist_wheel

```


@@ -106,6 +131,8 @@ python setup.py build_ext --debug-build --lineinfo bdist_wheel
- **CUDA Runtime** (for CUDA wheels): ≥13.0
- Pre-built wheels require NVIDIA Driver r580+
- Building from source requires CUDA Toolkit ≥12.8 and `CUDA_HOME` environment variable
- **ROCm PyTorch** (for ROCm backend)
- No additional compilation or ROCm stack installation required on Windows
- **nanobind**: ≥2.0.0 (for building from source)
- **CMake**: ≥3.18 (for building from source)

@@ -115,7 +142,7 @@ python setup.py build_ext --debug-build --lineinfo bdist_wheel
import comfy_kitchen as ck
import torch

# Automatic backend selection (cuda -> rocm -> triton -> eager)
x = torch.randn(100, 100, device="cuda")
scale = torch.tensor([1.0], device="cuda")
result = ck.quantize_per_tensor_fp8(x, scale)
Expand All @@ -127,27 +154,30 @@ print(ck.list_backends())
result = ck.quantize_per_tensor_fp8(x, scale, backend="eager")

# Temporarily use a different backend
with ck.use_backend("triton"):
with ck.use_backend("rocm"):
result = ck.quantize_per_tensor_fp8(x, scale)
```

## Backend System

The library supports multiple backends:

- **eager**: Pure PyTorch implementation, works on any device
- **cuda**: Custom CUDA C kernels (NVIDIA GPUs)
- **triton**: Triton JIT-compiled kernels (NVIDIA and AMD)
- **rocm**: Pure-Python AMD backend using hipBLASLt via `torch._scaled_mm`

### Automatic Backend Selection

When you call a function, the registry selects the best backend by checking **constraints** in priority order (`cuda` → `rocm` → `triton` → `eager`):

```python
# Backend is selected automatically based on input constraints
result = ck.quantize_per_tensor_fp8(x, scale)

# On CPU tensors → falls back to eager (only backend supporting CPU)
# On CUDA tensors (NVIDIA) → uses cuda or triton (higher priority)
# On CUDA tensors (AMD/ROCm) → uses rocm
```

### Constraint System
1 change: 1 addition & 0 deletions comfy_kitchen/__init__.py
@@ -4,6 +4,7 @@

# Import backends to trigger auto-registration
from .backends import eager as _eager_backend # noqa: F401
from .backends import rocm as _rocm_backend # noqa: F401
from .backends import triton as _triton_backend # noqa: F401
from .backends.eager.quantization import DTYPE_TO_CODE
from .exceptions import (