Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions opf/_cli/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,13 @@ def add_device_arg(parser: object) -> None:
parser.add_argument(
"--device",
type=str,
default="cuda",
help="Device to run on",
default="auto",
help=(
"Device to run on. 'auto' (default) picks the best available "
"backend: cuda > mps (Apple Silicon) > cpu. Pass an explicit "
"value like 'cuda', 'mps', or 'cpu' to override or to get a "
"loud error when the requested backend is unavailable."
),
)


Expand Down
51 changes: 51 additions & 0 deletions opf/_common/device.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""Device-name resolution helpers shared by CLI entrypoints."""

from __future__ import annotations

import sys

import torch

AUTO_DEVICE: str = "auto"


def _mps_is_available() -> bool:
"""Return True when the current PyTorch build supports Apple Metal (MPS)."""
backend = getattr(torch.backends, "mps", None)
if backend is None:
return False
is_available = getattr(backend, "is_available", None)
if is_available is None:
return False
try:
return bool(is_available())
except Exception:
return False


def resolve_device(device_name: str) -> torch.device:
"""Resolve a user-supplied device name into a concrete ``torch.device``.

``"auto"`` selects the best available device in this order: CUDA (NVIDIA
GPU) > MPS (Apple Silicon GPU) > CPU. Any other value is passed through
to ``torch.device`` as-is so that explicit requests like ``"cuda"`` or
``"mps"`` still fail loudly when the underlying backend is unavailable.
"""
if device_name == AUTO_DEVICE:
if torch.cuda.is_available():
return torch.device("cuda")
if _mps_is_available():
print(
"info: no CUDA device detected; using Apple Metal (MPS).",
file=sys.stderr,
flush=True,
)
return torch.device("mps")
print(
"info: no CUDA or MPS device detected; falling back to CPU "
"(pass --device cuda or --device mps to override).",
file=sys.stderr,
flush=True,
)
return torch.device("cpu")
return torch.device(device_name)
3 changes: 2 additions & 1 deletion opf/_core/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
REDACTED_OUTPUT_LABEL,
REDACTED_OUTPUT_PLACEHOLDER,
)
from .._common.device import resolve_device
from .._common.env import get_env_bool
from .decoding import ViterbiCRFDecoder
from .._common.label_space import resolve_label_space_from_config
Expand Down Expand Up @@ -215,7 +216,7 @@ def load_inference_runtime(
if output_mode not in OUTPUT_MODES:
raise ValueError(f"Unsupported output_mode: {output_mode!r}")
_validate_checkpoint_dir(checkpoint)
device = torch.device(device_name)
device = resolve_device(device_name)
checkpoint_config = _load_checkpoint_config(checkpoint)
n_ctx = _resolve_n_ctx(checkpoint_config, n_ctx_override, device)
encoding_name = checkpoint_config.get("encoding")
Expand Down
7 changes: 5 additions & 2 deletions opf/_model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,8 +750,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
expert_indices = experts.indices
expert_weights = expert_weights / self.experts_per_token
experts_per_token_eff = self.experts_per_token
not_running_on_cpu = t.device.type != "cpu"
use_triton = get_env_bool("OPF_MOE_TRITON", default=not_running_on_cpu)
# Triton kernels are CUDA-only; auto-enable only on CUDA devices. MPS
# and CPU fall back to the torch-ops path unless the user explicitly
# opts in via OPF_MOE_TRITON=1.
is_cuda_device = t.device.type == "cuda"
use_triton = get_env_bool("OPF_MOE_TRITON", default=is_cuda_device)
if use_triton:
_require_triton()

Expand Down
10 changes: 6 additions & 4 deletions opf/_train/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .args import parse_args
from .._api import resolve_checkpoint_path
from .._common.constants import SCHEMA_VERSION
from .._common.device import resolve_device
from .._common.label_space import (
resolve_checkpoint_label_space,
resolve_label_space_from_config,
Expand Down Expand Up @@ -588,11 +589,12 @@ def main(argv: Sequence[str] | None = None, *, prog: str | None = None) -> int:
progress_interval_s = parsed_interval

checkpoint = resolve_checkpoint_path(args.checkpoint)
device = torch.device(args.device)
device = resolve_device(args.device)

# Default to Triton-backed MoE kernels on non-CPU devices unless callers
# explicitly opt out. CPU uses torch ops by default so Triton stays optional.
if device.type != "cpu":
# Default to Triton-backed MoE kernels on CUDA devices unless callers
# explicitly opt out. CPU and MPS use torch ops by default so Triton
# stays CUDA-only (the kernels don't run on Metal).
if device.type == "cuda":
os.environ.setdefault("OPF_MOE_TRITON", "1")

base_config = _load_checkpoint_config(checkpoint)
Expand Down