diff --git a/opf/_cli/common.py b/opf/_cli/common.py index 31224e1..6852af6 100644 --- a/opf/_cli/common.py +++ b/opf/_cli/common.py @@ -55,11 +55,20 @@ def add_device_arg(parser: object) -> None: parser.add_argument( "--device", type=str, - default="cuda", - help="Device to run on", + default="auto", + help="Device to run on. 'auto' picks cuda if available, else cpu.", ) +def resolve_device(name: str) -> str: + """Resolve an ``auto`` device selection to a concrete torch device name.""" + if name != "auto": + return name + import torch + + return "cuda" if torch.cuda.is_available() else "cpu" + + def add_n_ctx_arg(parser: object) -> None: """Add the shared context-window override argument.""" parser.add_argument( diff --git a/opf/_cli/render.py b/opf/_cli/render.py index 784b543..1d8f6ff 100644 --- a/opf/_cli/render.py +++ b/opf/_cli/render.py @@ -111,12 +111,13 @@ def build_redactor_from_args( ): """Construct an ``OPF`` redactor from parsed CLI arguments.""" from .._api import OPF + from .common import resolve_device redactor = OPF( model=args.checkpoint, context_window_length=args.n_ctx, trim_whitespace=args.trim_span_whitespace, - device=args.device, + device=resolve_device(args.device), output_mode=args.output_mode, discard_overlapping_predicted_spans=args.discard_overlapping_predicted_spans, output_text_only=output_text_only, diff --git a/opf/_common/checkpoint_download.py b/opf/_common/checkpoint_download.py index 743f889..89c643e 100644 --- a/opf/_common/checkpoint_download.py +++ b/opf/_common/checkpoint_download.py @@ -47,7 +47,7 @@ def _reset_terminal_after_download() -> None: def _build_download_progress_class(): - from tqdm.auto import tqdm + from huggingface_hub.utils import tqdm class OpfDownloadTqdm(tqdm): """Flush HuggingFace's final progress footer before OPF resumes output.""" diff --git a/pyproject.toml b/pyproject.toml index c27d87a..cbcbc81 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ description = "Standalone local-only OPF PyTorch model and eval tools." readme = "README.md" requires-python = ">=3.10" dependencies = [ - "huggingface_hub", + "huggingface_hub>=0.20", "numpy", "packaging", "torch",