From 5be958b1df3dfd925b9901c6777cee56903ca2dc Mon Sep 17 00:00:00 2001 From: Net Zhang Date: Mon, 15 Jun 2026 13:27:36 -0400 Subject: [PATCH] Replace multiprocessing DataLoader with thread-based embedding pipeline Image preprocessing is GIL-light as PIL/torchvision release the GIL, therefore a thread pool parallelizes it near-linearly without process spawn, pickling, or fork/spawn divergence... This commit introduce `shared/utils/image_pipeline.py`: thread-pool decode+preprocess overlapped with a main-thread batch forward. The embedding service implements device-aware concurrency - GPU = wide pool + intra-op threads pinned to 1 - CPU = small pool, forward keeps the cores This commit drops the hpc-inference git dependency entirely, unblocks PyPI release. Co-Authored-By: Claude Opus 4.8 --- pyproject.toml | 4 +- shared/services/embedding_service.py | 168 +++++++++++++++++++-------- shared/utils/image_pipeline.py | 154 ++++++++++++++++++++++++ shared/utils/io.py | 31 +++-- 4 files changed, 296 insertions(+), 61 deletions(-) create mode 100644 shared/utils/image_pipeline.py diff --git a/pyproject.toml b/pyproject.toml index 2f2f4b5..b19de98 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,8 +48,8 @@ dependencies = [ "numba>=0.57.0", # Vision-language models "open-clip-torch>=2.20.0", - # Custom inference package - "hpc-inference @ git+https://github.com/Imageomics/hpc-inference.git" + # Image embedding pipeline (used directly in shared/utils/image_pipeline.py) + "torch>=1.11.0", ] [project.optional-dependencies] diff --git a/shared/services/embedding_service.py b/shared/services/embedding_service.py index 3b82e28..0eb2399 100644 --- a/shared/services/embedding_service.py +++ b/shared/services/embedding_service.py @@ -3,8 +3,68 @@ Heavy libraries (torch, open_clip) are imported lazily inside methods to avoid slowing down app startup. + +Device-aware concurrency: + +PyTorch has two kinds of parallelism built in, we focus on the intra-op +parallelism which is relevant to the embedding pipeline: + +Intra-op is the parallelism inside a single operation. One op, say +`Normalize` on a `[3, 244, 244]` tensor, or a big matrix multiply, splits its +own work across multiple threads (via an openMP/MKL thread pool). +`torch.get_num_threads()` queries how many threads one op may use, and +`torch.set_num_threads(n)` sets it. + +A single `preprocess(img)` is a chain of torch ops (resize -> to_tensor -> +normalize). With the default intra-op thread settings, each of those ops can +fan its work out across all CPU cores. So ONE preprocess call of one image +can momentarily spin up ~`cpu_count` threads to do that tiny bit of math. + +^^^ Why that's wasteful here? + +Since we already have our own parallelism layer at image level: the +`ThreadPoolExecutor` runs `workers` threads, one image per thread, and each +thread calls `preprocess(img)`. If each preprocess call fans out across all +CPU cores, then `workers` threads can easily oversubscribe the CPU with +`workers * cpu_count` threads. This causes contention and can actually slow +down the whole process. + +``` +Layer 1 (ThreadPoolExecutor): 16 worker threads, each handling one image preprocess +Layer 2 (torch intra-op): x Each preprocess call can use up to `cpu_count` threads + ======================================================== + Total threads = 16 (workers) * cpu_count (intra-op) => + Potentially 256 threads on a 16-core machine, + causing oversubscription and slowdown. +``` + +By setting `torch.set_num_threads(1)`, we ensure that each preprocess call +runs single-thread, no internal spliting. All parallelism comes cleanly from +one place - the `ThreadPoolExecutor`. Instead of two nested layers that +multiply into a thread explosion, each core does one useful thing (decode a +whole image) with no scheduling thrash and no per-op thread-launch overhead. + +``` +Layer 1 (ThreadPoolExecutor): 16 worker threads, each handling one image preprocess +Layer 2 (torch intra-op): x 1 (each op runs single-threaded, instantly) + ======================================================== + Total threads = 16 (workers) * 1 (intra-op) => + Potentially 16 threads on a 16-core machine, + fully utilizing the CPU without oversubscription. +``` + +What Intra-op is good for? + +Intra-op parallelism is excellent for big ops. On the CPU-only path, the +forward pass of the model is the bottleneck, and it benefits from intra-op +parallelism. So we leave torch's intra-op threads alone on CPU, and cap the +worker threads to a small number (2) to avoid too much contention. On GPU, +the forward pass is fast and doesn't need CPU cores, so we maximize worker +threads for decoding and set intra-op to 1 to avoid oversubscription. + """ +import os import numpy as np import streamlit as st import time @@ -79,90 +139,96 @@ def generate_embeddings( model_name: str, batch_size: int, n_workers: int, - progress_callback: Optional[Callable[[float, str], None]] = None + progress_callback: Optional[Callable[[float, str], None]] = None, + recursive: bool = False, ) -> Tuple[np.ndarray, List[str]]: """ Generate embeddings for images in a directory. + Preprocessing runs on a thread pool (GIL-light) overlapped with the model + forward pass — no multiprocessing, so behavior is identical on every OS. + Args: image_dir: Path to directory containing images model_name: Name of the model to use - batch_size: Batch size for processing - n_workers: Number of worker processes + batch_size: Batch size for the forward pass + n_workers: Max preprocessing threads (capped per device, see below) progress_callback: Optional callback for progress updates + recursive: Recurse into subdirectories when listing images Returns: Tuple of (embeddings array, list of valid image paths) """ import torch - from hpc_inference.datasets.image_folder_dataset import ImageFolderDataset + from shared.utils.image_pipeline import embed_image_folder logger.info(f"Starting embedding generation: dir={image_dir}, model={model_name}, " - f"batch_size={batch_size}, n_workers={n_workers}") + f"batch_size={batch_size}, n_workers={n_workers}, recursive={recursive}") total_start = time.time() if progress_callback: progress_callback(0.0, "Listing images...") - image_paths = list_image_files(image_dir) - logger.info(f"Found {len(image_paths)} images in {image_dir}") + image_paths = list_image_files(image_dir, recursive=recursive) + total = len(image_paths) + logger.info(f"Found {total} images in {image_dir}") if progress_callback: - progress_callback(0.1, f"Found {len(image_paths)} images. Loading model...") + progress_callback(0.05, f"Found {total} images. Loading model...") torch_device = "cuda" if torch.cuda.is_available() else "cpu" + device = torch.device(torch_device) logger.info(f"Using device: {torch_device}") model, preprocess = EmbeddingService.load_model_unified(model_name, torch_device) - if progress_callback: - progress_callback(0.2, "Creating dataset...") - - # Create dataset & DataLoader - dataset = ImageFolderDataset( - image_dir=image_dir, - preprocess=preprocess, - uuid_mode="fullpath", - rank=0, - world_size=1, - evenly_distribute=True, - validate=True - ) - dataloader = torch.utils.data.DataLoader( - dataset, - batch_size=batch_size, - shuffle=False, - num_workers=n_workers, - pin_memory=True - ) + # Device-aware concurrency: + cpu_count = os.cpu_count() or 1 + prev_threads = None + + if device.type == "cuda": + # GPU: feed the GPU with parallel decode, avoid per-op oversubscription. + # - preprocess threads: wide + # - torch intra-op threads: forced to 1 + + # Set the number of preprocessing threads, clamped by three ceilings: + # 1) the user-requested n_workers + # 2) the number of CPU cores + # 3) never more threads than images + workers = max(1, min(n_workers, cpu_count, max(total, 1))) + + prev_threads = torch.get_num_threads() + torch.set_num_threads(1) - total = len(image_paths) - valid_paths = [] - embeddings = [] - - processed = 0 - with torch.no_grad(): - for batch_paths, batch_imgs in dataloader: - batch_imgs = batch_imgs.to(torch_device, non_blocking=True) - batch_embeds = model.encode_image(batch_imgs).cpu().numpy() - embeddings.append(batch_embeds) - valid_paths.extend(batch_paths) - processed += len(batch_paths) - - if progress_callback: - progress = 0.2 + (processed / total) * 0.8 # Use 20% to 100% for actual processing - progress_callback(progress, f"Embedding {processed}/{total}") - - # Stack embeddings if available - if embeddings: - embeddings = np.vstack(embeddings) else: - embeddings = np.empty((0, model.visual.output_dim)) + # CPU: the CPU forward is the bottleneck, needs the cores, + # so keep preprocess pool small and leave torch threads alone. + workers = max(1, min(2, n_workers, max(total, 1))) + + # Map the pipeline's 0..1 progress into the 0.1..1.0 band (model load took 0..0.1). + def _embed_progress(frac: float, msg: str): + if progress_callback: + progress_callback(0.1 + 0.9 * frac, msg) + + try: + embeddings, valid_paths = embed_image_folder( + image_paths, + model, + preprocess, + device, + batch_size=batch_size, + n_workers=workers, + progress_callback=_embed_progress, + ) + finally: + if prev_threads is not None: + torch.set_num_threads(prev_threads) if progress_callback: progress_callback(1.0, f"Complete! Generated {embeddings.shape[0]} embeddings") total_elapsed = time.time() - total_start - logger.info(f"Embedding generation completed: {embeddings.shape[0]} embeddings in {total_elapsed:.2f}s " - f"({embeddings.shape[0] / total_elapsed:.1f} images/sec)") + rate = embeddings.shape[0] / total_elapsed if total_elapsed > 0 else 0.0 + logger.info(f"Embedding generation completed: {embeddings.shape[0]} embeddings in " + f"{total_elapsed:.2f}s ({rate:.1f} images/sec)") return embeddings, valid_paths diff --git a/shared/utils/image_pipeline.py b/shared/utils/image_pipeline.py new file mode 100644 index 0000000..7251a48 --- /dev/null +++ b/shared/utils/image_pipeline.py @@ -0,0 +1,154 @@ +"""Thread-parallel image embedding pipeline. + +Turns a list of image paths into embeddings on a single machine. Preprocessing +(decode + transform) runs on a thread pool; the model forward runs on the +calling thread that owns the device. Each batch is preprocessed while the +previous batch runs through the model (a one-batch prefetch), so CPU decoding +and the device forward overlap. + +Threads — rather than worker processes — carry the preprocessing because the +work is GIL-light: PIL decode and torchvision tensor ops release the GIL, so a +thread pool scales nearly linearly. Staying in one process means no per-image +data crosses a process boundary and there is no worker-spawn cost, so small +folders are cheap and behavior does not depend on the OS. + +This module is Streamlit-free and unit-testable. +""" + +from __future__ import annotations + +import concurrent.futures as cf +from collections import deque +from typing import Callable, List, Optional, Tuple + +import numpy as np +import torch +from PIL import Image + +from shared.utils.logging_config import get_logger + +logger = get_logger(__name__) + + +def _output_dim(model) -> int: + """Best-effort embedding width, for shaping an empty result.""" + return int(getattr(getattr(model, "visual", None), "output_dim", 0) or 0) + + +def _preprocess_one(path: str, preprocess: Callable, color_mode: str): + """Decode + preprocess a single image. + + Returns ``(path, tensor)`` on success or ``(path, None)`` if the file can't + be read/decoded. Pure and device-free, so it is safe on worker threads. + """ + try: + with Image.open(path) as im: + img = im.convert(color_mode) + return path, preprocess(img) + except Exception as e: + logger.warning(f"[Embed] Skipping unreadable image {path}: {e}") + return path, None + + +def embed_image_folder( + image_paths: List[str], + model, + preprocess: Callable, + device: torch.device, + *, + batch_size: int = 32, + n_workers: int = 8, + prefetch_batches: int = 1, + color_mode: str = "RGB", + progress_callback: Optional[Callable[[float, str], None]] = None, +) -> Tuple[np.ndarray, List[str]]: + """Embed a list of image paths, overlapping preprocessing with the forward. + + Preprocessing runs on a ``ThreadPoolExecutor``; the model forward runs on the + calling thread (which owns ``device``). Up to ``prefetch_batches`` batches are + preprocessed ahead of the batch currently being run through the model. + + Unreadable images are skipped (and logged), so the returned embeddings may + have fewer rows than ``image_paths``. ``embeddings[i]`` corresponds to + ``valid_paths[i]``. + + Args: + image_paths: Image file paths to embed. + model: Model exposing ``encode_image(tensor) -> tensor``. + preprocess: Callable mapping a PIL image to a CHW tensor. + device: Torch device the model lives on. + batch_size: Images per forward pass. + n_workers: Preprocessing threads. + prefetch_batches: Batches to preprocess ahead of the forward (overlap). + color_mode: PIL convert mode applied before preprocessing. + progress_callback: Optional ``(fraction, message)`` progress sink. + + Returns: + ``(embeddings [N, D] float array, valid_paths [N])``. + """ + total = len(image_paths) + if total == 0: + return np.empty((0, _output_dim(model)), dtype=np.float32), [] + + batches = [image_paths[i:i + batch_size] for i in range(0, total, batch_size)] + window = max(1, prefetch_batches + 1) + + emb_chunks: List[np.ndarray] = [] + valid_paths: List[str] = [] + processed = 0 + + # concurrent.futures.ThreadPoolExecutor(max_workers=n_workers) + # spins up `n_workers` OS threads sitting idle, waiting for work... + # hand it work with ex.submit(fn, *args), which returns a Future immediately, + # and runs fn(*args) on a worker thread when it gets scheduled by the OS... + with cf.ThreadPoolExecutor(max_workers=n_workers) as ex: + + # non-blocking, starts the preprocessing of a batch on the pool + def submit(batch: List[str]) -> List[cf.Future]: + return [ex.submit(_preprocess_one, p, preprocess, color_mode) for p in batch] + + # Prime the pipeline so the first forward already has successors decoding. + # pending is a queue of lists of futures, one list per batch. + pending: deque = deque() + next_idx = 0 + while next_idx < len(batches) and len(pending) < window: + pending.append(submit(batches[next_idx])) + next_idx += 1 + # pending is now a full window of batches being preprocessed + + with torch.no_grad(): + while pending: + # Take the oldest in-flight batch + futures = pending.popleft() + # Refill the window first: these batches preprocess on the pool + # while we run the current batch through the model below. + if next_idx < len(batches): + pending.append(submit(batches[next_idx])) + next_idx += 1 + + # If the worker alreadt=y finished, returns immediately; + # otherwise blocks until the batch is ready. + results = [f.result() for f in futures] + batch_paths = [p for p, t in results if t is not None] + tensors = [t for _, t in results if t is not None] + + if tensors: + x = torch.stack(tensors).to(device) + feats = model.encode_image(x).cpu().numpy() + emb_chunks.append(feats) + valid_paths.extend(batch_paths) + + processed += len(futures) + if progress_callback: + progress_callback(processed / total, f"Embedding {processed}/{total}") + + if emb_chunks: + embeddings = np.vstack(emb_chunks) + else: + embeddings = np.empty((0, _output_dim(model)), dtype=np.float32) + + logger.info( + f"[Embed] {embeddings.shape[0]}/{total} images embedded " + f"({total - embeddings.shape[0]} skipped)" + ) + return embeddings, valid_paths diff --git a/shared/utils/io.py b/shared/utils/io.py index 69652b2..98d2e7e 100644 --- a/shared/utils/io.py +++ b/shared/utils/io.py @@ -1,22 +1,37 @@ import os import shutil -def list_image_files(image_dir, allowed_extensions=('jpg', 'jpeg', 'png')): +# Image extensions we attempt to load (PIL-decodable raster formats). +IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff', '.webp') + + +def list_image_files(image_dir, allowed_extensions=IMAGE_EXTENSIONS, recursive=False): """ List image file paths in a directory with allowed extensions. Args: image_dir (str): Path to the directory containing images. - allowed_extensions (tuple, optional): Allowed file extensions. Defaults to ('jpg', 'jpeg', 'png'). + allowed_extensions (tuple, optional): Allowed file extensions (lowercase, + leading dot). Defaults to IMAGE_EXTENSIONS. + recursive (bool, optional): Recurse into subdirectories. Defaults to False. Returns: - list: List of full file paths for images with allowed extensions. + list: Sorted list of full file paths for images with allowed extensions. """ - return [ - os.path.join(image_dir, f) - for f in os.listdir(image_dir) - if f.lower().endswith(allowed_extensions) - ] + if recursive: + paths = [ + os.path.join(root, f) + for root, _, files in os.walk(image_dir) + for f in files + if f.lower().endswith(allowed_extensions) + ] + else: + paths = [ + os.path.join(image_dir, f) + for f in os.listdir(image_dir) + if f.lower().endswith(allowed_extensions) + ] + return sorted(paths) def copy_image(row, repartition_dir): """