Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions opf/_cli/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,13 @@ def add_device_arg(parser: object) -> None:
parser.add_argument(
"--device",
type=str,
default="cuda",
help="Device to run on",
default="auto",
help=(
"Device to run on. 'auto' (default) picks cuda if a GPU is "
"available, otherwise cpu. Pass 'cuda' or 'cpu' explicitly to "
"override or to get a loud error when the requested backend is "
"unavailable."
),
)


Expand Down
30 changes: 30 additions & 0 deletions opf/_common/device.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""Device-name resolution helpers shared by CLI entrypoints."""

from __future__ import annotations

import sys

import torch

AUTO_DEVICE: str = "auto"


def resolve_device(device_name: str) -> torch.device:
    """Map a user-supplied device name to a concrete ``torch.device``.

    Args:
        device_name: Either the ``AUTO_DEVICE`` sentinel (compared exactly)
            or any string accepted by ``torch.device``.

    Returns:
        For the sentinel: the CUDA device when ``torch.cuda.is_available()``
        reports true, otherwise the CPU device, after a one-line notice has
        been written to stderr. For any other value:
        ``torch.device(device_name)``, with the name forwarded unchanged so
        an explicit choice is never silently rewritten to a different device.
    """
    # Explicit names bypass detection entirely.
    if device_name != AUTO_DEVICE:
        return torch.device(device_name)

    if torch.cuda.is_available():
        return torch.device("cuda")

    # Tell the user why they ended up on CPU; stderr keeps stdout clean.
    fallback_notice = (
        "info: no CUDA device detected; falling back to CPU "
        "(pass --device cuda to override)."
    )
    print(fallback_notice, file=sys.stderr, flush=True)
    return torch.device("cpu")
3 changes: 2 additions & 1 deletion opf/_core/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
REDACTED_OUTPUT_LABEL,
REDACTED_OUTPUT_PLACEHOLDER,
)
from .._common.device import resolve_device
from .._common.env import get_env_bool
from .decoding import ViterbiCRFDecoder
from .._common.label_space import resolve_label_space_from_config
Expand Down Expand Up @@ -215,7 +216,7 @@ def load_inference_runtime(
if output_mode not in OUTPUT_MODES:
raise ValueError(f"Unsupported output_mode: {output_mode!r}")
_validate_checkpoint_dir(checkpoint)
device = torch.device(device_name)
device = resolve_device(device_name)
checkpoint_config = _load_checkpoint_config(checkpoint)
n_ctx = _resolve_n_ctx(checkpoint_config, n_ctx_override, device)
encoding_name = checkpoint_config.get("encoding")
Expand Down
3 changes: 2 additions & 1 deletion opf/_train/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .args import parse_args
from .._api import resolve_checkpoint_path
from .._common.constants import SCHEMA_VERSION
from .._common.device import resolve_device
from .._common.label_space import (
resolve_checkpoint_label_space,
resolve_label_space_from_config,
Expand Down Expand Up @@ -588,7 +589,7 @@ def main(argv: Sequence[str] | None = None, *, prog: str | None = None) -> int:
progress_interval_s = parsed_interval

checkpoint = resolve_checkpoint_path(args.checkpoint)
device = torch.device(args.device)
device = resolve_device(args.device)

# Default to Triton-backed MoE kernels on non-CPU devices unless callers
# explicitly opt out. CPU uses torch ops by default so Triton stays optional.
Expand Down