Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,12 @@ An example evaluation command for a sets model:
state-sets sets predict --output_dir /home/aadduri/state-sets/test/ --checkpoint last.ckpt
```

An example inference command for a sets model:

```bash
state-sets sets infer --output /home/dhruvgautam/state-sets/test/ --output_dir /path/to/model/ --checkpoint /path/to/model/checkpoints/last.ckpt --adata /path/to/anndata/processed.h5 --pert_col gene --embed_key X_hvg
```

The toml files should be setup to define perturbation splits, if running fewshot experiments. Here are some examples:

```toml
Expand Down
4 changes: 4 additions & 0 deletions src/state_sets/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
add_arguments_state,
run_sets_predict,
run_sets_train,
run_sets_infer,
run_state_embed,
run_state_train,
)
Expand Down Expand Up @@ -60,6 +61,9 @@ def main():
case "predict":
# For now, predict uses argparse and not hydra
run_sets_predict(args)
case "infer":
# Run inference using argparse, similar to predict
run_sets_infer(args)


if __name__ == "__main__":
Expand Down
3 changes: 2 additions & 1 deletion src/state_sets/_cli/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from ._sets import add_arguments_sets, run_sets_predict, run_sets_train
from ._sets import add_arguments_sets, run_sets_predict, run_sets_train, run_sets_infer
from ._state import add_arguments_state, run_state_embed, run_state_train

__all__ = [
Expand All @@ -7,5 +7,6 @@
"run_sets_train",
"run_state_embed",
"run_sets_predict",
"run_sets_infer",
"run_state_train",
]
4 changes: 3 additions & 1 deletion src/state_sets/_cli/_sets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@

from ._predict import add_arguments_predict, run_sets_predict
from ._train import add_arguments_train, run_sets_train
from ._infer import add_arguments_infer, run_sets_infer

__all__ = ["run_sets_train", "run_sets_predict", "add_arguments_sets"]
__all__ = ["run_sets_train", "run_sets_predict", "run_sets_infer", "add_arguments_sets"]


def add_arguments_sets(parser: ap.ArgumentParser):
""""""
subparsers = parser.add_subparsers(required=True, dest="subcommand")
add_arguments_train(subparsers.add_parser("train"))
add_arguments_predict(subparsers.add_parser("predict"))
add_arguments_infer(subparsers.add_parser("infer"))
228 changes: 228 additions & 0 deletions src/state_sets/_cli/_sets/_infer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
import argparse
import scanpy as sc
import torch
import numpy as np
import os
import pandas as pd
from tqdm import tqdm
import yaml
import pickle

from ...sets.models.pert_sets import PertSetsPerturbationModel


def add_arguments_infer(parser: argparse.ArgumentParser):
parser.add_argument("--checkpoint", type=str, required=True, help="Path to model checkpoint (.ckpt)")
parser.add_argument("--adata", type=str, required=True, help="Path to input AnnData file (.h5ad)")
parser.add_argument("--embed_key", type=str, default="X_hvg", help="Key in adata.obsm for input features")
parser.add_argument(
"--pert_col", type=str, default="drugname_drugconc", help="Column in adata.obs for perturbation labels"
)
parser.add_argument("--output", type=str, default=None, help="Path to output AnnData file (.h5ad)")
parser.add_argument(
"--output_dir",
type=str,
required=True,
help="Path to the output_dir containing the config.yaml file and the pert_onehot_map.pt file that was saved during training.",
)
parser.add_argument(
"--celltype_col", type=str, default=None, help="Column in adata.obs for cell type labels (optional)"
)
parser.add_argument(
"--celltypes", type=str, default=None, help="Comma-separated list of cell types to include (optional)"
)
parser.add_argument("--batch_size", type=int, default=1000, help="Batch size for inference (default: 1000)")


def run_sets_infer(args):
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def load_config(cfg_path: str) -> dict:
"""Load config from the YAML file that was dumped during training."""
if not os.path.exists(cfg_path):
raise FileNotFoundError(f"Could not find config file: {cfg_path}")
with open(cfg_path, "r") as f:
cfg = yaml.safe_load(f)
return cfg

# Load the config
config_path = os.path.join(args.output_dir, "config.yaml")
cfg = load_config(config_path)
logger.info(f"Loaded config from {config_path}")

# Get perturbation dimensions and mapping from data module
var_dims_path = os.path.join(args.output_dir, "var_dims.pkl")
with open(var_dims_path, "rb") as f:
var_dims = pickle.load(f)
pert_dim = var_dims["pert_dim"]

# Load model
logger.info(f"Loading model from checkpoint: {args.checkpoint}")
model = PertSetsPerturbationModel.load_from_checkpoint(args.checkpoint)
model.eval()
cell_sentence_len = model.cell_sentence_len
device = next(model.parameters()).device

# Load AnnData
logger.info(f"Loading AnnData from: {args.adata}")
adata = sc.read_h5ad(args.adata)

# Optionally filter by cell type
if args.celltype_col is not None and args.celltypes is not None:
celltypes = [ct.strip() for ct in args.celltypes.split(",")]
if args.celltype_col not in adata.obs:
raise ValueError(f"Column '{args.celltype_col}' not found in adata.obs.")
initial_n = adata.n_obs
adata = adata[adata.obs[args.celltype_col].isin(celltypes)].copy()
logger.info(f"Filtered AnnData to {adata.n_obs} cells of types {celltypes} (from {initial_n} cells)")
elif args.celltype_col is not None:
if args.celltype_col not in adata.obs:
raise ValueError(f"Column '{args.celltype_col}' not found in adata.obs.")
logger.info(f"No cell type filtering applied, but cell type column '{args.celltype_col}' is available.")

# Get input features
if args.embed_key in adata.obsm:
X = adata.obsm[args.embed_key]
logger.info(f"Using adata.obsm['{args.embed_key}'] as input features: shape {X.shape}")
else:
X = adata.X
logger.info(f"Using adata.X as input features: shape {X.shape}")

# Prepare perturbation tensor using the data module's mapping
pert_names = adata.obs[args.pert_col].values
pert_tensor = torch.zeros((len(pert_names), pert_dim), device="cpu") # Keep on CPU initially
logger.info(f"Perturbation tensor shape: {pert_tensor.shape}")

# Load perturbation mapping from torch file
pert_onehot_map_path = os.path.join(args.output_dir, "pert_onehot_map.pt")
pert_onehot_map = torch.load(pert_onehot_map_path, weights_only=False)

logger.info(f"Data module has {len(pert_onehot_map)} perturbations in mapping")
logger.info(f"First 10 perturbations in data module: {list(pert_onehot_map.keys())[:10]}")

unique_pert_names = sorted(set(pert_names))
logger.info(f"AnnData has {len(unique_pert_names)} unique perturbations")
logger.info(f"First 10 perturbations in AnnData: {unique_pert_names[:10]}")

# Check overlap
overlap = set(unique_pert_names) & set(pert_onehot_map.keys())
logger.info(f"Overlap between AnnData and data module: {len(overlap)} perturbations")
if len(overlap) < len(unique_pert_names):
missing = set(unique_pert_names) - set(pert_onehot_map.keys())
logger.warning(f"Missing perturbations: {list(missing)[:10]}")

# Check if there's a control perturbation that might match
control_pert = cfg["data"]["kwargs"]["control_pert"]
logger.info(f"Control perturbation in data module: '{control_pert}'")

matched_count = 0
for idx, name in enumerate(pert_names):
if name in pert_onehot_map:
pert_tensor[idx] = pert_onehot_map[name]
matched_count += 1
else:
# For now, use control perturbation as fallback
if control_pert in pert_onehot_map:
pert_tensor[idx] = pert_onehot_map[control_pert]
else:
# Use first available perturbation as fallback
first_pert = list(pert_onehot_map.keys())[0]
pert_tensor[idx] = pert_onehot_map[first_pert]

logger.info(f"Matched {matched_count} out of {len(pert_names)} perturbations")

# Process in batches with progress bar
# Use cell_sentence_len as batch size since model expects this
n_samples = len(pert_names)
batch_size = cell_sentence_len # Model requires this exact batch size
n_batches = (n_samples + batch_size - 1) // batch_size # Ceiling division

logger.info(
f"Running inference on {n_samples} samples in {n_batches} batches of size {batch_size} (model's cell_sentence_len)..."
)

all_preds = []

with torch.no_grad():
progress_bar = tqdm(total=n_samples, desc="Processing samples", unit="samples")

for batch_idx in range(n_batches):
start_idx = batch_idx * batch_size
end_idx = min(start_idx + batch_size, n_samples)
current_batch_size = end_idx - start_idx

# Get batch data
X_batch = torch.tensor(X[start_idx:end_idx], dtype=torch.float32).to(device)
pert_batch = pert_tensor[start_idx:end_idx].to(device)
pert_names_batch = pert_names[start_idx:end_idx].tolist()

# Pad the batch to cell_sentence_len if it's the last incomplete batch
if current_batch_size < cell_sentence_len:
# Pad with zeros for embeddings
padding_size = cell_sentence_len - current_batch_size
X_pad = torch.zeros((padding_size, X_batch.shape[1]), device=device)
X_batch = torch.cat([X_batch, X_pad], dim=0)

# Pad perturbation tensor with control perturbation
pert_pad = torch.zeros((padding_size, pert_batch.shape[1]), device=device)
if control_pert in pert_onehot_map:
pert_pad[:] = pert_onehot_map[control_pert].to(device)
else:
pert_pad[:, 0] = 1 # Default to first perturbation
pert_batch = torch.cat([pert_batch, pert_pad], dim=0)

# Extend perturbation names
pert_names_batch.extend([control_pert] * padding_size)

# Prepare batch - use same format as working code
batch = {
"ctrl_cell_emb": X_batch,
"pert_emb": pert_batch, # Keep as 2D tensor
"pert_name": pert_names_batch,
"batch": torch.zeros((1, cell_sentence_len), device=device), # Use (1, cell_sentence_len)
}

# Run inference on batch using padded=False like in working code
batch_preds = model.predict_step(batch, batch_idx=batch_idx, padded=False)

# Extract predictions from the dictionary returned by predict_step
# Use gene decoder output if available, otherwise use latent predictions
if "pert_cell_counts_preds" in batch_preds and batch_preds["pert_cell_counts_preds"] is not None:
# Use gene space predictions (from decoder)
pred_tensor = batch_preds["pert_cell_counts_preds"]
else:
# Use latent space predictions
pred_tensor = batch_preds["preds"]

# Only keep predictions for the actual samples (not padding)
actual_preds = pred_tensor[:current_batch_size]
all_preds.append(actual_preds.cpu().numpy())

# Update progress bar
progress_bar.update(current_batch_size)

progress_bar.close()

# Concatenate all predictions
preds_np = np.concatenate(all_preds, axis=0)

# Save predictions to AnnData
pred_key = "model_preds"
adata.obsm[pred_key] = preds_np
output_path = args.output or args.adata.replace(".h5ad", "_with_preds.h5ad")
adata.write_h5ad(output_path)
logger.info(f"Saved predictions to {output_path} (in adata.obsm['{pred_key}'])")


def main():
parser = argparse.ArgumentParser(description="Run inference on AnnData with a trained model checkpoint.")
add_arguments_infer(parser)
args = parser.parse_args()
run_sets_infer(args)


if __name__ == "__main__":
main()
Loading