From 1e8b17e7fc47283b8e34429b40a63069d9d5e1c0 Mon Sep 17 00:00:00 2001 From: berkkirik Date: Fri, 24 Apr 2026 11:06:54 +0300 Subject: [PATCH] fix: default to auto-detected device so CPU-only machines don't crash The CLI --device flag defaulted to "cuda", which crashed with a raw PyTorch traceback ("Found no NVIDIA driver on your system ...") on machines without a GPU. Users had to discover --device cpu themselves. Add an "auto" mode that picks the best available backend (cuda if detected, otherwise cpu) and make it the default. Users who explicitly pass --device cuda still get the original loud failure on non-CUDA machines, which is the correct behavior when they ask for cuda by name. - opf/_common/device.py (new): resolve_device("auto"|...) helper. - opf/_cli/common.py: flip --device default to "auto", expand help text. - opf/_core/runtime.py, opf/_train/runner.py: call resolve_device() where device names turn into torch.device objects. Stderr on auto-fallback: info: no CUDA device detected; falling back to CPU (pass --device cuda to override). Fixes #12 --- opf/_cli/common.py | 9 +++++++-- opf/_common/device.py | 30 ++++++++++++++++++++++++++++++ opf/_core/runtime.py | 3 ++- opf/_train/runner.py | 3 ++- 4 files changed, 41 insertions(+), 4 deletions(-) create mode 100644 opf/_common/device.py diff --git a/opf/_cli/common.py b/opf/_cli/common.py index 31224e1..b75f043 100644 --- a/opf/_cli/common.py +++ b/opf/_cli/common.py @@ -55,8 +55,13 @@ def add_device_arg(parser: object) -> None: parser.add_argument( "--device", type=str, - default="cuda", - help="Device to run on", + default="auto", + help=( + "Device to run on. 'auto' (default) picks cuda if a GPU is " + "available, otherwise cpu. Pass 'cuda' or 'cpu' explicitly to " + "override or to get a loud error when the requested backend is " + "unavailable." + ), ) diff --git a/opf/_common/device.py b/opf/_common/device.py new file mode 100644 index 0000000..f2b162c --- /dev/null +++ b/opf/_common/device.py @@ -0,0 +1,30 @@ +"""Device-name resolution helpers shared by CLI entrypoints.""" + +from __future__ import annotations + +import sys + +import torch + +AUTO_DEVICE: str = "auto" + + +def resolve_device(device_name: str) -> torch.device: + """Resolve a user-supplied device name into a concrete ``torch.device``. + + ``"auto"`` selects the best available device: CUDA if a GPU is detected, + otherwise CPU. Any other value is passed through to ``torch.device`` as-is + so that explicit requests like ``"cuda"`` or ``"cpu"`` still fail loudly + when the underlying backend is unavailable. + """ + if device_name == AUTO_DEVICE: + if torch.cuda.is_available(): + return torch.device("cuda") + print( + "info: no CUDA device detected; falling back to CPU " + "(pass --device cuda to override).", + file=sys.stderr, + flush=True, + ) + return torch.device("cpu") + return torch.device(device_name) diff --git a/opf/_core/runtime.py b/opf/_core/runtime.py index 2c3034e..5ed4f1a 100644 --- a/opf/_core/runtime.py +++ b/opf/_core/runtime.py @@ -19,6 +19,7 @@ REDACTED_OUTPUT_LABEL, REDACTED_OUTPUT_PLACEHOLDER, ) +from .._common.device import resolve_device from .._common.env import get_env_bool from .decoding import ViterbiCRFDecoder from .._common.label_space import resolve_label_space_from_config @@ -215,7 +216,7 @@ def load_inference_runtime( if output_mode not in OUTPUT_MODES: raise ValueError(f"Unsupported output_mode: {output_mode!r}") _validate_checkpoint_dir(checkpoint) - device = torch.device(device_name) + device = resolve_device(device_name) checkpoint_config = _load_checkpoint_config(checkpoint) n_ctx = _resolve_n_ctx(checkpoint_config, n_ctx_override, device) encoding_name = checkpoint_config.get("encoding") diff --git a/opf/_train/runner.py b/opf/_train/runner.py index b00e23f..274fa20 100644 --- a/opf/_train/runner.py +++ b/opf/_train/runner.py @@ -16,6 +16,7 @@ from .args import parse_args from .._api import resolve_checkpoint_path from .._common.constants import SCHEMA_VERSION +from .._common.device import resolve_device from .._common.label_space import ( resolve_checkpoint_label_space, resolve_label_space_from_config, @@ -588,7 +589,7 @@ def main(argv: Sequence[str] | None = None, *, prog: str | None = None) -> int: progress_interval_s = parsed_interval checkpoint = resolve_checkpoint_path(args.checkpoint) - device = torch.device(args.device) + device = resolve_device(args.device) # Default to Triton-backed MoE kernels on non-CPU devices unless callers # explicitly opt out. CPU uses torch ops by default so Triton stays optional.