From b282592ccd59ff9ec5fc0fc1decc863afdc7774f Mon Sep 17 00:00:00 2001 From: Net Zhang Date: Thu, 11 Jun 2026 16:21:09 -0400 Subject: [PATCH 1/3] Add per-cluster representative images: shared util + concurrent fetch Surfaces the members closest to each cluster centroid as a representative-image panel across both apps. This feature already exists in the embed&explore app; now it's made available on the precalculated embeddings app. Adds a shared compute/render core and a reusable, thread-safe image-fetching layer. Co-Authored-By: Claude Opus 4.8 --- apps/precalculated/app.py | 8 +- apps/precalculated/components/data_preview.py | 174 +++++++++------- pyproject.toml | 1 + shared/components/representatives.py | 76 +++++++ shared/components/summary.py | 31 +-- shared/services/clustering_service.py | 16 +- shared/utils/images.py | 197 ++++++++++++++++++ shared/utils/representatives.py | 62 ++++++ tests/test_representatives.py | 61 ++++++ 9 files changed, 525 insertions(+), 101 deletions(-) create mode 100644 shared/components/representatives.py create mode 100644 shared/utils/images.py create mode 100644 shared/utils/representatives.py create mode 100644 tests/test_representatives.py diff --git a/apps/precalculated/app.py b/apps/precalculated/app.py index 9456109..bdf205c 100644 --- a/apps/precalculated/app.py +++ b/apps/precalculated/app.py @@ -13,7 +13,10 @@ render_projection_section, render_kmeans_section, ) -from apps.precalculated.components.data_preview import render_data_preview +from apps.precalculated.components.data_preview import ( + render_data_preview, + render_cluster_representatives, +) from shared.components.visualization import render_scatter_plot from shared.components.summary import render_clustering_summary @@ -71,9 +74,10 @@ def app(): with col_preview: render_data_preview() - # Bottom: Taxonomy summary + # Bottom: Taxonomy summary + representative images st.markdown("---") render_clustering_summary(show_taxonomy=True) + 
render_cluster_representatives() if __name__ == "__main__": diff --git a/apps/precalculated/components/data_preview.py b/apps/precalculated/components/data_preview.py index f855a8a..75c470d 100644 --- a/apps/precalculated/components/data_preview.py +++ b/apps/precalculated/components/data_preview.py @@ -6,80 +6,21 @@ import streamlit as st import pandas as pd import numpy as np -import requests -import time -from typing import Optional -from PIL import Image -from io import BytesIO from shared.utils.logging_config import get_logger +from shared.utils.representatives import find_cluster_representatives +from shared.utils.images import ( + IMAGE_URL_COLUMNS, + fetch_images_concurrent, + get_image_from_url, + resolve_record_image_url, + _IMAGE_CACHE, +) +from shared.components.representatives import render_representative_images logger = get_logger(__name__) -@st.cache_data(ttl=300, show_spinner=False) -def _fetch_image_from_url_cached(url: str, timeout: int = 5) -> Optional[bytes]: - """Internal cached function to fetch image bytes.""" - if not url or not isinstance(url, str): - return None - - try: - if not url.startswith(('http://', 'https://')): - return None - - response = requests.get(url, timeout=timeout, stream=True) - response.raise_for_status() - - content_type = response.headers.get('content-type', '').lower() - if not content_type.startswith('image/'): - return None - - return response.content - - except Exception: - return None - - -def fetch_image_from_url(url: str, timeout: int = 5) -> Optional[bytes]: - """ - Fetch an image from a URL with logging. - Uses caching internally but logs the request. 
- """ - if not url or not isinstance(url, str): - return None - - if not url.startswith(('http://', 'https://')): - logger.warning(f"[Image] Invalid URL scheme: {url[:50]}...") - return None - - logger.info(f"[Image] Fetching: {url[:80]}...") - start_time = time.time() - - result = _fetch_image_from_url_cached(url, timeout) - - elapsed = time.time() - start_time - if result: - logger.info(f"[Image] Loaded: {len(result)/1024:.1f}KB in {elapsed:.3f}s") - else: - logger.warning(f"[Image] Failed to load: {url[:50]}...") - - return result - - -def get_image_from_url(url: str) -> Optional[Image.Image]: - """Get image from URL with caching and logging.""" - image_bytes = fetch_image_from_url(url) - if image_bytes: - try: - image = Image.open(BytesIO(image_bytes)) - logger.info(f"[Image] Opened: {image.size[0]}x{image.size[1]} {image.mode}") - return image - except Exception as e: - logger.error(f"[Image] Failed to open: {e}") - return None - return None - - def render_data_preview(): """Render the data preview panel (record details on point click).""" df_plot = st.session_state.get("data", None) @@ -110,15 +51,12 @@ def render_data_preview(): st.markdown("### Record Details") - # Try to display image if identifier/url column exists (cached to prevent re-fetch) - image_cols = ['identifier', 'image_url', 'url', 'img_url', 'image'] - for img_col in image_cols: - if img_col in record.index and pd.notna(record[img_col]): - url = record[img_col] - image = get_image_from_url(url) - if image is not None: - st.image(image, width=280) - break + # Try to display image if an image URL column exists (process-cached). 
+ url = resolve_record_image_url(record) + if url: + image = get_image_from_url(url) + if image is not None: + st.image(image, width=280) st.markdown(f"**UUID:** `{selected_uuid}`") @@ -276,3 +214,85 @@ def render_cluster_analysis(): st.code(tree_output, language="text") else: st.info(f"No valid '{color_by}' values to compare with KMeans clusters.") + + +def render_cluster_representatives(): + """Render representative images per KMeans cluster for the precalculated app. + + Representatives are the members closest to each cluster centroid (computed + on the full-dimensional embeddings). Images are fetched from each record's + URL column; URLs that fail to load are skipped and the next-closest + candidate is tried (fallback), so transient/broken URLs don't leave gaps. + """ + df_plot = st.session_state.get("data", None) + embeddings = st.session_state.get("embeddings", None) + if df_plot is None or embeddings is None: + return + + kmeans_cols = sorted( + [c for c in df_plot.columns if c.startswith("KMeans (k=")], + key=lambda c: int(c.split("=")[1].rstrip(")")), + ) + if not kmeans_cols: + return # nothing to show until a KMeans run exists + + st.markdown("### Representative Images") + st.caption( + "Members closest to each cluster centroid. Images load from each " + "record's URL; unreachable images are skipped automatically." + ) + + selected_col = st.selectbox( + "KMeans result", + options=kmeans_cols, + index=len(kmeans_cols) - 1, + key="representatives_kmeans_selector", + help="Which KMeans run to show representatives for.", + ) + + # Guard: embeddings must align row-for-row with df_plot. + if len(embeddings) != len(df_plot): + st.info("Re-run projection and KMeans to view representatives.") + return + + n_per_cluster = 3 + representatives = find_cluster_representatives( + embeddings, df_plot[selected_col].values, n_per_cluster=n_per_cluster + ) + + # Warm the cache concurrently. 
Representatives are oversampled for fallback, + # but we only need a few successes per cluster — prefetch a prefix (2x the + # display count) in parallel. Deeper fallback candidates (rare) resolve + # on-demand below. + prefetch_per_cluster = n_per_cluster * 2 + prefetch_urls = [ + resolve_record_image_url(df_plot.iloc[idx]) + for idxs in representatives.values() + for idx in idxs[:prefetch_per_cluster] + ] + with st.spinner("Loading representative images..."): + fetch_images_concurrent([u for u in prefetch_urls if u]) + + def _resolve(idx): + url = resolve_record_image_url(df_plot.iloc[idx]) + if not url: + return None + # Prefetched URLs hit the process cache; anything deeper falls back to + # a single synchronous fetch (also cached). + if url in _IMAGE_CACHE: + return _IMAGE_CACHE[url] + return get_image_from_url(url) + + def _caption(idx): + row = df_plot.iloc[idx] + for col in ("scientific_name", "species", "common_name", "uuid"): + if col in row.index and pd.notna(row[col]): + return str(row[col]) + return None + + render_representative_images( + representatives, + resolve_image=_resolve, + n_per_cluster=n_per_cluster, + caption_fn=_caption, + ) diff --git a/pyproject.toml b/pyproject.toml index 2f2f4b5..0ab13d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ dependencies = [ "pandas>=2.0.0", "pillow>=9.0.0", "pyarrow>=10.0.0", + "requests>=2.28.0", # Visualization "altair>=5.0.0", # Machine learning diff --git a/shared/components/representatives.py b/shared/components/representatives.py new file mode 100644 index 0000000..b7fb12d --- /dev/null +++ b/shared/components/representatives.py @@ -0,0 +1,76 @@ +"""Shared renderer for per-cluster representative images. + +Both apps surface representative images differently: +- embed_explore resolves a local image file path. +- precalculated fetches a remote image URL (which can fail). 
+ +This renderer is source-agnostic: the caller passes a `resolve_image(idx)` +callable that returns something `st.image` can display (a PIL image, a path, +or bytes) or `None` when the image is unavailable. The renderer walks each +cluster's ranked candidate indices and collects up to `n_per_cluster` +successful images, skipping any that resolve to `None` — the shared fallback. +""" + +from typing import Any, Callable, Dict, List, Optional + +import streamlit as st + +from shared.utils.logging_config import get_logger + +logger = get_logger(__name__) + + +def _sorted_cluster_ids(representatives: Dict[object, List[int]]) -> List[object]: + """Sort cluster ids numerically when possible, else as strings.""" + keys = list(representatives.keys()) + try: + return sorted(keys, key=lambda k: int(k)) + except (ValueError, TypeError): + return sorted(keys, key=str) + + +def render_representative_images( + representatives: Dict[object, List[int]], + resolve_image: Callable[[int], Optional[Any]], + n_per_cluster: int = 3, + caption_fn: Optional[Callable[[int], str]] = None, + columns: int = 3, +) -> None: + """Render up to `n_per_cluster` representative images per cluster. + + Args: + representatives: {cluster_id: [ranked candidate global indices]}, as + returned by `find_cluster_representatives`. + resolve_image: idx -> displayable (PIL image / path / bytes) or None. + None means "unavailable" and the renderer falls back to the next + candidate. + n_per_cluster: number of images to show per cluster. + caption_fn: optional idx -> caption string. + columns: images per row. + """ + for cluster_id in _sorted_cluster_ids(representatives): + candidates = representatives[cluster_id] + st.markdown(f"**Cluster {cluster_id}**") + + # Walk ranked candidates, collecting successful resolutions until we + # have n_per_cluster (or run out of candidates). 
+ shown: List[tuple] = [] # (displayable, caption) + for idx in candidates: + if len(shown) >= n_per_cluster: + break + try: + img = resolve_image(idx) + except Exception as e: # never let one bad image break the panel + logger.debug(f"resolve_image({idx}) raised: {e}") + img = None + if img is not None: + caption = caption_fn(idx) if caption_fn else None + shown.append((img, caption)) + + if not shown: + st.caption("No images available for this cluster.") + continue + + cols = st.columns(min(columns, len(shown))) + for i, (img, caption) in enumerate(shown): + cols[i % len(cols)].image(img, caption=caption, width="stretch") diff --git a/shared/components/summary.py b/shared/components/summary.py index 03dcd29..4b60841 100644 --- a/shared/components/summary.py +++ b/shared/components/summary.py @@ -6,6 +6,7 @@ import os import pandas as pd from shared.utils.taxonomy_tree import build_taxonomic_tree, format_tree_string, get_tree_statistics +from shared.components.representatives import render_representative_images from shared.utils.logging_config import get_logger logger = get_logger(__name__) @@ -165,18 +166,24 @@ def render_clustering_summary(show_taxonomy=False): st.dataframe(summary_df, hide_index=True, width='stretch') st.markdown("#### Representative Images") - for row in summary_df.itertuples(): - k = row.Cluster - st.markdown(f"**Cluster {k}**") - img_cols = st.columns(3) - for i, img_idx in enumerate(representatives[k]): - img_path = df_plot.iloc[img_idx]["image_path"] - logger.debug(f"Displaying representative image: {img_path}") - img_cols[i].image( - img_path, - width='stretch', - caption=os.path.basename(img_path) - ) + + def _resolve_local_image(idx): + """Return local image path if it exists, else None (fallback).""" + path = df_plot.iloc[idx]["image_path"] + if isinstance(path, str) and os.path.exists(path): + return path + return None + + def _local_caption(idx): + path = df_plot.iloc[idx]["image_path"] + return os.path.basename(path) if 
isinstance(path, str) else None + + render_representative_images( + representatives, + resolve_image=_resolve_local_image, + n_per_cluster=3, + caption_fn=_local_caption, + ) else: st.info("Clustering summary will be computed when you run clustering.") else: diff --git a/shared/services/clustering_service.py b/shared/services/clustering_service.py index 6847b9f..8bcc4a3 100644 --- a/shared/services/clustering_service.py +++ b/shared/services/clustering_service.py @@ -196,25 +196,21 @@ def generate_clustering_summary( Returns: Tuple of (summary dataframe, representatives dict) """ + from shared.utils.representatives import find_cluster_representatives + logger.info("Generating clustering summary statistics") cluster_ids = np.unique(labels) logger.debug(f"Found {len(cluster_ids)} unique clusters") - summary_data = [] - representatives = {} + # Ranked representative candidates per cluster (shared utility). + representatives = find_cluster_representatives(embeddings, labels) + + summary_data = [] for k in cluster_ids: idxs = np.where(labels == k)[0] cluster_embeds = embeddings[idxs] centroid = cluster_embeds.mean(axis=0) - - # Internal variance variance = np.mean(np.sum((cluster_embeds - centroid) ** 2, axis=1)) - - # Find 3 closest images - dists = np.sum((cluster_embeds - centroid) ** 2, axis=1) - closest_indices = idxs[np.argsort(dists)[:3]] - representatives[k] = closest_indices - summary_data.append({ "Cluster": int(k), "Count": len(idxs), diff --git a/shared/utils/images.py b/shared/utils/images.py new file mode 100644 index 0000000..61af0e8 --- /dev/null +++ b/shared/utils/images.py @@ -0,0 +1,197 @@ +"""Shared image-fetching utilities. + +App-agnostic helpers for resolving and fetching record images from remote +URLs. Kept free of Streamlit so the helpers run safely on worker threads and +in any app (precalculated, a URL-based embed_explore, the demo Space, ...). 
+ +In-app fetch flow +----------------- +A record (parquet row) holds an image URL in one of ``IMAGE_URL_COLUMNS``. +Two call paths consume these: + +1. Cluster representatives (bulk, eager). + ``render_cluster_representatives`` resolves a URL per candidate with + ``resolve_record_image_url`` and warms the cache up front via + ``fetch_images_concurrent`` (thread pool, 8 workers). Each thread calls + ``download_image_bytes`` -> ``bytes_to_image`` and stores the PIL image + (or ``None`` on failure) in ``_IMAGE_CACHE``. The renderer then reads + results straight from the cache; broken URLs are skipped and the next + candidate is tried. + +2. Click preview (single, lazy). + ``render_data_preview`` resolves one URL and calls ``get_image_from_url``, + which serves the cached image if present and otherwise does a single + synchronous ``download_image_bytes`` -> ``bytes_to_image`` and caches it. + +So both paths share one fetch primitive and one cache; the only difference is +concurrent prefetch vs. on-demand single fetch. + +Why a process-level cache (not ``@st.cache_data``) +-------------------------------------------------- +- The bulk path fetches from worker threads, where ``st.*`` calls are unsafe; + a plain module-level dict is thread-friendly and lets both paths share the + same entries. +- It survives Streamlit reruns within the process, so panning/clicking does + not refetch. A soft FIFO cap (``_IMAGE_CACHE_MAX``) bounds memory. + Trimming only happens at the end of a fetch call, not on every insertion. +- ``None`` is cached as a known miss, so a dead URL is fetched at most once. + +The single shared ``requests.Session`` carries the project User-Agent so data +hosts can identify / allowlist us. 
+""" + +import concurrent.futures +import time +from io import BytesIO +from typing import Dict, Iterable, Optional + +import requests +from PIL import Image + +from shared import __version__ as _EMB_EXPLORER_VERSION +from shared.utils.logging_config import get_logger + +logger = get_logger(__name__) + +# Columns checked, in order, for an image URL when resolving a record's image. +IMAGE_URL_COLUMNS = ['identifier', 'image_url', 'url', 'img_url', 'image'] + +# Be a polite client: identify the app and link the repo so data hosts can +# contact us / allowlist us if needed. +USER_AGENT = ( + f"emb-explorer/{_EMB_EXPLORER_VERSION} " + "(https://github.com/Imageomics/emb-explorer)" +) + +_session: Optional[requests.Session] = None + + +def _get_session() -> requests.Session: + """Lazily build a shared requests.Session carrying our User-Agent.""" + global _session + if _session is None: + s = requests.Session() + s.headers.update({"User-Agent": USER_AGENT}) + _session = s + return _session + + +def download_image_bytes(url: str, timeout: int = 5) -> Optional[bytes]: + """Fetch raw image bytes via the shared session. None on any failure. + + Contains no Streamlit calls, so it is safe to run from worker threads. + """ + if not isinstance(url, str) or not url.startswith(('http://', 'https://')): + return None + try: + resp = _get_session().get(url, timeout=timeout, stream=True) + resp.raise_for_status() + if not resp.headers.get('content-type', '').lower().startswith('image/'): + return None + return resp.content + except Exception: + return None + + +def bytes_to_image(data: Optional[bytes]) -> Optional[Image.Image]: + """Decode image bytes to a PIL image, or None on failure.""" + if not data: + return None + try: + return Image.open(BytesIO(data)) + except Exception as e: + logger.error(f"[Image] Failed to open: {e}") + return None + + +# Process-level cache for fetched images. Survives Streamlit reruns within the +# process; value is a PIL image or None (known miss). 
+_IMAGE_CACHE: Dict[str, Optional[Image.Image]] = {} +_IMAGE_CACHE_MAX = 512 + + +def _trim_cache() -> None: + """Soft FIFO cap so the cache doesn't grow unbounded across sessions.""" + if len(_IMAGE_CACHE) > _IMAGE_CACHE_MAX: + for k in list(_IMAGE_CACHE.keys())[: len(_IMAGE_CACHE) - _IMAGE_CACHE_MAX]: + _IMAGE_CACHE.pop(k, None) + + +def fetch_images_concurrent( + urls: Iterable[str], max_workers: int = 8, timeout: int = 5 +) -> Dict[str, Optional[Image.Image]]: + """Fetch many image URLs concurrently with a thread pool. + + Returns {url: PIL image or None}. Per-URL results are cached in a + process-level dict so reruns and overlapping clusters don't refetch. + Threads only do HTTP + PIL decode (no st.* calls), which is Streamlit-safe. + """ + unique = [u for u in dict.fromkeys(urls) if isinstance(u, str) and u] + missing = [u for u in unique if u not in _IMAGE_CACHE] + + if missing: + t0 = time.time() + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as ex: + future_to_url = { + ex.submit(download_image_bytes, u, timeout): u for u in missing + } + for fut in concurrent.futures.as_completed(future_to_url): + u = future_to_url[fut] + try: + _IMAGE_CACHE[u] = bytes_to_image(fut.result()) + except Exception: + _IMAGE_CACHE[u] = None + ok = sum(1 for u in missing if _IMAGE_CACHE.get(u) is not None) + logger.info( + f"[Image] Concurrently fetched {len(missing)} url(s) in " + f"{time.time() - t0:.2f}s ({ok} ok)" + ) + _trim_cache() + + return {u: _IMAGE_CACHE.get(u) for u in unique} + + +def get_image_from_url(url: str, timeout: int = 5) -> Optional[Image.Image]: + """Get a single image from a URL, using the process cache. + + Logs the request; results (including misses) are cached so repeated + lookups and the concurrent path share one cache. 
+ """ + if not url or not isinstance(url, str): + return None + if url in _IMAGE_CACHE: + return _IMAGE_CACHE[url] + if not url.startswith(('http://', 'https://')): + logger.warning(f"[Image] Invalid URL scheme: {url[:50]}...") + return None + + logger.info(f"[Image] Fetching: {url[:80]}...") + start_time = time.time() + image = bytes_to_image(download_image_bytes(url, timeout)) + elapsed = time.time() - start_time + if image is not None: + logger.info(f"[Image] Loaded in {elapsed:.3f}s") + else: + logger.warning(f"[Image] Failed to load: {url[:50]}...") + + _IMAGE_CACHE[url] = image + _trim_cache() + return image + + +def resolve_record_image_url(row) -> Optional[str]: + """Return the first valid HTTP(S) image URL from a record/row, else None. + + `row` is anything supporting `col in row` membership and `row[col]` + indexing (e.g. a pandas Series or a dict). + """ + for col in IMAGE_URL_COLUMNS: + try: + present = col in row.index + except AttributeError: + present = col in row + if present: + val = row[col] + if isinstance(val, str) and val.startswith(('http://', 'https://')): + return val + return None diff --git a/shared/utils/representatives.py b/shared/utils/representatives.py new file mode 100644 index 0000000..329d078 --- /dev/null +++ b/shared/utils/representatives.py @@ -0,0 +1,62 @@ +"""Find representative members of clusters. + +Given embeddings and cluster labels, rank each cluster's members by proximity +to the cluster centroid. Returns more candidates than strictly requested +(oversampled) so callers that render images can skip candidates whose image +fails to load and still show the desired number per cluster. 
+""" + +from typing import Dict, List + +import numpy as np + +from shared.utils.logging_config import get_logger + +logger = get_logger(__name__) + + +def find_cluster_representatives( + embeddings: np.ndarray, + labels, + n_per_cluster: int = 3, + oversample: int = 4, +) -> Dict[object, List[int]]: + """Rank each cluster's members by closeness to the cluster centroid. + + Args: + embeddings: (N, D) array of embeddings (row i aligns with label i). + labels: array-like of length N with cluster labels (int or str). + n_per_cluster: how many representatives the caller intends to show. + oversample: multiplier for how many candidate indices to return per + cluster (n_per_cluster * oversample), so failed image loads can be + skipped while still surfacing n_per_cluster images. + + Returns: + Dict mapping each cluster label to a list of global indices into + `embeddings`, ordered closest-to-centroid first, capped at + n_per_cluster * oversample (or the cluster size, whichever is smaller). + """ + labels = np.asarray(labels) + embeddings = np.asarray(embeddings) + n_candidates = max(n_per_cluster * oversample, n_per_cluster) + + representatives: Dict[object, List[int]] = {} + for cluster_id in np.unique(labels): + member_idxs = np.where(labels == cluster_id)[0] + if member_idxs.size == 0: + continue + cluster_embeds = embeddings[member_idxs] + centroid = cluster_embeds.mean(axis=0) + + # Compute squared Euclidean distance to the centroid for each member. + dists = np.sum((cluster_embeds - centroid) ** 2, axis=1) + order = np.argsort(dists)[:n_candidates] + # Keep the label's native Python type for clean dict keys / display. 
+ key = cluster_id.item() if hasattr(cluster_id, "item") else cluster_id + representatives[key] = member_idxs[order].tolist() + + logger.debug( + f"Found representatives for {len(representatives)} clusters " + f"(up to {n_candidates} candidates each)" + ) + return representatives diff --git a/tests/test_representatives.py b/tests/test_representatives.py new file mode 100644 index 0000000..779d838 --- /dev/null +++ b/tests/test_representatives.py @@ -0,0 +1,61 @@ +"""Tests for shared/utils/representatives.py (find_cluster_representatives). + +Covers the pure ranking logic only! Centroid ordering, global-index +correctness, the oversample cap, and label-type handling. The Streamlit +renderer is intentionally not tested here. +""" + +import numpy as np + +from shared.utils.representatives import find_cluster_representatives + + +def test_ranks_closest_to_centroid_first(): + """Members are ordered by ascending distance to the cluster centroid.""" + # Single cluster; centroid_x = (0+3+100)/3 = 34.33, so idx 1 (x=3) is + # closest, then idx 0 (x=0), then idx 2 (x=100). + embeddings = np.array([[0.0, 0.0], [3.0, 0.0], [100.0, 0.0]]) + labels = [0, 0, 0] # All in one cluster + + reps = find_cluster_representatives(embeddings, labels, n_per_cluster=3) + + assert reps[0] == [1, 0, 2] + + +def test_indices_are_global_and_match_cluster(): + """Returned indices are global and only reference members of that cluster.""" + labels = [0, 1, 0, 1, 0] + embeddings = np.random.RandomState(0).rand(5, 4) + + reps = find_cluster_representatives(embeddings, labels, n_per_cluster=2) + + assert set(reps.keys()) == {0, 1} + labels_arr = np.asarray(labels) + for cluster_id, idxs in reps.items(): + assert all(labels_arr[i] == cluster_id for i in idxs) + + +def test_oversample_capped_at_cluster_size(): + """Candidates = min(n_per_cluster * oversample, cluster size).""" + # Cluster 0: 4 members, cluster 1: 10 members. 
+ labels = [0] * 4 + [1] * 10 + embeddings = np.random.RandomState(1).rand(14, 3) + + reps = find_cluster_representatives( + embeddings, labels, n_per_cluster=2, oversample=3 + ) # n_candidates = 6 + + assert len(reps[0]) == 4 # capped at cluster size + assert len(reps[1]) == 6 # capped at n_candidates + + +def test_preserves_label_type(): + """String labels stay string keys; numpy int labels become Python ints.""" + embeddings = np.random.RandomState(2).rand(3, 2) + + str_reps = find_cluster_representatives(embeddings, ["a", "a", "b"]) + assert set(str_reps.keys()) == {"a", "b"} + + int_reps = find_cluster_representatives(embeddings, np.array([0, 0, 1])) + assert set(int_reps.keys()) == {0, 1} + assert all(isinstance(k, int) for k in int_reps.keys()) From 806f557515c5fb46fd56bc22cd91de19a148df67 Mon Sep 17 00:00:00 2001 From: Net Zhang Date: Tue, 16 Jun 2026 08:16:12 -0400 Subject: [PATCH 2/3] Use (+URL) crawler convention in image-fetch User-Agent --- shared/utils/images.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/shared/utils/images.py b/shared/utils/images.py index 61af0e8..548dab7 100644 --- a/shared/utils/images.py +++ b/shared/utils/images.py @@ -60,7 +60,7 @@ # contact us / allowlist us if needed. USER_AGENT = ( f"emb-explorer/{_EMB_EXPLORER_VERSION} " - "(https://github.com/Imageomics/emb-explorer)" + "(+https://github.com/Imageomics/emb-explorer)" ) _session: Optional[requests.Session] = None From 741eb35606d1e81b1c81bdee37b02422f37546eb Mon Sep 17 00:00:00 2001 From: Net Zhang Date: Tue, 16 Jun 2026 08:17:04 -0400 Subject: [PATCH 3/3] Single-source the pkg version in pyproject.toml [project].version (static, 1.0.0) is now the sole source of truth. Drop the dynamic [tool.hatch.version], and read __version__ from installed metadata via `importlib.metadata`. Also remove [tool.hatch.metadata] allow-direct-references, obsolete now that the hpc-inference git dependency is gone. 
--- pyproject.toml | 6 ------ shared/__init__.py | 9 ++++++++- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d0936dd..922fbab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,12 +98,6 @@ emb-embed-explore = "apps.embed_explore.app:main" emb-precalculated = "apps.precalculated.app:main" list-models = "shared.utils.models:print_available_models" -[tool.hatch.version] -path = "shared/__init__.py" - -[tool.hatch.metadata] -allow-direct-references = true - [tool.hatch.build.targets.wheel] packages = ["shared", "apps"] diff --git a/shared/__init__.py b/shared/__init__.py index ae13b8e..d0214a5 100644 --- a/shared/__init__.py +++ b/shared/__init__.py @@ -2,4 +2,11 @@ Shared utilities and services for the emb-explorer applications. """ -__version__ = "0.1.0" +from importlib.metadata import PackageNotFoundError, version as _version + +try: + # Single source of truth: the version declared in pyproject.toml, + # read from the installed package metadata. + __version__ = _version("emb-explorer") +except PackageNotFoundError: # running from a source tree without an install + __version__ = "0.0.0+unknown"