Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ dependencies = [
"numba>=0.57.0",
# Vision-language models
"open-clip-torch>=2.20.0",
# Custom inference package
"hpc-inference @ git+https://github.com/Imageomics/hpc-inference.git"
# Image embedding pipeline (used directly in shared/utils/image_pipeline.py)
"torch>=1.11.0",
]

[project.optional-dependencies]
Expand Down
168 changes: 117 additions & 51 deletions shared/services/embedding_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,68 @@

Heavy libraries (torch, open_clip) are imported lazily inside methods
to avoid slowing down app startup.

Device-aware concurrency:

PyTorch has two kinds of parallelism built in, we focus on the intra-op
parallelism which is relevant to the embedding pipeline:

Intra-op is the parallelism inside a single operation. One op, say
`Normalize` on a `[3, 244, 244]` tensor, or a big matrix multiply, splits its
own work across multiple threads (via an openMP/MKL thread pool).
`torch.get_num_threads()` queries how many threads one op may use, and
`torch.set_num_threads(n)` sets it.

A single `preprocess(img)` is a chain of torch ops (resize -> to_tensor ->
normalize). With the default intra-op thread settings, each of those ops can
fan its work out across all CPU cores. So ONE preprocess call of one image
can momentarily spin up ~`cpu_count` threads to do that tiny bit of math.

^^^ Why that's wasteful here?

Since we already have our own parallelism layer at image level: the
`ThreadPoolExecutor` runs `workers` threads, one image per thread, and each
thread calls `preprocess(img)`. If each preprocess call fans out across all
CPU cores, then `workers` threads can easily oversubscribe the CPU with
`workers * cpu_count` threads. This causes contention and can actually slow
down the whole process.

```
Layer 1 (ThreadPoolExecutor): 16 worker threads, each handling one image preprocess
Layer 2 (torch intra-op): x Each preprocess call can use up to `cpu_count` threads
========================================================
Total threads = 16 (workers) * cpu_count (intra-op) =>
Potentially 256 threads on a 16-core machine,
causing oversubscription and slowdown.
```

By setting `torch.set_num_threads(1)`, we ensure that each preprocess call
runs single-thread, no internal spliting. All parallelism comes cleanly from
one place - the `ThreadPoolExecutor`. Instead of two nested layers that
multiply into a thread explosion, each core does one useful thing (decode a
whole image) with no scheduling thrash and no per-op thread-launch overhead.

```
Layer 1 (ThreadPoolExecutor): 16 worker threads, each handling one image preprocess
Layer 2 (torch intra-op): x 1 (each op runs single-threaded, instantly)
========================================================
Total threads = 16 (workers) * 1 (intra-op) =>
Potentially 16 threads on a 16-core machine,
fully utilizing the CPU without oversubscription.
```

What Intra-op is good for?

Intra-op parallelism is excellent for big ops. On the CPU-only path, the
forward pass of the model is the bottleneck, and it benefits from intra-op
parallelism. So we leave torch's intra-op threads alone on CPU, and cap the
worker threads to a small number (2) to avoid too much contention. On GPU,
the forward pass is fast and doesn't need CPU cores, so we maximize worker
threads for decoding and set intra-op to 1 to avoid oversubscription.

"""

import os
import numpy as np
import streamlit as st
import time
Expand Down Expand Up @@ -79,90 +139,96 @@ def generate_embeddings(
model_name: str,
batch_size: int,
n_workers: int,
progress_callback: Optional[Callable[[float, str], None]] = None
progress_callback: Optional[Callable[[float, str], None]] = None,
recursive: bool = False,
) -> Tuple[np.ndarray, List[str]]:
"""
Generate embeddings for images in a directory.

Preprocessing runs on a thread pool (GIL-light) overlapped with the model
forward pass — no multiprocessing, so behavior is identical on every OS.

Args:
image_dir: Path to directory containing images
model_name: Name of the model to use
batch_size: Batch size for processing
n_workers: Number of worker processes
batch_size: Batch size for the forward pass
n_workers: Max preprocessing threads (capped per device, see below)
progress_callback: Optional callback for progress updates
recursive: Recurse into subdirectories when listing images

Returns:
Tuple of (embeddings array, list of valid image paths)
"""
import torch
from hpc_inference.datasets.image_folder_dataset import ImageFolderDataset
from shared.utils.image_pipeline import embed_image_folder

logger.info(f"Starting embedding generation: dir={image_dir}, model={model_name}, "
f"batch_size={batch_size}, n_workers={n_workers}")
f"batch_size={batch_size}, n_workers={n_workers}, recursive={recursive}")
total_start = time.time()

if progress_callback:
progress_callback(0.0, "Listing images...")

image_paths = list_image_files(image_dir)
logger.info(f"Found {len(image_paths)} images in {image_dir}")
image_paths = list_image_files(image_dir, recursive=recursive)
total = len(image_paths)
logger.info(f"Found {total} images in {image_dir}")

if progress_callback:
progress_callback(0.1, f"Found {len(image_paths)} images. Loading model...")
progress_callback(0.05, f"Found {total} images. Loading model...")

torch_device = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(torch_device)
logger.info(f"Using device: {torch_device}")
model, preprocess = EmbeddingService.load_model_unified(model_name, torch_device)

if progress_callback:
progress_callback(0.2, "Creating dataset...")

# Create dataset & DataLoader
dataset = ImageFolderDataset(
image_dir=image_dir,
preprocess=preprocess,
uuid_mode="fullpath",
rank=0,
world_size=1,
evenly_distribute=True,
validate=True
)
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
shuffle=False,
num_workers=n_workers,
pin_memory=True
)
# Device-aware concurrency:
cpu_count = os.cpu_count() or 1
prev_threads = None

if device.type == "cuda":
# GPU: feed the GPU with parallel decode, avoid per-op oversubscription.
# - preprocess threads: wide
# - torch intra-op threads: forced to 1

# Set the number of preprocessing threads, clamped by three ceilings:
# 1) the user-requested n_workers
# 2) the number of CPU cores
# 3) never more threads than images
workers = max(1, min(n_workers, cpu_count, max(total, 1)))

prev_threads = torch.get_num_threads()
torch.set_num_threads(1)

total = len(image_paths)
valid_paths = []
embeddings = []

processed = 0
with torch.no_grad():
for batch_paths, batch_imgs in dataloader:
batch_imgs = batch_imgs.to(torch_device, non_blocking=True)
batch_embeds = model.encode_image(batch_imgs).cpu().numpy()
embeddings.append(batch_embeds)
valid_paths.extend(batch_paths)
processed += len(batch_paths)

if progress_callback:
progress = 0.2 + (processed / total) * 0.8 # Use 20% to 100% for actual processing
progress_callback(progress, f"Embedding {processed}/{total}")

# Stack embeddings if available
if embeddings:
embeddings = np.vstack(embeddings)
else:
embeddings = np.empty((0, model.visual.output_dim))
# CPU: the CPU forward is the bottleneck, needs the cores,
# so keep preprocess pool small and leave torch threads alone.
workers = max(1, min(2, n_workers, max(total, 1)))

# Map the pipeline's 0..1 progress into the 0.1..1.0 band (model load took 0..0.1).
def _embed_progress(frac: float, msg: str):
if progress_callback:
progress_callback(0.1 + 0.9 * frac, msg)

try:
embeddings, valid_paths = embed_image_folder(
image_paths,
model,
preprocess,
device,
batch_size=batch_size,
n_workers=workers,
progress_callback=_embed_progress,
)
finally:
if prev_threads is not None:
torch.set_num_threads(prev_threads)

if progress_callback:
progress_callback(1.0, f"Complete! Generated {embeddings.shape[0]} embeddings")

total_elapsed = time.time() - total_start
logger.info(f"Embedding generation completed: {embeddings.shape[0]} embeddings in {total_elapsed:.2f}s "
f"({embeddings.shape[0] / total_elapsed:.1f} images/sec)")
rate = embeddings.shape[0] / total_elapsed if total_elapsed > 0 else 0.0
logger.info(f"Embedding generation completed: {embeddings.shape[0]} embeddings in "
f"{total_elapsed:.2f}s ({rate:.1f} images/sec)")

return embeddings, valid_paths
154 changes: 154 additions & 0 deletions shared/utils/image_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
"""Thread-parallel image embedding pipeline.

Turns a list of image paths into embeddings on a single machine. Preprocessing
(decode + transform) runs on a thread pool; the model forward runs on the
calling thread that owns the device. Each batch is preprocessed while the
previous batch runs through the model (a one-batch prefetch), so CPU decoding
and the device forward overlap.

Threads — rather than worker processes — carry the preprocessing because the
work is GIL-light: PIL decode and torchvision tensor ops release the GIL, so a
thread pool scales nearly linearly. Staying in one process means no per-image
data crosses a process boundary and there is no worker-spawn cost, so small
folders are cheap and behavior does not depend on the OS.

This module is Streamlit-free and unit-testable.
"""

from __future__ import annotations

import concurrent.futures as cf
from collections import deque
from typing import Callable, List, Optional, Tuple

import numpy as np
import torch
from PIL import Image

from shared.utils.logging_config import get_logger

logger = get_logger(__name__)


def _output_dim(model) -> int:
"""Best-effort embedding width, for shaping an empty result."""
return int(getattr(getattr(model, "visual", None), "output_dim", 0) or 0)


def _preprocess_one(path: str, preprocess: Callable, color_mode: str):
"""Decode + preprocess a single image.

Returns ``(path, tensor)`` on success or ``(path, None)`` if the file can't
be read/decoded. Pure and device-free, so it is safe on worker threads.
"""
try:
with Image.open(path) as im:
img = im.convert(color_mode)
return path, preprocess(img)
except Exception as e:
logger.warning(f"[Embed] Skipping unreadable image {path}: {e}")
return path, None


def embed_image_folder(
image_paths: List[str],
model,
preprocess: Callable,
device: torch.device,
*,
batch_size: int = 32,
n_workers: int = 8,
prefetch_batches: int = 1,
color_mode: str = "RGB",
progress_callback: Optional[Callable[[float, str], None]] = None,
) -> Tuple[np.ndarray, List[str]]:
"""Embed a list of image paths, overlapping preprocessing with the forward.

Preprocessing runs on a ``ThreadPoolExecutor``; the model forward runs on the
calling thread (which owns ``device``). Up to ``prefetch_batches`` batches are
preprocessed ahead of the batch currently being run through the model.

Unreadable images are skipped (and logged), so the returned embeddings may
have fewer rows than ``image_paths``. ``embeddings[i]`` corresponds to
``valid_paths[i]``.

Args:
image_paths: Image file paths to embed.
model: Model exposing ``encode_image(tensor) -> tensor``.
preprocess: Callable mapping a PIL image to a CHW tensor.
device: Torch device the model lives on.
batch_size: Images per forward pass.
n_workers: Preprocessing threads.
prefetch_batches: Batches to preprocess ahead of the forward (overlap).
color_mode: PIL convert mode applied before preprocessing.
progress_callback: Optional ``(fraction, message)`` progress sink.

Returns:
``(embeddings [N, D] float array, valid_paths [N])``.
"""
total = len(image_paths)
if total == 0:
return np.empty((0, _output_dim(model)), dtype=np.float32), []

batches = [image_paths[i:i + batch_size] for i in range(0, total, batch_size)]
window = max(1, prefetch_batches + 1)

emb_chunks: List[np.ndarray] = []
valid_paths: List[str] = []
processed = 0

# concurrent.futures.ThreadPoolExecutor(max_workers=n_workers)
# spins up `n_workers` OS threads sitting idle, waiting for work...
# hand it work with ex.submit(fn, *args), which returns a Future immediately,
# and runs fn(*args) on a worker thread when it gets scheduled by the OS...
with cf.ThreadPoolExecutor(max_workers=n_workers) as ex:

# non-blocking, starts the preprocessing of a batch on the pool
def submit(batch: List[str]) -> List[cf.Future]:
return [ex.submit(_preprocess_one, p, preprocess, color_mode) for p in batch]

# Prime the pipeline so the first forward already has successors decoding.
# pending is a queue of lists of futures, one list per batch.
pending: deque = deque()
next_idx = 0
while next_idx < len(batches) and len(pending) < window:
pending.append(submit(batches[next_idx]))
next_idx += 1
# pending is now a full window of batches being preprocessed

with torch.no_grad():
while pending:
# Take the oldest in-flight batch
futures = pending.popleft()
# Refill the window first: these batches preprocess on the pool
# while we run the current batch through the model below.
if next_idx < len(batches):
pending.append(submit(batches[next_idx]))
next_idx += 1

# If the worker alreadt=y finished, returns immediately;
# otherwise blocks until the batch is ready.
results = [f.result() for f in futures]
batch_paths = [p for p, t in results if t is not None]
tensors = [t for _, t in results if t is not None]

if tensors:
x = torch.stack(tensors).to(device)
feats = model.encode_image(x).cpu().numpy()
emb_chunks.append(feats)
valid_paths.extend(batch_paths)

processed += len(futures)
if progress_callback:
progress_callback(processed / total, f"Embedding {processed}/{total}")

if emb_chunks:
embeddings = np.vstack(emb_chunks)
else:
embeddings = np.empty((0, _output_dim(model)), dtype=np.float32)

logger.info(
f"[Embed] {embeddings.shape[0]}/{total} images embedded "
f"({total - embeddings.shape[0]} skipped)"
)
return embeddings, valid_paths
Loading
Loading