Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions opf/_cli/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,20 @@ def add_device_arg(parser: object) -> None:
parser.add_argument(
"--device",
type=str,
default="cuda",
help="Device to run on",
default="auto",
help="Device to run on. 'auto' picks cuda if available, else cpu.",
)


def resolve_device(name: str) -> str:
"""Resolve an ``auto`` device selection to a concrete torch device name."""
if name != "auto":
return name
import torch

return "cuda" if torch.cuda.is_available() else "cpu"


def add_n_ctx_arg(parser: object) -> None:
"""Add the shared context-window override argument."""
parser.add_argument(
Expand Down
3 changes: 2 additions & 1 deletion opf/_cli/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,13 @@ def build_redactor_from_args(
):
"""Construct an ``OPF`` redactor from parsed CLI arguments."""
from .._api import OPF
from .common import resolve_device

redactor = OPF(
model=args.checkpoint,
context_window_length=args.n_ctx,
trim_whitespace=args.trim_span_whitespace,
device=args.device,
device=resolve_device(args.device),
output_mode=args.output_mode,
discard_overlapping_predicted_spans=args.discard_overlapping_predicted_spans,
output_text_only=output_text_only,
Expand Down
2 changes: 1 addition & 1 deletion opf/_common/checkpoint_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def _reset_terminal_after_download() -> None:


def _build_download_progress_class():
from tqdm.auto import tqdm
from huggingface_hub.utils import tqdm

class OpfDownloadTqdm(tqdm):
"""Flush HuggingFace's final progress footer before OPF resumes output."""
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ description = "Standalone local-only OPF PyTorch model and eval tools."
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"huggingface_hub",
"huggingface_hub>=0.20",
"numpy",
"packaging",
"torch",
Expand Down