Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "arc-state"
version = "0.10.2"
version = "0.10.3"
description = "State is a machine learning model that predicts cellular perturbation response across diverse contexts."
readme = "README.md"
authors = [
Expand All @@ -13,7 +13,7 @@ authors = [
requires-python = ">=3.10,<3.13"
dependencies = [
"anndata>=0.11.4",
"cell-load>=0.8.7",
"cell-load>=0.9.0",
"numpy>=2.2.6",
"pandas>=2.2.3",
"pyyaml>=6.0.2",
Expand Down
3 changes: 3 additions & 0 deletions src/state/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
run_tx_infer,
run_tx_predict,
run_tx_preprocess_train,
run_tx_sort,
run_tx_train,
)

Expand Down Expand Up @@ -126,6 +127,8 @@ def main():
case "preprocess_train":
# Run preprocessing using argparse
run_tx_preprocess_train(args.adata, args.output, args.num_hvgs)
case "sort":
run_tx_sort(args)


if __name__ == "__main__":
Expand Down
2 changes: 2 additions & 0 deletions src/state/_cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
run_tx_infer,
run_tx_predict,
run_tx_preprocess_train,
run_tx_sort,
run_tx_train,
)

Expand All @@ -14,6 +15,7 @@
"run_tx_predict",
"run_tx_infer",
"run_tx_preprocess_train",
"run_tx_sort",
"run_emb_fit",
"run_emb_query",
"run_emb_transform",
Expand Down
3 changes: 3 additions & 0 deletions src/state/_cli/_tx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
from ._infer import add_arguments_infer, run_tx_infer
from ._predict import add_arguments_predict, run_tx_predict
from ._preprocess_train import add_arguments_preprocess_train, run_tx_preprocess_train
from ._sort import add_arguments_sort, run_tx_sort
from ._train import add_arguments_train, run_tx_train

__all__ = [
"run_tx_train",
"run_tx_predict",
"run_tx_infer",
"run_tx_preprocess_train",
"run_tx_sort",
"add_arguments_tx",
]

Expand All @@ -21,3 +23,4 @@ def add_arguments_tx(parser: ap.ArgumentParser):
add_arguments_predict(subparsers.add_parser("predict"))
add_arguments_infer(subparsers.add_parser("infer"))
add_arguments_preprocess_train(subparsers.add_parser("preprocess_train"))
add_arguments_sort(subparsers.add_parser("sort"))
72 changes: 72 additions & 0 deletions src/state/_cli/_tx/_sort.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import argparse as ap


def add_arguments_sort(parser: ap.ArgumentParser):
"""Add arguments for the sort subcommand."""
parser.add_argument(
"--input",
type=str,
required=True,
help="Path to input AnnData file (.h5ad)",
)
parser.add_argument(
"--output",
type=str,
required=True,
help="Path to output sorted AnnData file (.h5ad)",
)
parser.add_argument(
"--context-col",
type=str,
required=True,
help="obs column to sort by context (e.g. cell type)",
)
parser.add_argument(
"--batch-col",
type=str,
required=False,
default=None,
help="optional obs column to sort by batch (if omitted, sorts by context + perturbation)",
)
parser.add_argument(
"--pert-col",
type=str,
required=True,
help="obs column to sort by perturbation",
)


def run_tx_sort(args: ap.Namespace):
import logging
from pathlib import Path
Comment on lines +40 to +41
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For better code organization and adherence to PEP 8, standard library imports like logging and pathlib should be placed at the top of the file. While it's a good practice to delay importing heavy libraries like anndata inside a function for CLI tools to improve startup time for other subcommands, this doesn't apply to lightweight standard libraries.


import anndata as ad

logging.basicConfig(level=logging.INFO)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

logging.basicConfig() should ideally be called only once at the application's entry point (e.g., in src/state/__main__.py). Calling it within subcommand functions can lead to unpredictable logging behavior, as it only has an effect the first time it's called. If another part of the application has already configured logging, this call will be ignored. Centralizing logging configuration would make it more robust.

logger = logging.getLogger(__name__)

input_path = args.input
output_path = args.output
sort_cols = [args.context_col]
if args.batch_col:
sort_cols.append(args.batch_col)
sort_cols.append(args.pert_col)

logger.info("Loading AnnData from %s", input_path)
adata = ad.read_h5ad(input_path)

missing = [col for col in sort_cols if col not in adata.obs.columns]
if missing:
raise ValueError(f"Missing obs columns for sorting: {missing}")

logger.info("Sorting AnnData by columns: %s", sort_cols)
order = adata.obs.sort_values(by=sort_cols, kind="mergesort").index
adata_sorted = adata[order].copy()

output_dir = Path(output_path).parent
if output_dir and not output_dir.exists():
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The check if output_dir is redundant because Path(output_path).parent always returns a Path object, which is truthy. You can simplify the condition to just check for the directory's existence.

Suggested change
if output_dir and not output_dir.exists():
if not output_dir.exists():

output_dir.mkdir(parents=True, exist_ok=True)

logger.info("Writing sorted AnnData to %s", output_path)
adata_sorted.write_h5ad(output_path)
logger.info("Sort complete. Wrote %d cells.", adata_sorted.n_obs)