Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions apps/precalculated/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@ def app():
render_projection_section,
render_kmeans_section,
)
from apps.precalculated.components.data_preview import render_data_preview
from apps.precalculated.components.data_preview import (
render_data_preview,
render_cluster_representatives,
)
from shared.components.visualization import render_scatter_plot
from shared.components.summary import render_clustering_summary

Expand Down Expand Up @@ -71,9 +74,10 @@ def app():
with col_preview:
render_data_preview()

# Bottom: Taxonomy summary
# Bottom: Taxonomy summary + representative images
st.markdown("---")
render_clustering_summary(show_taxonomy=True)
render_cluster_representatives()


if __name__ == "__main__":
Expand Down
174 changes: 97 additions & 77 deletions apps/precalculated/components/data_preview.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,80 +6,21 @@
import streamlit as st
import pandas as pd
import numpy as np
import requests
import time
from typing import Optional
from PIL import Image
from io import BytesIO

from shared.utils.logging_config import get_logger
from shared.utils.representatives import find_cluster_representatives
from shared.utils.images import (
IMAGE_URL_COLUMNS,
fetch_images_concurrent,
get_image_from_url,
resolve_record_image_url,
_IMAGE_CACHE,
)
from shared.components.representatives import render_representative_images

logger = get_logger(__name__)


@st.cache_data(ttl=300, show_spinner=False)
def _fetch_image_from_url_cached(url: str, timeout: int = 5) -> Optional[bytes]:
"""Internal cached function to fetch image bytes."""
if not url or not isinstance(url, str):
return None

try:
if not url.startswith(('http://', 'https://')):
return None

response = requests.get(url, timeout=timeout, stream=True)
response.raise_for_status()

content_type = response.headers.get('content-type', '').lower()
if not content_type.startswith('image/'):
return None

return response.content

except Exception:
return None


def fetch_image_from_url(url: str, timeout: int = 5) -> Optional[bytes]:
"""
Fetch an image from a URL with logging.
Uses caching internally but logs the request.
"""
if not url or not isinstance(url, str):
return None

if not url.startswith(('http://', 'https://')):
logger.warning(f"[Image] Invalid URL scheme: {url[:50]}...")
return None

logger.info(f"[Image] Fetching: {url[:80]}...")
start_time = time.time()

result = _fetch_image_from_url_cached(url, timeout)

elapsed = time.time() - start_time
if result:
logger.info(f"[Image] Loaded: {len(result)/1024:.1f}KB in {elapsed:.3f}s")
else:
logger.warning(f"[Image] Failed to load: {url[:50]}...")

return result


def get_image_from_url(url: str) -> Optional[Image.Image]:
"""Get image from URL with caching and logging."""
image_bytes = fetch_image_from_url(url)
if image_bytes:
try:
image = Image.open(BytesIO(image_bytes))
logger.info(f"[Image] Opened: {image.size[0]}x{image.size[1]} {image.mode}")
return image
except Exception as e:
logger.error(f"[Image] Failed to open: {e}")
return None
return None


def render_data_preview():
"""Render the data preview panel (record details on point click)."""
df_plot = st.session_state.get("data", None)
Expand Down Expand Up @@ -110,15 +51,12 @@ def render_data_preview():

st.markdown("### Record Details")

# Try to display image if identifier/url column exists (cached to prevent re-fetch)
image_cols = ['identifier', 'image_url', 'url', 'img_url', 'image']
for img_col in image_cols:
if img_col in record.index and pd.notna(record[img_col]):
url = record[img_col]
image = get_image_from_url(url)
if image is not None:
st.image(image, width=280)
break
# Try to display image if an image URL column exists (process-cached).
url = resolve_record_image_url(record)
if url:
image = get_image_from_url(url)
if image is not None:
st.image(image, width=280)

st.markdown(f"**UUID:** `{selected_uuid}`")

Expand Down Expand Up @@ -276,3 +214,85 @@ def render_cluster_analysis():
st.code(tree_output, language="text")
else:
st.info(f"No valid '{color_by}' values to compare with KMeans clusters.")


def render_cluster_representatives():
"""Render representative images per KMeans cluster for the precalculated app.

Representatives are the members closest to each cluster centroid (computed
on the full-dimensional embeddings). Images are fetched from each record's
URL column; URLs that fail to load are skipped and the next-closest
candidate is tried (fallback), so transient/broken URLs don't leave gaps.
"""
df_plot = st.session_state.get("data", None)
embeddings = st.session_state.get("embeddings", None)
if df_plot is None or embeddings is None:
return

kmeans_cols = sorted(
[c for c in df_plot.columns if c.startswith("KMeans (k=")],
key=lambda c: int(c.split("=")[1].rstrip(")")),
)
if not kmeans_cols:
return # nothing to show until a KMeans run exists

st.markdown("### Representative Images")
st.caption(
"Members closest to each cluster centroid. Images load from each "
"record's URL; unreachable images are skipped automatically."
)

selected_col = st.selectbox(
"KMeans result",
options=kmeans_cols,
index=len(kmeans_cols) - 1,
key="representatives_kmeans_selector",
help="Which KMeans run to show representatives for.",
)

# Guard: embeddings must align row-for-row with df_plot.
if len(embeddings) != len(df_plot):
st.info("Re-run projection and KMeans to view representatives.")
return

n_per_cluster = 3
representatives = find_cluster_representatives(
embeddings, df_plot[selected_col].values, n_per_cluster=n_per_cluster
)

# Warm the cache concurrently. Representatives are oversampled for fallback,
# but we only need a few successes per cluster — prefetch a prefix (2x the
# display count) in parallel. Deeper fallback candidates (rare) resolve
# on-demand below.
prefetch_per_cluster = n_per_cluster * 2
prefetch_urls = [
resolve_record_image_url(df_plot.iloc[idx])
for idxs in representatives.values()
for idx in idxs[:prefetch_per_cluster]
]
with st.spinner("Loading representative images..."):
fetch_images_concurrent([u for u in prefetch_urls if u])

def _resolve(idx):
url = resolve_record_image_url(df_plot.iloc[idx])
if not url:
return None
# Prefetched URLs hit the process cache; anything deeper falls back to
# a single synchronous fetch (also cached).
if url in _IMAGE_CACHE:
return _IMAGE_CACHE[url]
return get_image_from_url(url)

def _caption(idx):
row = df_plot.iloc[idx]
for col in ("scientific_name", "species", "common_name", "uuid"):
if col in row.index and pd.notna(row[col]):
return str(row[col])
return None

render_representative_images(
representatives,
resolve_image=_resolve,
n_per_cluster=n_per_cluster,
caption_fn=_caption,
)
7 changes: 1 addition & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ dependencies = [
"pandas>=2.0.0",
"pillow>=9.0.0",
"pyarrow>=10.0.0",
"requests>=2.28.0",
# Visualization
"altair>=5.0.0",
# Machine learning
Expand Down Expand Up @@ -97,12 +98,6 @@ emb-embed-explore = "apps.embed_explore.app:main"
emb-precalculated = "apps.precalculated.app:main"
list-models = "shared.utils.models:print_available_models"

[tool.hatch.version]
path = "shared/__init__.py"

[tool.hatch.metadata]
allow-direct-references = true

[tool.hatch.build.targets.wheel]
packages = ["shared", "apps"]

Expand Down
9 changes: 8 additions & 1 deletion shared/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,11 @@
Shared utilities and services for the emb-explorer applications.
"""

__version__ = "0.1.0"
from importlib.metadata import PackageNotFoundError, version as _version

try:
# Single source of truth: the version declared in pyproject.toml,
# read from the installed package metadata.
__version__ = _version("emb-explorer")
except PackageNotFoundError: # running from a source tree without an install
__version__ = "0.0.0+unknown"
76 changes: 76 additions & 0 deletions shared/components/representatives.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""Shared renderer for per-cluster representative images.

Both apps surface representative images differently:
- embed_explore resolves a local image file path.
- precalculated fetches a remote image URL (which can fail).

This renderer is source-agnostic: the caller passes a `resolve_image(idx)`
callable that returns something `st.image` can display (a PIL image, a path,
or bytes) or `None` when the image is unavailable. The renderer walks each
cluster's ranked candidate indices and collects up to `n_per_cluster`
successful images, skipping any that resolve to `None` — the shared fallback.
"""

from typing import Any, Callable, Dict, List, Optional

import streamlit as st

from shared.utils.logging_config import get_logger

logger = get_logger(__name__)


def _sorted_cluster_ids(representatives: Dict[object, List[int]]) -> List[object]:
"""Sort cluster ids numerically when possible, else as strings."""
keys = list(representatives.keys())
try:
return sorted(keys, key=lambda k: int(k))
except (ValueError, TypeError):
return sorted(keys, key=str)


def render_representative_images(
representatives: Dict[object, List[int]],
resolve_image: Callable[[int], Optional[Any]],
n_per_cluster: int = 3,
caption_fn: Optional[Callable[[int], str]] = None,
columns: int = 3,
) -> None:
"""Render up to `n_per_cluster` representative images per cluster.

Args:
representatives: {cluster_id: [ranked candidate global indices]}, as
returned by `find_cluster_representatives`.
resolve_image: idx -> displayable (PIL image / path / bytes) or None.
None means "unavailable" and the renderer falls back to the next
candidate.
n_per_cluster: number of images to show per cluster.
caption_fn: optional idx -> caption string.
columns: images per row.
"""
for cluster_id in _sorted_cluster_ids(representatives):
candidates = representatives[cluster_id]
st.markdown(f"**Cluster {cluster_id}**")

# Walk ranked candidates, collecting successful resolutions until we
# have n_per_cluster (or run out of candidates).
shown: List[tuple] = [] # (displayable, caption)
for idx in candidates:
if len(shown) >= n_per_cluster:
break
try:
img = resolve_image(idx)
except Exception as e: # never let one bad image break the panel
logger.debug(f"resolve_image({idx}) raised: {e}")
img = None
if img is not None:
caption = caption_fn(idx) if caption_fn else None
shown.append((img, caption))

if not shown:
st.caption("No images available for this cluster.")
continue

cols = st.columns(min(columns, len(shown)))
for i, (img, caption) in enumerate(shown):
cols[i % len(cols)].image(img, caption=caption, width="stretch")
31 changes: 19 additions & 12 deletions shared/components/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import os
import pandas as pd
from shared.utils.taxonomy_tree import build_taxonomic_tree, format_tree_string, get_tree_statistics
from shared.components.representatives import render_representative_images
from shared.utils.logging_config import get_logger

logger = get_logger(__name__)
Expand Down Expand Up @@ -197,18 +198,24 @@ def render_clustering_summary(show_taxonomy=False):
st.dataframe(summary_df, hide_index=True, width='stretch')

st.markdown("#### Representative Images")
for row in summary_df.itertuples():
k = row.Cluster
st.markdown(f"**Cluster {k}**")
img_cols = st.columns(3)
for i, img_idx in enumerate(representatives[k]):
img_path = df_plot.iloc[img_idx]["image_path"]
logger.debug(f"Displaying representative image: {img_path}")
img_cols[i].image(
img_path,
width='stretch',
caption=os.path.basename(img_path),
)

def _resolve_local_image(idx):
"""Return the local image path if it exists, else None (fallback)."""
path = df_plot.iloc[idx]["image_path"]
if isinstance(path, str) and os.path.exists(path):
return path
return None

def _local_caption(idx):
path = df_plot.iloc[idx]["image_path"]
return os.path.basename(path) if isinstance(path, str) else None

render_representative_images(
representatives,
resolve_image=_resolve_local_image,
n_per_cluster=3,
caption_fn=_local_caption,
)
else:
# Precalculated app: show taxonomy tree (works with or without KMeans)
if show_taxonomy:
Expand Down
Loading
Loading