diff --git a/chemap/plotting/chem_space_umap.py b/chemap/plotting/chem_space_umap.py index 138867d..8fe28b5 100644 --- a/chemap/plotting/chem_space_umap.py +++ b/chemap/plotting/chem_space_umap.py @@ -3,7 +3,11 @@ import numpy as np import pandas as pd from chemap import FingerprintConfig, compute_fingerprints -from chemap.fingerprint_conversions import fingerprints_to_csr +from chemap.fingerprint_conversions import ( + fingerprints_to_csr, + fingerprints_to_tfidf, + idf_normalized, +) from chemap.metrics import ( tanimoto_distance_dense, tanimoto_distance_sparse, @@ -53,7 +57,7 @@ def create_chem_space_umap( fpgen: Optional[Any] = None, fingerprint_config: Optional[FingerprintConfig] = None, show_progress: bool = True, - log_count: bool = False, + scaling: str = None, # UMAP (CPU / umap-learn) n_neighbors: int = 100, min_dist: float = 0.25, @@ -80,9 +84,9 @@ def create_chem_space_umap( FingerprintConfig(count=True, folded=False, invalid_policy="raise") show_progress: Forwarded to compute_fingerprints. - log_count: - If True, apply np.log1p to counts (works for sparse CSR and dense arrays). - (For binary fingerprints this is harmless) + scaling: + Define scaling for count fingerprints. Default is None, which means no scaling. + Can be set to "log" for log1p scaling, or to "tfidf" for TF-IDF scaling of bits. n_neighbors, min_dist, umap_random_state: Standard UMAP parameters. n_jobs: @@ -137,14 +141,20 @@ def create_chem_space_umap( if not fingerprint_config.folded: # Convert to CSR matrix - fps_csr = fingerprints_to_csr(fingerprints).X + if scaling == "tfidf": + fps_csr = fingerprints_to_tfidf(fingerprints).X + else: + fps_csr = fingerprints_to_csr(fingerprints).X - if log_count: - # Works well for count fingerprints ( for binary it's essentially unchanged). - fps_csr = _log1p_csr_inplace(fps_csr) + if scaling == "log": + fps_csr = _log1p_csr_inplace(fps_csr) coords = reducer.fit_transform(fps_csr) else: + if scaling == "log": + fingerprints = np.log1p(fingerprints) + elif scaling == "tfidf": + fingerprints *= idf_normalized((fingerprints > 0).sum(axis=0), fingerprints.shape[0]) coords = reducer.fit_transform(fingerprints) df[x_col] = coords[:, 0] @@ -163,13 +173,39 @@ def create_chem_space_umap_gpu( fpgen: Optional[Any] = None, fingerprint_config: Optional[FingerprintConfig] = None, show_progress: bool = True, - log_count: bool = False, + scaling: str = None, # UMAP (GPU / cuML) n_neighbors: int = 100, min_dist: float = 0.25, ) -> pd.DataFrame: """Compute fingerprints and create 2D UMAP coordinates using cuML (GPU). + Parameters + ---------- + data: + Input dataframe containing a SMILES column. + col_smiles: + Name of the SMILES column. + inplace: + If True, write x/y columns into `data` and return it. Else returns a copy. + x_col, y_col: + Output coordinate column names. + fpgen: + RDKit fingerprint generator. Defaults to Morgan radius=9, fpSize=4096. + fingerprint_config: + FingerprintConfig for chemap.compute_fingerprints. Defaults to: + FingerprintConfig(count=True, folded=False, invalid_policy="raise") + show_progress: + Forwarded to compute_fingerprints. + scaling: + Define scaling for count fingerprints. Default is None, which means no scaling. + Can be set to "log" for log1p scaling, or to "tfidf" for TF-IDF scaling of bits. + n_neighbors, min_dist, umap_random_state: + Standard UMAP parameters. + n_jobs: + Passed to umap-learn UMAP for parallelism. Ignores random_state when n_jobs != 1. + Default -1 uses all CPUs. + Notes ----- - cuML UMAP here is fixed to metric="cosine" @@ -222,12 +258,12 @@ def create_chem_space_umap_gpu( ) # Reduce memory footprint (works well for count fingerprints) - if not log_count: - # stays integer-like - fps = fingerprints.astype(np.int8, copy=False) + if scaling == "log": + fingerprints = np.log1p(fingerprints).astype(np.float32, copy=False) + elif scaling == "tfidf": + fingerprints *= idf_normalized((fingerprints > 0).sum(axis=0), fingerprints.shape[0]) else: - # log1p returns float - fps = np.log1p(fingerprints).astype(np.float32, copy=False) + fingerprints = fingerprints.astype(np.int8, copy=False) umap_model = cuUMAP( n_neighbors=int(n_neighbors), @@ -238,7 +274,7 @@ def create_chem_space_umap_gpu( n_components=2, ) - coords = umap_model.fit_transform(fps) + coords = umap_model.fit_transform(fingerprints) # cuML may return cupy/cudf-backed arrays; np.asarray makes it safe for pandas columns. coords_np = np.asarray(coords) diff --git a/chemap/plotting/scatter_plots.py b/chemap/plotting/scatter_plots.py index aee388a..22bdb3b 100644 --- a/chemap/plotting/scatter_plots.py +++ b/chemap/plotting/scatter_plots.py @@ -36,6 +36,9 @@ class ScatterStyle: alpha: float = 0.25 linewidths: float = 0.0 + display_legend: bool = True + legend_outside: bool = False + legend_title: Optional[str] = None legend_loc: str = "lower left" legend_frameon: bool = False @@ -132,21 +135,40 @@ def scatter_plot_base( ax.set_xlabel("") ax.set_ylabel("") - legend_title = style.legend_title if style.legend_title is not None else label_col - handles = _build_legend_handles( - legend_labels, - palette, - markersize=style.legend_markersize, - alpha=style.legend_alpha, - ) + # ---- legend (optional + outside option) ---- + if style.display_legend: + legend_title = style.legend_title if style.legend_title is not None else label_col + handles = _build_legend_handles( + legend_labels, + palette, + markersize=style.legend_markersize, + alpha=style.legend_alpha, + ) - ax.legend( - handles=handles, - title=legend_title, - loc=style.legend_loc, - frameon=style.legend_frameon, - ncol=style.legend_ncol, - ) + if style.legend_outside: + # Put legend outside right; loc controls anchor point of legend box itself. + ax.legend( + handles=handles, + title=legend_title, + loc="center left", + bbox_to_anchor=(1.02, 0.5), + frameon=style.legend_frameon, + ncol=style.legend_ncol, + borderaxespad=0.0, + ) + # Leave room on the right so legend isn't clipped + fig.tight_layout(rect=(0, 0, 0.85, 1)) + else: + ax.legend( + handles=handles, + title=legend_title, + loc=style.legend_loc, + frameon=style.legend_frameon, + ncol=style.legend_ncol, + ) + fig.tight_layout() + else: + fig.tight_layout() fig.tight_layout() return fig, ax @@ -174,6 +196,8 @@ def scatter_plot_all_classes( s: float = 5.0, alpha: float = 0.25, linewidths: float = 0.0, + display_legend: bool = True, + legend_outside: bool = False, legend_title: Optional[str] = None, legend_loc: str = "lower left", legend_frameon: bool = False, @@ -243,6 +267,8 @@ def scatter_plot_all_classes( s=s, alpha=alpha, linewidths=linewidths, + display_legend=display_legend, + legend_outside=legend_outside, legend_title=legend_title if legend_title is not None else subclass_col, legend_loc=legend_loc, legend_frameon=legend_frameon, @@ -300,6 +326,8 @@ def scatter_plot_hierarchical_labels( s: float = 2.0, alpha: float = 0.2, linewidths: float = 0.0, + display_legend: bool = True, + legend_outside: bool = False, legend_title: str = "Class / Superclass", legend_loc: str = "lower left", legend_frameon: bool = False, @@ -398,6 +426,8 @@ def scatter_plot_hierarchical_labels( s=s, alpha=alpha, linewidths=linewidths, + display_legend=display_legend, + legend_outside=legend_outside, legend_title=legend_title, legend_loc=legend_loc, legend_frameon=legend_frameon,