-
Notifications
You must be signed in to change notification settings - Fork 58
feat: implement strict type validation using strict from huggingface_hub
#393
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
a8cabe0
574d239
ef5f2e1
bd84f0f
fb44e0c
3279ecd
13d4122
bc042be
5f78c04
4a0a854
d2d4804
7eb1439
0d6db17
e6d4f1f
3bce9b3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,6 +12,7 @@ | |
| from pathlib import Path | ||
| from typing import Any | ||
|
|
||
| from huggingface_hub.dataclasses import strict | ||
| from huggingface_hub.utils import ( | ||
| build_hf_headers, | ||
| disable_progress_bars, | ||
|
|
@@ -69,6 +70,7 @@ def _calculate_iqr_and_outliers( | |
| return q1, q3, iqr, outliers | ||
|
|
||
|
|
||
| @strict | ||
| @dataclass | ||
| class TimingResults: | ||
| mean_ms: float | ||
|
|
@@ -83,7 +85,18 @@ class TimingResults: | |
| verified: bool | None = None # None = no verify fn, True = passed, False = failed | ||
| ref_mean_ms: float | None = None # Reference implementation mean time | ||
|
|
||
| def validate_iterations(self): | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These are "magic" methods. |
||
| if self.iterations <= 0: | ||
| raise ValueError(f"iterations must be > 0, got {self.iterations}") | ||
|
|
||
| def validate_timing_range(self): | ||
| if self.min_ms > self.max_ms: | ||
| raise ValueError( | ||
| f"min_ms ({self.min_ms}) must be <= max_ms ({self.max_ms})" | ||
| ) | ||
|
|
||
|
|
||
| @strict | ||
| @dataclass | ||
| class MachineInfo: | ||
| gpu: str | ||
|
|
@@ -94,6 +107,7 @@ class MachineInfo: | |
| gpu_cores: int | None = None | ||
|
|
||
|
|
||
| @strict | ||
| @dataclass | ||
| class BenchmarkResult: | ||
| timing_results: dict[str, TimingResults] # workload name -> timing | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,8 +1,10 @@ | ||
| import ast | ||
| from dataclasses import dataclass | ||
| from pathlib import Path | ||
| from typing import Any | ||
|
|
||
| from huggingface_hub import ModelCard | ||
| from huggingface_hub.dataclasses import strict | ||
|
|
||
| from ..compat import tomllib | ||
|
|
||
|
|
@@ -22,18 +24,87 @@ | |
| LIBRARY_NAME = "kernels" | ||
|
|
||
|
|
||
| def _parse_build_toml(local_path: str | Path) -> dict | None: | ||
| local_path = Path(local_path) | ||
| build_toml_path = local_path / "build.toml" | ||
| @strict | ||
| @dataclass | ||
| class HubConfig: | ||
| repo_id: str | None | ||
|
|
||
| @staticmethod | ||
| def from_dict(data: dict) -> "HubConfig": | ||
| return HubConfig(repo_id=data.get("repo-id")) | ||
|
|
||
|
|
||
| @strict | ||
| @dataclass | ||
| class GeneralConfig: | ||
| name: str | None | ||
| version: int | None | ||
| license: str | None | ||
| backends: list[str] | None | ||
| hub: HubConfig | None | ||
|
|
||
| @staticmethod | ||
| def from_dict(data: dict) -> "GeneralConfig": | ||
| # TODO: revisit `from_dict` as per | ||
| # https://github.com/huggingface/kernels/pull/393/changes#r2981209411 | ||
| hub_data = data.get("hub") | ||
| return GeneralConfig( | ||
| name=data.get("name"), | ||
| version=data.get("version"), | ||
| license=data.get("license"), | ||
| backends=data.get("backends"), | ||
| hub=HubConfig.from_dict(hub_data) if hub_data else None, | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note on the discussion from Slack, this is one thing that goes wrong with hand-rolled deserialization. What if the user put in something that makes You can avoid this by programming very defensively (e.g. here checking that it is a dict first), but those things you get for free from a library that does deserialization for you. I think for now it's ok, since we are going to replace this deserialization with Rust as discussed.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Made a note about it as well in bc042be. |
||
| ) | ||
|
|
||
| if not build_toml_path.exists(): | ||
| return None | ||
|
|
||
| try: | ||
| with open(build_toml_path, "rb") as f: | ||
| return tomllib.load(f) | ||
| except Exception: | ||
| return None | ||
| @strict | ||
| @dataclass | ||
| class KernelConfig: | ||
| cuda_capabilities: list[str] | None | ||
|
|
||
| @staticmethod | ||
| def from_dict(data: dict) -> "KernelConfig": | ||
| return KernelConfig(cuda_capabilities=data.get("cuda-capabilities")) | ||
|
|
||
|
|
||
| @strict | ||
| @dataclass | ||
| class BuildConfig: | ||
| general: GeneralConfig | ||
| kernel: dict[str, KernelConfig] | None | ||
| upstream: str | None | ||
|
|
||
| @staticmethod | ||
| def from_dict(data: dict) -> "BuildConfig": | ||
| general_data = data.get("general", {}) | ||
| kernel_data = data.get("kernel") | ||
| return BuildConfig( | ||
| general=GeneralConfig.from_dict(general_data), | ||
| kernel=( | ||
| { | ||
| name: KernelConfig.from_dict(info) | ||
| for name, info in kernel_data.items() | ||
| } | ||
| if kernel_data | ||
| else None | ||
| ), | ||
| upstream=data.get("upstream"), | ||
| ) | ||
|
|
||
| @staticmethod | ||
| def load(build_toml_path: Path) -> "BuildConfig | None": | ||
| if not build_toml_path.exists(): | ||
| return None | ||
| try: | ||
| with open(build_toml_path, "rb") as f: | ||
| data = tomllib.load(f) | ||
| except Exception: | ||
| return None | ||
| return BuildConfig.from_dict(data) | ||
|
|
||
|
|
||
| def _parse_build_toml(local_path: Path) -> BuildConfig | None: | ||
| return BuildConfig.load(local_path / "build.toml") | ||
|
|
||
|
|
||
| def _find_torch_ext_init(local_path: str | Path) -> Path | None: | ||
|
|
@@ -44,7 +115,7 @@ def _find_torch_ext_init(local_path: str | Path) -> Path | None: | |
| return None | ||
|
|
||
| try: | ||
| kernel_name = config.get("general", {}).get("name") | ||
| kernel_name = config.general.name | ||
| if not kernel_name: | ||
| return None | ||
|
|
||
|
|
@@ -88,8 +159,9 @@ def _parse_repo_id(local_path: str | Path) -> str | None: | |
| if not config: | ||
| return None | ||
|
|
||
| repo_id = config.get("general", {}).get("hub", {}).get("repo-id", None) | ||
| return repo_id | ||
| if config.general.hub is None: | ||
| return None | ||
| return config.general.hub.repo_id | ||
|
|
||
|
|
||
| def _build_kernel_card_vars( | ||
|
|
@@ -111,28 +183,25 @@ def _build_kernel_card_vars( | |
| # --- backends, CUDA capabilities, upstream --- | ||
| config = _parse_build_toml(local_path) | ||
| if config: | ||
| general_config = config.get("general", {}) | ||
| backends = general_config.get("backends") | ||
| backends = config.general.backends | ||
| if backends: | ||
| vars["supported_backends"] = "\n".join(f"- {b}" for b in backends) | ||
|
|
||
| kernel_configs = config.get("kernel", {}) | ||
| cuda_capabilities: set[Any] = set() | ||
|
|
||
| # TODO (sayakpaul): implement this to read from `metadata.json` per each build | ||
| for k in kernel_configs: | ||
| caps = kernel_configs[k].get("cuda-capabilities") | ||
| if caps: | ||
| cuda_capabilities.update(caps) | ||
| if config.kernel: | ||
| for kernel_cfg in config.kernel.values(): | ||
| if kernel_cfg.cuda_capabilities: | ||
| cuda_capabilities.update(kernel_cfg.cuda_capabilities) | ||
| if cuda_capabilities: | ||
| vars["cuda_capabilities"] = "\n".join( | ||
| f"- {cap}" for cap in cuda_capabilities | ||
| ) | ||
|
|
||
| upstream_repo = config.get("upstream", None) | ||
| if upstream_repo: | ||
| if config.upstream: | ||
| vars["source_code"] = ( | ||
| f"Source code of this kernel originally comes from {upstream_repo}" | ||
| f"Source code of this kernel originally comes from {config.upstream}" | ||
| " and it was repurposed for compatibility with `kernels`." | ||
| ) | ||
|
|
||
|
|
@@ -148,15 +217,13 @@ def _build_kernel_card_vars( | |
| return vars | ||
|
|
||
|
|
||
| def _update_kernel_card_license( | ||
| kernel_card: ModelCard, local_path: str | Path | ||
| ) -> ModelCard: | ||
| def _update_kernel_card_license(kernel_card: ModelCard, local_path: Path) -> ModelCard: | ||
| config = _parse_build_toml(local_path) | ||
| if not config: | ||
| return kernel_card | ||
|
|
||
| existing_license = kernel_card.data.get("license", None) | ||
| license_from_config = config.get("general", {}).get("license", None) | ||
| license_from_config = config.general.license | ||
| final_license = license_from_config or existing_license | ||
| kernel_card.data["license"] = final_license | ||
| return kernel_card | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Because
@strictvalidates field types usingisinstance()checks at construction time. WhenArch.backendis typed asBackend(aProtocol),@stricttries to doisinstance(value, Backend)— and Python'styping.Protocolraises TypeError unless it's decorated with@runtime_checkable.