diff --git a/apps/embed_explore/components/image_preview.py b/apps/embed_explore/components/image_preview.py index 368483b..092b56f 100644 --- a/apps/embed_explore/components/image_preview.py +++ b/apps/embed_explore/components/image_preview.py @@ -19,24 +19,31 @@ def render_image_preview(): valid_paths = st.session_state.get("valid_paths", None) labels = st.session_state.get("labels", None) - selected_idx = st.session_state.get("selected_image_idx", 0) + kmeans_col = st.session_state.get("kmeans_column", None) + selected_idx = st.session_state.get("selected_image_idx", None) if ( valid_paths is not None and - labels is not None and selected_idx is not None and 0 <= selected_idx < len(valid_paths) ): img_path = valid_paths[selected_idx] - cluster = labels[selected_idx] if labels is not None else "?" + cluster = labels[selected_idx] if labels is not None else None - # Log only when image changes if _last_displayed_path != img_path: - logger.info(f"[Image] Loading local file: {os.path.basename(img_path)} (cluster={cluster})") + log_msg = f"[Image] Loading local file: {os.path.basename(img_path)}" + if cluster is not None: + log_msg += f" (cluster={cluster})" + logger.info(log_msg) _last_displayed_path = img_path - st.image(img_path, caption=f"Cluster {cluster}: {os.path.basename(img_path)}", width='stretch') + caption = os.path.basename(img_path) + if cluster is not None and kmeans_col: + caption = f"{kmeans_col}={cluster}: {caption}" + + st.image(img_path, caption=caption, width='stretch') st.markdown(f"**File:** `{os.path.basename(img_path)}`") - st.markdown(f"**Cluster:** `{cluster}`") + if cluster is not None and kmeans_col: + st.markdown(f"**{kmeans_col}:** `{cluster}`") else: - st.info("Image preview will appear here after you select a cluster point.") + st.info("Image preview will appear here after you select a point in the scatter.") diff --git a/apps/embed_explore/components/sidebar.py b/apps/embed_explore/components/sidebar.py index 129b2a2..511d88b 100644 
--- a/apps/embed_explore/components/sidebar.py +++ b/apps/embed_explore/components/sidebar.py @@ -4,13 +4,20 @@ import streamlit as st import os -from typing import Tuple, List, Optional +import time +import hashlib +import numpy as np +import pandas as pd +from typing import Tuple, Optional from shared.services.embedding_service import EmbeddingService from shared.services.clustering_service import ClusteringService from shared.services.file_service import FileService from shared.lib.progress import StreamlitProgressContext -from shared.components.clustering_controls import render_clustering_backend_controls, render_basic_clustering_controls +from shared.components.clustering_controls import ( + render_projection_controls, + render_kmeans_controls, +) from shared.utils.backend import check_cuda_available, resolve_backend, is_oom_error from shared.utils.logging_config import get_logger @@ -74,10 +81,11 @@ def render_embedding_section() -> Tuple[bool, Optional[str], Optional[str], int, st.session_state.valid_paths = valid_paths st.session_state.last_image_dir = image_dir st.session_state.embedding_complete = True - # Reset clustering/selection state + # Reset projection/clustering/selection state for the new embeddings st.session_state.labels = None + st.session_state.kmeans_column = None st.session_state.data = None - st.session_state.selected_image_idx = 0 + st.session_state.selected_image_idx = None except Exception as e: st.error(f"Error during embedding: {e}") @@ -89,152 +97,239 @@ def render_embedding_section() -> Tuple[bool, Optional[str], Optional[str], int, return embed_button, image_dir, model_name, n_workers, batch_size -def render_clustering_section(n_workers: int = 1) -> Tuple[bool, int, str]: - """ - Render the clustering section of the sidebar. 
def render_projection_section():
    """Render the "Project to 2D" expander in the sidebar.

    Reads ``embeddings`` and ``valid_paths`` from ``st.session_state``. When
    either is missing, or fewer than two paths are available, only an info
    message is shown. Otherwise the user picks a reduction method and backend
    options, and pressing the button delegates to ``_run_projection``.
    """
    with st.expander("Project to 2D", expanded=False):
        state = st.session_state
        embeddings = state.get("embeddings", None)
        valid_paths = state.get("valid_paths", None)

        # Projection needs embeddings plus at least two images to lay out.
        can_project = (
            embeddings is not None
            and valid_paths is not None
            and len(valid_paths) >= 2
        )
        if not can_project:
            st.info("Run embedding first to enable projection.")
            return

        n_samples, emb_dim = embeddings.shape
        st.markdown(f"**Ready to project:** {n_samples:,} images ({emb_dim}-dim embeddings)")

        method_choices = ["TSNE", "PCA", "UMAP"]
        reduction_method = st.selectbox(
            "Dimensionality Reduction",
            method_choices,
            help="Method to project high-dimensional embeddings to 2D for visualization.",
        )

        dim_reduction_backend, seed = render_projection_controls()

        project_clicked = st.button("Project to 2D", type="primary")
        if project_clicked:
            _run_projection(
                embeddings,
                valid_paths,
                reduction_method,
                dim_reduction_backend,
                seed,
            )
def render_kmeans_section():
    """Render the optional "KMeans Clustering" expander in the sidebar.

    Requires both the projected dataframe (``st.session_state["data"]``) and
    the embeddings to be present; otherwise only an info message is shown.
    The user chooses the number of clusters and backend options, and pressing
    the button delegates to ``_run_kmeans``.

    The selectable k range is ``[2, min(100, n_points // 2)]``. The slider's
    default is clamped into that range, and when the range collapses to a
    single value the slider is skipped and k is fixed at 2.
    """
    with st.expander("KMeans Clustering", expanded=False):
        df_plot = st.session_state.get("data", None)
        embeddings = st.session_state.get("embeddings", None)

        if df_plot is None or embeddings is None:
            st.info("Run projection first to enable KMeans.")
            return

        n_points = len(df_plot)
        emb_dim = embeddings.shape[1]
        st.markdown(f"**{n_points:,} points** ({emb_dim}-dim embeddings)")

        min_k = 2
        max_k = min(100, max(min_k, n_points // 2))

        if max_k > min_k:
            # A fixed default of 5 would fall outside [min_k, max_k] whenever
            # fewer than 10 points are projected, so clamp it.
            default_k = min(5, max_k)
            n_clusters = st.slider("Number of clusters", min_k, max_k, default_k)
        else:
            # Nothing to choose between: a slider with min == max is pointless
            # and could let k exceed what the data supports.
            n_clusters = min_k
            st.caption(f"Too few points to choose k; using k={min_k}.")

        clustering_backend, n_workers, seed = render_kmeans_controls()

        if st.button("Run KMeans", type="primary"):
            _run_kmeans(embeddings, n_clusters, clustering_backend, n_workers, seed)
def _run_projection(embeddings, valid_paths, reduction_method, dim_reduction_backend, seed):
    """Run dimensionality reduction and store the 2D scatter dataframe.

    Args:
        embeddings: 2-D array-like with a ``.shape`` of (n_samples, emb_dim).
        valid_paths: Image paths, one per embedding row.
        reduction_method: Name of the reduction method passed through to
            ``ClusteringService.run_dim_reduction_safe``.
        dim_reduction_backend: Requested backend; resolved via
            ``resolve_backend(..., "reduction")``.
        seed: Random seed forwarded to the reduction call.

    Side effects:
        On success writes ``data``, ``data_version`` and resets
        ``selected_image_idx`` in ``st.session_state``. Any ``KMeans (k=...)``
        columns of a previous dataframe of the same length are carried over.
        All errors are reported through ``st.error`` and logged; nothing is
        re-raised.
    """
    try:
        cuda_available, device_info = check_cuda_available()
        actual_backend = resolve_backend(dim_reduction_backend, "reduction")

        logger.info("=" * 60)
        logger.info("PROJECTION START")
        logger.info(f"Device: {device_info} (CUDA: {'Yes' if cuda_available else 'No'})")
        logger.info(f"Backend: {actual_backend} (requested: {dim_reduction_backend})")

        # perf_counter is monotonic, so the elapsed time cannot go negative or
        # jump if the system clock is adjusted mid-run.
        t_start = time.perf_counter()
        n_samples, emb_dim = embeddings.shape
        logger.info(f"Records: {n_samples:,} | Dim: {emb_dim}")

        with st.spinner(f"Running {reduction_method}..."):
            reduced = ClusteringService.run_dim_reduction_safe(
                embeddings, reduction_method,
                n_workers=8, dim_reduction_backend=actual_backend, seed=seed
            )

        t_total = time.perf_counter() - t_start
        logger.info(f"Projection complete in {t_total:.2f}s")

        # Build plot dataframe (no cluster column)
        df_plot = pd.DataFrame({
            "x": reduced[:, 0],
            "y": reduced[:, 1],
            "image_path": valid_paths,
            "file_name": [os.path.basename(p) for p in valid_paths],
            "idx": range(len(valid_paths)),
        })

        # Carry over any prior KMeans columns from the previous df_plot
        # (only if the row count matches, i.e. same set of points).
        prev_df = st.session_state.get("data")
        if prev_df is not None and len(prev_df) == len(df_plot):
            for col in prev_df.columns:
                # Column labels are not guaranteed to be strings.
                if isinstance(col, str) and col.startswith("KMeans (k="):
                    df_plot[col] = prev_df[col].values

        # Short tag that changes on every projection run; the elapsed time
        # makes it differ between runs with identical settings.
        data_hash = hashlib.md5(
            f"{len(df_plot)}_{reduction_method}_{t_total}".encode()
        ).hexdigest()[:8]
        st.session_state.data = df_plot
        st.session_state.data_version = data_hash
        st.session_state.selected_image_idx = None

        logger.info("=" * 60)
        st.success(f"Projected {n_samples:,} points to 2D using {reduction_method}.")

    except (RuntimeError, OSError) as e:
        if is_oom_error(e):
            st.error("**GPU Out of Memory**")
            st.info("Try: Reduce dataset size, use 'sklearn' backend, or try PCA.")
            logger.exception("GPU OOM during projection")
        else:
            st.error(f"Error during projection: {e}")
            logger.exception("Projection error")
    except MemoryError:
        st.error("**System Out of Memory** - Reduce dataset size")
        logger.exception("System memory exhausted during projection")
    except Exception as e:
        # UI boundary: surface anything unexpected instead of crashing the app.
        st.error(f"Error: {e}")
        logger.exception("Unexpected projection error")
def _run_kmeans(embeddings, n_clusters, clustering_backend, n_workers, seed):
    """Run KMeans on the embeddings and attach the labels to the plot dataframe.

    Args:
        embeddings: Embedding matrix passed to the clustering service.
        n_clusters: Requested number of clusters (k).
        clustering_backend: Requested backend; resolved via
            ``resolve_backend(..., "clustering")``.
        n_workers: Worker count forwarded to the clustering service.
        seed: Random seed forwarded to the clustering service.

    Side effects:
        Adds a ``"KMeans (k=<k>)"`` string column to
        ``st.session_state["data"]`` and updates ``labels`` and
        ``kmeans_column``. The per-run summary and representatives are cached
        in ``clustering_summaries`` / ``clustering_representatives_by_col``,
        keyed by the column name. All errors are reported through ``st.error``
        and logged; nothing is re-raised.
    """
    try:
        # Check for the projected dataframe before the expensive clustering
        # call, so a missing projection fails fast with a clear message.
        df_plot = st.session_state.get("data", None)
        if df_plot is None:
            st.error("Run projection first to enable KMeans.")
            return

        actual_backend = resolve_backend(clustering_backend, "clustering")
        logger.info(f"KMeans: k={n_clusters}, backend={actual_backend}")

        with st.spinner(f"Running KMeans (k={n_clusters})..."):
            labels = ClusteringService.run_kmeans_only_safe(
                embeddings, n_clusters,
                n_workers=n_workers, clustering_backend=actual_backend, seed=seed
            )

        kmeans_col = f"KMeans (k={n_clusters})"

        # Labels are stored as strings on the dataframe column.
        df_plot[kmeans_col] = labels.astype(str)
        st.session_state.data = df_plot
        st.session_state.labels = labels
        st.session_state.kmeans_column = kmeans_col

        # Compute clustering summary on the full embedding space.
        # Cache by kmeans_col so multiple KMeans runs can each have their own
        # summary + representatives that the user can switch between.
        logger.info("Computing clustering summary statistics...")
        summary_df, representatives = ClusteringService.generate_clustering_summary(
            embeddings, labels, df_plot
        )
        # `or {}` also covers a key that exists but holds None, matching how
        # render_clustering_summary reads these caches.
        summaries = st.session_state.get("clustering_summaries", None) or {}
        reps_by_col = st.session_state.get("clustering_representatives_by_col", None) or {}
        summaries[kmeans_col] = summary_df
        reps_by_col[kmeans_col] = representatives
        st.session_state.clustering_summaries = summaries
        st.session_state.clustering_representatives_by_col = reps_by_col
        logger.info(f"Clustering summary computed for {kmeans_col}: {len(summary_df)} clusters")

        n_found = len(np.unique(labels))
        logger.info(f"KMeans complete: {n_found} clusters")
        st.success(f"KMeans complete! {n_found} clusters assigned.")

    except (RuntimeError, OSError) as e:
        if is_oom_error(e):
            st.error("**GPU Out of Memory**")
            logger.exception("GPU OOM during KMeans")
        else:
            st.error(f"Error during KMeans: {e}")
            logger.exception("KMeans error")
    except MemoryError:
        st.error("**System Out of Memory** - Reduce dataset size")
        logger.exception("System memory exhausted during KMeans")
    except Exception as e:
        # UI boundary: surface anything unexpected instead of crashing the app.
        st.error(f"Error: {e}")
        logger.exception("Unexpected KMeans error")
def _get_available_kmeans_cols(df_plot) -> list:
    """Return the KMeans result columns of ``df_plot`` sorted by k.

    Args:
        df_plot: Object exposing a ``columns`` iterable (e.g. the plot
            dataframe), or ``None``.

    Returns:
        Column names starting with ``"KMeans (k="``, ordered by ascending k.
        Prefixed columns whose k cannot be parsed as an integer are kept and
        placed after all parsable ones in their original relative order,
        instead of raising. Non-string column labels are ignored. An empty
        list is returned for ``None``.
    """
    if df_plot is None:
        return []

    def _sort_key(col):
        # "KMeans (k=12)" -> 12. The prefix guarantees an "=" is present, so
        # only the int conversion can fail.
        try:
            return (0, int(col.split("=")[1].rstrip(")")))
        except ValueError:
            return (1, 0)

    kmeans_cols = [
        c for c in df_plot.columns
        if isinstance(c, str) and c.startswith("KMeans (k=")
    ]
    # sorted() is stable, so ties keep dataframe order.
    return sorted(kmeans_cols, key=_sort_key)
When multiple KMeans runs exist, the user + picks which one to operate on via a shared selector at the top. + """ + df_plot = st.session_state.get("data", None) + kmeans_cols = _get_available_kmeans_cols(df_plot) + + if not kmeans_cols: + st.info("Run KMeans first to enable saving by cluster.") + return + + # Shared selector: which KMeans run drives both save operations + default_idx = len(kmeans_cols) - 1 # most recent run + selected_kmeans_col = st.selectbox( + "KMeans result", + options=kmeans_cols, + index=default_idx, + key="save_kmeans_selector", + help="Pick which KMeans run to use for save / repartition.", + ) - else: - st.info("Run clustering first to enable this utility.") + # --- Save images from a specific cluster utility --- + save_status_placeholder = st.empty() + with st.expander("Save Images from Specific Cluster", expanded=True): + available_clusters = sorted(df_plot[selected_kmeans_col].unique(), key=lambda x: int(x)) + selected_clusters = st.multiselect( + "Select cluster(s) to save", + available_clusters, + default=available_clusters[:1] if available_clusters else [], + key="save_cluster_select", + ) + save_dir = st.text_input( + "Directory to save selected cluster images", + value="cluster_selected_output", + key="save_cluster_dir", + ) + save_cluster_button = st.button("Save images", key="save_cluster_btn") + + if save_cluster_button and selected_clusters: + cluster_rows = df_plot[df_plot[selected_kmeans_col].isin(selected_clusters)].copy() + # FileService expects a 'cluster' column + cluster_rows["cluster"] = cluster_rows[selected_kmeans_col] + max_workers = st.session_state.get("num_threads", 8) + + with StreamlitProgressContext( + save_status_placeholder, + f"Images from cluster(s) {', '.join(map(str, selected_clusters))} saved successfully!" 
+ ) as progress: + try: + save_summary_df, csv_path = FileService.save_cluster_images( + cluster_rows, save_dir, max_workers, progress_callback=progress + ) + st.info(f"Summary CSV saved at {csv_path}") + except Exception as e: + save_status_placeholder.error(f"Error saving images: {e}") + elif save_cluster_button: + save_status_placeholder.warning("Please select at least one cluster.") # --- Repartition expander and status --- repartition_status_placeholder = st.empty() @@ -243,7 +338,7 @@ def render_save_section(): repartition_dir = st.text_input( "Directory", value="repartitioned_output", - key="repartition_dir" + key="repartition_dir", ) max_workers = st.number_input( "Number of threads (higher = faster, try 8-32)", @@ -251,49 +346,34 @@ def render_save_section(): max_value=64, value=8, step=1, - key="num_threads" + key="num_threads", ) repartition_button = st.button("Repartition images by cluster", key="repartition_btn") - # Handle repartition execution if repartition_button: - df_plot = st.session_state.get("data", None) - - if df_plot is None or len(df_plot) < 1: - repartition_status_placeholder.warning("Please run clustering first before repartitioning images.") - else: - with StreamlitProgressContext( - repartition_status_placeholder, - f"Repartition complete! Images organized in {repartition_dir}" - ) as progress: - try: - repartition_summary_df, csv_path = FileService.repartition_images_by_cluster( - df_plot, repartition_dir, max_workers, progress_callback=progress - ) - st.info(f"Summary CSV saved at {csv_path}") - - except Exception as e: - repartition_status_placeholder.error(f"Error repartitioning images: {e}") + df_for_repartition = df_plot.copy() + df_for_repartition["cluster"] = df_for_repartition[selected_kmeans_col] + with StreamlitProgressContext( + repartition_status_placeholder, + f"Repartition complete! 
def render_clustering_sidebar():
    """Render the whole sidebar as two tabs.

    The "Compute" tab holds the embedding, projection and KMeans sections, in
    that order; the "Save" tab holds the save section. Nothing is returned.
    """
    compute_tab, save_tab = st.tabs(["Compute", "Save"])

    compute_sections = (
        render_embedding_section,
        render_projection_section,
        render_kmeans_section,
    )
    with compute_tab:
        for render_section in compute_sections:
            render_section()

    with save_tab:
        render_save_section()
def render_clustering_summary(show_taxonomy=False):
    """Render the clustering summary panel using cached results per KMeans run.

    For a dataframe with an ``image_path`` column (embed_explore app), the user
    picks which ``KMeans (k=...)`` column to inspect; the matching summary
    table and representative images are read from the
    ``clustering_summaries`` / ``clustering_representatives_by_col`` caches in
    ``st.session_state``. The selector defaults to the most recent run
    (``st.session_state["kmeans_column"]``) when that column is present, and
    otherwise to the last entry in k-sorted order.

    For a dataframe without ``image_path`` (precalculated app), the taxonomy
    tree is rendered when ``show_taxonomy`` is true and a filtered dataframe
    is available.

    Args:
        show_taxonomy: Whether to render the taxonomic tree summary in the
            precalculated (no ``image_path``) case.
    """
    df_plot = st.session_state.get("data", None)

    if df_plot is None:
        st.info("Summary will appear here after projection.")
        return

    has_images = 'image_path' in df_plot.columns

    if not has_images:
        # Precalculated app: show taxonomy tree (works with or without KMeans)
        if show_taxonomy:
            filtered_df = st.session_state.get("filtered_df_for_clustering", None)
            if filtered_df is not None:
                render_taxonomic_tree_summary()
        return

    def _k_sort_key(col):
        # "KMeans (k=12)" -> 12; unparsable names sort last instead of raising.
        try:
            return (0, int(col.split("=")[1].rstrip(")")))
        except ValueError:
            return (1, 0)

    # embed_explore app: full clustering summary with representative images
    kmeans_cols = sorted(
        [c for c in df_plot.columns if isinstance(c, str) and c.startswith("KMeans (k=")],
        key=_k_sort_key,
    )

    st.subheader("Clustering Summary")
    if not kmeans_cols:
        st.info("Run KMeans to see the clustering summary and representative images.")
        return

    summaries = st.session_state.get("clustering_summaries", {}) or {}
    reps_by_col = st.session_state.get("clustering_representatives_by_col", {}) or {}

    # kmeans_cols is ordered by k, so its last entry is the largest k, not
    # necessarily the latest run. Prefer the column recorded by the last run.
    latest_col = st.session_state.get("kmeans_column", None)
    if latest_col in kmeans_cols:
        default_idx = kmeans_cols.index(latest_col)
    else:
        default_idx = len(kmeans_cols) - 1

    selected_kmeans_col = st.selectbox(
        "KMeans result",
        options=kmeans_cols,
        index=default_idx,
        key="summary_kmeans_selector",
        help="Select which KMeans run to view summary + representative images for.",
    )
    summary_df = summaries.get(selected_kmeans_col)
    representatives = reps_by_col.get(selected_kmeans_col)

    if summary_df is None or representatives is None:
        st.info(
            f"No cached summary for {selected_kmeans_col}. "
            "Re-run KMeans with this k to regenerate it."
        )
        return

    logger.debug(f"Displaying cached clustering summary for {selected_kmeans_col}")
    st.dataframe(summary_df, hide_index=True, width='stretch')

    st.markdown("#### Representative Images")
    for row in summary_df.itertuples():
        k = row.Cluster
        st.markdown(f"**Cluster {k}**")
        img_cols = st.columns(3)
        # zip caps the display at the three available columns rather than
        # indexing past the end if more representatives are cached.
        for img_col, img_idx in zip(img_cols, representatives[k]):
            img_path = df_plot.iloc[img_idx]["image_path"]
            logger.debug(f"Displaying representative image: {img_path}")
            img_col.image(
                img_path,
                width='stretch',
                caption=os.path.basename(img_path),
            )
skip_color_cols = {'x', 'y', 'idx', 'uuid', 'emb', 'embedding', 'embeddings', 'vector', - 'identifier', 'image_url', 'url', 'img_url', 'image'} - colorable_cols = [c for c in df_plot.columns - if c not in skip_color_cols and df_plot[c].nunique() <= 100] - - # Sort KMeans columns to front (all runs, sorted by k) - kmeans_cols = sorted( - [c for c in colorable_cols if c.startswith("KMeans (k=")], - key=lambda c: int(c.split("=")[1].rstrip(")")) + # Determine color column — same dropdown pattern for both apps. + # Build list of colorable columns (skip technical/identifier columns). + skip_color_cols = {'x', 'y', 'idx', 'uuid', 'emb', 'embedding', 'embeddings', 'vector', + 'identifier', 'image_url', 'url', 'img_url', 'image', + 'image_path', 'file_name'} + colorable_cols = [c for c in df_plot.columns + if c not in skip_color_cols and df_plot[c].nunique() <= 100] + + # Sort KMeans columns to front (all runs, sorted by k) + kmeans_cols = sorted( + [c for c in colorable_cols if c.startswith("KMeans (k=")], + key=lambda c: int(c.split("=")[1].rstrip(")")) + ) + other_cols = [c for c in colorable_cols if not c.startswith("KMeans (k=")] + colorable_cols = kmeans_cols + other_cols + + # Build unique count lookup for display + col_nunique = {c: df_plot[c].nunique() for c in colorable_cols} + + if colorable_cols: + color_col = st.selectbox( + "Color by", + options=["(none)"] + colorable_cols, + index=0, + key="color_by_column", + format_func=lambda c: c if c == "(none)" else f"{c} ({col_nunique[c]})", + help="Select a column to color the points by" ) - other_cols = [c for c in colorable_cols if not c.startswith("KMeans (k=")] - colorable_cols = kmeans_cols + other_cols - - # Build unique count lookup for display - col_nunique = {c: df_plot[c].nunique() for c in colorable_cols} - - if colorable_cols: - color_col = st.selectbox( - "Color by", - options=["(none)"] + colorable_cols, - index=0, - key="color_by_column", - format_func=lambda c: c if c == "(none)" else f"{c} 
({col_nunique[c]})", - help="Select a column to color the points by" - ) - if color_col == "(none)": - color_col = None - else: + if color_col == "(none)": color_col = None - - # Warning for high cardinality - if color_col and df_plot[color_col].nunique() > 20: - st.warning(f"'{color_col}' has {df_plot[color_col].nunique()} unique values. Colors may repeat.") - - # Trigger full page rerun when color changes (so bottom section updates). - # Use a sentinel to distinguish "never set" from "set to None". - _sentinel = object() - prev_color = st.session_state.get("_prev_color_by", _sentinel) - if color_col != prev_color: - st.session_state["_prev_color_by"] = color_col - if prev_color is not _sentinel: - st.rerun(scope="app") else: - # embed_explore app: always color by cluster - color_col = 'cluster' if 'cluster' in df_plot.columns else None + color_col = None + + # Warning for high cardinality + if color_col and df_plot[color_col].nunique() > 20: + st.warning(f"'{color_col}' has {df_plot[color_col].nunique()} unique values. Colors may repeat.") + + # Trigger full page rerun when color changes (so bottom section updates). + # Use a sentinel to distinguish "never set" from "set to None". 
+ _sentinel = object() + prev_color = st.session_state.get("_prev_color_by", _sentinel) + if color_col != prev_color: + st.session_state["_prev_color_by"] = color_col + if prev_color is not _sentinel: + st.rerun(scope="app") point_selector = alt.selection_point(fields=["idx"], name="point_selection") @@ -133,13 +130,11 @@ def _render_chart_fragment(df_plot): skip_cols = {'x', 'y', 'idx', 'emb', 'embedding', 'embeddings', 'vector', 'uuid', 'identifier', 'image_url', 'url', 'img_url', 'image'} - # For embed_explore, include cluster/cluster_name in tooltip - if not is_precalculated: - if 'cluster_name' in df_plot.columns: - tooltip_fields.append('cluster_name:N') - elif 'cluster' in df_plot.columns: - tooltip_fields.append('cluster:N') - skip_cols.update({'cluster', 'cluster_name'}) + # For embed_explore, include the file_name in the tooltip for quick reference + if not is_precalculated and 'file_name' in df_plot.columns: + tooltip_fields.append('file_name:N') + skip_cols.add('file_name') + skip_cols.add('image_path') # Add the color column first if set (and not already in tooltip) if color_col and color_col not in skip_cols: