From a8cabe0e35737eef9217187729587f13520e72bc Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 24 Mar 2026 14:44:30 +0530 Subject: [PATCH 01/10] feat: implement strict type validation using strict from huggingface_hub. --- kernels/pyproject.toml | 2 +- kernels/src/kernels/backends.py | 11 +- kernels/src/kernels/cli/benchmark.py | 14 +++ kernels/src/kernels/cli/kernel_card_utils.py | 120 +++++++++++++++---- kernels/src/kernels/deps.py | 71 +++++++++-- kernels/src/kernels/layer/device.py | 17 +++ kernels/src/kernels/lockfile.py | 4 + kernels/src/kernels/metadata.py | 3 + kernels/src/kernels/status.py | 2 + kernels/src/kernels/variants.py | 6 + 10 files changed, 213 insertions(+), 37 deletions(-) diff --git a/kernels/pyproject.toml b/kernels/pyproject.toml index b73f7fe7..3468d159 100644 --- a/kernels/pyproject.toml +++ b/kernels/pyproject.toml @@ -10,7 +10,7 @@ license = { text = "Apache-2.0" } readme = "README.md" requires-python = ">= 3.9" dependencies = [ - "huggingface_hub>=0.26.0,<2.0", + "huggingface_hub>=0.31.4,<2.0", "Jinja2>=3.1.5", "packaging>=20.0", "pyyaml>=6", diff --git a/kernels/src/kernels/backends.py b/kernels/src/kernels/backends.py index cc14ba5e..2cc3fc98 100644 --- a/kernels/src/kernels/backends.py +++ b/kernels/src/kernels/backends.py @@ -3,13 +3,15 @@ import re import warnings from dataclasses import dataclass -from typing import ClassVar, Optional, Protocol +from typing import ClassVar, Optional, Protocol, runtime_checkable +from huggingface_hub.dataclasses import strict from packaging.version import Version from kernels.compat import has_torch +@runtime_checkable class Backend(Protocol): @property def name(self) -> str: @@ -27,6 +29,7 @@ def variant_str(self) -> str: ... 
+@strict @dataclass(unsafe_hash=True) class CANN: _VARIANT_REGEX: ClassVar[re.Pattern] = re.compile(r"cann(\d+)(\d+)") @@ -49,6 +52,7 @@ def parse(s: str) -> "CANN": return CANN(version=Version(f"{m.group(1)}.{m.group(2)}")) +@strict @dataclass(unsafe_hash=True) class CPU: @property @@ -66,6 +70,7 @@ def parse(s: str) -> "CPU": return CPU() +@strict @dataclass(unsafe_hash=True) class CUDA: _VARIANT_REGEX: ClassVar[re.Pattern] = re.compile(r"cu(\d+)(\d+)") @@ -88,6 +93,7 @@ def parse(s: str) -> "CUDA": return CUDA(version=Version(f"{m.group(1)}.{m.group(2)}")) +@strict @dataclass(unsafe_hash=True) class Metal: @property @@ -105,6 +111,7 @@ def parse(s: str) -> "Metal": return Metal() +@strict @dataclass(unsafe_hash=True) class Neuron: @property @@ -122,6 +129,7 @@ def parse(s: str) -> "Neuron": return Neuron() +@strict @dataclass(unsafe_hash=True) class ROCm: _VARIANT_REGEX: ClassVar[re.Pattern] = re.compile(r"rocm(\d+)(\d+)") @@ -144,6 +152,7 @@ def parse(s: str) -> "ROCm": return ROCm(version=Version(f"{m.group(1)}.{m.group(2)}")) +@strict @dataclass(unsafe_hash=True) class XPU: _VARIANT_REGEX: ClassVar[re.Pattern] = re.compile(r"xpu(\d+)(\d+)") diff --git a/kernels/src/kernels/cli/benchmark.py b/kernels/src/kernels/cli/benchmark.py index 0257d084..c22eb100 100644 --- a/kernels/src/kernels/cli/benchmark.py +++ b/kernels/src/kernels/cli/benchmark.py @@ -12,6 +12,7 @@ from pathlib import Path from typing import Any +from huggingface_hub.dataclasses import strict from huggingface_hub.utils import ( build_hf_headers, disable_progress_bars, @@ -69,6 +70,7 @@ def _calculate_iqr_and_outliers( return q1, q3, iqr, outliers +@strict @dataclass class TimingResults: mean_ms: float @@ -83,7 +85,18 @@ class TimingResults: verified: bool | None = None # None = no verify fn, True = passed, False = failed ref_mean_ms: float | None = None # Reference implementation mean time + def validate_iterations(self): + if self.iterations <= 0: + raise ValueError(f"iterations must be > 0, got 
{self.iterations}") + def validate_timing_range(self): + if self.min_ms > self.max_ms: + raise ValueError( + f"min_ms ({self.min_ms}) must be <= max_ms ({self.max_ms})" + ) + + +@strict @dataclass class MachineInfo: gpu: str @@ -94,6 +107,7 @@ class MachineInfo: gpu_cores: int | None = None +@strict @dataclass class BenchmarkResult: timing_results: dict[str, TimingResults] # workload name -> timing diff --git a/kernels/src/kernels/cli/kernel_card_utils.py b/kernels/src/kernels/cli/kernel_card_utils.py index 7363860a..67f8ab5d 100644 --- a/kernels/src/kernels/cli/kernel_card_utils.py +++ b/kernels/src/kernels/cli/kernel_card_utils.py @@ -1,8 +1,10 @@ import ast +from dataclasses import dataclass, field from pathlib import Path -from typing import Any +from typing import Any, Dict, List, Optional from huggingface_hub import ModelCard +from huggingface_hub.dataclasses import strict from ..compat import tomllib @@ -22,18 +24,85 @@ LIBRARY_NAME = "kernels" -def _parse_build_toml(local_path: str | Path) -> dict | None: - local_path = Path(local_path) - build_toml_path = local_path / "build.toml" +@strict +@dataclass +class HubConfig: + repo_id: str = "" + + @staticmethod + def from_dict(data: dict) -> "HubConfig": + return HubConfig(repo_id=data.get("repo-id", "")) + + +@strict +@dataclass +class GeneralConfig: + name: str = "" + version: Optional[int] = None + license: Optional[str] = None + backends: Optional[List[str]] = None + hub: Optional[HubConfig] = None + + @staticmethod + def from_dict(data: dict) -> "GeneralConfig": + hub_data = data.get("hub") + return GeneralConfig( + name=data.get("name", ""), + version=data.get("version"), + license=data.get("license"), + backends=data.get("backends"), + hub=HubConfig.from_dict(hub_data) if hub_data else None, + ) - if not build_toml_path.exists(): - return None - try: - with open(build_toml_path, "rb") as f: - return tomllib.load(f) - except Exception: - return None +@strict +@dataclass +class KernelConfig: + 
cuda_capabilities: Optional[List[str]] = None + + @staticmethod + def from_dict(data: dict) -> "KernelConfig": + return KernelConfig(cuda_capabilities=data.get("cuda-capabilities")) + + +@strict +@dataclass +class BuildConfig: + general: GeneralConfig = field(default_factory=GeneralConfig) + kernel: Optional[Dict[str, KernelConfig]] = None + upstream: Optional[str] = None + + @staticmethod + def from_dict(data: dict) -> "BuildConfig": + general_data = data.get("general", {}) + kernel_data = data.get("kernel") + return BuildConfig( + general=GeneralConfig.from_dict(general_data), + kernel=( + { + name: KernelConfig.from_dict(info) + for name, info in kernel_data.items() + } + if kernel_data + else None + ), + upstream=data.get("upstream"), + ) + + @staticmethod + def load(build_toml_path: Path) -> Optional["BuildConfig"]: + if not build_toml_path.exists(): + return None + try: + with open(build_toml_path, "rb") as f: + data = tomllib.load(f) + except Exception: + return None + return BuildConfig.from_dict(data) + + +def _parse_build_toml(local_path: Path) -> BuildConfig | None: + return BuildConfig.load(local_path / "build.toml") def _find_torch_ext_init(local_path: str | Path) -> Path | None: @@ -44,7 +113,7 @@ def _find_torch_ext_init(local_path: str | Path) -> Path | None: return None try: - kernel_name = config.get("general", {}).get("name") + kernel_name = config.general.name if not kernel_name: return None @@ -88,8 +157,10 @@ def _parse_repo_id(local_path: str | Path) -> str | None: if not config: return None - repo_id = config.get("general", {}).get("hub", {}).get("repo-id", None) - return repo_id + if config.general.hub is None: + return None + repo_id = config.general.hub.repo_id + return repo_id if repo_id else None def _build_kernel_card_vars( @@ -111,28 +182,25 @@ def _build_kernel_card_vars( # --- backends, CUDA capabilities, upstream --- config = _parse_build_toml(local_path) if config: - general_config = config.get("general", {}) - backends = 
general_config.get("backends") + backends = config.general.backends if backends: vars["supported_backends"] = "\n".join(f"- {b}" for b in backends) - kernel_configs = config.get("kernel", {}) cuda_capabilities: set[Any] = set() # TODO (sayakpaul): implement this to read from `metadata.json` per each build - for k in kernel_configs: - caps = kernel_configs[k].get("cuda-capabilities") - if caps: - cuda_capabilities.update(caps) + if config.kernel: + for kernel_cfg in config.kernel.values(): + if kernel_cfg.cuda_capabilities: + cuda_capabilities.update(kernel_cfg.cuda_capabilities) if cuda_capabilities: vars["cuda_capabilities"] = "\n".join( f"- {cap}" for cap in cuda_capabilities ) - upstream_repo = config.get("upstream", None) - if upstream_repo: + if config.upstream: vars["source_code"] = ( - f"Source code of this kernel originally comes from {upstream_repo}" + f"Source code of this kernel originally comes from {config.upstream}" " and it was repurposed for compatibility with `kernels`." 
) @@ -151,12 +219,12 @@ def _build_kernel_card_vars( def _update_kernel_card_license( kernel_card: ModelCard, local_path: str | Path ) -> ModelCard: - config = _parse_build_toml(local_path) + config = _parse_build_toml(Path(local_path)) if not config: return kernel_card existing_license = kernel_card.data.get("license", None) - license_from_config = config.get("general", {}).get("license", None) + license_from_config = config.general.license final_license = license_from_config or existing_license kernel_card.data["license"] = final_license return kernel_card diff --git a/kernels/src/kernels/deps.py b/kernels/src/kernels/deps.py index 7ae1a1d8..7bbcc2f2 100644 --- a/kernels/src/kernels/deps.py +++ b/kernels/src/kernels/deps.py @@ -1,12 +1,66 @@ import importlib.util import json +from dataclasses import dataclass, field from pathlib import Path +from typing import Optional + +from huggingface_hub.dataclasses import strict from kernels.backends import Backend + +@strict +@dataclass +class PythonPackage: + pkg: str + import_name: Optional[str] = None + + @staticmethod + def from_dict(data: dict) -> "PythonPackage": + return PythonPackage( + pkg=data["pkg"], + import_name=data.get("import"), + ) + + +@strict +@dataclass +class DependencyInfo: + nix: list + python: list + + @staticmethod + def from_dict(data: dict) -> "DependencyInfo": + return DependencyInfo( + nix=data.get("nix", []), + python=[PythonPackage.from_dict(p) for p in data.get("python", [])], + ) + + +@strict +@dataclass +class DependencyData: + general: dict = field(default_factory=dict) + backends: dict = field(default_factory=dict) + + @staticmethod + def from_dict(data: dict) -> "DependencyData": + general = { + name: DependencyInfo.from_dict(info) + for name, info in data.get("general", {}).items() + } + backends = { + backend_name: { + name: DependencyInfo.from_dict(info) for name, info in deps.items() + } + for backend_name, deps in data.get("backends", {}).items() + } + return 
DependencyData(general=general, backends=backends) + + try: with open(Path(__file__).parent / "python_depends.json", "r") as f: - DEPENDENCY_DATA: dict = json.load(f) + _DEPENDENCY_DATA = DependencyData.from_dict(json.load(f)) except FileNotFoundError: raise FileNotFoundError( "Cannot load dependency data, is `kernels` correctly installed?" @@ -23,16 +77,16 @@ def validate_dependencies( dependencies (`list[str]`): A list of dependency strings to validate. backend (`str`): The backend to validate dependencies for. """ - general_deps = DEPENDENCY_DATA.get("general", {}) - backend_deps = DEPENDENCY_DATA.get("backends", {}).get(backend.name, {}) + general_deps = _DEPENDENCY_DATA.general + backend_deps = _DEPENDENCY_DATA.backends.get(backend.name, {}) # Validate each dependency for dependency in dependencies: # Look up dependency in general dependencies first, then backend-specific if dependency in general_deps: - python_packages = general_deps[dependency].get("python", []) + dep_info = general_deps[dependency] elif dependency in backend_deps: - python_packages = backend_deps[dependency].get("python", []) + dep_info = backend_deps[dependency] else: # Dependency not found in general or backend-specific dependencies raise ValueError( @@ -40,15 +94,14 @@ def validate_dependencies( ) # Check if each python package is installed - for python_package in python_packages: - # Convert package name to module name (replace - with _) - pkg_name = python_package.get("pkg") + for python_package in dep_info.python: + pkg_name = python_package.pkg # Assertion because this should not happen and is a bug. assert ( pkg_name is not None ), f"Invalid dependency data for `{dependency}`: missing `pkg` field." - module_name = python_package.get("import") + module_name = python_package.import_name if module_name is None: # These are typically packages that do not provide any Python # code, but get installed to Python's library dirctory. E.g. 
diff --git a/kernels/src/kernels/layer/device.py b/kernels/src/kernels/layer/device.py index b7faedfa..c796b647 100644 --- a/kernels/src/kernels/layer/device.py +++ b/kernels/src/kernels/layer/device.py @@ -1,6 +1,9 @@ from dataclasses import dataclass +from huggingface_hub.dataclasses import strict + +@strict @dataclass(frozen=True) class CUDAProperties: """ @@ -34,6 +37,12 @@ class CUDAProperties: min_capability: int max_capability: int + def validate_capability_range(self): + if self.min_capability > self.max_capability: + raise ValueError( + f"min_capability ({self.min_capability}) must be <= max_capability ({self.max_capability})" + ) + def __eq__(self, other): if not isinstance(other, CUDAProperties): return NotImplemented @@ -46,6 +55,7 @@ def __hash__(self): return hash((self.min_capability, self.max_capability)) +@strict @dataclass(frozen=True) class ROCMProperties: """ @@ -79,6 +89,12 @@ class ROCMProperties: min_capability: int max_capability: int + def validate_capability_range(self): + if self.min_capability > self.max_capability: + raise ValueError( + f"min_capability ({self.min_capability}) must be <= max_capability ({self.max_capability})" + ) + def __eq__(self, other): if not isinstance(other, ROCMProperties): return NotImplemented @@ -91,6 +107,7 @@ def __hash__(self): return hash((self.min_capability, self.max_capability)) +@strict @dataclass(frozen=True) class Device: """ diff --git a/kernels/src/kernels/lockfile.py b/kernels/src/kernels/lockfile.py index 2aca2eef..ccc7755a 100644 --- a/kernels/src/kernels/lockfile.py +++ b/kernels/src/kernels/lockfile.py @@ -2,17 +2,21 @@ from dataclasses import dataclass from pathlib import Path +from huggingface_hub.dataclasses import strict + from kernels._versions import resolve_version_spec_as_ref from kernels.compat import tomllib from kernels.status import resolve_status +@strict @dataclass class VariantLock: hash: str hash_type: str = "git_lfs_concat" +@strict @dataclass class KernelLock: repo_id: str 
diff --git a/kernels/src/kernels/metadata.py b/kernels/src/kernels/metadata.py index 1c50abc4..d102a099 100644 --- a/kernels/src/kernels/metadata.py +++ b/kernels/src/kernels/metadata.py @@ -2,9 +2,12 @@ from dataclasses import dataclass from pathlib import Path +from huggingface_hub.dataclasses import strict + from kernels.compat import tomllib +@strict @dataclass class Metadata: python_depends: list[str] diff --git a/kernels/src/kernels/status.py b/kernels/src/kernels/status.py index f22cc099..b0a219b5 100644 --- a/kernels/src/kernels/status.py +++ b/kernels/src/kernels/status.py @@ -3,11 +3,13 @@ from typing import Union from huggingface_hub import HfApi +from huggingface_hub.dataclasses import strict from huggingface_hub.utils import EntryNotFoundError from kernels.compat import tomllib +@strict @dataclass class Redirect: kind: str # must be "redirect" diff --git a/kernels/src/kernels/variants.py b/kernels/src/kernels/variants.py index c40d6e69..fdb0d262 100644 --- a/kernels/src/kernels/variants.py +++ b/kernels/src/kernels/variants.py @@ -6,6 +6,7 @@ from typing import ClassVar from huggingface_hub import HfApi +from huggingface_hub.dataclasses import strict from huggingface_hub.hf_api import RepoFolder from packaging.version import Version, parse @@ -25,6 +26,7 @@ ) +@strict @dataclass(unsafe_hash=True) class Torch: _VARIANT_REGEX: ClassVar[re.Pattern] = re.compile(r"torch(\d+?)(\d+)") @@ -47,6 +49,7 @@ def parse(s: str) -> "Torch": return Torch(version=Version(f"{m.group(1)}.{m.group(2)}")) +@strict @dataclass(unsafe_hash=True) class TvmFfi: _VARIANT_REGEX: ClassVar[re.Pattern] = re.compile(r"tvm-ffi(\d+?)(\d+)") @@ -65,6 +68,7 @@ def parse(s: str) -> "TvmFfi": return TvmFfi(version=Version(f"{m.group(1)}.{m.group(2)}")) +@strict @dataclass(unsafe_hash=True) class Arch: """Aarch kernel information.""" @@ -101,6 +105,7 @@ def parse(parts: list[str]) -> "Arch": return Arch(backend=backend, platform=platform, os=os, cxx11_abi=cxx11_abi) +@strict 
@dataclass(unsafe_hash=True) class Noarch: """Noarch kernel information.""" @@ -116,6 +121,7 @@ def parse(s: str) -> "Noarch": return Noarch(backend_name=s) +@strict @dataclass(unsafe_hash=True) class Variant: """Kernel build variant.""" From 574d23946f14a55748e93eb392228a8d72f63a39 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 24 Mar 2026 15:01:27 +0530 Subject: [PATCH 02/10] remove from classes that using packaging.version.Version --- kernels/src/kernels/backends.py | 4 ---- kernels/src/kernels/variants.py | 2 -- 2 files changed, 6 deletions(-) diff --git a/kernels/src/kernels/backends.py b/kernels/src/kernels/backends.py index 2cc3fc98..41378bc0 100644 --- a/kernels/src/kernels/backends.py +++ b/kernels/src/kernels/backends.py @@ -29,7 +29,6 @@ def variant_str(self) -> str: ... -@strict @dataclass(unsafe_hash=True) class CANN: _VARIANT_REGEX: ClassVar[re.Pattern] = re.compile(r"cann(\d+)(\d+)") @@ -70,7 +69,6 @@ def parse(s: str) -> "CPU": return CPU() -@strict @dataclass(unsafe_hash=True) class CUDA: _VARIANT_REGEX: ClassVar[re.Pattern] = re.compile(r"cu(\d+)(\d+)") @@ -129,7 +127,6 @@ def parse(s: str) -> "Neuron": return Neuron() -@strict @dataclass(unsafe_hash=True) class ROCm: _VARIANT_REGEX: ClassVar[re.Pattern] = re.compile(r"rocm(\d+)(\d+)") @@ -152,7 +149,6 @@ def parse(s: str) -> "ROCm": return ROCm(version=Version(f"{m.group(1)}.{m.group(2)}")) -@strict @dataclass(unsafe_hash=True) class XPU: _VARIANT_REGEX: ClassVar[re.Pattern] = re.compile(r"xpu(\d+)(\d+)") diff --git a/kernels/src/kernels/variants.py b/kernels/src/kernels/variants.py index fdb0d262..8bfdddb8 100644 --- a/kernels/src/kernels/variants.py +++ b/kernels/src/kernels/variants.py @@ -26,7 +26,6 @@ ) -@strict @dataclass(unsafe_hash=True) class Torch: _VARIANT_REGEX: ClassVar[re.Pattern] = re.compile(r"torch(\d+?)(\d+)") @@ -49,7 +48,6 @@ def parse(s: str) -> "Torch": return Torch(version=Version(f"{m.group(1)}.{m.group(2)}")) -@strict @dataclass(unsafe_hash=True) class TvmFfi: 
_VARIANT_REGEX: ClassVar[re.Pattern] = re.compile(r"tvm-ffi(\d+?)(\d+)") From ef5f2e16f6dd0eaff579f0650287518b423b5534 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 24 Mar 2026 15:18:58 +0530 Subject: [PATCH 03/10] fix --- kernels/pyproject.toml | 2 +- kernels/src/kernels/cli/kernel_card_utils.py | 18 +++++++++--------- kernels/src/kernels/deps.py | 3 +-- 3 files changed, 11 insertions(+), 12 deletions(-) diff --git a/kernels/pyproject.toml b/kernels/pyproject.toml index 3468d159..46f59ffd 100644 --- a/kernels/pyproject.toml +++ b/kernels/pyproject.toml @@ -10,7 +10,7 @@ license = { text = "Apache-2.0" } readme = "README.md" requires-python = ">= 3.9" dependencies = [ - "huggingface_hub>=0.31.4,<2.0", + "huggingface_hub>=1.3.0,<2.0", "Jinja2>=3.1.5", "packaging>=20.0", "pyyaml>=6", diff --git a/kernels/src/kernels/cli/kernel_card_utils.py b/kernels/src/kernels/cli/kernel_card_utils.py index 67f8ab5d..b2dbcf95 100644 --- a/kernels/src/kernels/cli/kernel_card_utils.py +++ b/kernels/src/kernels/cli/kernel_card_utils.py @@ -1,7 +1,7 @@ import ast from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any from huggingface_hub import ModelCard from huggingface_hub.dataclasses import strict @@ -38,10 +38,10 @@ def from_dict(data: dict) -> "HubConfig": @dataclass class GeneralConfig: name: str = "" - version: Optional[int] = None - license: Optional[str] = None - backends: Optional[List[str]] = None - hub: Optional[HubConfig] = None + version: int | None = None + license: str | None = None + backends: list[str] | None = None + hub: HubConfig | None = None @staticmethod def from_dict(data: dict) -> "GeneralConfig": @@ -58,7 +58,7 @@ def from_dict(data: dict) -> "GeneralConfig": @strict @dataclass class KernelConfig: - cuda_capabilities: Optional[List[str]] = None + cuda_capabilities: list[str] | None = None @staticmethod def from_dict(data: dict) -> "KernelConfig": @@ -69,8 +69,8 @@ def 
from_dict(data: dict) -> "KernelConfig": @dataclass class BuildConfig: general: GeneralConfig = field(default_factory=GeneralConfig) - kernel: Optional[Dict[str, KernelConfig]] = None - upstream: Optional[str] = None + kernel: dict[str, KernelConfig] | None = None + upstream: str | None = None @staticmethod def from_dict(data: dict) -> "BuildConfig": @@ -90,7 +90,7 @@ def from_dict(data: dict) -> "BuildConfig": ) @staticmethod - def load(build_toml_path: Path) -> Optional["BuildConfig"]: + def load(build_toml_path: Path) -> "BuildConfig | None": if not build_toml_path.exists(): return None try: diff --git a/kernels/src/kernels/deps.py b/kernels/src/kernels/deps.py index 7bbcc2f2..312e8a76 100644 --- a/kernels/src/kernels/deps.py +++ b/kernels/src/kernels/deps.py @@ -2,7 +2,6 @@ import json from dataclasses import dataclass, field from pathlib import Path -from typing import Optional from huggingface_hub.dataclasses import strict @@ -13,7 +12,7 @@ @dataclass class PythonPackage: pkg: str - import_name: Optional[str] = None + import_name: str | None = None @staticmethod def from_dict(data: dict) -> "PythonPackage": From bd84f0faa205fc95446d42760591c7ddf7e602a2 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 24 Mar 2026 16:03:24 +0530 Subject: [PATCH 04/10] up --- nix-builder/overlay.nix | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/nix-builder/overlay.nix b/nix-builder/overlay.nix index c2f2839a..89bba5cd 100644 --- a/nix-builder/overlay.nix +++ b/nix-builder/overlay.nix @@ -77,6 +77,17 @@ in else python-self.callPackage ./pkgs/python-modules/cuda-python { }; + huggingface-hub = python-super.huggingface-hub.overridePythonAttrs (old: rec { + version = "1.3.0"; + src = python-super.fetchPypi { + pname = "huggingface_hub"; + inherit version; + hash = "sha256-KJ4qNYb98B41iClE6qBvvVdDbeJLbmU9H6skhYSs1ms="; + }; + # Skip tests since they require network access. 
+ doCheck = false; + }); + fastapi = python-super.fastapi.overrideAttrs ( _: prevAttrs: { # Gets stuck sometimes, already tested in nixpkgs. From fb44e0cdd12cd16374a7a417d972a1bbd4762f65 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 24 Mar 2026 16:12:56 +0530 Subject: [PATCH 05/10] up --- nix-builder/overlay.nix | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/nix-builder/overlay.nix b/nix-builder/overlay.nix index 89bba5cd..e6be59c8 100644 --- a/nix-builder/overlay.nix +++ b/nix-builder/overlay.nix @@ -84,6 +84,11 @@ in inherit version; hash = "sha256-KJ4qNYb98B41iClE6qBvvVdDbeJLbmU9H6skhYSs1ms="; }; + dependencies = (old.dependencies or [ ]) ++ [ + python-self.httpx + python-self.shellingham + python-self.typer-slim + ]; # Skip tests since they require network access. doCheck = false; }); From 13d4122e6889ce7c6186a9e2daf1873d8db16e06 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 24 Mar 2026 20:20:04 +0530 Subject: [PATCH 06/10] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Daniël de Kok --- kernels/src/kernels/cli/kernel_card_utils.py | 23 ++++++++++---------- kernels/src/kernels/deps.py | 6 ++--- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/kernels/src/kernels/cli/kernel_card_utils.py b/kernels/src/kernels/cli/kernel_card_utils.py index b2dbcf95..0beda2b8 100644 --- a/kernels/src/kernels/cli/kernel_card_utils.py +++ b/kernels/src/kernels/cli/kernel_card_utils.py @@ -27,27 +27,27 @@ @strict @dataclass class HubConfig: - repo_id: str = "" + repo_id: str | None @staticmethod def from_dict(data: dict) -> "HubConfig": - return HubConfig(repo_id=data.get("repo-id", "")) + return HubConfig(repo_id=data.get("repo-id")) @strict @dataclass class GeneralConfig: name: str = "" - version: int | None = None - license: str | None = None - backends: list[str] | None = None - hub: HubConfig | None = None + version: int | None + license: str | None 
+ backends: list[str] | None + hub: HubConfig | None @staticmethod def from_dict(data: dict) -> "GeneralConfig": hub_data = data.get("hub") return GeneralConfig( - name=data.get("name", ""), + name=data.get("name"), version=data.get("version"), license=data.get("license"), backends=data.get("backends"), @@ -58,7 +58,7 @@ def from_dict(data: dict) -> "GeneralConfig": @strict @dataclass class KernelConfig: - cuda_capabilities: list[str] | None = None + cuda_capabilities: list[str] | None @staticmethod def from_dict(data: dict) -> "KernelConfig": @@ -69,8 +69,8 @@ def from_dict(data: dict) -> "KernelConfig": @dataclass class BuildConfig: general: GeneralConfig = field(default_factory=GeneralConfig) - kernel: dict[str, KernelConfig] | None = None - upstream: str | None = None + kernel: dict[str, KernelConfig] | None + upstream: str | None @staticmethod def from_dict(data: dict) -> "BuildConfig": @@ -159,8 +159,7 @@ def _parse_repo_id(local_path: str | Path) -> str | None: if config.general.hub is None: return None - repo_id = config.general.hub.repo_id - return repo_id if repo_id else None + return config.general.hub.repo_id def _build_kernel_card_vars( diff --git a/kernels/src/kernels/deps.py b/kernels/src/kernels/deps.py index 312e8a76..1303cc3e 100644 --- a/kernels/src/kernels/deps.py +++ b/kernels/src/kernels/deps.py @@ -12,7 +12,7 @@ @dataclass class PythonPackage: pkg: str - import_name: str | None = None + import_name: str | None @staticmethod def from_dict(data: dict) -> "PythonPackage": @@ -25,8 +25,8 @@ def from_dict(data: dict) -> "PythonPackage": @strict @dataclass class DependencyInfo: - nix: list - python: list + nix: list[str] + python: list[str] @staticmethod def from_dict(data: dict) -> "DependencyInfo": From bc042beb05392989f13ad55b6e506ffe6de2bdd2 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 24 Mar 2026 20:20:30 +0530 Subject: [PATCH 07/10] add comment about using HubConfig.from_dict. 
--- kernels/src/kernels/cli/kernel_card_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/kernels/src/kernels/cli/kernel_card_utils.py b/kernels/src/kernels/cli/kernel_card_utils.py index 0beda2b8..a0b4fbde 100644 --- a/kernels/src/kernels/cli/kernel_card_utils.py +++ b/kernels/src/kernels/cli/kernel_card_utils.py @@ -45,6 +45,8 @@ class GeneralConfig: @staticmethod def from_dict(data: dict) -> "GeneralConfig": + # TODO: revisit `from_dict` as per + # https://github.com/huggingface/kernels/pull/393/changes#r2981209411 hub_data = data.get("hub") return GeneralConfig( name=data.get("name"), From 5f78c049194eb345d19d6097f00273c0d591d1ee Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 24 Mar 2026 20:22:04 +0530 Subject: [PATCH 08/10] use Path in _update_kernel_card_license --- kernels/src/kernels/cli/kernel_card_utils.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/kernels/src/kernels/cli/kernel_card_utils.py b/kernels/src/kernels/cli/kernel_card_utils.py index a0b4fbde..62f6ffb5 100644 --- a/kernels/src/kernels/cli/kernel_card_utils.py +++ b/kernels/src/kernels/cli/kernel_card_utils.py @@ -217,10 +217,8 @@ def _build_kernel_card_vars( return vars -def _update_kernel_card_license( - kernel_card: ModelCard, local_path: str | Path -) -> ModelCard: - config = _parse_build_toml(Path(local_path)) +def _update_kernel_card_license(kernel_card: ModelCard, local_path: Path) -> ModelCard: + config = _parse_build_toml(local_path) if not config: return kernel_card From 4a0a854a7cc8cd8cd908df780d04b9a72e1cfebd Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 24 Mar 2026 20:25:12 +0530 Subject: [PATCH 09/10] fix --- kernels/src/kernels/deps.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kernels/src/kernels/deps.py b/kernels/src/kernels/deps.py index 1303cc3e..74ba5606 100644 --- a/kernels/src/kernels/deps.py +++ b/kernels/src/kernels/deps.py @@ -26,7 +26,7 @@ def from_dict(data: dict) -> "PythonPackage": @dataclass 
class DependencyInfo: nix: list[str] - python: list[str] + python: list[PythonPackage] @staticmethod def from_dict(data: dict) -> "DependencyInfo": From d2d4804ed5af45b2a15d938f1916ccf04bf3db69 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 24 Mar 2026 20:37:20 +0530 Subject: [PATCH 10/10] fix mypy errors. --- kernels/src/kernels/cli/kernel_card_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/kernels/src/kernels/cli/kernel_card_utils.py b/kernels/src/kernels/cli/kernel_card_utils.py index 62f6ffb5..56775aa8 100644 --- a/kernels/src/kernels/cli/kernel_card_utils.py +++ b/kernels/src/kernels/cli/kernel_card_utils.py @@ -1,5 +1,5 @@ import ast -from dataclasses import dataclass, field +from dataclasses import dataclass from pathlib import Path from typing import Any @@ -37,7 +37,7 @@ def from_dict(data: dict) -> "HubConfig": @strict @dataclass class GeneralConfig: - name: str = "" + name: str | None version: int | None license: str | None backends: list[str] | None @@ -70,7 +70,7 @@ def from_dict(data: dict) -> "KernelConfig": @strict @dataclass class BuildConfig: - general: GeneralConfig = field(default_factory=GeneralConfig) + general: GeneralConfig kernel: dict[str, KernelConfig] | None upstream: str | None