Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ authors = [
{ name = "Abhinav Adduri", email = "abhinav.adduri@arcinstitute.org" },
{ name = "Yusuf Roohani", email = "yusuf.roohani@arcinstitute.org" },
]
requires-python = ">=3.10,<3.13"
requires-python = ">=3.11,<3.13"
dependencies = [
"igraph>=0.11.8",
"pdex>=0.1.26",
"pdex>=0.2.0",
"polars>=1.30.0",
"pyyaml>=6.0.2",
"scanpy>=1.10.3",
Expand All @@ -24,7 +24,11 @@ requires = ["hatchling"]
build-backend = "hatchling.build"

[dependency-groups]
dev = ["ipykernel>=6.29.5", "pytest>=8.3.5", "ruff>=0.11.8"]
dev = [
"ipykernel>=6.29.5",
"pytest>=8.3.5",
"ruff>=0.11.8",
]

[project.scripts]
cell-eval = "cell_eval.__main__:main"
Expand Down
13 changes: 5 additions & 8 deletions src/cell_eval/_baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
import polars as pl
from numpy.typing import NDArray
from pdex import parallel_differential_expression
from pdex import pdex
from scipy.sparse import issparse

from ._evaluator import _build_pdex_kwargs, _convert_to_normlog
Expand All @@ -23,9 +23,7 @@ def build_base_mean_adata(
allow_discrete: bool = False,
output_path: str | None = None,
output_de_path: str | None = None,
batch_size: int = 1000,
num_threads: int = 1,
de_method: str = "wilcoxon",
pdex_kwargs: dict[str, Any] = {},
) -> ad.AnnData:
if isinstance(adata, str):
Expand Down Expand Up @@ -83,16 +81,15 @@ def build_base_mean_adata(
if output_de_path is not None:
logger.info("Calculating differential expression")
pdex_kwargs = _build_pdex_kwargs(
groupby_key=pert_col,
groupby=pert_col,
reference=control_pert,
num_workers=num_threads,
metric=de_method,
batch_size=batch_size,
threads=num_threads,
allow_discrete=allow_discrete,
pdex_kwargs=pdex_kwargs,
)
frame = parallel_differential_expression(
frame = pdex(
adata=baseline_adata,
mode="ref",
**pdex_kwargs,
)
logger.info(f"Saving differential expression results to {output_de_path}")
Expand Down
16 changes: 0 additions & 16 deletions src/cell_eval/_cli/_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,18 +79,6 @@ def parse_args_run(parser: ap.ArgumentParser):
default=1,
help="Number of threads to use for parallel processing [default: %(default)s]",
)
parser.add_argument(
"--batch-size",
type=int,
default=100,
help="Batch size for parallel processing [default: %(default)s]",
)
parser.add_argument(
"--de-method",
type=str,
default="wilcoxon",
help="Method to use for differential expression analysis [default: %(default)s]",
)
parser.add_argument(
"--allow-discrete",
action="store_true",
Expand Down Expand Up @@ -166,9 +154,7 @@ def run_evaluation(args: ap.Namespace):
de_real=args.de_real,
control_pert=args.control_pert,
pert_col=args.pert_col,
de_method=args.de_method,
num_threads=args.num_threads,
batch_size=args.batch_size,
outdir=args.outdir,
allow_discrete=args.allow_discrete,
prefix=ct,
Expand All @@ -189,9 +175,7 @@ def run_evaluation(args: ap.Namespace):
de_real=args.de_real,
control_pert=args.control_pert,
pert_col=args.pert_col,
de_method=args.de_method,
num_threads=args.num_threads,
batch_size=args.batch_size,
outdir=args.outdir,
allow_discrete=args.allow_discrete,
skip_de=args.profile == "pds",
Expand Down
53 changes: 16 additions & 37 deletions src/cell_eval/_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
import pandas as pd
import polars as pl
import scanpy as sc
from pdex import parallel_differential_expression
from pdex import pdex

from cell_eval.utils import guess_is_lognorm

from ._pipeline import MetricPipeline
from ._types import PerturbationAnndataPair, initialize_de_comparison
from .utils import _cast_float16_to_float32

logger = logging.getLogger(__name__)

Expand All @@ -38,12 +39,8 @@ class MetricsEvaluator:
Control perturbation name.
pert_col: str = "target"
Perturbation column name.
de_method: str = "wilcoxon"
Differential expression method.
num_threads: int = -1
Number of threads for parallel differential expression.
batch_size: int = 100
Batch size for parallel differential expression.
outdir: str = "./cell-eval-outdir"
Output directory.
allow_discrete: bool = False
Expand All @@ -63,9 +60,7 @@ def __init__(
de_real: pl.DataFrame | str | None = None,
control_pert: str = "non-targeting",
pert_col: str = "target",
de_method: str = "wilcoxon",
num_threads: int = -1,
batch_size: int = 100,
outdir: str = "./cell-eval-outdir",
allow_discrete: bool = False,
prefix: str | None = None,
Expand Down Expand Up @@ -96,9 +91,7 @@ def __init__(
anndata_pair=self.anndata_pair,
de_pred=de_pred,
de_real=de_real,
de_method=de_method,
num_threads=num_threads if num_threads != -1 else mp.cpu_count(),
batch_size=batch_size,
allow_discrete=allow_discrete,
outdir=outdir,
prefix=prefix,
Expand Down Expand Up @@ -170,6 +163,10 @@ def _build_anndata_pair(
logger.info(f"Reading pred anndata from {pred}")
pred = ad.read_h5ad(pred)

# Cast float16 to float32 since NUMBA (used by pdex) does not support float16
_cast_float16_to_float32(real, which="real")
_cast_float16_to_float32(pred, which="pred")

# Validate that the input is normalized and log-transformed
_convert_to_normlog(real, which="real", allow_discrete=allow_discrete)
_convert_to_normlog(pred, which="pred", allow_discrete=allow_discrete)
Expand Down Expand Up @@ -220,9 +217,7 @@ def _build_de_comparison(
anndata_pair: PerturbationAnndataPair | None = None,
de_pred: pl.DataFrame | str | None = None,
de_real: pl.DataFrame | str | None = None,
de_method: str = "wilcoxon",
num_threads: int = 1,
batch_size: int = 100,
allow_discrete: bool = False,
outdir: str | None = None,
prefix: str | None = None,
Expand All @@ -233,9 +228,7 @@ def _build_de_comparison(
mode="real",
de_path=de_real,
anndata_pair=anndata_pair,
de_method=de_method,
num_threads=num_threads,
batch_size=batch_size,
allow_discrete=allow_discrete,
outdir=outdir,
prefix=prefix,
Expand All @@ -245,9 +238,7 @@ def _build_de_comparison(
mode="pred",
de_path=de_pred,
anndata_pair=anndata_pair,
de_method=de_method,
num_threads=num_threads,
batch_size=batch_size,
allow_discrete=allow_discrete,
outdir=outdir,
prefix=prefix,
Expand All @@ -258,42 +249,31 @@ def _build_de_comparison(

def _build_pdex_kwargs(
reference: str,
groupby_key: str,
num_workers: int,
batch_size: int,
metric: str,
groupby: str,
threads: int,
allow_discrete: bool,
pdex_kwargs: dict[str, Any] | None = None,
) -> dict[str, Any]:
pdex_kwargs = pdex_kwargs or {}
if "reference" not in pdex_kwargs:
pdex_kwargs["reference"] = reference
if "groupby_key" not in pdex_kwargs:
pdex_kwargs["groupby_key"] = groupby_key
if "num_workers" not in pdex_kwargs:
pdex_kwargs["num_workers"] = num_workers
if "batch_size" not in pdex_kwargs:
pdex_kwargs["batch_size"] = batch_size
if "metric" not in pdex_kwargs:
pdex_kwargs["metric"] = metric
if "groupby" not in pdex_kwargs:
pdex_kwargs["groupby"] = groupby
if "threads" not in pdex_kwargs:
pdex_kwargs["threads"] = threads
if "is_log1p" not in pdex_kwargs:
if allow_discrete:
pdex_kwargs["is_log1p"] = False
else:
pdex_kwargs["is_log1p"] = True

# always return polars DataFrames
pdex_kwargs["as_polars"] = True
return pdex_kwargs


def _load_or_build_de(
mode: Literal["pred", "real"],
de_path: pl.DataFrame | str | None = None,
anndata_pair: PerturbationAnndataPair | None = None,
de_method: str = "wilcoxon",
num_threads: int = 1,
batch_size: int = 100,
outdir: str | None = None,
prefix: str | None = None,
allow_discrete: bool = False,
Expand All @@ -305,16 +285,15 @@ def _load_or_build_de(
logger.info(f"Computing DE for {mode} data")
pdex_kwargs = _build_pdex_kwargs(
reference=anndata_pair.control_pert,
groupby_key=anndata_pair.pert_col,
num_workers=num_threads,
metric=de_method,
batch_size=batch_size,
groupby=anndata_pair.pert_col,
threads=num_threads,
allow_discrete=allow_discrete,
pdex_kwargs=pdex_kwargs or {},
)
logger.info(f"Using the following pdex kwargs: {pdex_kwargs}")
frame = parallel_differential_expression(
frame = pdex(
adata=anndata_pair.real if mode == "real" else anndata_pair.pred,
mode="ref",
**pdex_kwargs,
)
if outdir is not None:
Expand Down
20 changes: 20 additions & 0 deletions src/cell_eval/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import anndata as ad
import numpy as np
import scipy.sparse as sp
from scipy.sparse import csc_matrix, csr_matrix

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -105,3 +106,22 @@ def split_anndata_on_celltype(
ct: adata[adata.obs[celltype_col] == ct]
for ct in adata.obs[celltype_col].unique()
}


def _cast_float16_to_float32(adata: ad.AnnData, which: str | None = None) -> None:
    """Cast a float16 expression matrix to float32 (in place).

    NUMBA (used by pdex) does not support float16 operations, so ``adata.X``
    is replaced with a float32 copy when its dtype is float16. Matrices of
    any other dtype are left untouched.

    Parameters
    ----------
    adata
        AnnData whose ``X`` attribute is inspected and, if needed, reassigned.
    which
        Optional label (e.g. ``"real"`` or ``"pred"``) used only in the log
        message; when falsy, the cast happens silently.
    """
    x = adata.X
    if x is None:
        # Nothing to cast when there is no expression matrix.
        return

    # Sparse matrices keep their values in `.data`; dense arrays expose
    # `.dtype` directly.
    dtype = x.data.dtype if sp.issparse(x) else x.dtype
    if dtype != np.float16:
        return

    if which:
        logger.info(
            f"Casting {which} anndata from float16 to float32 (NUMBA does not support float16)."
        )
    # `astype` is called the same way for dense and sparse inputs, so no
    # branching on sparsity is needed here.
    adata.X = x.astype(np.float32)
23 changes: 3 additions & 20 deletions tests/test_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def test_eval_pdex_kwargs():
control_pert="control",
pert_col="perturbation",
pdex_kwargs={
"exp_post_agg": True,
"geometric_mean": False,
},
)
evaluator.compute(
Expand All @@ -282,8 +282,8 @@ def test_eval_pdex_kwargs_duplicated():
control_pert="control",
pert_col="perturbation",
pdex_kwargs={
"exp_post_agg": True,
"num_workers": 4,
"geometric_mean": False,
"threads": 4,
},
)
evaluator.compute(
Expand Down Expand Up @@ -390,20 +390,3 @@ def test_eval_downsampled_cells():
break_on_error=True,
)
validate_expected_files(OUTDIR)


def test_eval_alt_metric():
adata_real = build_random_anndata()
adata_pred = downsample_cells(adata_real, fraction=0.5)
evaluator = MetricsEvaluator(
adata_pred=adata_pred,
adata_real=adata_real,
control_pert=CONTROL_VAR,
pert_col=PERT_COL,
outdir=OUTDIR,
de_method="anderson",
)
evaluator.compute(
break_on_error=True,
)
validate_expected_files(OUTDIR)