Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion kernels/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ license = { text = "Apache-2.0" }
readme = "README.md"
requires-python = ">= 3.9"
dependencies = [
"huggingface_hub>=0.26.0,<2.0",
"huggingface_hub>=1.3.0,<2.0",
"Jinja2>=3.1.5",
"packaging>=20.0",
"pyyaml>=6",
Expand Down
7 changes: 6 additions & 1 deletion kernels/src/kernels/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
import re
import warnings
from dataclasses import dataclass
from typing import ClassVar, Optional, Protocol
from typing import ClassVar, Optional, Protocol, runtime_checkable

from huggingface_hub.dataclasses import strict
from packaging.version import Version

from kernels.compat import has_torch


@runtime_checkable
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because @strict validates field types using isinstance() checks at construction time. When Arch.backend is typed as Backend (a Protocol), @strict tries to do isinstance(value, Backend) — and Python's typing.Protocol raises TypeError unless it's decorated with @runtime_checkable.

class Backend(Protocol):
@property
def name(self) -> str:
Expand Down Expand Up @@ -49,6 +51,7 @@ def parse(s: str) -> "CANN":
return CANN(version=Version(f"{m.group(1)}.{m.group(2)}"))


@strict
@dataclass(unsafe_hash=True)
class CPU:
@property
Expand Down Expand Up @@ -88,6 +91,7 @@ def parse(s: str) -> "CUDA":
return CUDA(version=Version(f"{m.group(1)}.{m.group(2)}"))


@strict
@dataclass(unsafe_hash=True)
class Metal:
@property
Expand All @@ -105,6 +109,7 @@ def parse(s: str) -> "Metal":
return Metal()


@strict
@dataclass(unsafe_hash=True)
class Neuron:
@property
Expand Down
14 changes: 14 additions & 0 deletions kernels/src/kernels/cli/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from pathlib import Path
from typing import Any

from huggingface_hub.dataclasses import strict
from huggingface_hub.utils import (
build_hf_headers,
disable_progress_bars,
Expand Down Expand Up @@ -69,6 +70,7 @@ def _calculate_iqr_and_outliers(
return q1, q3, iqr, outliers


@strict
@dataclass
class TimingResults:
mean_ms: float
Expand All @@ -83,7 +85,18 @@ class TimingResults:
verified: bool | None = None # None = no verify fn, True = passed, False = failed
ref_mean_ms: float | None = None # Reference implementation mean time

def validate_iterations(self):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are "magic" methods. @strict auto-discovers any method named validate_* and calls it during __init__ (and on __setattr__). They replaced the manual checks that were previously inside from_dict().

if self.iterations <= 0:
raise ValueError(f"iterations must be > 0, got {self.iterations}")

def validate_timing_range(self):
    """Ensure ``min_ms <= max_ms``.

    NOTE(review): per the @strict convention, methods named ``validate_*``
    are discovered and run automatically during ``__init__`` and on
    ``__setattr__`` — this is not called manually.
    """
    if self.min_ms > self.max_ms:
        raise ValueError(
            f"min_ms ({self.min_ms}) must be <= max_ms ({self.max_ms})"
        )


@strict
@dataclass
class MachineInfo:
gpu: str
Expand All @@ -94,6 +107,7 @@ class MachineInfo:
gpu_cores: int | None = None


@strict
@dataclass
class BenchmarkResult:
timing_results: dict[str, TimingResults] # workload name -> timing
Expand Down
121 changes: 94 additions & 27 deletions kernels/src/kernels/cli/kernel_card_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import ast
from dataclasses import dataclass
from pathlib import Path
from typing import Any

from huggingface_hub import ModelCard
from huggingface_hub.dataclasses import strict

from ..compat import tomllib

Expand All @@ -22,18 +24,87 @@
LIBRARY_NAME = "kernels"


def _parse_build_toml(local_path: str | Path) -> dict | None:
local_path = Path(local_path)
build_toml_path = local_path / "build.toml"
@strict
@dataclass
class HubConfig:
    """Typed view of the ``hub`` sub-table of ``build.toml``'s general section."""

    # Hub repository identifier (the ``repo-id`` TOML key); None when absent.
    repo_id: str | None

    @staticmethod
    def from_dict(data: dict) -> "HubConfig":
        """Deserialize from a raw TOML mapping."""
        repo_id = data.get("repo-id")
        return HubConfig(repo_id=repo_id)


@strict
@dataclass
class GeneralConfig:
name: str | None
version: int | None
license: str | None
backends: list[str] | None
hub: HubConfig | None

@staticmethod
def from_dict(data: dict) -> "GeneralConfig":
# TODO: revisit `from_dict` as per
# https://github.com/huggingface/kernels/pull/393/changes#r2981209411
hub_data = data.get("hub")
return GeneralConfig(
name=data.get("name"),
version=data.get("version"),
license=data.get("license"),
backends=data.get("backends"),
hub=HubConfig.from_dict(hub_data) if hub_data else None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note on the discussion from Slack: this is one thing that goes wrong with hand-rolled deserialization. What if the user puts in something that makes hub_data truthy, but not a dict? Then HubConfig.from_dict will fail with a hard-to-understand error, whereas it should just say what the field/section is expected to be.

You can avoid this by programming very defensively (e.g. here checking that it is a dict first), but those things you get for free from a library that does deserialization for you.

I think for now it's ok, since we are going to replace this deserialization with Rust as discussed.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Made a note about it as well in bc042be.

)

if not build_toml_path.exists():
return None

try:
with open(build_toml_path, "rb") as f:
return tomllib.load(f)
except Exception:
return None
@strict
@dataclass
class KernelConfig:
    """Typed view of a single ``[kernel.<name>]`` table from ``build.toml``."""

    # Value of the ``cuda-capabilities`` TOML key; None when unspecified.
    cuda_capabilities: list[str] | None

    @staticmethod
    def from_dict(data: dict) -> "KernelConfig":
        """Deserialize from a raw TOML mapping."""
        caps = data.get("cuda-capabilities")
        return KernelConfig(cuda_capabilities=caps)


@strict
@dataclass
class BuildConfig:
    """Typed representation of a kernel repository's ``build.toml``."""

    # The ``[general]`` section; always constructed, even from an empty table.
    general: GeneralConfig
    # Kernel name -> its ``[kernel.<name>]`` table; None when absent or empty.
    kernel: dict[str, KernelConfig] | None
    # Value of the top-level ``upstream`` key, if declared.
    upstream: str | None

    @staticmethod
    def from_dict(data: dict) -> "BuildConfig":
        """Build a BuildConfig from the mapping produced by the TOML parser."""
        kernel_data = data.get("kernel")
        kernels: dict[str, KernelConfig] | None = None
        if kernel_data:
            kernels = {}
            for name, info in kernel_data.items():
                kernels[name] = KernelConfig.from_dict(info)
        return BuildConfig(
            general=GeneralConfig.from_dict(data.get("general", {})),
            kernel=kernels,
            upstream=data.get("upstream"),
        )

    @staticmethod
    def load(build_toml_path: Path) -> "BuildConfig | None":
        """Read and parse ``build_toml_path``.

        Returns None when the file is missing or is not valid TOML —
        parse errors are deliberately swallowed (best-effort loading).
        """
        if not build_toml_path.exists():
            return None
        try:
            with open(build_toml_path, "rb") as fh:
                raw = tomllib.load(fh)
        except Exception:
            # Unreadable or malformed TOML is treated like a missing file.
            return None
        return BuildConfig.from_dict(raw)


def _parse_build_toml(local_path: Path) -> BuildConfig | None:
    """Parse ``<local_path>/build.toml``; None when missing or unparseable."""
    build_toml_path = local_path / "build.toml"
    return BuildConfig.load(build_toml_path)


def _find_torch_ext_init(local_path: str | Path) -> Path | None:
Expand All @@ -44,7 +115,7 @@ def _find_torch_ext_init(local_path: str | Path) -> Path | None:
return None

try:
kernel_name = config.get("general", {}).get("name")
kernel_name = config.general.name
if not kernel_name:
return None

Expand Down Expand Up @@ -88,8 +159,9 @@ def _parse_repo_id(local_path: str | Path) -> str | None:
if not config:
return None

repo_id = config.get("general", {}).get("hub", {}).get("repo-id", None)
return repo_id
if config.general.hub is None:
return None
return config.general.hub.repo_id


def _build_kernel_card_vars(
Expand All @@ -111,28 +183,25 @@ def _build_kernel_card_vars(
# --- backends, CUDA capabilities, upstream ---
config = _parse_build_toml(local_path)
if config:
general_config = config.get("general", {})
backends = general_config.get("backends")
backends = config.general.backends
if backends:
vars["supported_backends"] = "\n".join(f"- {b}" for b in backends)

kernel_configs = config.get("kernel", {})
cuda_capabilities: set[Any] = set()

# TODO (sayakpaul): implement this to read from `metadata.json` per each build
for k in kernel_configs:
caps = kernel_configs[k].get("cuda-capabilities")
if caps:
cuda_capabilities.update(caps)
if config.kernel:
for kernel_cfg in config.kernel.values():
if kernel_cfg.cuda_capabilities:
cuda_capabilities.update(kernel_cfg.cuda_capabilities)
if cuda_capabilities:
vars["cuda_capabilities"] = "\n".join(
f"- {cap}" for cap in cuda_capabilities
)

upstream_repo = config.get("upstream", None)
if upstream_repo:
if config.upstream:
vars["source_code"] = (
f"Source code of this kernel originally comes from {upstream_repo}"
f"Source code of this kernel originally comes from {config.upstream}"
" and it was repurposed for compatibility with `kernels`."
)

Expand All @@ -148,15 +217,13 @@ def _build_kernel_card_vars(
return vars


def _update_kernel_card_license(
kernel_card: ModelCard, local_path: str | Path
) -> ModelCard:
def _update_kernel_card_license(kernel_card: ModelCard, local_path: Path) -> ModelCard:
config = _parse_build_toml(local_path)
if not config:
return kernel_card

existing_license = kernel_card.data.get("license", None)
license_from_config = config.get("general", {}).get("license", None)
license_from_config = config.general.license
final_license = license_from_config or existing_license
kernel_card.data["license"] = final_license
return kernel_card
70 changes: 61 additions & 9 deletions kernels/src/kernels/deps.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,65 @@
import importlib.util
import json
from dataclasses import dataclass, field
from pathlib import Path

from huggingface_hub.dataclasses import strict

from kernels.backends import Backend


@strict
@dataclass
class PythonPackage:
    """One Python package entry of a kernel dependency."""

    # Package (distribution) name, from the required ``pkg`` key.
    pkg: str
    # Importable module name, from the optional ``import`` key; None for
    # packages that install files but expose no importable Python module.
    import_name: str | None

    @staticmethod
    def from_dict(data: dict) -> "PythonPackage":
        """Deserialize from a raw mapping; ``pkg`` is required."""
        module = data.get("import")
        return PythonPackage(pkg=data["pkg"], import_name=module)


@strict
@dataclass
class DependencyInfo:
    """Nix and Python requirements of a single kernel dependency."""

    # Entries under the ``nix`` key; empty when none are declared.
    nix: list[str]
    # Entries under the ``python`` key, parsed into PythonPackage records.
    python: list[PythonPackage]

    @staticmethod
    def from_dict(data: dict) -> "DependencyInfo":
        """Deserialize from a raw mapping; both keys default to empty lists."""
        packages = [PythonPackage.from_dict(entry) for entry in data.get("python", [])]
        return DependencyInfo(nix=data.get("nix", []), python=packages)


@strict
@dataclass
class DependencyData:
    """Parsed contents of ``python_depends.json``."""

    # Dependency name -> DependencyInfo, applicable to every backend.
    general: dict = field(default_factory=dict)
    # Backend name -> (dependency name -> DependencyInfo).
    backends: dict = field(default_factory=dict)

    @staticmethod
    def from_dict(data: dict) -> "DependencyData":
        """Deserialize the raw JSON mapping into typed dependency tables."""
        general: dict = {}
        for name, info in data.get("general", {}).items():
            general[name] = DependencyInfo.from_dict(info)

        backends: dict = {}
        for backend_name, deps in data.get("backends", {}).items():
            backends[backend_name] = {
                name: DependencyInfo.from_dict(info) for name, info in deps.items()
            }

        return DependencyData(general=general, backends=backends)


try:
with open(Path(__file__).parent / "python_depends.json", "r") as f:
DEPENDENCY_DATA: dict = json.load(f)
_DEPENDENCY_DATA = DependencyData.from_dict(json.load(f))
except FileNotFoundError:
raise FileNotFoundError(
"Cannot load dependency data, is `kernels` correctly installed?"
Expand All @@ -23,32 +76,31 @@ def validate_dependencies(
dependencies (`list[str]`): A list of dependency strings to validate.
backend (`str`): The backend to validate dependencies for.
"""
general_deps = DEPENDENCY_DATA.get("general", {})
backend_deps = DEPENDENCY_DATA.get("backends", {}).get(backend.name, {})
general_deps = _DEPENDENCY_DATA.general
backend_deps = _DEPENDENCY_DATA.backends.get(backend.name, {})

# Validate each dependency
for dependency in dependencies:
# Look up dependency in general dependencies first, then backend-specific
if dependency in general_deps:
python_packages = general_deps[dependency].get("python", [])
dep_info = general_deps[dependency]
elif dependency in backend_deps:
python_packages = backend_deps[dependency].get("python", [])
dep_info = backend_deps[dependency]
else:
# Dependency not found in general or backend-specific dependencies
raise ValueError(
f"Kernel module `{kernel_module_name}` uses unsupported kernel dependency: {dependency}"
)

# Check if each python package is installed
for python_package in python_packages:
# Convert package name to module name (replace - with _)
pkg_name = python_package.get("pkg")
for python_package in dep_info.python:
pkg_name = python_package.pkg
# Assertion because this should not happen and is a bug.
assert (
pkg_name is not None
), f"Invalid dependency data for `{dependency}`: missing `pkg` field."

module_name = python_package.get("import")
module_name = python_package.import_name
if module_name is None:
# These are typically packages that do not provide any Python
# code, but get installed to Python's library directory. E.g.
Expand Down
Loading
Loading