Merged

Commits (23)
- `4d39b85` refactor: update to new pdex (noamteyssier, Feb 26, 2026)
- `13bfa6c` refactor: remove unused pdex arguments for batch_size and metric (noamteyssier, Feb 26, 2026)
- `e276970` fix: properly recase float16 to float32 at minimum (noamteyssier, Feb 26, 2026)
- `e5407e4` refactor: move recast to utils (noamteyssier, Feb 26, 2026)
- `f3ea3ab` test: remove unused alt metric (noamteyssier, Feb 26, 2026)
- `824e63a` chore: finish removing all outdated pdex arguments (noamteyssier, Feb 26, 2026)
- `287ed12` Merge pull request #222 from ArcInstitute/220-update-to-new-pdex (noamteyssier, Feb 26, 2026)
- `76b3f93` dep: added ty to project (noamteyssier, Feb 26, 2026)
- `e25e1e1` style: fix all typing errors or ambiguities (noamteyssier, Feb 26, 2026)
- `4c529f9` ci: added typing to ci (noamteyssier, Feb 26, 2026)
- `88023cf` Merge pull request #223 from ArcInstitute/refactor/move-typing-to-ty (noamteyssier, Feb 26, 2026)
- `a3609d1` chore: remove unused configuration (noamteyssier, Feb 26, 2026)
- `fe81c31` feat: added a claude md (noamteyssier, Feb 26, 2026)
- `e9c11a9` Merge pull request #224 from ArcInstitute/feat/claude-support (noamteyssier, Feb 26, 2026)
- `8901383` chore(semver): bump - breaking changes (noamteyssier, Feb 26, 2026)
- `4984a71` refactor: remove redundant code (noamteyssier, Feb 26, 2026)
- `e7047a2` fix: deprecation warning on is_view (noamteyssier, Feb 26, 2026)
- `3e36c0e` dep: enforce anndata version (noamteyssier, Feb 26, 2026)
- `6a71856` Merge pull request #226 from ArcInstitute/fix/deprecation-notices (noamteyssier, Feb 26, 2026)
- `2d314f5` Use left join for DE Spearman/AUC (beabevi, Mar 5, 2026)
- `b923274` Merge pull request #228 from beabevi/fix-de-metrics-left-join (noamteyssier, Mar 19, 2026)
- `ff5acae` dep: remove upper limit on python version (noamteyssier, Mar 24, 2026)
- `e629a20` ci: added more python versions to pytest (noamteyssier, Mar 24, 2026)
.github/workflows/CI.yml (29 additions, 2 deletions)

@@ -5,7 +5,7 @@ on: [push, pull_request]
jobs:
all_jobs:
runs-on: ubuntu-latest
-needs: [formatting, pytest, cli-test]
+needs: [formatting, typing, pytest, cli-test]
steps:
- name: Complete
run: echo "Complete"
@@ -50,7 +50,7 @@ jobs:
run: |
uv run ruff format --check
-pytest:
+typing:
runs-on: ubuntu-latest

needs: [install-job]
@@ -69,6 +69,33 @@ jobs:
run: |
uv sync --all-extras --dev
- name: run type checking
run: |
uv run ty check
pytest:
runs-on: ubuntu-latest

needs: [install-job]

strategy:
matrix:
python-version: ["3.12", "3.13", "3.14"]

steps:
- uses: actions/checkout@v4

- name: install uv
uses: astral-sh/setup-uv@v5
with:
enable-cache: true
cache-dependency-glob: "pyproject.toml"
python-version: "${{ matrix.python-version }}"

- name: install dependencies
run: |
uv sync --all-extras --dev
- name: run pytest
run: |
uv run pytest -v
.python-version (1 addition, 1 deletion)

@@ -1 +1 @@
-3.12
+3.14
CLAUDE.md (84 additions, 0 deletions)
@@ -0,0 +1,84 @@
# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

## Project Overview

**cell-eval** is a Python package and CLI tool for evaluating the performance of models that predict cellular responses to perturbations at the single-cell level. Developed by the Arc Research Institute.

It generally revolves around a *real* anndata and a *predicted* anndata where it measures the general differences between the two across a variety of metrics.

- Python 3.11–3.12, managed with **UV** and built with **hatchling**
- CLI entry point: `cell-eval` (defined in `src/cell_eval/__main__.py`)

## Common Commands

```bash
# Install dependencies
uv sync --all-extras --dev

# Run all tests
uv run pytest -v

# Run a single test
uv run pytest tests/test_eval.py::test_broken_adata_not_normlog -v

# Formatting (check / fix)
uv run ruff format --check
uv run ruff format

# Type checking
uv run ty check

# Verify CLI works
uv run cell-eval --help
```

CI runs: formatting, typing, pytest, and cli-test (see `.github/workflows/CI.yml`).

## Architecture

### Core Data Flow

```
AnnData inputs (predicted + real)
→ MetricsEvaluator (validation, normalization, DE computation)
→ MetricPipeline (profile-based metric selection + execution)
→ metrics_registry (global MetricRegistry instance)
→ individual metric functions
→ polars DataFrames (per-perturbation + aggregated results)
```

### Key Abstractions

- **`MetricsEvaluator`** (`src/cell_eval/_evaluator.py`) — Main programmatic entry point. Validates input AnnData objects, computes differential expression via `pdex`, and orchestrates the metric pipeline.

- **`MetricRegistry`** (`src/cell_eval/metrics/_registry.py`) — Global singleton `metrics_registry`. Metrics are registered with a name, type (`DE` or `ANNDATA_PAIR`), compute function, and best-value indicator. Supports both plain functions and class-based metrics requiring instantiation.

- **`MetricPipeline`** (`src/cell_eval/_pipeline/_runner.py`) — Selects and runs metrics based on a profile (`full`, `minimal`, `vcc`, `de`, `anndata`, `pds`). Collects per-perturbation results and aggregates them.

- **`Metric` protocol** (`src/cell_eval/metrics/base.py`) — All metric functions take either a `PerturbationAnndataPair` or `DEComparison` and return `float | dict[str, float]`.

- **Type system** (`src/cell_eval/_types/`) — Immutable dataclasses: `PerturbationAnndataPair`, `DEComparison`, plus enums `MetricType`, `MetricBestValue`, `DESortBy`.

### Metrics

Metrics are split into two categories registered in `src/cell_eval/metrics/_impl.py`:

- **AnnData metrics** (`_anndata.py`): pearson_delta, mse, mae, mse_delta, mae_delta, discrimination_score, clustering_agreement, edistance
- **DE metrics** (`_de.py`): overlap/precision at N, spearman correlations, direction match, significant gene recall, ROC/PR AUC
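
As a rough illustration of the delta-style AnnData metrics, `pearson_delta`-like logic correlates the predicted mean expression shift from control with the real shift (a sketch under assumed semantics, not the package's implementation):

```python
import numpy as np

def pearson_delta(pred: np.ndarray, real: np.ndarray, ctrl: np.ndarray) -> float:
    # Correlate the predicted mean shift from control with the real mean shift
    # (cells x genes matrices; shifts are per-gene pseudobulk deltas).
    d_pred = pred.mean(axis=0) - ctrl.mean(axis=0)
    d_real = real.mean(axis=0) - ctrl.mean(axis=0)
    return float(np.corrcoef(d_pred, d_real)[0, 1])

# Sanity check: a prediction with exactly the real per-gene shift scores ~1.
rng = np.random.default_rng(0)
ctrl = rng.normal(size=(20, 5))
shift = np.arange(1.0, 6.0)
score = pearson_delta(ctrl + shift, ctrl + shift, ctrl)
```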

### CLI

Subcommands in `src/cell_eval/_cli/`: `prep` (data preparation for VCC), `run` (evaluation), `baseline` (create baseline), `score` (normalize against baseline). CLI defaults are in `_cli/_const.py`.

### Test Data Utilities

`cell_eval.data` provides `build_random_anndata()` and `downsample_cells()` for generating synthetic AnnData objects in tests.
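
A comparable synthetic-data helper can be sketched with plain numpy (a hypothetical stand-in, not the actual `build_random_anndata` signature):

```python
import numpy as np

def build_random_counts(n_cells: int = 100, n_genes: int = 50, seed: int = 0):
    # Hypothetical stand-in for cell_eval.data.build_random_anndata:
    # a dense Poisson count matrix plus per-cell perturbation labels.
    rng = np.random.default_rng(seed)
    counts = rng.poisson(5.0, size=(n_cells, n_genes))
    perts = rng.choice(["control", "pertA", "pertB"], size=n_cells)
    return counts, perts

counts, perts = build_random_counts()
```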

## Conventions

- Uses `polars` (not pandas) for DataFrames
- Uses `match`/`case` statements (Python 3.10+ syntax)
- Type hints throughout; PEP 561 `py.typed` marker present
- Private modules prefixed with `_` (public API is re-exported from `__init__.py`)
pyproject.toml (10 additions, 8 deletions)
@@ -1,34 +1,36 @@
[project]
name = "cell-eval"
-version = "0.6.8"
+version = "0.7.0"
description = "Evaluation metrics for single-cell perturbation predictions"
readme = "README.md"
authors = [
{ name = "Noam Teyssier", email = "noam.teyssier@arcinstitute.org" },
{ name = "Abhinav Adduri", email = "abhinav.adduri@arcinstitute.org" },
{ name = "Yusuf Roohani", email = "yusuf.roohani@arcinstitute.org" },
]
-requires-python = ">=3.10,<3.13"
+requires-python = ">=3.11"
dependencies = [
"igraph>=0.11.8",
-"pdex>=0.1.26",
+"pdex>=0.2.0",
"polars>=1.30.0",
"pyyaml>=6.0.2",
"scanpy>=1.10.3",
"pyarrow>=18.0.0",
"tqdm>=4.67.1",
+"anndata>=0.12.10",
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[dependency-groups]
-dev = ["ipykernel>=6.29.5", "pytest>=8.3.5", "ruff>=0.11.8"]
+dev = [
+"ipykernel>=6.29.5",
+"pytest>=8.3.5",
+"ruff>=0.11.8",
+"ty>=0.0.19",
+]

[project.scripts]
cell-eval = "cell_eval.__main__:main"

[tool.pyright]
venvPath = "."
venv = ".venv"
ruff.toml (0 additions, 5 deletions)

This file was deleted.
src/cell_eval/_baseline.py (12 additions, 14 deletions)
@@ -1,11 +1,12 @@
import logging
-from typing import Any
+from typing import Any, cast

import anndata as ad
import numpy as np
import pandas as pd
import polars as pl
from numpy.typing import NDArray
-from pdex import parallel_differential_expression
+from pdex import pdex
from scipy.sparse import issparse

from ._evaluator import _build_pdex_kwargs, _convert_to_normlog
@@ -23,9 +24,7 @@ def build_base_mean_adata(
allow_discrete: bool = False,
output_path: str | None = None,
output_de_path: str | None = None,
-batch_size: int = 1000,
num_threads: int = 1,
-de_method: str = "wilcoxon",
pdex_kwargs: dict[str, Any] = {},
) -> ad.AnnData:
if isinstance(adata, str):
@@ -67,7 +66,7 @@ def build_base_mean_adata(
(int(counts[counts_col].sum()), baseline.size),
baseline,
),
-var=adata.var,
+var=cast(pd.DataFrame, adata.var),
obs=obs,
)

@@ -78,21 +77,20 @@

if output_path is not None:
logger.info(f"Saving baseline data to {output_path}")
-baseline_adata.write_h5ad(output_path)
+baseline_adata.write_h5ad(output_path)  # type: ignore[invalid-argument-type]

if output_de_path is not None:
logger.info("Calculating differential expression")
pdex_kwargs = _build_pdex_kwargs(
-groupby_key=pert_col,
+groupby=pert_col,
reference=control_pert,
-num_workers=num_threads,
-metric=de_method,
-batch_size=batch_size,
+threads=num_threads,
allow_discrete=allow_discrete,
pdex_kwargs=pdex_kwargs,
)
-frame = parallel_differential_expression(
+frame = pdex(
adata=baseline_adata,
+mode="ref",
**pdex_kwargs,
)
logger.info(f"Saving differential expression results to {output_de_path}")
@@ -137,9 +135,9 @@ def _build_counts_df_from_adata(
raise ValueError(
f"Column '{pert_col}' not found in adata.obs: {adata.obs.columns}"
)
-if control_pert not in adata.obs[pert_col].unique():
+if control_pert not in cast(pd.Series, adata.obs[pert_col]).unique():
raise ValueError(
-f"Control pert '{control_pert}' not found in adata.obs[{pert_col}]: {adata.obs[pert_col].unique()}"
+f"Control pert '{control_pert}' not found in adata.obs[{pert_col}]: {cast(pd.Series, adata.obs[pert_col]).unique()}"
)
logger.info("Building counts DataFrame from adata")
return (
@@ -161,7 +159,7 @@ def _build_pert_baseline(
raise ValueError(
f"Column '{pert_col}' not found in adata.obs: {adata.obs.columns}"
)
-unique_perts = adata.obs[pert_col].unique()
+unique_perts = cast(pd.Series, adata.obs[pert_col]).unique()
if control_pert not in unique_perts:
raise ValueError(
f"Control pert '{control_pert}' not found in unique_perts: {unique_perts}"
src/cell_eval/_cli/_prep.py (6 additions, 5 deletions)

@@ -5,6 +5,7 @@
import shutil
import subprocess
from tempfile import TemporaryDirectory
+from typing import cast

import anndata as ad
import numpy as np
@@ -136,9 +137,9 @@ def strip_anndata(
raise ValueError(
f"Provided celltype column: '{celltype_col}' missing from anndata: {adata.obs.columns}"
)
-if ntc_name not in adata.obs[pert_col].unique():
+if ntc_name not in cast(pd.Series, adata.obs[pert_col]).unique():
raise ValueError(
-f"Provided negative control name: '{ntc_name}' missing from anndata: {adata.obs[pert_col].unique()}"
+f"Provided negative control name: '{ntc_name}' missing from anndata: {cast(pd.Series, adata.obs[pert_col]).unique()}"
)

# Check if expected dimension is provided and matches the length of the genelist
@@ -196,11 +197,11 @@

logger.info("Simplifying obs dataframe")
new_obs = pd.DataFrame(
-{output_pert_col: adata.obs[pert_col].values},
+{output_pert_col: cast(pd.Series, adata.obs[pert_col]).values},
index=np.arange(adata.shape[0]).astype(str),
)
if celltype_col:
-new_obs[output_celltype_col] = adata.obs[celltype_col].values
+new_obs[output_celltype_col] = cast(pd.Series, adata.obs[celltype_col]).values

logger.info("Simplifying var dataframe")
new_var = pd.DataFrame(
@@ -225,7 +226,7 @@

# Write the h5ad file
logger.info(f"Writing h5ad output to {tmp_h5ad}")
-minimal.write_h5ad(tmp_h5ad)
+minimal.write_h5ad(tmp_h5ad)  # type: ignore[invalid-argument-type]

# Zstd compress the h5ad file (will create pred.h5ad.zst)
logger.info(f"Zstd compressing {tmp_h5ad}")
src/cell_eval/_cli/_run.py (0 additions, 16 deletions)

@@ -79,18 +79,6 @@ def parse_args_run(parser: ap.ArgumentParser):
default=1,
help="Number of threads to use for parallel processing [default: %(default)s]",
)
-parser.add_argument(
-"--batch-size",
-type=int,
-default=100,
-help="Batch size for parallel processing [default: %(default)s]",
-)
-parser.add_argument(
-"--de-method",
-type=str,
-default="wilcoxon",
-help="Method to use for differential expression analysis [default: %(default)s]",
-)
parser.add_argument(
"--allow-discrete",
action="store_true",
@@ -166,9 +154,9 @@ def run_evaluation(args: ap.Namespace):
de_real=args.de_real,
control_pert=args.control_pert,
pert_col=args.pert_col,
-de_method=args.de_method,
num_threads=args.num_threads,
-batch_size=args.batch_size,
outdir=args.outdir,
allow_discrete=args.allow_discrete,
prefix=ct,
@@ -189,9 +175,7 @@
de_real=args.de_real,
control_pert=args.control_pert,
pert_col=args.pert_col,
-de_method=args.de_method,
num_threads=args.num_threads,
-batch_size=args.batch_size,
outdir=args.outdir,
allow_discrete=args.allow_discrete,
skip_de=args.profile == "pds",