diff --git a/docs/tagger-cli-api.md b/docs/tagger-cli-api.md new file mode 100644 index 00000000..e6624736 --- /dev/null +++ b/docs/tagger-cli-api.md @@ -0,0 +1,254 @@ +# Tagger CLI and OpenAI-Compatible API Captioning + +This document covers the first command-line dataset tagging path for issue #40. +It does not change the WebUI tagger page and does not integrate the separate +dataset tag editor. + +## Modes + +### Local TAG mode + +Local mode uses the existing WD/CL ONNX taggers in `mikazuki/tagger/` and writes +Danbooru-style TAG captions beside each image: + +```powershell +python -m mikazuki.tagger.cli local --path .\input --model wd14-convnextv2-v2 +``` + +```bash +python -m mikazuki.tagger.cli local --path ./input --model wd14-convnextv2-v2 +``` + +Useful options: + +- `--threshold 0.35`: general tag threshold. +- `--character-threshold 0.6`: character tag threshold. +- `--recursive`: scan child folders. +- `--additional-tags "masterpiece, best quality"`: always append tags. +- `--exclude-tags "lowres, bad anatomy"`: remove exact tags. +- `--use-cn-mirror`: use `https://hf-mirror.com` if the local model is missing. +- `--hf-endpoint https://...`: use a custom Hugging Face-compatible endpoint. +- `--on-conflict ignore|copy|prepend`: skip existing `.txt`, replace it, or + prepend new tags to existing text. +- `--no-replace-underscore`: keep underscores in tags. +- `--no-escape-tag`: do not escape parentheses and backslashes. + +Wrapper scripts are available: + +```powershell +.\scripts\cli\tagger.ps1 local --path .\input +``` + +```bash +bash scripts/cli/tagger.sh local --path ./input +``` + +The wrappers only set `HF_HOME=huggingface`, prefer the local `venv` Python when +present, and forward all arguments to `python -m mikazuki.tagger.cli`. +The Python service also defaults `HF_HOME` to the project `huggingface/` folder +when the environment variable is not already set, so first-run model downloads do +not go to the user's global Hugging Face cache. +Set `USE_CN_MIRROR=1` before running the wrapper to set +`HF_ENDPOINT=https://hf-mirror.com` when `HF_ENDPOINT` is not already set: + +```powershell +$env:USE_CN_MIRROR = "1" +.\scripts\cli\tagger.ps1 local --path .\input +``` + +```bash +USE_CN_MIRROR=1 bash scripts/cli/tagger.sh local --path ./input +``` + +Mirror note: the code path honors `HF_ENDPOINT`, but the mirror itself must be +compatible with the installed `huggingface_hub` version and the target model's +large-file hosting. If mirror metadata resolution fails, unset `HF_ENDPOINT` or +prefetch the model into `huggingface/` by another network path. + +Local model priority: + +1. If `MIKAZUKI_TAGGER_DIR` is set, the loader checks that directory first. +2. Then it checks project-local built-in locations: + - `taggers//` + - `models/taggers//` + - `huggingface/taggers//` +3. If required files are present, they are used directly and no network download + is attempted. +4. If no local files are found, Hugging Face download is attempted through + `HF_ENDPOINT`, `--hf-endpoint`, `--use-cn-mirror`, or direct HF in that order. + +For WD taggers, place `model.onnx` and `selected_tags.csv` together, for example: + +```text +taggers/ + wd14-convnextv2-v2/ + model.onnx + selected_tags.csv +``` + +For `cl_tagger_1_01`, place `model.onnx` and `tag_mapping.json` together: + +```text +taggers/ + cl_tagger_1_01/ + model.onnx + tag_mapping.json +``` + +### Local NL caption mode + +Caption mode uses a local Hugging Face BLIP-compatible caption model and writes +natural-language captions beside each image: + +```powershell +python -m mikazuki.tagger.cli caption --path .\input +``` + +```bash +python -m mikazuki.tagger.cli caption --path ./input +``` + +Defaults: + +- Model: `Salesforce/blip-image-captioning-base` +- Output: natural-language `.txt` captions +- Cache: project `huggingface/` folder unless `HF_HOME` is already set + +Useful options: + +- `--model Salesforce/blip-image-captioning-base`: Hugging Face model id. +- `--prompt "a photo of"`: optional conditional caption prompt. +- `--device auto|cpu|cuda`: torch device selection. +- `--max-new-tokens 64`: generated caption length cap. +- `--use-cn-mirror` / `--hf-endpoint`: download source for missing caption + model files. +- `--recursive`, `--additional-tags`, `--exclude-tags`, and `--on-conflict` + behave like API NL mode. + +First-run model downloads print a stage message and then rely on Hugging Face / +Transformers console progress for file downloads. This satisfies the command-line +progress requirement; WebUI SSE/WebSocket progress remains a later UI task. + +### API NL mode + +API mode calls an OpenAI-compatible Chat Completions vision endpoint and writes a +natural-language caption beside each image: + +```powershell +$env:OPENAI_API_KEY = "sk-..." +python -m mikazuki.tagger.cli api --path .\input --model gpt-4o-mini ` + --prompt "Describe this image for LoRA training. Return one concise caption." +``` + +```bash +export OPENAI_API_KEY="sk-..." +python -m mikazuki.tagger.cli api --path ./input --model gpt-4o-mini \ + --prompt "Describe this image for LoRA training. Return one concise caption." +``` + +Useful options: + +- `--endpoint https://api.openai.com/v1`: base endpoint. The CLI posts to + `{endpoint}/chat/completions`. +- `--api-key sk-...`: explicit key. This takes precedence over env lookup. +- `--api-key-env OPENAI_API_KEY`: environment variable used when `--api-key` is + not provided. +- `--timeout 60`: request timeout per image. +- `--retries 2`: retry count per image. +- `--recursive`, `--additional-tags`, `--exclude-tags`, and `--on-conflict` + behave like local mode. For API captions, additional/exclude values are applied + as comma/newline text fragments rather than confidence-scored tags. + +API mode sends image bytes to the configured endpoint. Users should confirm +privacy, safety, and billing terms for their provider before running it on a +dataset. + +## Output Rules + +- Supported images are detected through Pillow's registered image extensions. +- A sidecar caption is written as `image_name.txt` in the same directory as the + image. +- `local` writes Danbooru-style TAG captions; `caption` and `api` write + natural-language captions. +- `--on-conflict ignore` skips images that already have a sidecar `.txt`. +- `--on-conflict copy` replaces existing text. +- `--on-conflict prepend` writes the new caption before the existing text. +- Duplicate comma-separated fragments are removed while preserving first + occurrence order. + +## OpenAI-Compatible Request Shape + +The CLI sends a `POST` request to `{endpoint}/chat/completions`: + +```json +{ + "model": "gpt-4o-mini", + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "Describe this image for image model training." + }, + { + "type": "image_url", + "image_url": { + "url": "data:image/png;base64,..." + } + } + ] + } + ] +} +``` + +The response parser expects: + +```json +{ + "choices": [ + { + "message": { + "content": "A natural-language caption." + } + } + ] +} +``` + +If the response does not contain `choices[0].message.content`, the CLI fails with +a clear error and does not silently write an empty caption. + +## Interface Reference + +Public command-line interfaces: + +- `python -m mikazuki.tagger.cli local`: local TAG tagging with WD/CL ONNX + interrogators. +- `python -m mikazuki.tagger.cli caption`: local NL captioning with a + BLIP-compatible Hugging Face model. +- `python -m mikazuki.tagger.cli api`: OpenAI-compatible NL captioning through + Chat Completions vision. +- `scripts/cli/tagger.ps1` and `scripts/cli/tagger.sh`: thin wrappers that set + project-local cache defaults and forward arguments. + +Programmatic interfaces reserved for WebUI or future tooling reuse: + +- `run_local_tagger(...)`: local TAG batch runner. +- `run_caption_tagger(...)`: local NL batch runner. +- `run_api_tagger(...)`: API NL batch runner. +- `OpenAICompatibleCaptionClient`: API client using `/chat/completions`. +- `LocalBlipCaptionClient`: local BLIP-compatible caption client. +- Existing WebUI endpoint `POST /interrogate`: still accepts the current + `TaggerInterrogateRequest` fields and now delegates to `run_local_tagger`. + +No dataset tag editor API is introduced in this phase. + +## Later Model Candidate + +PixAI Tagger v0.9 is a strong future local TAG candidate because its model card +describes a newer Danbooru snapshot through 2025-01 and about 13.5k +Danbooru-style tags. It is intentionally not added in this first CLI pass because +it would introduce new dependency and model-size decisions beyond the existing +WD/CL ONNX path. diff --git a/mikazuki/app/api.py b/mikazuki/app/api.py index ab6470ef..900aef8b 100644 --- a/mikazuki/app/api.py +++ b/mikazuki/app/api.py @@ -22,8 +22,7 @@ from mikazuki.app.models import (APIResponse, APIResponseFail, APIResponseSuccess, TaggerInterrogateRequest) from mikazuki.log import log -from mikazuki.tagger.interrogator import (available_interrogators, - on_interrogate) +from mikazuki.tagger.service import run_local_tagger from mikazuki.tasks import tm from mikazuki.train_log_hub import hub as train_log_hub from mikazuki.utils import train_utils @@ -402,29 +401,21 @@ async def run_script(request: Request, background_tasks: BackgroundTasks): @router.post("/interrogate") async def run_interrogate(req: TaggerInterrogateRequest, background_tasks: BackgroundTasks): - interrogator = available_interrogators.get(req.interrogator_model, available_interrogators["wd14-convnextv2-v2"]) background_tasks.add_task( - on_interrogate, - image=None, - batch_input_glob=req.path, - batch_input_recursive=req.batch_input_recursive, - batch_output_dir="", - batch_output_filename_format="[name].[output_extension]", - batch_output_action_on_conflict=req.batch_output_action_on_conflict, - batch_remove_duplicated_tag=True, - batch_output_save_json=False, - interrogator=interrogator, + run_local_tagger, + input_path=req.path, + model=req.interrogator_model, threshold=req.threshold, character_threshold=req.character_threshold, - add_rating_tag=req.add_rating_tag, - add_model_tag=req.add_model_tag, + recursive=req.batch_input_recursive, additional_tags=req.additional_tags, exclude_tags=req.exclude_tags, - sort_by_alphabetical_order=False, - add_confident_as_weight=False, + on_conflict=req.batch_output_action_on_conflict, replace_underscore=req.replace_underscore, replace_underscore_excludes=req.replace_underscore_excludes, escape_tag=req.escape_tag, + add_rating_tag=req.add_rating_tag, + add_model_tag=req.add_model_tag, unload_model_after_running=True ) return APIResponseSuccess() diff --git a/mikazuki/tagger/cli.py b/mikazuki/tagger/cli.py new file mode 100644 index 00000000..bd442a39 --- /dev/null +++ b/mikazuki/tagger/cli.py @@ -0,0 +1,184 @@ +from __future__ import annotations + +import argparse +import os +import sys + +from mikazuki.tagger.interrogator import available_interrogators +from mikazuki.tagger.service import ( + DEFAULT_API_PROMPT, + DEFAULT_LOCAL_CAPTION_MODEL, + DEFAULT_OPENAI_ENDPOINT, + LocalBlipCaptionClient, + OpenAICompatibleCaptionClient, + run_api_tagger, + run_caption_tagger, + run_local_tagger, +) + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + prog="python -m mikazuki.tagger.cli", + description="Batch tag/caption dataset images and write sidecar .txt files.", + ) + subparsers = parser.add_subparsers(dest="command", required=True) + + local = subparsers.add_parser("local", help="Run a local WD/CL tagger and write TAG captions.") + local.add_argument("--path", required=True, help="Image file, folder, or glob pattern.") + local.add_argument( + "--model", + default="wd14-convnextv2-v2", + choices=sorted(available_interrogators), + help="Local tagger model key.", + ) + local.add_argument("--threshold", type=float, default=0.35, help="General tag threshold.") + local.add_argument("--character-threshold", type=float, default=0.6, help="Character tag threshold.") + local.add_argument("--recursive", action="store_true", help="Search folders recursively.") + local.add_argument("--additional-tags", default="", help="Comma-separated tags to append.") + local.add_argument("--exclude-tags", default="", help="Comma-separated tags to remove.") + local.add_argument("--use-cn-mirror", action="store_true", help="Use https://hf-mirror.com for missing model downloads.") + local.add_argument("--hf-endpoint", help="Custom Hugging Face endpoint for missing model downloads.") + local.add_argument( + "--on-conflict", + choices=("ignore", "copy", "prepend"), + default="ignore", + help="What to do when the sidecar .txt already exists.", + ) + local.add_argument( + "--replace-underscore", + action=argparse.BooleanOptionalAction, + default=True, + help="Replace underscores with spaces in local tags.", + ) + local.add_argument( + "--escape-tag", + action=argparse.BooleanOptionalAction, + default=True, + help="Escape parentheses and backslashes in local tags.", + ) + + api = subparsers.add_parser( + "api", + help="Call an OpenAI-compatible Chat Completions vision endpoint and write NL captions.", + ) + api.add_argument("--path", required=True, help="Image file, folder, or glob pattern.") + api.add_argument("--endpoint", default=DEFAULT_OPENAI_ENDPOINT, help="Base endpoint, e.g. https://api.openai.com/v1") + api.add_argument("--model", required=True, help="Vision-capable chat model name.") + api.add_argument("--prompt", default=DEFAULT_API_PROMPT, help="Prompt sent together with each image.") + api.add_argument("--api-key", help="API key value. Takes precedence over --api-key-env.") + api.add_argument("--api-key-env", default="OPENAI_API_KEY", help="Environment variable containing the API key.") + api.add_argument("--timeout", type=float, default=60, help="Request timeout in seconds.") + api.add_argument("--retries", type=int, default=2, help="Retry count per image.") + api.add_argument("--recursive", action="store_true", help="Search folders recursively.") + api.add_argument("--additional-tags", default="", help="Comma/newline fragment to append to API captions.") + api.add_argument("--exclude-tags", default="", help="Comma/newline fragment to remove from API captions.") + api.add_argument( + "--on-conflict", + choices=("ignore", "copy", "prepend"), + default="ignore", + help="What to do when the sidecar .txt already exists.", + ) + + caption = subparsers.add_parser( + "caption", + help="Run a local BLIP caption model and write NL captions.", + ) + caption.add_argument("--path", required=True, help="Image file, folder, or glob pattern.") + caption.add_argument( + "--model", + default=DEFAULT_LOCAL_CAPTION_MODEL, + help="Hugging Face BLIP-compatible caption model.", + ) + caption.add_argument("--prompt", default="", help="Optional conditional caption prompt.") + caption.add_argument("--device", default="auto", help="auto, cpu, cuda, or another torch device string.") + caption.add_argument("--max-new-tokens", type=int, default=64, help="Maximum generated caption tokens.") + caption.add_argument("--use-cn-mirror", action="store_true", help="Use https://hf-mirror.com for missing model downloads.") + caption.add_argument("--hf-endpoint", help="Custom Hugging Face endpoint for missing model downloads.") + caption.add_argument("--recursive", action="store_true", help="Search folders recursively.") + caption.add_argument("--additional-tags", default="", help="Comma/newline fragment to append to captions.") + caption.add_argument("--exclude-tags", default="", help="Comma/newline fragment to remove from captions.") + caption.add_argument( + "--on-conflict", + choices=("ignore", "copy", "prepend"), + default="ignore", + help="What to do when the sidecar .txt already exists.", + ) + return parser + + +def resolve_api_key(explicit_key: str | None, env_name: str) -> str | None: + return explicit_key or os.environ.get(env_name) + + +def main(argv: list[str] | None = None) -> int: + parser = build_parser() + args = parser.parse_args(argv) + + try: + if args.command == "local": + result = run_local_tagger( + args.path, + model=args.model, + threshold=args.threshold, + character_threshold=args.character_threshold, + recursive=args.recursive, + additional_tags=args.additional_tags, + exclude_tags=args.exclude_tags, + on_conflict=args.on_conflict, + replace_underscore=args.replace_underscore, + escape_tag=args.escape_tag, + use_cn_mirror=args.use_cn_mirror, + hf_endpoint=args.hf_endpoint, + ) + elif args.command == "api": + api_key = resolve_api_key(args.api_key, args.api_key_env) + if not api_key: + parser.error(f"API key is required: pass --api-key or set {args.api_key_env}") + client = OpenAICompatibleCaptionClient( + endpoint=args.endpoint, + api_key=api_key, + model=args.model, + prompt=args.prompt, + timeout=args.timeout, + retries=args.retries, + ) + result = run_api_tagger( + args.path, + client=client, + recursive=args.recursive, + additional_tags=args.additional_tags, + exclude_tags=args.exclude_tags, + on_conflict=args.on_conflict, + ) + else: + client = LocalBlipCaptionClient( + model=args.model, + prompt=args.prompt, + device=args.device, + max_new_tokens=args.max_new_tokens, + use_cn_mirror=args.use_cn_mirror, + hf_endpoint=args.hf_endpoint, + ) + result = run_caption_tagger( + args.path, + client=client, + recursive=args.recursive, + additional_tags=args.additional_tags, + exclude_tags=args.exclude_tags, + on_conflict=args.on_conflict, + ) + except Exception as error: + print(f"tagger failed: {error}", file=sys.stderr) + return 1 + + print( + "done: " + f"found={result.found}, processed={result.processed}, " + f"skipped={result.skipped}, failed={result.failed}" + ) + return 0 if result.failed == 0 else 1 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/mikazuki/tagger/interrogators/cl.py b/mikazuki/tagger/interrogators/cl.py index 31eebd38..ca775934 100644 --- a/mikazuki/tagger/interrogators/cl.py +++ b/mikazuki/tagger/interrogators/cl.py @@ -14,6 +14,7 @@ from dataclasses import dataclass from mikazuki.tagger import dbimutils, format from mikazuki.tagger.interrogators.base import Interrogator +from mikazuki.tagger.local_models import resolve_local_tagger_files @dataclass @@ -139,6 +140,12 @@ def __init__( self.kwargs = kwargs def download(self) -> Tuple[os.PathLike, os.PathLike]: + local_files = resolve_local_tagger_files(self.name, [self.model_path, self.tag_mapping_path]) + if local_files is not None: + model_path, tag_mapping_path = local_files + print(f"Loading {self.name} model from local files: {model_path.parent}") + return model_path, tag_mapping_path + print(f"Loading {self.name} model file from {self.kwargs['repo_id']}") model_path = Path(hf_hub_download( @@ -265,4 +272,4 @@ def stable_sigmoid(x): # output_text = ", ".join(output_tags) print(predictions) - return predictions \ No newline at end of file + return predictions diff --git a/mikazuki/tagger/interrogators/wd14.py b/mikazuki/tagger/interrogators/wd14.py index 38e8a016..8758b8c4 100644 --- a/mikazuki/tagger/interrogators/wd14.py +++ b/mikazuki/tagger/interrogators/wd14.py @@ -14,6 +14,7 @@ from huggingface_hub import hf_hub_download from mikazuki.tagger.interrogators.base import Interrogator from mikazuki.tagger import dbimutils, format +from mikazuki.tagger.local_models import resolve_local_tagger_files class WaifuDiffusionInterrogator(Interrogator): @@ -30,6 +31,12 @@ def __init__( self.kwargs = kwargs def download(self) -> Tuple[os.PathLike, os.PathLike]: + local_files = resolve_local_tagger_files(self.name, [self.model_path, self.tags_path]) + if local_files is not None: + model_path, tags_path = local_files + print(f"Loading {self.name} model from local files: {model_path.parent}") + return model_path, tags_path + repo_id = self.kwargs["repo_id"] print(f"Loading {self.name} model from {repo_id} (first run may download ~400MB, see console log)") diff --git a/mikazuki/tagger/local_models.py b/mikazuki/tagger/local_models.py new file mode 100644 index 00000000..704cfe76 --- /dev/null +++ b/mikazuki/tagger/local_models.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +import os +from pathlib import Path + + +REPO_ROOT = Path(__file__).resolve().parents[2] +DEFAULT_TAGGER_DIRS = ( + REPO_ROOT / "taggers", + REPO_ROOT / "models" / "taggers", + REPO_ROOT / "huggingface" / "taggers", +) + + +def iter_tagger_roots() -> list[Path]: + roots: list[Path] = [] + env_value = os.environ.get("MIKAZUKI_TAGGER_DIR", "") + for item in env_value.split(os.pathsep): + if item.strip(): + roots.append(Path(item).expanduser().resolve()) + roots.extend(path.resolve() for path in DEFAULT_TAGGER_DIRS) + return roots + + +def _candidate_file(model_dir: Path, relative_path: str) -> list[Path]: + rel = Path(relative_path) + return [ + model_dir / rel, + model_dir / rel.name, + ] + + +def resolve_local_tagger_files(model_key: str, required_files: list[str]) -> tuple[Path, ...] | None: + for root in iter_tagger_roots(): + model_dirs = [root] + if root.name != model_key: + model_dirs.insert(0, root / model_key) + + for model_dir in model_dirs: + resolved: list[Path] = [] + for required_file in required_files: + found = next( + (candidate for candidate in _candidate_file(model_dir, required_file) if candidate.is_file()), + None, + ) + if found is None: + break + resolved.append(found) + else: + return tuple(resolved) + + return None diff --git a/mikazuki/tagger/service.py b/mikazuki/tagger/service.py new file mode 100644 index 00000000..99d47aeb --- /dev/null +++ b/mikazuki/tagger/service.py @@ -0,0 +1,449 @@ +from __future__ import annotations + +import base64 +import copy +import mimetypes +import os +import time +from collections import OrderedDict +from dataclasses import dataclass +from glob import glob +from pathlib import Path +from typing import Iterable, Protocol + +import httpx +from PIL import Image, UnidentifiedImageError + +from mikazuki.tagger.interrogator import available_interrogators +from mikazuki.tagger.interrogators.base import Interrogator + + +DEFAULT_OPENAI_ENDPOINT = "https://api.openai.com/v1" +DEFAULT_API_PROMPT = ( + "Describe this image for image model training. Return a concise natural-language " + "caption only, without markdown or explanations." +) +DEFAULT_LOCAL_CAPTION_MODEL = "Salesforce/blip-image-captioning-base" +REPO_ROOT = Path(__file__).resolve().parents[2] + + +def ensure_project_hf_home() -> Path: + hf_home = Path(os.environ.get("HF_HOME", REPO_ROOT / "huggingface")).resolve() + os.environ["HF_HOME"] = str(hf_home) + hf_home.mkdir(parents=True, exist_ok=True) + return hf_home + + +def configure_hf_download( + *, + use_cn_mirror: bool = False, + hf_endpoint: str | None = None, +) -> Path: + hf_home = ensure_project_hf_home() + if hf_endpoint: + os.environ["HF_ENDPOINT"] = hf_endpoint + elif use_cn_mirror and not os.environ.get("HF_ENDPOINT"): + os.environ["HF_ENDPOINT"] = "https://hf-mirror.com" + return hf_home + + +@dataclass +class TaggerRunResult: + found: int = 0 + processed: int = 0 + skipped: int = 0 + failed: int = 0 + + +class CaptionClient(Protocol): + def caption(self, image_path: Path) -> str: + ... + + +def split_csv(value: str | Iterable[str] | None) -> list[str]: + if value is None: + return [] + if isinstance(value, str): + parts = value.split(",") + else: + parts = value + return [str(part).strip() for part in parts if str(part).strip()] + + +def collect_image_paths(input_path: str, recursive: bool = False) -> list[Path]: + path_text = input_path.strip() + if not path_text: + return [] + + Image.init() + supported_extensions = { + ext.lower() + for ext, opener in Image.registered_extensions().items() + if opener in Image.OPEN + } + + candidate_paths: list[Path] + path = Path(path_text) + if any(mark in path_text for mark in ("*", "?")): + candidate_paths = [Path(item) for item in glob(path_text, recursive=recursive)] + elif path.is_dir(): + pattern = "**/*" if recursive else "*" + candidate_paths = [item for item in path.glob(pattern)] + elif path.is_file(): + candidate_paths = [path] + else: + raise ValueError(f"input path does not exist: {input_path}") + + return sorted( + item + for item in candidate_paths + if item.is_file() and item.suffix.lower() in supported_extensions + ) + + +def merge_output( + existing_text: str, + new_text: str, + *, + on_conflict: str, + remove_duplicated_tag: bool = True, +) -> str: + if not existing_text: + merged = new_text + elif on_conflict == "copy": + merged = new_text + elif on_conflict == "prepend": + merged = f"{new_text}, {existing_text}" + else: + merged = f"{existing_text}, {new_text}" + + if not remove_duplicated_tag: + return merged.strip() + + return ", ".join( + OrderedDict.fromkeys(part.strip() for part in merged.split(",") if part.strip()) + ) + + +def apply_text_fragments(text: str, additional_tags: str = "", exclude_tags: str = "") -> str: + additions = split_csv(additional_tags) + excludes = set(split_csv(exclude_tags)) + fragments = [ + fragment.strip() + for line in text.splitlines() + for fragment in line.split(",") + if fragment.strip() + ] + fragments = [fragment for fragment in fragments if fragment not in excludes] + fragments.extend(addition for addition in additions if addition not in excludes) + return ", ".join(OrderedDict.fromkeys(fragments)) + + +def write_caption( + image_path: Path, + caption: str, + *, + on_conflict: str, + remove_duplicated_tag: bool = True, +) -> bool: + output_path = image_path.with_suffix(".txt") + if output_path.is_file(): + existing_text = output_path.read_text(encoding="utf-8", errors="ignore").strip() + if on_conflict == "ignore": + return False + else: + existing_text = "" + + output_path.write_text( + merge_output( + existing_text, + caption, + on_conflict=on_conflict, + remove_duplicated_tag=remove_duplicated_tag, + ), + encoding="utf-8", + ) + return True + + +def run_local_tagger( + input_path: str, + *, + model: str = "wd14-convnextv2-v2", + threshold: float = 0.35, + character_threshold: float = 0.6, + recursive: bool = False, + additional_tags: str = "", + exclude_tags: str = "", + on_conflict: str = "ignore", + replace_underscore: bool = True, + replace_underscore_excludes: str = "", + escape_tag: bool = True, + add_rating_tag: bool = False, + add_model_tag: bool = False, + unload_model_after_running: bool = True, + use_cn_mirror: bool = False, + hf_endpoint: str | None = None, +) -> TaggerRunResult: + configure_hf_download(use_cn_mirror=use_cn_mirror, hf_endpoint=hf_endpoint) + image_paths = collect_image_paths(input_path, recursive=recursive) + result = TaggerRunResult(found=len(image_paths)) + interrogator = available_interrogators.get(model, available_interrogators["wd14-convnextv2-v2"]) + + try: + for image_path in image_paths: + if image_path.with_suffix(".txt").is_file() and on_conflict == "ignore": + result.skipped += 1 + print(f"skipping {image_path}") + continue + + try: + with Image.open(image_path) as image: + tags = interrogator.interrogate(image) + except UnidentifiedImageError: + result.failed += 1 + print(f"{image_path} is not supported image type") + continue + except Exception: + result.failed += 1 + raise + + processed_tags = Interrogator.postprocess_tags( + copy.deepcopy(tags), + threshold, + character_threshold, + add_rating_tag, + add_model_tag, + split_csv(additional_tags), + split_csv(exclude_tags), + False, + False, + replace_underscore, + split_csv(replace_underscore_excludes), + escape_tag, + ) + caption = ", ".join(processed_tags) + if write_caption(image_path, caption, on_conflict=on_conflict): + result.processed += 1 + print(f"tagged {image_path} ({len(processed_tags)} tags)") + else: + result.skipped += 1 + print(f"skipping {image_path}") + finally: + if unload_model_after_running: + interrogator.unload() + + return result + + +class OpenAICompatibleCaptionClient: + def __init__( + self, + *, + endpoint: str = DEFAULT_OPENAI_ENDPOINT, + api_key: str, + model: str, + prompt: str = DEFAULT_API_PROMPT, + timeout: float = 60, + retries: int = 2, + http_client: httpx.Client | None = None, + ) -> None: + self.url = build_chat_completions_url(endpoint) + self.api_key = api_key + self.model = model + self.prompt = prompt + self.timeout = timeout + self.retries = retries + self.http_client = http_client + + def caption(self, image_path: Path) -> str: + payload = { + "model": self.model, + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": self.prompt}, + { + "type": "image_url", + "image_url": {"url": image_to_data_url(image_path)}, + }, + ], + } + ], + } + headers = {"Authorization": f"Bearer {self.api_key}"} + last_error: Exception | None = None + + for attempt in range(self.retries + 1): + try: + if self.http_client is None: + with httpx.Client(timeout=self.timeout) as client: + response = client.post(self.url, headers=headers, json=payload) + else: + response = self.http_client.post(self.url, headers=headers, json=payload) + response.raise_for_status() + data = response.json() + return parse_openai_caption(data) + except Exception as error: + last_error = error + if attempt >= self.retries: + break + time.sleep(min(2 ** attempt, 5)) + + raise RuntimeError(f"API tagging failed for {image_path}: {last_error}") from last_error + + +class LocalBlipCaptionClient: + def __init__( + self, + *, + model: str = DEFAULT_LOCAL_CAPTION_MODEL, + prompt: str = "", + device: str = "auto", + max_new_tokens: int = 64, + use_cn_mirror: bool = False, + hf_endpoint: str | None = None, + ) -> None: + self.model_name = model + self.prompt = prompt + self.device_name = device + self.max_new_tokens = max_new_tokens + self.use_cn_mirror = use_cn_mirror + self.hf_endpoint = hf_endpoint + self.processor = None + self.model = None + self.device = None + + def load(self) -> None: + if self.model is not None and self.processor is not None: + return + + configure_hf_download(use_cn_mirror=self.use_cn_mirror, hf_endpoint=self.hf_endpoint) + print(f"[tagger] Loading local caption model: {self.model_name}") + print("[tagger] First run may download model files; Hugging Face progress is shown in the console.") + + import torch + from transformers import BlipForConditionalGeneration, BlipProcessor + + if self.device_name == "auto": + self.device = "cuda" if torch.cuda.is_available() else "cpu" + else: + self.device = self.device_name + + self.processor = BlipProcessor.from_pretrained(self.model_name) + self.model = BlipForConditionalGeneration.from_pretrained(self.model_name) + self.model.to(self.device) + self.model.eval() + print(f"[tagger] Loaded {self.model_name} on {self.device}") + + def caption(self, image_path: Path) -> str: + self.load() + + import torch + from PIL import Image + + with Image.open(image_path) as image: + image = image.convert("RGB") + if self.prompt: + inputs = self.processor(image, self.prompt, return_tensors="pt") + else: + inputs = self.processor(image, return_tensors="pt") + + inputs = {key: value.to(self.device) for key, value in inputs.items()} + with torch.no_grad(): + output = self.model.generate(**inputs, max_new_tokens=self.max_new_tokens) + caption = self.processor.decode(output[0], skip_special_tokens=True).strip() + + if not caption: + raise ValueError(f"local caption model returned an empty caption for {image_path}") + return caption + + +def build_chat_completions_url(endpoint: str) -> str: + endpoint = endpoint.rstrip("/") + if endpoint.endswith("/chat/completions"): + return endpoint + return f"{endpoint}/chat/completions" + + +def image_to_data_url(image_path: Path) -> str: + mime_type, _ = mimetypes.guess_type(str(image_path)) + if not mime_type: + mime_type = "image/png" + data = base64.b64encode(image_path.read_bytes()).decode("ascii") + return f"data:{mime_type};base64,{data}" + + +def parse_openai_caption(data: dict) -> str: + try: + content = data["choices"][0]["message"]["content"] + except (KeyError, IndexError, TypeError) as exc: + raise ValueError("OpenAI-compatible response does not contain choices[0].message.content") from exc + + if isinstance(content, str): + caption = content.strip() + elif isinstance(content, list): + caption = " ".join( + item.get("text", "").strip() + for item in content + if isinstance(item, dict) and item.get("type") in {None, "text", "output_text"} + ).strip() + else: + caption = "" + + if not caption: + raise ValueError("OpenAI-compatible response caption is empty") + return caption + + +def run_api_tagger( + input_path: str, + *, + client: CaptionClient, + recursive: bool = False, + additional_tags: str = "", + exclude_tags: str = "", + on_conflict: str = "ignore", +) -> TaggerRunResult: + image_paths = collect_image_paths(input_path, recursive=recursive) + result = TaggerRunResult(found=len(image_paths)) + + for image_path in image_paths: + output_path = image_path.with_suffix(".txt") + if output_path.is_file() and on_conflict == "ignore": + result.skipped += 1 + print(f"skipping {image_path}") + continue + + try: + caption = client.caption(image_path) + caption = apply_text_fragments(caption, additional_tags, exclude_tags) + write_caption(image_path, caption, on_conflict=on_conflict) + result.processed += 1 + print(f"captioned {image_path}") + except Exception as error: + result.failed += 1 + print(f"failed {image_path}: {error}") + raise + + return result + + +def run_caption_tagger( + input_path: str, + *, + client: CaptionClient, + recursive: bool = False, + additional_tags: str = "", + exclude_tags: str = "", + on_conflict: str = "ignore", +) -> TaggerRunResult: + return run_api_tagger( + input_path, + client=client, + recursive=recursive, + additional_tags=additional_tags, + exclude_tags=exclude_tags, + on_conflict=on_conflict, + ) diff --git a/scripts/cli/tagger.ps1 b/scripts/cli/tagger.ps1 new file mode 100644 index 00000000..52826027 --- /dev/null +++ b/scripts/cli/tagger.ps1 @@ -0,0 +1,20 @@ +$ErrorActionPreference = "Stop" + +$ScriptDir = Split-Path -Parent $MyInvocation.MyCommand.Path +$RepoRoot = Resolve-Path (Join-Path $ScriptDir "..\..") +Set-Location $RepoRoot + +if (-not $Env:HF_HOME) { + $Env:HF_HOME = "huggingface" +} + +if ($Env:USE_CN_MIRROR -and -not $Env:HF_ENDPOINT) { + $Env:HF_ENDPOINT = "https://hf-mirror.com" +} + +$VenvPython = Join-Path $RepoRoot "venv\Scripts\python.exe" +if (Test-Path $VenvPython) { + & $VenvPython -m mikazuki.tagger.cli @args +} else { + python -m mikazuki.tagger.cli @args +} diff --git a/scripts/cli/tagger.sh b/scripts/cli/tagger.sh new file mode 100644 index 00000000..e38edcc1 --- /dev/null +++ b/scripts/cli/tagger.sh @@ -0,0 +1,18 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/../.." && pwd)" +cd "${REPO_ROOT}" + +export HF_HOME="${HF_HOME:-huggingface}" + +if [[ -n "${USE_CN_MIRROR:-}" && -z "${HF_ENDPOINT:-}" ]]; then + export HF_ENDPOINT="https://hf-mirror.com" +fi + +if [[ -x "venv/bin/python" ]]; then + exec "venv/bin/python" -m mikazuki.tagger.cli "$@" +else + exec python -m mikazuki.tagger.cli "$@" +fi diff --git a/tests/test_tagger_cli_service.py b/tests/test_tagger_cli_service.py new file mode 100644 index 00000000..d0a04eec --- /dev/null +++ b/tests/test_tagger_cli_service.py @@ -0,0 +1,263 @@ +import os +import tempfile +import unittest +from pathlib import Path + +from PIL import Image + +from mikazuki.tagger import cli +from mikazuki.tagger import local_models +from mikazuki.tagger import service +from mikazuki.tagger.cli import resolve_api_key + + +class FakeInterrogator: + def __init__(self): + self.calls = 0 + self.unloaded = False + + def interrogate(self, image): + self.calls += 1 + return { + "rating": [("general", 0.9)], + "general": [("blue_hair", 0.8), ("lowres", 0.7), ("red_eyes", 0.2)], + "character": [("test_character", 0.7)], + "model": [], + } + + def unload(self): + self.unloaded = True + return True + + +class FakeResponse: + def __init__(self, data, status_code=200): + self.data = data + self.status_code = status_code + + def raise_for_status(self): + if self.status_code >= 400: + raise RuntimeError(f"HTTP {self.status_code}") + + def json(self): + return self.data + + +class FakeHttpClient: + def __init__(self, response): + self.response = response + self.calls = [] + + def post(self, url, headers=None, json=None): + self.calls.append({"url": url, "headers": headers, "json": json}) + return self.response + + +class TaggerServiceTests(unittest.TestCase): + def make_image(self, path: Path): + Image.new("RGB", (8, 8), color=(255, 0, 0)).save(path) + + def test_collect_image_paths_supports_folder_glob_recursive_and_skips_non_images(self): + with tempfile.TemporaryDirectory() as temp_dir: + root = Path(temp_dir) + self.make_image(root / "a.png") + (root / "note.txt").write_text("not an image", encoding="utf-8") + child = root / "child" + child.mkdir() + self.make_image(child / "b.jpg") + + self.assertEqual([p.name for p in service.collect_image_paths(str(root))], ["a.png"]) + self.assertEqual( + [p.name for p in service.collect_image_paths(str(root / "*.png"))], + ["a.png"], + ) + self.assertEqual( + [p.name for p in service.collect_image_paths(str(root), recursive=True)], + ["a.png", "b.jpg"], + ) + + def test_local_tagger_applies_thresholds_tag_edits_dedupe_and_conflict(self): + fake = FakeInterrogator() + original = service.available_interrogators.get("fake") + service.available_interrogators["fake"] = fake + try: + with tempfile.TemporaryDirectory() as temp_dir: + root = Path(temp_dir) + image_path = root / "a.png" + self.make_image(image_path) + image_path.with_suffix(".txt").write_text("old_tag", encoding="utf-8") + + skipped = service.run_local_tagger(str(root), model="fake", on_conflict="ignore") + self.assertEqual(skipped.skipped, 1) + self.assertEqual(fake.calls, 0) + + result = service.run_local_tagger( + str(root), + model="fake", + threshold=0.35, + character_threshold=0.6, + additional_tags="blue hair, best quality", + exclude_tags="lowres", + on_conflict="copy", + replace_underscore=True, + escape_tag=False, + ) + + self.assertEqual(result.processed, 1) + self.assertEqual( + image_path.with_suffix(".txt").read_text(encoding="utf-8"), + "blue hair, best quality, test character", + ) + self.assertTrue(fake.unloaded) + finally: + if original is None: + del service.available_interrogators["fake"] + else: + service.available_interrogators["fake"] = original + + def test_openai_client_posts_expected_chat_completions_payload(self): + with tempfile.TemporaryDirectory() as temp_dir: + image_path = Path(temp_dir) / "a.png" + self.make_image(image_path) + http_client = FakeHttpClient( + FakeResponse({"choices": [{"message": {"content": "a concise caption"}}]}) + ) + client = service.OpenAICompatibleCaptionClient( + endpoint="https://example.test/v1/", + api_key="secret", + model="vision-model", + prompt="caption it", + retries=0, + http_client=http_client, + ) + + caption = client.caption(image_path) + + self.assertEqual(caption, "a concise caption") + call = http_client.calls[0] + self.assertEqual(call["url"], "https://example.test/v1/chat/completions") + self.assertEqual(call["headers"]["Authorization"], "Bearer secret") + self.assertEqual(call["json"]["model"], "vision-model") + content = call["json"]["messages"][0]["content"] + self.assertEqual(content[0]["text"], "caption it") + self.assertTrue(content[1]["image_url"]["url"].startswith("data:image/png;base64,")) + + def test_api_tagger_writes_nl_caption_with_fragment_edits(self): + class StaticClient: + def caption(self, image_path): + return "a girl, lowres\nstanding" + + with tempfile.TemporaryDirectory() as temp_dir: + root = Path(temp_dir) + image_path = root / "a.png" + self.make_image(image_path) + + result = service.run_api_tagger( + str(root), + client=StaticClient(), + additional_tags="best quality", + exclude_tags="lowres", + on_conflict="copy", + ) + + self.assertEqual(result.processed, 1) + self.assertEqual( + image_path.with_suffix(".txt").read_text(encoding="utf-8"), + "a girl, standing, best quality", + ) + + def test_caption_tagger_reuses_folder_writer_for_local_nl_models(self): + class StaticCaptionClient: + def caption(self, image_path): + return "a small test image" + + with tempfile.TemporaryDirectory() as temp_dir: + root = Path(temp_dir) + image_path = root / "a.png" + self.make_image(image_path) + + result = service.run_caption_tagger( + str(root), + client=StaticCaptionClient(), + additional_tags="training caption", + on_conflict="copy", + ) + + self.assertEqual(result.processed, 1) + self.assertEqual( + image_path.with_suffix(".txt").read_text(encoding="utf-8"), + "a small test image, training caption", + ) + + def test_cli_exposes_local_api_and_caption_modes(self): + parser = cli.build_parser() + + local_args = parser.parse_args(["local", "--path", "input"]) + api_args = parser.parse_args(["api", "--path", "input", "--model", "vision"]) + caption_args = parser.parse_args(["caption", "--path", "input"]) + + self.assertEqual(local_args.command, "local") + self.assertEqual(api_args.command, "api") + self.assertEqual(caption_args.command, "caption") + self.assertEqual(caption_args.model, service.DEFAULT_LOCAL_CAPTION_MODEL) + + def test_local_tagger_files_can_be_resolved_from_user_directory(self): + old_value = os.environ.get("MIKAZUKI_TAGGER_DIR") + try: + with tempfile.TemporaryDirectory() as temp_dir: + root = Path(temp_dir) + model_dir = root / "wd14-convnextv2-v2" + model_dir.mkdir() + (model_dir / "model.onnx").write_bytes(b"fake onnx") + (model_dir / "selected_tags.csv").write_text("name\nfake_tag\n", encoding="utf-8") + + os.environ["MIKAZUKI_TAGGER_DIR"] = str(root) + resolved = local_models.resolve_local_tagger_files( + "wd14-convnextv2-v2", + ["model.onnx", "selected_tags.csv"], + ) + + self.assertEqual(resolved, (model_dir / "model.onnx", model_dir / "selected_tags.csv")) + finally: + if old_value is None: + os.environ.pop("MIKAZUKI_TAGGER_DIR", None) + else: + os.environ["MIKAZUKI_TAGGER_DIR"] = old_value + + def test_hf_download_config_uses_project_cache_and_optional_mirror(self): + old_home = os.environ.get("HF_HOME") + old_endpoint = os.environ.get("HF_ENDPOINT") + try: + os.environ.pop("HF_HOME", None) + os.environ.pop("HF_ENDPOINT", None) + + hf_home = service.configure_hf_download(use_cn_mirror=True) + + self.assertEqual(hf_home, (service.REPO_ROOT / "huggingface").resolve()) + self.assertEqual(os.environ["HF_HOME"], str((service.REPO_ROOT / "huggingface").resolve())) + self.assertEqual(os.environ["HF_ENDPOINT"], "https://hf-mirror.com") + finally: + if old_home is None: + os.environ.pop("HF_HOME", None) + else: + os.environ["HF_HOME"] = old_home + if old_endpoint is None: + os.environ.pop("HF_ENDPOINT", None) + else: + os.environ["HF_ENDPOINT"] = old_endpoint + + def test_api_key_explicit_value_takes_precedence_over_environment(self): + old_value = os.environ.get("TAGGER_TEST_KEY") + os.environ["TAGGER_TEST_KEY"] = "from-env" + try: + self.assertEqual(resolve_api_key("explicit", "TAGGER_TEST_KEY"), "explicit") + self.assertEqual(resolve_api_key(None, "TAGGER_TEST_KEY"), "from-env") + finally: + if old_value is None: + os.environ.pop("TAGGER_TEST_KEY", None) + else: + os.environ["TAGGER_TEST_KEY"] = old_value + + +if __name__ == "__main__": + unittest.main()