# Review notes. This free text sits ahead of the first "diff --git" header;
# patch tools are expected to skip leading text -- confirm with your tooling.
#
# Summary of the patch below:
#   * opf/_cli/common.py: the --device flag's default changes from "cuda" to
#     "auto", and its help text is expanded to describe 'auto', 'cuda', 'cpu'.
#   * opf/_common/device.py (new file): adds the AUTO_DEVICE constant and
#     resolve_device(). For "auto" it returns torch.device("cuda") when
#     torch.cuda.is_available() is true; otherwise it prints an "info:" line
#     to stderr and returns torch.device("cpu"). Any other name is passed
#     unchanged to torch.device().
#   * opf/_core/runtime.py and opf/_train/runner.py: each imports
#     resolve_device and uses it in place of a direct torch.device(...) call.
#
# NOTE(review): the help text and docstring promise a "loud error" when an
# explicitly requested backend is unavailable. resolve_device() itself only
# constructs torch.device(device_name) and performs no availability check, so
# presumably the error is raised later when the device is first used --
# confirm.
#
# NOTE(review): line breaks and indentation were reconstructed from a
# whitespace-collapsed copy of this patch. The hunk line counts agree with the
# layout below, but the indentation of context lines is assumed (notably
# "progress_interval_s = parsed_interval" in runner.py) -- verify against the
# target files before applying.
diff --git a/opf/_cli/common.py b/opf/_cli/common.py
index 31224e1..b75f043 100644
--- a/opf/_cli/common.py
+++ b/opf/_cli/common.py
@@ -55,8 +55,13 @@ def add_device_arg(parser: object) -> None:
     parser.add_argument(
         "--device",
         type=str,
-        default="cuda",
-        help="Device to run on",
+        default="auto",
+        help=(
+            "Device to run on. 'auto' (default) picks cuda if a GPU is "
+            "available, otherwise cpu. Pass 'cuda' or 'cpu' explicitly to "
+            "override or to get a loud error when the requested backend is "
+            "unavailable."
+        ),
     )
 
 
diff --git a/opf/_common/device.py b/opf/_common/device.py
new file mode 100644
index 0000000..f2b162c
--- /dev/null
+++ b/opf/_common/device.py
@@ -0,0 +1,30 @@
+"""Device-name resolution helpers shared by CLI entrypoints."""
+
+from __future__ import annotations
+
+import sys
+
+import torch
+
+AUTO_DEVICE: str = "auto"
+
+
+def resolve_device(device_name: str) -> torch.device:
+    """Resolve a user-supplied device name into a concrete ``torch.device``.
+
+    ``"auto"`` selects the best available device: CUDA if a GPU is detected,
+    otherwise CPU. Any other value is passed through to ``torch.device`` as-is
+    so that explicit requests like ``"cuda"`` or ``"cpu"`` still fail loudly
+    when the underlying backend is unavailable.
+    """
+    if device_name == AUTO_DEVICE:
+        if torch.cuda.is_available():
+            return torch.device("cuda")
+        print(
+            "info: no CUDA device detected; falling back to CPU "
+            "(pass --device cuda to override).",
+            file=sys.stderr,
+            flush=True,
+        )
+        return torch.device("cpu")
+    return torch.device(device_name)
diff --git a/opf/_core/runtime.py b/opf/_core/runtime.py
index 2c3034e..5ed4f1a 100644
--- a/opf/_core/runtime.py
+++ b/opf/_core/runtime.py
@@ -19,6 +19,7 @@
     REDACTED_OUTPUT_LABEL,
     REDACTED_OUTPUT_PLACEHOLDER,
 )
+from .._common.device import resolve_device
 from .._common.env import get_env_bool
 from .decoding import ViterbiCRFDecoder
 from .._common.label_space import resolve_label_space_from_config
@@ -215,7 +216,7 @@ def load_inference_runtime(
     if output_mode not in OUTPUT_MODES:
         raise ValueError(f"Unsupported output_mode: {output_mode!r}")
     _validate_checkpoint_dir(checkpoint)
-    device = torch.device(device_name)
+    device = resolve_device(device_name)
     checkpoint_config = _load_checkpoint_config(checkpoint)
     n_ctx = _resolve_n_ctx(checkpoint_config, n_ctx_override, device)
     encoding_name = checkpoint_config.get("encoding")
diff --git a/opf/_train/runner.py b/opf/_train/runner.py
index b00e23f..274fa20 100644
--- a/opf/_train/runner.py
+++ b/opf/_train/runner.py
@@ -16,6 +16,7 @@
 from .args import parse_args
 from .._api import resolve_checkpoint_path
 from .._common.constants import SCHEMA_VERSION
+from .._common.device import resolve_device
 from .._common.label_space import (
     resolve_checkpoint_label_space,
     resolve_label_space_from_config,
@@ -588,7 +589,7 @@ def main(argv: Sequence[str] | None = None, *, prog: str | None = None) -> int:
         progress_interval_s = parsed_interval
 
     checkpoint = resolve_checkpoint_path(args.checkpoint)
-    device = torch.device(args.device)
+    device = resolve_device(args.device)
 
     # Default to Triton-backed MoE kernels on non-CPU devices unless callers
     # explicitly opt out. CPU uses torch ops by default so Triton stays optional.