diff --git a/kernels/pyproject.toml b/kernels/pyproject.toml index b73f7fe7..46f59ffd 100644 --- a/kernels/pyproject.toml +++ b/kernels/pyproject.toml @@ -10,7 +10,7 @@ license = { text = "Apache-2.0" } readme = "README.md" requires-python = ">= 3.9" dependencies = [ - "huggingface_hub>=0.26.0,<2.0", + "huggingface_hub>=1.3.0,<2.0", "Jinja2>=3.1.5", "packaging>=20.0", "pyyaml>=6", diff --git a/kernels/src/kernels/backends.py b/kernels/src/kernels/backends.py index cc14ba5e..41378bc0 100644 --- a/kernels/src/kernels/backends.py +++ b/kernels/src/kernels/backends.py @@ -3,13 +3,15 @@ import re import warnings from dataclasses import dataclass -from typing import ClassVar, Optional, Protocol +from typing import ClassVar, Optional, Protocol, runtime_checkable +from huggingface_hub.dataclasses import strict from packaging.version import Version from kernels.compat import has_torch +@runtime_checkable class Backend(Protocol): @property def name(self) -> str: @@ -49,6 +51,7 @@ def parse(s: str) -> "CANN": return CANN(version=Version(f"{m.group(1)}.{m.group(2)}")) +@strict @dataclass(unsafe_hash=True) class CPU: @property @@ -88,6 +91,7 @@ def parse(s: str) -> "CUDA": return CUDA(version=Version(f"{m.group(1)}.{m.group(2)}")) +@strict @dataclass(unsafe_hash=True) class Metal: @property @@ -105,6 +109,7 @@ def parse(s: str) -> "Metal": return Metal() +@strict @dataclass(unsafe_hash=True) class Neuron: @property diff --git a/kernels/src/kernels/cli/benchmark.py b/kernels/src/kernels/cli/benchmark.py index 0257d084..c22eb100 100644 --- a/kernels/src/kernels/cli/benchmark.py +++ b/kernels/src/kernels/cli/benchmark.py @@ -12,6 +12,7 @@ from pathlib import Path from typing import Any +from huggingface_hub.dataclasses import strict from huggingface_hub.utils import ( build_hf_headers, disable_progress_bars, @@ -69,6 +70,7 @@ def _calculate_iqr_and_outliers( return q1, q3, iqr, outliers +@strict @dataclass class TimingResults: mean_ms: float @@ -83,7 +85,18 @@ class TimingResults: verified: bool | None = None # None = no verify fn, True = passed, False = failed ref_mean_ms: float | None = None # Reference implementation mean time + def validate_iterations(self): + if self.iterations <= 0: + raise ValueError(f"iterations must be > 0, got {self.iterations}") + def validate_timing_range(self): + if self.min_ms > self.max_ms: + raise ValueError( + f"min_ms ({self.min_ms}) must be <= max_ms ({self.max_ms})" + ) + + +@strict @dataclass class MachineInfo: gpu: str @@ -94,6 +107,7 @@ class MachineInfo: gpu_cores: int | None = None +@strict @dataclass class BenchmarkResult: timing_results: dict[str, TimingResults] # workload name -> timing diff --git a/kernels/src/kernels/cli/kernel_card_utils.py b/kernels/src/kernels/cli/kernel_card_utils.py index 7363860a..56775aa8 100644 --- a/kernels/src/kernels/cli/kernel_card_utils.py +++ b/kernels/src/kernels/cli/kernel_card_utils.py @@ -1,8 +1,10 @@ import ast +from dataclasses import dataclass from pathlib import Path from typing import Any from huggingface_hub import ModelCard +from huggingface_hub.dataclasses import strict from ..compat import tomllib @@ -22,18 +24,87 @@ LIBRARY_NAME = "kernels" -def _parse_build_toml(local_path: str | Path) -> dict | None: - local_path = Path(local_path) - build_toml_path = local_path / "build.toml" +@strict +@dataclass +class HubConfig: + repo_id: str | None + + @staticmethod + def from_dict(data: dict) -> "HubConfig": + return HubConfig(repo_id=data.get("repo-id")) + + +@strict +@dataclass +class GeneralConfig: + name: str | None + version: int | None + license: str | None + backends: list[str] | None + hub: HubConfig | None + + @staticmethod + def from_dict(data: dict) -> "GeneralConfig": + # TODO: revisit `from_dict` as per + # https://github.com/huggingface/kernels/pull/393/changes#r2981209411 + hub_data = data.get("hub") + return GeneralConfig( + name=data.get("name"), + version=data.get("version"), + license=data.get("license"), + backends=data.get("backends"), + hub=HubConfig.from_dict(hub_data) if hub_data else None, + ) - if not build_toml_path.exists(): - return None - try: - with open(build_toml_path, "rb") as f: - return tomllib.load(f) - except Exception: - return None +@strict +@dataclass +class KernelConfig: + cuda_capabilities: list[str] | None + + @staticmethod + def from_dict(data: dict) -> "KernelConfig": + return KernelConfig(cuda_capabilities=data.get("cuda-capabilities")) + + +@strict +@dataclass +class BuildConfig: + general: GeneralConfig + kernel: dict[str, KernelConfig] | None + upstream: str | None + + @staticmethod + def from_dict(data: dict) -> "BuildConfig": + general_data = data.get("general", {}) + kernel_data = data.get("kernel") + return BuildConfig( + general=GeneralConfig.from_dict(general_data), + kernel=( + { + name: KernelConfig.from_dict(info) + for name, info in kernel_data.items() + } + if kernel_data + else None + ), + upstream=data.get("upstream"), + ) + + @staticmethod + def load(build_toml_path: Path) -> "BuildConfig | None": + if not build_toml_path.exists(): + return None + try: + with open(build_toml_path, "rb") as f: + data = tomllib.load(f) + except Exception: + return None + return BuildConfig.from_dict(data) + + +def _parse_build_toml(local_path: Path) -> BuildConfig | None: + return BuildConfig.load(local_path / "build.toml") def _find_torch_ext_init(local_path: str | Path) -> Path | None: @@ -44,7 +115,7 @@ def _find_torch_ext_init(local_path: str | Path) -> Path | None: return None try: - kernel_name = config.get("general", {}).get("name") + kernel_name = config.general.name if not kernel_name: return None @@ -88,8 +159,9 @@ def _parse_repo_id(local_path: str | Path) -> str | None: if not config: return None - repo_id = config.get("general", {}).get("hub", {}).get("repo-id", None) - return repo_id + if config.general.hub is None: + return None + return config.general.hub.repo_id def _build_kernel_card_vars( @@ -111,28 +183,25 @@ def _build_kernel_card_vars( # --- backends, CUDA capabilities, upstream --- config = _parse_build_toml(local_path) if config: - general_config = config.get("general", {}) - backends = general_config.get("backends") + backends = config.general.backends if backends: vars["supported_backends"] = "\n".join(f"- {b}" for b in backends) - kernel_configs = config.get("kernel", {}) cuda_capabilities: set[Any] = set() # TODO (sayakpaul): implement this to read from `metadata.json` per each build - for k in kernel_configs: - caps = kernel_configs[k].get("cuda-capabilities") - if caps: - cuda_capabilities.update(caps) + if config.kernel: + for kernel_cfg in config.kernel.values(): + if kernel_cfg.cuda_capabilities: + cuda_capabilities.update(kernel_cfg.cuda_capabilities) if cuda_capabilities: vars["cuda_capabilities"] = "\n".join( f"- {cap}" for cap in cuda_capabilities ) - upstream_repo = config.get("upstream", None) - if upstream_repo: + if config.upstream: vars["source_code"] = ( - f"Source code of this kernel originally comes from {upstream_repo}" + f"Source code of this kernel originally comes from {config.upstream}" " and it was repurposed for compatibility with `kernels`." ) @@ -148,15 +217,13 @@ def _build_kernel_card_vars( return vars -def _update_kernel_card_license( - kernel_card: ModelCard, local_path: str | Path -) -> ModelCard: +def _update_kernel_card_license(kernel_card: ModelCard, local_path: Path) -> ModelCard: config = _parse_build_toml(local_path) if not config: return kernel_card existing_license = kernel_card.data.get("license", None) - license_from_config = config.get("general", {}).get("license", None) + license_from_config = config.general.license final_license = license_from_config or existing_license kernel_card.data["license"] = final_license return kernel_card diff --git a/kernels/src/kernels/deps.py b/kernels/src/kernels/deps.py index 7ae1a1d8..74ba5606 100644 --- a/kernels/src/kernels/deps.py +++ b/kernels/src/kernels/deps.py @@ -1,12 +1,65 @@ import importlib.util import json +from dataclasses import dataclass, field from pathlib import Path +from huggingface_hub.dataclasses import strict + from kernels.backends import Backend + +@strict +@dataclass +class PythonPackage: + pkg: str + import_name: str | None + + @staticmethod + def from_dict(data: dict) -> "PythonPackage": + return PythonPackage( + pkg=data["pkg"], + import_name=data.get("import"), + ) + + +@strict +@dataclass +class DependencyInfo: + nix: list[str] + python: list[PythonPackage] + + @staticmethod + def from_dict(data: dict) -> "DependencyInfo": + return DependencyInfo( + nix=data.get("nix", []), + python=[PythonPackage.from_dict(p) for p in data.get("python", [])], + ) + + +@strict +@dataclass +class DependencyData: + general: dict = field(default_factory=dict) + backends: dict = field(default_factory=dict) + + @staticmethod + def from_dict(data: dict) -> "DependencyData": + general = { + name: DependencyInfo.from_dict(info) + for name, info in data.get("general", {}).items() + } + backends = { + backend_name: { + name: DependencyInfo.from_dict(info) for name, info in deps.items() + } + for backend_name, deps in data.get("backends", {}).items() + } + return DependencyData(general=general, backends=backends) + + try: with open(Path(__file__).parent / "python_depends.json", "r") as f: - DEPENDENCY_DATA: dict = json.load(f) + _DEPENDENCY_DATA = DependencyData.from_dict(json.load(f)) except FileNotFoundError: raise FileNotFoundError( "Cannot load dependency data, is `kernels` correctly installed?" @@ -23,16 +76,16 @@ def validate_dependencies( dependencies (`list[str]`): A list of dependency strings to validate. backend (`str`): The backend to validate dependencies for. """ - general_deps = DEPENDENCY_DATA.get("general", {}) - backend_deps = DEPENDENCY_DATA.get("backends", {}).get(backend.name, {}) + general_deps = _DEPENDENCY_DATA.general + backend_deps = _DEPENDENCY_DATA.backends.get(backend.name, {}) # Validate each dependency for dependency in dependencies: # Look up dependency in general dependencies first, then backend-specific if dependency in general_deps: - python_packages = general_deps[dependency].get("python", []) + dep_info = general_deps[dependency] elif dependency in backend_deps: - python_packages = backend_deps[dependency].get("python", []) + dep_info = backend_deps[dependency] else: # Dependency not found in general or backend-specific dependencies raise ValueError( @@ -40,15 +93,14 @@ def validate_dependencies( ) # Check if each python package is installed - for python_package in python_packages: - # Convert package name to module name (replace - with _) - pkg_name = python_package.get("pkg") + for python_package in dep_info.python: + pkg_name = python_package.pkg # Assertion because this should not happen and is a bug. assert ( pkg_name is not None ), f"Invalid dependency data for `{dependency}`: missing `pkg` field." - module_name = python_package.get("import") + module_name = python_package.import_name if module_name is None: # These are typically packages that do not provide any Python # code, but get installed to Python's library dirctory. E.g. diff --git a/kernels/src/kernels/layer/device.py b/kernels/src/kernels/layer/device.py index b7faedfa..c796b647 100644 --- a/kernels/src/kernels/layer/device.py +++ b/kernels/src/kernels/layer/device.py @@ -1,6 +1,9 @@ from dataclasses import dataclass +from huggingface_hub.dataclasses import strict + +@strict @dataclass(frozen=True) class CUDAProperties: """ @@ -34,6 +37,12 @@ class CUDAProperties: min_capability: int max_capability: int + def validate_capability_range(self): + if self.min_capability > self.max_capability: + raise ValueError( + f"min_capability ({self.min_capability}) must be <= max_capability ({self.max_capability})" + ) + def __eq__(self, other): if not isinstance(other, CUDAProperties): return NotImplemented @@ -46,6 +55,7 @@ def __hash__(self): return hash((self.min_capability, self.max_capability)) +@strict @dataclass(frozen=True) class ROCMProperties: """ @@ -79,6 +89,12 @@ class ROCMProperties: min_capability: int max_capability: int + def validate_capability_range(self): + if self.min_capability > self.max_capability: + raise ValueError( + f"min_capability ({self.min_capability}) must be <= max_capability ({self.max_capability})" + ) + def __eq__(self, other): if not isinstance(other, ROCMProperties): return NotImplemented @@ -91,6 +107,7 @@ def __hash__(self): return hash((self.min_capability, self.max_capability)) +@strict @dataclass(frozen=True) class Device: """ diff --git a/kernels/src/kernels/lockfile.py b/kernels/src/kernels/lockfile.py index 2aca2eef..ccc7755a 100644 --- a/kernels/src/kernels/lockfile.py +++ b/kernels/src/kernels/lockfile.py @@ -2,17 +2,21 @@ from dataclasses import dataclass from pathlib import Path +from huggingface_hub.dataclasses import strict + from kernels._versions import resolve_version_spec_as_ref from kernels.compat import tomllib from kernels.status import resolve_status +@strict @dataclass class VariantLock: hash: str hash_type: str = "git_lfs_concat" +@strict @dataclass class KernelLock: repo_id: str diff --git a/kernels/src/kernels/metadata.py b/kernels/src/kernels/metadata.py index 1c50abc4..d102a099 100644 --- a/kernels/src/kernels/metadata.py +++ b/kernels/src/kernels/metadata.py @@ -2,9 +2,12 @@ from dataclasses import dataclass from pathlib import Path +from huggingface_hub.dataclasses import strict + from kernels.compat import tomllib +@strict @dataclass class Metadata: python_depends: list[str] diff --git a/kernels/src/kernels/status.py b/kernels/src/kernels/status.py index f22cc099..b0a219b5 100644 --- a/kernels/src/kernels/status.py +++ b/kernels/src/kernels/status.py @@ -3,11 +3,13 @@ from typing import Union from huggingface_hub import HfApi +from huggingface_hub.dataclasses import strict from huggingface_hub.utils import EntryNotFoundError from kernels.compat import tomllib +@strict @dataclass class Redirect: kind: str # must be "redirect" diff --git a/kernels/src/kernels/variants.py b/kernels/src/kernels/variants.py index c40d6e69..8bfdddb8 100644 --- a/kernels/src/kernels/variants.py +++ b/kernels/src/kernels/variants.py @@ -6,6 +6,7 @@ from typing import ClassVar from huggingface_hub import HfApi +from huggingface_hub.dataclasses import strict from huggingface_hub.hf_api import RepoFolder from packaging.version import Version, parse @@ -65,6 +66,7 @@ def parse(s: str) -> "TvmFfi": return TvmFfi(version=Version(f"{m.group(1)}.{m.group(2)}")) +@strict @dataclass(unsafe_hash=True) class Arch: """Aarch kernel information.""" @@ -101,6 +103,7 @@ def parse(parts: list[str]) -> "Arch": return Arch(backend=backend, platform=platform, os=os, cxx11_abi=cxx11_abi) +@strict @dataclass(unsafe_hash=True) class Noarch: """Noarch kernel information.""" @@ -116,6 +119,7 @@ def parse(s: str) -> "Noarch": return Noarch(backend_name=s) +@strict @dataclass(unsafe_hash=True) class Variant: """Kernel build variant.""" diff --git a/nix-builder/overlay.nix b/nix-builder/overlay.nix index b82cdbf9..4c451919 100644 --- a/nix-builder/overlay.nix +++ b/nix-builder/overlay.nix @@ -77,6 +77,22 @@ in else python-self.callPackage ./pkgs/python-modules/cuda-python { }; + huggingface-hub = python-super.huggingface-hub.overridePythonAttrs (old: rec { + version = "1.3.0"; + src = python-super.fetchPypi { + pname = "huggingface_hub"; + inherit version; + hash = "sha256-KJ4qNYb98B41iClE6qBvvVdDbeJLbmU9H6skhYSs1ms="; + }; + dependencies = (old.dependencies or [ ]) ++ [ + python-self.httpx + python-self.shellingham + python-self.typer-slim + ]; + # Skip tests since they require network access. + doCheck = false; + }); + fastapi = python-super.fastapi.overrideAttrs ( _: prevAttrs: { # Gets stuck sometimes, already tested in nixpkgs.