diff --git a/opf/_common/checkpoint_download.py b/opf/_common/checkpoint_download.py index 743f889..ace1968 100644 --- a/opf/_common/checkpoint_download.py +++ b/opf/_common/checkpoint_download.py @@ -1,5 +1,6 @@ from __future__ import annotations +import os from pathlib import Path import shutil import sys @@ -9,6 +10,8 @@ DEFAULT_HF_MODEL_REPO: Final[str] = "openai/privacy-filter" +DEFAULT_HF_MODEL_REVISION: Final[str | None] = None +HF_MODEL_REVISION_ENV_VAR: Final[str] = "OPF_HF_REVISION" def _checkpoint_override_message() -> str: @@ -104,17 +107,20 @@ def ensure_default_checkpoint() -> str: f"{_checkpoint_override_message()}." ) from exc + revision = os.environ.get(HF_MODEL_REVISION_ENV_VAR) or DEFAULT_HF_MODEL_REVISION try: + revision_suffix = f" (revision={revision!r})" if revision else "" print( "Default OPF checkpoint not found at " f"{target}. Downloading from HuggingFace repo " - f"{DEFAULT_HF_MODEL_REPO!r} to {target}.", + f"{DEFAULT_HF_MODEL_REPO!r}{revision_suffix} to {target}.", file=sys.stderr, flush=True, ) try: snapshot_download( repo_id=DEFAULT_HF_MODEL_REPO, + revision=revision, local_dir=str(target), tqdm_class=_build_download_progress_class(), allow_patterns=["original/*"],