Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 52 additions & 16 deletions chemap/plotting/chem_space_umap.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
import numpy as np
import pandas as pd
from chemap import FingerprintConfig, compute_fingerprints
from chemap.fingerprint_conversions import fingerprints_to_csr
from chemap.fingerprint_conversions import (
fingerprints_to_csr,
fingerprints_to_tfidf,
idf_normalized,
)
from chemap.metrics import (
tanimoto_distance_dense,
tanimoto_distance_sparse,
Expand Down Expand Up @@ -53,7 +57,7 @@ def create_chem_space_umap(
fpgen: Optional[Any] = None,
fingerprint_config: Optional[FingerprintConfig] = None,
show_progress: bool = True,
log_count: bool = False,
scaling: str = None,
# UMAP (CPU / umap-learn)
n_neighbors: int = 100,
min_dist: float = 0.25,
Expand All @@ -80,9 +84,9 @@ def create_chem_space_umap(
FingerprintConfig(count=True, folded=False, invalid_policy="raise")
show_progress:
Forwarded to compute_fingerprints.
log_count:
If True, apply np.log1p to counts (works for sparse CSR and dense arrays).
(For binary fingerprints this is harmless)
scaling:
Scaling applied to the fingerprint values before running UMAP. Default is None,
which means no scaling. Set to "log" for log1p scaling, or to "tfidf" for TF-IDF
weighting of bits. Any other value is silently treated as no scaling.
n_neighbors, min_dist, umap_random_state:
Standard UMAP parameters.
n_jobs:
Expand Down Expand Up @@ -137,14 +141,20 @@ def create_chem_space_umap(

if not fingerprint_config.folded:
# Convert to CSR matrix
fps_csr = fingerprints_to_csr(fingerprints).X
if scaling == "tfidf":
fps_csr = fingerprints_to_tfidf(fingerprints).X
else:
fps_csr = fingerprints_to_csr(fingerprints).X

if log_count:
# Works well for count fingerprints ( for binary it's essentially unchanged).
fps_csr = _log1p_csr_inplace(fps_csr)
if scaling == "log":
fps_csr = _log1p_csr_inplace(fps_csr)

coords = reducer.fit_transform(fps_csr)
else:
if scaling == "log":
fingerprints = np.log1p(fingerprints)
elif scaling == "tfidf":
fingerprints *= idf_normalized((fingerprints > 0).sum(axis=0), fingerprints.shape[0])
coords = reducer.fit_transform(fingerprints)

df[x_col] = coords[:, 0]
Expand All @@ -163,13 +173,39 @@ def create_chem_space_umap_gpu(
fpgen: Optional[Any] = None,
fingerprint_config: Optional[FingerprintConfig] = None,
show_progress: bool = True,
log_count: bool = False,
scaling: str = None,
# UMAP (GPU / cuML)
n_neighbors: int = 100,
min_dist: float = 0.25,
) -> pd.DataFrame:
"""Compute fingerprints and create 2D UMAP coordinates using cuML (GPU).

Parameters
----------
data:
Input dataframe containing a SMILES column.
col_smiles:
Name of the SMILES column.
inplace:
If True, write x/y columns into `data` and return it. Else returns a copy.
x_col, y_col:
Output coordinate column names.
fpgen:
RDKit fingerprint generator. Defaults to Morgan radius=9, fpSize=4096.
fingerprint_config:
FingerprintConfig for chemap.compute_fingerprints. Defaults to:
FingerprintConfig(count=True, folded=False, invalid_policy="raise")
show_progress:
Forwarded to compute_fingerprints.
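scaling:
Scaling applied to the fingerprint values before running UMAP. Default is None,
which means no scaling (the fingerprints are then cast to int8). Set to "log" for
log1p scaling (result cast to float32), or to "tfidf" for TF-IDF weighting of bits.
Any other value is silently treated as no scaling.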
n_neighbors, min_dist:
Standard UMAP parameters.
(Unlike `create_chem_space_umap`, this function has no `umap_random_state` or
`n_jobs` parameter.)

Notes
-----
- cuML UMAP here is fixed to metric="cosine"
Expand Down Expand Up @@ -222,12 +258,12 @@ def create_chem_space_umap_gpu(
)

# Apply the requested scaling; without scaling, cast to int8 to reduce the memory footprint.
if not log_count:
# stays integer-like
fps = fingerprints.astype(np.int8, copy=False)
if scaling == "log":
fingerprints = np.log1p(fingerprints).astype(np.float32, copy=False)
elif scaling == "tfidf":
fingerprints *= idf_normalized((fingerprints > 0).sum(axis=0), fingerprints.shape[0])
else:
# log1p returns float
fps = np.log1p(fingerprints).astype(np.float32, copy=False)
fingerprints = fingerprints.astype(np.int8, copy=False)

umap_model = cuUMAP(
n_neighbors=int(n_neighbors),
Expand All @@ -238,7 +274,7 @@ def create_chem_space_umap_gpu(
n_components=2,
)

coords = umap_model.fit_transform(fps)
coords = umap_model.fit_transform(fingerprints)

# cuML may return cupy/cudf-backed arrays; np.asarray makes it safe for pandas columns.
coords_np = np.asarray(coords)
Expand Down
58 changes: 44 additions & 14 deletions chemap/plotting/scatter_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ class ScatterStyle:
alpha: float = 0.25
linewidths: float = 0.0

display_legend: bool = True
legend_outside: bool = False

legend_title: Optional[str] = None
legend_loc: str = "lower left"
legend_frameon: bool = False
Expand Down Expand Up @@ -132,21 +135,40 @@ def scatter_plot_base(
ax.set_xlabel("")
ax.set_ylabel("")

legend_title = style.legend_title if style.legend_title is not None else label_col
handles = _build_legend_handles(
legend_labels,
palette,
markersize=style.legend_markersize,
alpha=style.legend_alpha,
)
# ---- legend (optional + outside option) ----
if style.display_legend:
legend_title = style.legend_title if style.legend_title is not None else label_col
handles = _build_legend_handles(
legend_labels,
palette,
markersize=style.legend_markersize,
alpha=style.legend_alpha,
)

ax.legend(
handles=handles,
title=legend_title,
loc=style.legend_loc,
frameon=style.legend_frameon,
ncol=style.legend_ncol,
)
if style.legend_outside:
# Put legend outside right; loc controls anchor point of legend box itself.
ax.legend(
handles=handles,
title=legend_title,
loc="center left",
bbox_to_anchor=(1.02, 0.5),
frameon=style.legend_frameon,
ncol=style.legend_ncol,
borderaxespad=0.0,
)
# Leave room on the right so legend isn't clipped
fig.tight_layout(rect=(0, 0, 0.85, 1))
else:
ax.legend(
handles=handles,
title=legend_title,
loc=style.legend_loc,
frameon=style.legend_frameon,
ncol=style.legend_ncol,
)
fig.tight_layout()
else:
fig.tight_layout()

fig.tight_layout()
return fig, ax
Expand Down Expand Up @@ -174,6 +196,8 @@ def scatter_plot_all_classes(
s: float = 5.0,
alpha: float = 0.25,
linewidths: float = 0.0,
display_legend: bool = True,
legend_outside: bool = False,
legend_title: Optional[str] = None,
legend_loc: str = "lower left",
legend_frameon: bool = False,
Expand Down Expand Up @@ -243,6 +267,8 @@ def scatter_plot_all_classes(
s=s,
alpha=alpha,
linewidths=linewidths,
display_legend=display_legend,
legend_outside=legend_outside,
legend_title=legend_title if legend_title is not None else subclass_col,
legend_loc=legend_loc,
legend_frameon=legend_frameon,
Expand Down Expand Up @@ -300,6 +326,8 @@ def scatter_plot_hierarchical_labels(
s: float = 2.0,
alpha: float = 0.2,
linewidths: float = 0.0,
display_legend: bool = True,
legend_outside: bool = False,
legend_title: str = "Class / Superclass",
legend_loc: str = "lower left",
legend_frameon: bool = False,
Expand Down Expand Up @@ -398,6 +426,8 @@ def scatter_plot_hierarchical_labels(
s=s,
alpha=alpha,
linewidths=linewidths,
display_legend=display_legend,
legend_outside=legend_outside,
legend_title=legend_title,
legend_loc=legend_loc,
legend_frameon=legend_frameon,
Expand Down