diff --git a/.gitignore b/.gitignore index b88be16..c272e22 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 +11,4 @@ _build/* .fleet .zed .ipynb_checkpoints/ +.venv-dask/ diff --git a/Cargo.toml b/Cargo.toml index b733005..60afdbd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,6 +25,7 @@ polars = { version = "0.51", features = [ "pivot" ] } ndarray = { version = "0.16", features = ["rayon"] } +rayon = "1" nalgebra-sparse = "0.11.0" anyhow = "1.0.86" log = "0.4.27" diff --git a/demo/OOC_BENCHMARK.md b/demo/OOC_BENCHMARK.md new file mode 100644 index 0000000..204580b --- /dev/null +++ b/demo/OOC_BENCHMARK.md @@ -0,0 +1,102 @@ +# Out-of-core preprocessing — memory vs scanpy + +> **Tier-1 OOC pipeline is complete.** All of QC, `normalize_total`, `log1p`, a fused +> `preprocess`, highly variable genes (Seurat), and PCA run disk-backed in bounded memory, exposed +> as `sr_ooc ` and `sr.pp.*` (drop-in for `sc.pp.*`). Each step is unit-tested against its +> in-memory / dense counterpart: log1p & normalize match scanpy's X exactly; QC matches +> (totals/mito/top-N/var); HVG selects the **same** genes as SingleRust's in-memory HVG; PCA +> matches a dense covariance-PCA reference (variance ratios exact, embedding norms match, up to +> per-component sign). End-to-end vs scanpy, the PCA *spectrum* tracks scanpy but the exact axes +> differ because SingleRust's Seurat HVG picks a different gene set than scanpy's (a pre-existing +> in-memory difference, not introduced by the OOC path). + +The benchmark below covers the QC + normalize_total + log1p portion. + +## Deterministic parallelism + +The compute-bound passes — QC accumulation, HVG sum/sum-of-squares, and PCA's gene×gene Gram +matrix — are parallelized with rayon. 
Naive parallel floating-point reduction is **not** +reproducible (FP addition isn't associative and rayon's work-stealing varies the summation order), +so all reductions use a **fixed-block, ordered-merge** scheme (`backed::processing::det`): rows are +cut into fixed-size blocks at fixed indices, each block is folded sequentially, and partials are +merged in block order. The partition and merge order depend only on the data size and a constant +block size — never on thread count or scheduling — so results are bit-identical on every run. + +Verified two ways: +- **Unit tests** run QC / HVG / preprocess / PCA under rayon pools of 1 vs 8 threads and assert + bit-identical metrics, masks, embeddings, and variance ratios. +- **At scale (50k cells, real data):** the full pipeline (`preprocess → hvg → pca`) run with + `RAYON_NUM_THREADS=1` vs `=18` produced **bit-identical** X, obs QC columns, var HVG columns, + `obsm["X_pca"]`, and `uns` variance ratios. + +(The cheap elementwise transforms — normalize/log1p — have no cross-row reduction and are +deterministic by construction; PCA projection is per-cell independent, also order-free.) + + +Same work in all three lanes (QC + `normalize_total(1e4)` + `log1p`) on the same `.h5ad`, each run as +a subprocess under `/usr/bin/time -l` to capture **peak RSS** and wall time. + +```bash +cargo build --release --example sr_ooc +python demo/bench_ooc.py data/bench_input_500k.h5ad 20000 +``` + +## Result (500,000 cells × 48,788 genes, 48 GB / 18-core) + +| lane | time | peak RSS | +|---|---:|---:| +| scanpy in-memory | 13.5 s (compute) | 13.6 GB | +| scanpy + Dask OOC | 35.4 s (compute) | 14.5 GB | +| **SingleRust OOC (fused)** | 45.0 s (wall, incl. 
I/O) | **2.5 GB** | + +**SingleRust OOC uses ~5.4× less peak memory than scanpy in-memory and ~5.8× less than +scanpy+Dask**, and that footprint is chunk-bounded — it stays roughly flat as the cell count +grows, whereas scanpy in-memory scales linearly and eventually OOMs (≈46 GB at 2M cells, over a +48 GB machine). That is the point of out-of-core: process data that does not fit in RAM. + +The standout finding is the **Dask lane**: its sparse out-of-core path gives essentially **no +memory benefit** — peak RSS lands at/above scanpy in-memory (it has measured 7.5–14.5 GB across +runs, i.e. ≥ in-memory) while still costing 2.6× the runtime. This is exactly the Dask-sparse +immaturity scanpy's own issues describe. Native Rust streaming is the only lane with truly bounded +memory (2.5 GB). + +The SingleRust lane is a **single fused command** (`sr_ooc preprocess` = QC + normalize_total + +log1p in one job; the per-cell total computed for QC is reused as the normalization row-sum, so +it's computed once). Fusing cut its wall time from 128 s (three separate ops + a working copy) to +45 s — now in the same ballpark as Dask on time, at ~1/6th the memory. + +### Caveats / honest reading + +- **Time is not apples-to-apples.** scanpy's number is compute-only (excludes its initial load); + the Dask lane includes a lazy read + chunked write-out; SingleRust's wall time is end-to-end for + the single fused `sr_ooc preprocess` process, including the source read and output write (the + unfused 128 s run used a 6 GB working copy and separate CLI passes). Chunk-size tuning and parallel + chunk processing could cut it further. **Peak RSS is the fair, durable signal** and it is immune to thermal/throttling. +- SingleRust OOC RSS includes one chunk (here 20k rows) plus small per-cell / per-gene + accumulators; smaller chunks lower it further. +- Both scanpy lanes are run with the newer `.venv-dask` stack (anndata 0.12, scanpy 1.12, dask) + for consistency; SingleRust is the `sr_ooc` release binary. 
+ +### Projection to 2M cells (not run — memory safety) + +At ~2M cells the count matrix is ≈46 GB dense-equivalent; scanpy in-memory would exceed the 48 GB +machine (OOM/heavy swap), and the Dask lane's ~0.6× memory ratio still lands near the ceiling. +SingleRust OOC stays ~2 GB (chunk-bounded) and is the only lane that completes comfortably. We +did **not** run 2M here to avoid destabilizing the machine; it needs a ~32 GB base dataset and a +box with headroom. + +## Reproducing the Dask lane + +The Dask lane needs `anndata>=0.11` (`experimental.read_elem_lazy`) + dask + xarray, which require +Python ≥3.10 — kept in an isolated venv so the main 3.9 demo env is untouched: + +```bash +python3.12 -m venv .venv-dask +. .venv-dask/bin/activate +pip install "anndata>=0.11" "scanpy>=1.11" "dask[array]" xarray hdf5plugin h5py scipy numpy +``` + +`bench_ooc.py` auto-uses `.venv-dask/bin/python` for both scanpy lanes (override with `SR_PY`). +Relevant scanpy issues: [#4095](https://github.com/scverse/scanpy/issues/4095), +[dask#11880](https://github.com/dask/dask/issues/11880), +[pydata/sparse#860](https://github.com/pydata/sparse/issues/860). diff --git a/demo/__pycache__/singlerust.cpython-312.pyc b/demo/__pycache__/singlerust.cpython-312.pyc new file mode 100644 index 0000000..9527218 Binary files /dev/null and b/demo/__pycache__/singlerust.cpython-312.pyc differ diff --git a/demo/_scanpy_dask_pp.py b/demo/_scanpy_dask_pp.py new file mode 100644 index 0000000..aa3c748 --- /dev/null +++ b/demo/_scanpy_dask_pp.py @@ -0,0 +1,65 @@ +"""scanpy OUT-OF-CORE via Dask: read only X lazily as a dask array (obs/var eager + small), run +dask-enabled normalize_total + log1p (and attempt qc), then stream the result to disk to force +chunked execution. + +Requires anndata>=0.11 + dask + xarray. Run with the isolated .venv-dask. Prints +STEP_SECONDS= and QC_DASK=ok|err:. 
Peak RSS is captured by the caller via +/usr/bin/time -l; if Dask's sparse path densifies/materializes, RSS spikes — itself the finding. + + python demo/_scanpy_dask_pp.py [chunk_rows] +""" +import sys +import time +import tempfile +import os + +import hdf5plugin # noqa: F401 +import h5py +import anndata as ad +import scanpy as sc +from anndata.experimental import read_elem_lazy +from anndata.io import read_elem + + +def main(): + path = sys.argv[1] + chunk = int(sys.argv[2]) if len(sys.argv) > 2 else 20000 + + # Keep the file open for the whole run — the lazy dask array reads from it on compute. + f = h5py.File(path, "r") + try: + X = read_elem_lazy(f["X"], chunks=(chunk, -1)) # dask array, streamed in row blocks + obs = read_elem(f["obs"]) # eager, small + var = read_elem(f["var"]) + adata = ad.AnnData(X=X, obs=obs, var=var) + + t = time.perf_counter() + + qc_status = "ok" + try: + adata.var["mito"] = adata.var_names.str.upper().str.startswith("MT-") + sc.pp.calculate_qc_metrics( + adata, qc_vars=["mito"], percent_top=[50, 100, 200, 500], log1p=True, inplace=True + ) + except Exception as e: + qc_status = f"err:{type(e).__name__}:{str(e)[:120]}" + + sc.pp.normalize_total(adata, target_sum=1e4) + sc.pp.log1p(adata) + + # Force streaming execution by writing the transformed X out (chunked). + out = tempfile.NamedTemporaryFile(suffix=".h5ad", delete=False).name + try: + adata.write_h5ad(out) + finally: + if os.path.exists(out): + os.remove(out) + + print(f"STEP_SECONDS={time.perf_counter() - t}") + print(f"QC_DASK={qc_status}") + finally: + f.close() + + +if __name__ == "__main__": + main() diff --git a/demo/_scanpy_inmem_pp.py b/demo/_scanpy_inmem_pp.py new file mode 100644 index 0000000..c46325f --- /dev/null +++ b/demo/_scanpy_inmem_pp.py @@ -0,0 +1,29 @@ +"""scanpy IN-MEMORY qc + normalize_total + log1p on an .h5ad (loads the whole matrix). + +Used by the OOC benchmark as the baseline lane. Prints STEP_SECONDS=. 
+Peak RSS (captured by the caller via /usr/bin/time -l) scales with the dataset — that's the +point of the comparison against SingleRust's bounded-memory OOC lane. +""" +import sys +import time + +import hdf5plugin # noqa: F401 +import scanpy as sc +import anndata as ad + + +def main(): + path = sys.argv[1] + adata = ad.read_h5ad(path) # full load into memory + adata.var["mito"] = adata.var_names.str.upper().str.startswith("MT-") + t = time.perf_counter() + sc.pp.calculate_qc_metrics( + adata, qc_vars=["mito"], percent_top=[50, 100, 200, 500], log1p=True, inplace=True + ) + sc.pp.normalize_total(adata, target_sum=1e4) + sc.pp.log1p(adata) + print(f"STEP_SECONDS={time.perf_counter() - t}") + + +if __name__ == "__main__": + main() diff --git a/demo/bench_ooc.py b/demo/bench_ooc.py new file mode 100644 index 0000000..fc3029d --- /dev/null +++ b/demo/bench_ooc.py @@ -0,0 +1,110 @@ +"""Out-of-core memory benchmark: scanpy in-memory vs SingleRust OOC. + +Same work in both lanes — QC + normalize_total(1e4) + log1p — on the same .h5ad. Each lane runs +as a subprocess under /usr/bin/time -l so we capture **peak RSS** (the headline) and compute time: + + * scanpy in-memory : loads the whole matrix; peak RSS scales with the dataset. + * SingleRust OOC : streams the file in chunks; peak RSS stays ~bounded. + +Usage: + python demo/bench_ooc.py [chunk_size] + +(A scanpy+Dask out-of-core lane needs anndata>=0.11's read_elem_as_dask; see notes in the +performance roadmap. This harness leaves a hook for it.) +""" +import os +import re +import subprocess +import sys +import pathlib + +ROOT = pathlib.Path(__file__).resolve().parent.parent +BIN = ROOT / "target" / "release" / "examples" / "sr_ooc" +ENV = dict(os.environ) +ENV["PATH"] = "/opt/homebrew/bin:" + str(pathlib.Path.home() / ".cargo" / "bin") + ":" + ENV.get("PATH", "") +ENV.setdefault("HDF5_USE_FILE_LOCKING", "FALSE") + +# Python with the newer stack (anndata>=0.11 + dask) for both scanpy lanes; override with SR_PY. 
+PY = os.environ.get("SR_PY", str(ROOT / ".venv-dask" / "bin" / "python")) + + +def run_timed(cmd): + """Run under /usr/bin/time -l; return (compute_seconds|None, peak_rss_gb).""" + p = subprocess.run(["/usr/bin/time", "-l", *cmd], cwd=ROOT, env=ENV, capture_output=True, text=True) + if p.returncode: + sys.stderr.write(p.stdout[-1000:] + "\n" + p.stderr[-2000:] + "\n") + raise RuntimeError(f"failed: {' '.join(map(str, cmd))}") + secs = None + for line in p.stdout.splitlines(): + if line.startswith("STEP_SECONDS="): + secs = float(line.split("=", 1)[1]) + m = re.search(r"(\d+)\s+maximum resident set size", p.stderr) + rss_gb = int(m.group(1)) / 1e9 if m else float("nan") # macOS: bytes + w = re.search(r"([\d.]+)\s+real", p.stderr) # wall time from /usr/bin/time -l + wall = float(w.group(1)) if w else float("nan") + return secs if secs is not None else wall, rss_gb + + +def scanpy_inmem(path): + return run_timed([PY, str(ROOT / "demo" / "_scanpy_inmem_pp.py"), str(path)]) + + +def scanpy_dask(path, chunk): + args = [PY, str(ROOT / "demo" / "_scanpy_dask_pp.py"), str(path)] + if chunk: + args.append(str(chunk)) + return run_timed(args) + + +def singlerust_ooc(path, chunk): + # Single fused pass (qc + normalize_total + log1p), reading source -> temp output. No 6 GB + # copy and one process, so wall time and peak RSS are a single clean /usr/bin/time -l measure. 
+ out = path.with_suffix(".ooc_out.h5ad") + chunk_args = ["--chunk", str(chunk)] if chunk else [] + t, r = run_timed([str(BIN), "preprocess", str(path), "--out", str(out), + "--target-sum", "10000.0", *chunk_args]) + out.unlink(missing_ok=True) + return t, r + + +def main(): + if len(sys.argv) < 2: + print("usage: python demo/bench_ooc.py [chunk_size]") + sys.exit(2) + path = pathlib.Path(sys.argv[1]) + chunk = int(sys.argv[2]) if len(sys.argv) > 2 else None + if not BIN.exists(): + raise SystemExit(f"build the CLI first: cargo build --release --example sr_ooc ({BIN} missing)") + + import h5py + with h5py.File(path) as f: + shape = list(f["X"].attrs.get("shape", [])) + print(f"dataset: {path.name} shape={shape} chunk={chunk or 'default'}\n") + + s_t, s_rss = scanpy_inmem(path) + print(f"scanpy in-memory : {s_t:7.2f}s compute | peak RSS {s_rss:6.2f} GB") + + if pathlib.Path(PY).exists(): + try: + d_t, d_rss = scanpy_dask(path, chunk) + print(f"scanpy + Dask OOC: {d_t:7.2f}s compute | peak RSS {d_rss:6.2f} GB") + except Exception as e: + d_rss = None + print(f"scanpy + Dask OOC: FAILED ({e})") + else: + d_rss = None + print(f"scanpy + Dask OOC: skipped (no {PY}; create .venv-dask with anndata>=0.11 + dask)") + + r_t, r_rss = singlerust_ooc(path, chunk) + print(f"SingleRust OOC : {r_t:7.2f}s wall | peak RSS {r_rss:6.2f} GB") + + print() + if s_rss and r_rss: + print(f"memory: SingleRust OOC uses {s_rss / r_rss:.1f}× less peak RAM than scanpy in-memory") + if d_rss and r_rss: + print(f"memory: SingleRust OOC uses {d_rss / r_rss:.1f}× less peak RAM than scanpy+Dask " + f"(Dask's sparse path barely beats in-memory)") + + +if __name__ == "__main__": + main() diff --git a/demo/singlerust.py b/demo/singlerust.py new file mode 100644 index 0000000..9d793d1 --- /dev/null +++ b/demo/singlerust.py @@ -0,0 +1,125 @@ +"""singlerust — a near-drop-in `sc.pp.*` replacement backed by SingleRust's out-of-core engine. 
+ +The goal: take an existing scanpy preprocessing script and change as little as possible to run +the heavy steps in Rust, out-of-core, on a disk-backed `.h5ad`. + + import scanpy as sc import singlerust as sr + adata = sc.read_h5ad(path, backed="r") adata = sc.read_h5ad(path, backed="r") + sc.pp.calculate_qc_metrics(adata, ...) sr.pp.calculate_qc_metrics(adata) + sc.pp.normalize_total(adata, 1e4) sr.pp.normalize_total(adata, target_sum=1e4) + sc.pp.log1p(adata) sr.pp.log1p(adata) + +Each `sr.pp.*` call operates on the file the AnnData is backed by (or a path you pass directly), +streaming in bounded memory, and edits it **in place** — matching scanpy's inplace default. The +expression matrix is never fully loaded. + +Accepted `adata` argument forms: + * a path / `pathlib.Path` to a `.h5ad` + * a backed AnnData (``sc.read_h5ad(path, backed="r"|"r+")``) — its ``.filename`` is used +An in-memory AnnData is rejected with a clear message (use scanpy directly, or write to disk). + +This shells out to the ``sr_ooc`` example binary; build it once with +``cargo build --release --features enrichment --example sr_ooc``. +""" +from __future__ import annotations + +import os +import pathlib +import subprocess +import sys + +# Locate the compiled CLI (override with SR_OOC_BIN). +_ROOT = pathlib.Path(__file__).resolve().parent.parent +_BIN = pathlib.Path(os.environ.get("SR_OOC_BIN", _ROOT / "target" / "release" / "examples" / "sr_ooc")) + + +def _resolve_path(adata) -> pathlib.Path: + if isinstance(adata, (str, os.PathLike)): + return pathlib.Path(adata) + # Duck-type a backed AnnData: it exposes `.filename` (truthy when backed). The Rust process + # opens the file read-write, so we must release Python's handle first (HDF5 file locking), + # otherwise the open collides. The backed object is stale afterwards — re-read to see results. 
+ fn = getattr(adata, "filename", None) + if fn: + path = pathlib.Path(fn) + f = getattr(adata, "file", None) + if f is not None and hasattr(f, "close"): + try: + f.close() + except Exception: + pass + return path + raise TypeError( + "singlerust operates on a disk-backed .h5ad. Pass a path, or open with " + "sc.read_h5ad(path, backed='r+'). (Got an in-memory AnnData.)" + ) + + +def _run(args: list[str]) -> None: + if not _BIN.exists(): + raise FileNotFoundError( + f"sr_ooc binary not found at {_BIN}. Build it with:\n" + " cargo build --release --features enrichment --example sr_ooc" + ) + env = dict(os.environ) + env.setdefault("HDF5_USE_FILE_LOCKING", "FALSE") # avoid stale-lock false positives + proc = subprocess.run([str(_BIN), *args], capture_output=True, text=True, env=env) + if proc.returncode != 0: + sys.stderr.write(proc.stdout + "\n" + proc.stderr + "\n") + raise RuntimeError(f"sr_ooc {' '.join(args)} failed (exit {proc.returncode})") + + +class _PP: + """Mirror of the subset of ``scanpy.pp`` that SingleRust implements out-of-core.""" + + def calculate_qc_metrics(self, adata, *, chunk=None, **_ignored) -> None: + """≈ ``sc.pp.calculate_qc_metrics(adata, inplace=True)`` — writes obs/var in place. + + Mitochondrial genes are auto-detected by ``MT-``/``mt-`` prefix; top-N segments are + [50,100,200,500]. Extra scanpy kwargs are accepted and ignored for compatibility. 
+ """ + path = _resolve_path(adata) + _run(["qc", str(path), *(["--chunk", str(chunk)] if chunk else [])]) + + def normalize_total(self, adata, *, target_sum=1e4, chunk=None, **_ignored) -> None: + """≈ ``sc.pp.normalize_total(adata, target_sum=...)`` — rewrites X in place.""" + path = _resolve_path(adata) + _run(["normalize_total", str(path), "--target-sum", repr(float(target_sum)), + *(["--chunk", str(chunk)] if chunk else [])]) + + def log1p(self, adata, *, chunk=None, **_ignored) -> None: + """≈ ``sc.pp.log1p(adata)`` — rewrites X in place.""" + path = _resolve_path(adata) + _run(["log1p", str(path), *(["--chunk", str(chunk)] if chunk else [])]) + + def normalize_total_log1p(self, adata, *, target_sum=1e4, chunk=None, **_ignored) -> None: + """Fused normalize_total + log1p in a single streaming pass (one rewrite of X).""" + path = _resolve_path(adata) + _run(["normalize_total", str(path), "--target-sum", repr(float(target_sum)), "--log1p", + *(["--chunk", str(chunk)] if chunk else [])]) + + def highly_variable_genes(self, adata, *, n_top_genes=2000, chunk=None, **_ignored) -> None: + """≈ ``sc.pp.highly_variable_genes(adata, flavor="seurat", n_top_genes=...)`` — writes + means/dispersions/dispersions_norm/highly_variable into var in place.""" + path = _resolve_path(adata) + _run(["hvg", str(path), + *(["--n-top-genes", str(int(n_top_genes))] if n_top_genes else []), + *(["--chunk", str(chunk)] if chunk else [])]) + + def pca(self, adata, *, n_comps=50, chunk=None, **_ignored) -> None: + """≈ ``sc.pp.pca(adata, n_comps=...)`` — out-of-core covariance PCA over the HVGs + (run highly_variable_genes first); writes obsm["X_pca"] + uns variance ratios in place.""" + path = _resolve_path(adata) + _run(["pca", str(path), "--n-comps", str(int(n_comps)), + *(["--chunk", str(chunk)] if chunk else [])]) + + def preprocess(self, adata, *, target_sum=1e4, log1p=True, chunk=None, **_ignored) -> None: + """Fused QC + normalize_total + log1p in one out-of-core job (≈ sc.pp QC 
then + normalize_total then log1p). Writes obs/var QC metrics + the normalized X in place.""" + path = _resolve_path(adata) + _run(["preprocess", str(path), "--target-sum", repr(float(target_sum)), + *([] if log1p else ["--no-log1p"]), + *(["--chunk", str(chunk)] if chunk else [])]) + + +pp = _PP() diff --git a/examples/sr_ooc.rs b/examples/sr_ooc.rs new file mode 100644 index 0000000..b1ddf91 --- /dev/null +++ b/examples/sr_ooc.rs @@ -0,0 +1,117 @@ +//! # `sr_ooc` — out-of-core preprocessing CLI +//! +//! Thin command-line front end over SingleRust's disk-backed streaming ops, so a Python shim +//! (or a shell) can run them on a `.h5ad` without loading it into memory. +//! +//! ```text +//! sr_ooc qc [--chunk N] +//! sr_ooc normalize_total [--target-sum 1e4] [--log1p] [--chunk N] [--out OUT] +//! sr_ooc log1p [--chunk N] [--out OUT] +//! ``` +//! +//! `qc` always writes obs/var back into the file in place. `normalize_total`/`log1p` rewrite `X`; +//! with no `--out` they edit the file in place (write a sibling temp, then atomically rename). 
+ +use std::path::{Path, PathBuf}; + +use single_rust::backed::processing::hvg::highly_variable_genes_backed; +use single_rust::backed::processing::pca::pca_backed; +use single_rust::backed::processing::pipeline::preprocess_backed; +use single_rust::backed::processing::qc::qc_metrics_backed; +use single_rust::backed::processing::transformation::{log1p_backed, normalize_total_backed}; + +fn main() -> anyhow::Result<()> { + let args: Vec = std::env::args().collect(); + if args.len() < 3 { + eprintln!( + "usage:\n sr_ooc qc [--chunk N]\n \ + sr_ooc normalize_total [--target-sum F] [--log1p] [--out OUT] [--chunk N]\n \ + sr_ooc log1p [--out OUT] [--chunk N]\n \ + sr_ooc preprocess [--target-sum F] [--no-log1p] [--out OUT] [--chunk N]\n \ + sr_ooc hvg [--n-top-genes N] [--chunk N]\n \ + sr_ooc pca [--n-comps N] [--chunk N] (needs hvg first)" + ); + std::process::exit(2); + } + let cmd = args[1].as_str(); + let input = PathBuf::from(&args[2]); + let flags = &args[3..]; + + let chunk = flag_val(flags, "--chunk").map(|s| s.parse()).transpose()?; + let out = flag_val(flags, "--out").map(PathBuf::from); + + match cmd { + "qc" => { + qc_metrics_backed(&input, chunk)?; + println!("qc -> {} (in place)", input.display()); + } + "hvg" => { + let n_top = flag_val(flags, "--n-top-genes") + .map(|s| s.parse()) + .transpose()?; + highly_variable_genes_backed(&input, n_top, chunk)?; + println!("hvg -> {} (in place)", input.display()); + } + "pca" => { + let n_comps = flag_val(flags, "--n-comps") + .map(|s| s.parse()) + .transpose()? + .unwrap_or(50); + pca_backed(&input, n_comps, None, chunk)?; + println!("pca -> {} obsm[\"X_pca\"] (in place)", input.display()); + } + "log1p" => { + in_place_or_out(&input, out, |inp, outp| log1p_backed(inp, outp, chunk))?; + } + "normalize_total" => { + let target = flag_val(flags, "--target-sum") + .map(|s| s.parse()) + .transpose()? 
+ .unwrap_or(1e4); + let log1p = flags.iter().any(|f| f == "--log1p"); + in_place_or_out(&input, out, |inp, outp| { + normalize_total_backed(inp, outp, target, log1p, chunk) + })?; + } + "preprocess" => { + // Fused QC + normalize_total + log1p (log1p on by default; --no-log1p to disable). + let target = flag_val(flags, "--target-sum") + .map(|s| s.parse()) + .transpose()? + .unwrap_or(1e4); + let log1p = !flags.iter().any(|f| f == "--no-log1p"); + in_place_or_out(&input, out, |inp, outp| { + preprocess_backed(inp, outp, target, log1p, chunk) + })?; + } + other => anyhow::bail!("unknown command '{other}'"), + } + Ok(()) +} + +/// Run a transform that needs an output path. If `out` is `None`, edit `input` in place by +/// writing to a sibling temp file and renaming over the original on success. +fn in_place_or_out( + input: &Path, + out: Option, + run: impl FnOnce(&Path, &Path) -> anyhow::Result<()>, +) -> anyhow::Result<()> { + match out { + Some(outp) => { + run(input, &outp)?; + println!("-> {}", outp.display()); + } + None => { + let tmp = input.with_extension("h5ad.tmp"); + run(input, &tmp)?; + std::fs::rename(&tmp, input)?; + println!("-> {} (in place)", input.display()); + } + } + Ok(()) +} + +/// Value following `name` in the flag list (e.g. `--chunk 1000` -> `Some("1000")`). +fn flag_val<'a>(flags: &'a [String], name: &str) -> Option<&'a str> { + flags.iter().position(|f| f == name).and_then(|i| flags.get(i + 1)).map(|s| s.as_str()) +} diff --git a/src/backed/processing/det.rs b/src/backed/processing/det.rs new file mode 100644 index 0000000..66ac366 --- /dev/null +++ b/src/backed/processing/det.rs @@ -0,0 +1,59 @@ +//! Deterministic parallelism helpers. +//! +//! Floating-point addition is not associative, so a naive parallel reduction (rayon's work +//! stealing splits and combines partials in a scheduling-dependent order) produces results that +//! vary bit-for-bit between runs and thread counts. Everything here instead uses **fixed-size +//! 
blocks with an ordered merge**: row range `0..n` is cut into `ceil(n/BLOCK)` blocks at fixed +//! indices, each block is folded sequentially (deterministic), and the per-block partials are +//! merged in block order. The block boundaries and merge order depend only on `n` and the +//! constant block size — never on the thread count or scheduler — so the result is identical on +//! every run and on any machine. + +use rayon::prelude::*; + +/// Fixed block size (rows) for deterministic block reductions. Constant so the partition (and +/// hence the floating-point summation order) is reproducible across runs and machines. +pub(crate) const DET_BLOCK: usize = 4096; + +/// Deterministically reduce `0..n` in parallel. +/// +/// `init` makes a zero partial, `fold_row(&mut partial, row)` accumulates one row into a partial +/// (called sequentially within a block, in increasing row order), and `merge(&mut acc, partial)` +/// combines a block's partial into the accumulator (called in increasing block order). The result +/// is bit-identical regardless of how many threads rayon uses. +pub(crate) fn det_block_reduce( + n: usize, + init: Init, + fold_row: Fold, + merge: Merge, +) -> P +where + P: Send, + Init: Fn() -> P + Sync, + Fold: Fn(&mut P, usize) + Sync, + Merge: Fn(&mut P, P), +{ + if n == 0 { + return init(); + } + let nblocks = n.div_ceil(DET_BLOCK); + // Parallel across blocks; `collect` preserves block order. + let partials: Vec

= (0..nblocks) + .into_par_iter() + .map(|b| { + let lo = b * DET_BLOCK; + let hi = ((b + 1) * DET_BLOCK).min(n); + let mut p = init(); + for row in lo..hi { + fold_row(&mut p, row); + } + p + }) + .collect(); + // Sequential, ordered merge -> fixed summation order. + let mut acc = init(); + for p in partials { + merge(&mut acc, p); + } + acc +} diff --git a/src/backed/processing/hvg.rs b/src/backed/processing/hvg.rs new file mode 100644 index 0000000..d6674aa --- /dev/null +++ b/src/backed/processing/hvg.rs @@ -0,0 +1,226 @@ +//! # Out-of-core highly variable genes (Seurat flavor) +//! +//! One streaming pass over `X` accumulates per-gene sum and sum-of-squares; from those we form +//! the per-gene mean and **sample** variance (`(Σx²/n − mean²)·n/(n−1)`) — identical to the +//! in-memory `var_col` — then hand off to the shared [`seurat_select`] for the dispersion +//! binning/normalization and top-N selection. Results are written into `var` in place. +//! +//! Memory is bounded by one chunk plus two `n_vars`-length f64 accumulators. + +use std::path::Path; + +use anndata::data::DynCsrMatrix; +use anndata::{AnnData, AnnDataOp, ArrayData, ArrayElemOp, Backend}; +use anndata_hdf5::H5; +use anyhow::bail; +use polars::prelude::Column; + +use crate::memory::processing::hvg::seurat_select; +use crate::shared::processing::HVGParams; + +use super::det::det_block_reduce; +use super::transformation::DEFAULT_CHUNK_SIZE; + +/// Compute highly variable genes (Seurat flavor) out-of-core and write the results +/// (`means`, `dispersions`, `dispersions_norm`, `highly_variable`) into `var` in place. +pub fn highly_variable_genes_backed( + path: &Path, + n_top_genes: Option, + chunk_size: Option, +) -> anyhow::Result<()> { + let chunk_size = chunk_size.unwrap_or(DEFAULT_CHUNK_SIZE); + let adata = AnnData::

::open(H5::open_rw(path)?)?; + let (n_obs, n_vars) = (adata.n_obs(), adata.n_vars()); + + // Streaming pass: per-gene sum and sum of squares. Each chunk is reduced with a deterministic + // block reduction (parallel, fixed summation order), then added to the running totals. + let mut col_sum = vec![0.0_f64; n_vars]; + let mut col_sumsq = vec![0.0_f64; n_vars]; + for (chunk, _start, _end) in adata.x().iter::(chunk_size) { + let (s, sq) = sumsq_chunk(&chunk, n_vars)?; + for (a, b) in col_sum.iter_mut().zip(s.iter()) { + *a += *b; + } + for (a, b) in col_sumsq.iter_mut().zip(sq.iter()) { + *a += *b; + } + } + + let n = n_obs as f64; + let raw_means: Vec = col_sum.iter().map(|&s| s / n).collect(); + let variances: Vec = col_sum + .iter() + .zip(col_sumsq.iter()) + .map(|(&s, &sq)| { + let mean = s / n; + let pop_var = sq / n - mean * mean; + if n > 1.0 { + pop_var * (n / (n - 1.0)) // Bessel-corrected, matches in-memory var_col + } else { + 0.0 + } + }) + .collect(); + + let params = HVGParams { + n_top_genes, + ..Default::default() + }; + let (log1p_means, log_dispersions, dispersions_norm, highly_variable) = + seurat_select(&raw_means, &variances, ¶ms)?; + + let mut var = adata.read_var()?; + var.with_column(Column::new("means".into(), log1p_means))?; + var.with_column(Column::new("dispersions".into(), log_dispersions))?; + var.with_column(Column::new("dispersions_norm".into(), dispersions_norm))?; + var.with_column(Column::new("highly_variable".into(), highly_variable))?; + adata.set_var(var)?; + + adata.close()?; + Ok(()) +} + +/// One chunk's per-gene `(sum, sum_of_squares)` via a deterministic block reduction over rows. 
+fn sumsq_chunk(chunk: &ArrayData, n_vars: usize) -> anyhow::Result<(Vec, Vec)> { + match chunk { + ArrayData::CsrMatrix(DynCsrMatrix::F32(m)) => Ok(sumsq_rows(m, n_vars)), + ArrayData::CsrMatrix(DynCsrMatrix::F64(m)) => Ok(sumsq_rows(m, n_vars)), + other => bail!("OOC HVG supports only F32/F64 CSR matrices, got {:?}", other), + } +} + +fn sumsq_rows( + m: &nalgebra_sparse::CsrMatrix, + n_vars: usize, +) -> (Vec, Vec) { + det_block_reduce( + m.nrows(), + || (vec![0.0_f64; n_vars], vec![0.0_f64; n_vars]), + |(sum, sumsq), r| { + let row = m.row(r); + for (&c, &v) in row.col_indices().iter().zip(row.values().iter()) { + let v = v.to_f64().unwrap_or(0.0); + sum[c] += v; + sumsq[c] += v * v; + } + }, + |(asum, asq), (sum, sq)| { + for (a, b) in asum.iter_mut().zip(sum.iter()) { + *a += *b; + } + for (a, b) in asq.iter_mut().zip(sq.iter()) { + *a += *b; + } + }, + ) +} + +#[cfg(test)] +mod tests { + use super::*; + use anndata::data::DynCsrMatrix; + use nalgebra_sparse::{CooMatrix, CsrMatrix}; + + fn tmp(name: &str) -> std::path::PathBuf { + let mut p = std::env::temp_dir(); + p.push(name); + let _ = std::fs::remove_file(&p); + p + } + + /// OOC HVG must select exactly the same genes as the in-memory implementation on the same + /// data (both share `seurat_select`; only the mean/variance computation differs in path). + #[test] + fn ooc_hvg_matches_in_memory() -> anyhow::Result<()> { + // 40 cells × 12 genes, varied magnitudes so dispersion ranking is unambiguous. + let (nr, nc) = (40usize, 12usize); + let mut coo = CooMatrix::::new(nr, nc); + for i in 0..nr { + for j in 0..nc { + let v = (((i * 13 + j * 7) % 11) as f32) * (1.0 + (j as f32) * 0.3); + if v != 0.0 { + coo.push(i, j, v); + } + } + } + let path = tmp("sr_ooc_hvg.h5ad"); + let adata = AnnData::
::new(&path)?; + adata.set_obs_names((0..nr).map(|i| format!("c{i}")).collect::>().into())?; + adata.set_var_names((0..nc).map(|j| format!("g{j}")).collect::>().into())?; + adata.set_x(ArrayData::CsrMatrix(DynCsrMatrix::F32(CsrMatrix::from(&coo))))?; + adata.close()?; + + // In-memory reference (does not mutate the file). + let im = crate::io::read_h5ad_memory(&path)?; + crate::memory::processing::compute_highly_variable_genes( + &im, + Some(HVGParams { n_top_genes: Some(5), ..Default::default() }), + )?; + let ref_mask: Vec = im + .var() + .get_column_from_df("highly_variable")? + .bool()? + .into_iter() + .map(|b| b.unwrap_or(false)) + .collect(); + + // Out-of-core (mutates the file's var), tiny chunk to exercise streaming. + highly_variable_genes_backed(&path, Some(5), Some(7))?; + let back = AnnData::
::open(H5::open(&path)?)?; + let ooc_mask: Vec = back + .read_var()? + .column("highly_variable")? + .bool()? + .into_iter() + .map(|b| b.unwrap_or(false)) + .collect(); + back.close()?; + + assert_eq!(ref_mask, ooc_mask, "OOC HVG mask must equal in-memory HVG mask"); + assert!(ooc_mask.iter().any(|&b| b), "expected some HVGs selected"); + std::fs::remove_file(path).ok(); + Ok(()) + } + + /// Determinism: HVG run with 1 vs 8 threads must give bit-identical means/dispersions and the + /// same selection (guards the parallel sum/sum-of-squares reduction). + #[test] + fn ooc_hvg_is_deterministic_across_thread_counts() -> anyhow::Result<()> { + let (nr, nc) = (5000usize, 80usize); + let build = |tag: &str| -> anyhow::Result { + let mut coo = CooMatrix::::new(nr, nc); + for i in 0..nr { + for j in 0..nc { + let v = (((i * 29 + j * 11) % 17) as f32) * (0.5 + (j as f32) * 0.1); + if v != 0.0 { + coo.push(i, j, v); + } + } + } + let p = tmp(&format!("sr_hvg_det_{tag}.h5ad")); + let a = AnnData::
::new(&p)?; + a.set_obs_names((0..nr).map(|i| format!("c{i}")).collect::>().into())?; + a.set_var_names((0..nc).map(|j| format!("g{j}")).collect::>().into())?; + a.set_x(ArrayData::CsrMatrix(DynCsrMatrix::F32(CsrMatrix::from(&coo))))?; + a.close()?; + Ok(p) + }; + let run = |threads: usize, tag: &str| -> anyhow::Result<(Vec, Vec)> { + let p = build(tag)?; + let pool = rayon::ThreadPoolBuilder::new().num_threads(threads).build()?; + pool.install(|| highly_variable_genes_backed(&p, Some(20), Some(1000)))?; + let a = AnnData::
::open(H5::open(&p)?)?; + let var = a.read_var()?; + let means = var.column("means")?.f64()?.into_iter().map(|x| x.unwrap_or(f64::NAN)).collect(); + let hv = var.column("highly_variable")?.bool()?.into_iter().map(|b| b.unwrap_or(false)).collect(); + a.close()?; + std::fs::remove_file(p).ok(); + Ok((means, hv)) + }; + let (m1, h1) = run(1, "t1")?; + let (m8, h8) = run(8, "t8")?; + assert_eq!(m1, m8, "HVG means differ between 1 and 8 threads"); + assert_eq!(h1, h8, "HVG selection differs between 1 and 8 threads"); + Ok(()) + } +} diff --git a/src/backed/processing/mod.rs b/src/backed/processing/mod.rs index 8b13789..6c87d1a 100644 --- a/src/backed/processing/mod.rs +++ b/src/backed/processing/mod.rs @@ -1 +1,16 @@ - +//! Out-of-core (disk-backed) processing — streaming transforms/metrics that never hold the full +//! expression matrix in memory. +//! +//! - [`transformation`]: streaming `normalize_total` / `log1p`. +//! - [`qc`]: streaming QC metrics, written into `obs`/`var` in place. +//! - [`pipeline`]: fused QC + normalize + log1p in one job. +//! - [`hvg`]: streaming highly variable genes (Seurat), written into `var` in place. +//! +//! Parallel passes use [`det`]'s fixed-block ordered reduction so results are bit-identical +//! regardless of thread count (floating-point summation order is pinned). +pub(crate) mod det; +pub mod hvg; +pub mod pca; +pub mod pipeline; +pub mod qc; +pub mod transformation; diff --git a/src/backed/processing/pca.rs b/src/backed/processing/pca.rs new file mode 100644 index 0000000..056b1e8 --- /dev/null +++ b/src/backed/processing/pca.rs @@ -0,0 +1,406 @@ +//! # Out-of-core PCA (covariance / eigendecomposition method) +//! +//! Exact PCA over the highly-variable genes without ever holding the cell×gene matrix in memory: +//! +//! 1. **Pass 1** streams `X` and accumulates, over the selected genes only, the gene×gene Gram +//! matrix `Gᵢⱼ = Σ_cells xᵢxⱼ` and per-gene sums. Memory is `O(n_hvg²)` (e.g. 2000² f64 = 32 MB) +//! 
plus one chunk — independent of cell count. +//! 2. The centered covariance `C = (G − n·μμᵀ)/(n−1)` is eigendecomposed (`n_hvg × n_hvg`, +//! symmetric). Top-`k` eigenvectors are the principal axes; eigenvalues are the variances. +//! 3. **Pass 2** streams `X` again and projects each centered cell onto the top axes, writing the +//! embedding to `obsm["X_pca"]` (and variance ratios to `uns`). +//! +//! This is exact (not randomized) PCA with centering; results match the in-memory PCA up to the +//! usual per-component sign ambiguity. + +use std::path::Path; + +use anndata::data::{DynArray, DynCsrMatrix}; +use anndata::{ + AnnData, AnnDataOp, ArrayData, ArrayElemOp, AxisArraysOp, Backend, Data, ElemCollectionOp, +}; +use anndata_hdf5::H5; +use anyhow::bail; +use nalgebra::{DMatrix, SymmetricEigen}; +use ndarray::Array2; +use rayon::prelude::*; + +use super::det::det_block_reduce; +use super::transformation::DEFAULT_CHUNK_SIZE; + +/// Run PCA out-of-core over the genes flagged in `var["highly_variable"]`, writing the cell +/// embedding to `obsm["X_{key}"]` (default `X_pca`) and variance ratios to +/// `uns["{key}_variance_ratio"]`, in place. +pub fn pca_backed( + path: &Path, + n_comps: usize, + key: Option<&str>, + chunk_size: Option, +) -> anyhow::Result<()> { + let key = key.unwrap_or("pca"); + let chunk_size = chunk_size.unwrap_or(DEFAULT_CHUNK_SIZE); + let adata = AnnData::
::open(H5::open_rw(path)?)?; + let (n_obs, n_vars) = (adata.n_obs(), adata.n_vars()); + + // Selected (highly variable) genes -> local index map. + let hv: Vec = adata + .read_var()? + .column("highly_variable")? + .bool()? + .into_iter() + .map(|b| b.unwrap_or(false)) + .collect(); + let selected: Vec = hv.iter().enumerate().filter(|(_, &b)| b).map(|(i, _)| i).collect(); + let n_sel = selected.len(); + if n_sel == 0 { + bail!("OOC PCA needs var['highly_variable'] set (run hvg first); none selected"); + } + let mut local = vec![-1i64; n_vars]; + for (li, &g) in selected.iter().enumerate() { + local[g] = li as i64; + } + + // ---- Pass 1: Gram matrix + per-gene sums over selected genes ---- + // Each read-chunk's contribution is computed with a deterministic block reduction (parallel + // across fixed row-blocks, ordered merge), then added to the running totals in chunk order. + let mut gram = vec![0.0_f64; n_sel * n_sel]; + let mut col_sum = vec![0.0_f64; n_sel]; + for (chunk, _start, _end) in adata.x().iter::(chunk_size) { + let (g_chunk, cs_chunk) = gram_chunk(&chunk, &local, n_sel)?; + for (a, b) in gram.iter_mut().zip(g_chunk.iter()) { + *a += *b; + } + for (a, b) in col_sum.iter_mut().zip(cs_chunk.iter()) { + *a += *b; + } + } + + // ---- Covariance + eigendecomposition (small, n_sel × n_sel) ---- + let n = n_obs as f64; + let mean: Vec = col_sum.iter().map(|&s| s / n).collect(); + let mut cov = DMatrix::::zeros(n_sel, n_sel); + for i in 0..n_sel { + for j in 0..n_sel { + cov[(i, j)] = (gram[i * n_sel + j] - n * mean[i] * mean[j]) / (n - 1.0); + } + } + let eig = SymmetricEigen::new(cov); + let mut order: Vec = (0..n_sel).collect(); + order.sort_by(|&a, &b| { + eig.eigenvalues[b] + .partial_cmp(&eig.eigenvalues[a]) + .unwrap_or(std::cmp::Ordering::Equal) + }); + let k = n_comps.min(n_sel); + let total_var: f64 = eig.eigenvalues.iter().sum::().max(f64::MIN_POSITIVE); + let top: Vec = order[..k].to_vec(); + let variance_ratio: Vec = top.iter().map(|&t| 
eig.eigenvalues[t] / total_var).collect(); + + // Precompute the per-component centering offset c_k = Σ_j mean_j · V[j,k]. + let offset: Vec = (0..k) + .map(|c| (0..n_sel).map(|j| mean[j] * eig.eigenvectors[(j, top[c])]).sum::()) + .collect(); + + // ---- Pass 2: project centered cells onto the top axes ---- + // Each cell's embedding is independent (a fixed-order sum over its own nonzeros), so rows are + // computed in parallel and written to disjoint positions — order-independent, hence + // deterministic, with no cross-row reduction. + let mut emb = Array2::::zeros((n_obs, k)); + for (chunk, start, _end) in adata.x().iter::(chunk_size) { + let rows = project_chunk(&chunk, &local, &eig, &top, &offset, k)?; + for (i, row) in rows.into_iter().enumerate() { + for (c, val) in row.into_iter().enumerate() { + emb[[start + i, c]] = val; + } + } + } + + // ---- Store results in place ---- + let emb_ad: ArrayData = DynArray::from(emb).into(); + adata.obsm().add(&format!("X_{key}"), emb_ad)?; + let vr: ArrayData = DynArray::from(ndarray::Array1::from(variance_ratio)).into(); + adata.uns().add(&format!("{key}_variance_ratio"), Data::ArrayData(vr))?; + + adata.close()?; + Ok(()) +} + +/// Gather a CSR row's selected nonzeros as (local_index, value), reusing `buf`. +fn gather_selected( + cols: &[usize], + vals: &[T], + local: &[i64], + buf: &mut Vec<(usize, f64)>, +) { + buf.clear(); + for (&c, &v) in cols.iter().zip(vals.iter()) { + let li = local[c]; + if li >= 0 { + buf.push((li as usize, v.to_f64().unwrap_or(0.0))); + } + } +} + +/// One chunk's contribution to the Gram matrix and per-gene sums, via a deterministic block +/// reduction over the chunk's rows. Returns `(gram_flat, col_sum)`. 
+fn gram_chunk(
+    chunk: &ArrayData,
+    local: &[i64],
+    n_sel: usize,
+) -> anyhow::Result<(Vec<f64>, Vec<f64>)> {
+    // Only the two floating-point CSR variants are handled; any other chunk type is an error.
+    match chunk {
+        ArrayData::CsrMatrix(DynCsrMatrix::F32(m)) => Ok(gram_rows(m, local, n_sel)),
+        ArrayData::CsrMatrix(DynCsrMatrix::F64(m)) => Ok(gram_rows(m, local, n_sel)),
+        other => bail!("OOC PCA supports only F32/F64 CSR matrices, got {:?}", other),
+    }
+}
+
+/// Accumulate the Gram matrix and per-gene sums over all rows of one CSR chunk.
+///
+/// Returns `(gram, col_sum)`. `gram` is a flat buffer of `n_sel * n_sel` entries: for every pair
+/// of selected nonzeros in a row, at local genes `a` and `b`, the product of their values is
+/// added at index `a * n_sel + b`. `col_sum[a]` receives the plain sum for local gene `a`.
+/// `local` maps a column to its local index and is negative for columns that are not selected.
+///
+/// The statement order of the fold and merge closures fixes the floating-point summation order,
+/// so it must not be rearranged.
+// NOTE(review): the `<…>` generic lists were missing from the text this was recovered from; the
+// bound on `T` is a reconstruction — confirm against the repository.
+fn gram_rows<T: num_traits::ToPrimitive + Copy + Sync>(
+    m: &nalgebra_sparse::CsrMatrix<T>,
+    local: &[i64],
+    n_sel: usize,
+) -> (Vec<f64>, Vec<f64>) {
+    det_block_reduce(
+        m.nrows(),
+        || (vec![0.0_f64; n_sel * n_sel], vec![0.0_f64; n_sel]),
+        |(gram, col_sum), r| {
+            let row = m.row(r);
+            let mut buf: Vec<(usize, f64)> = Vec::new();
+            gather_selected(row.col_indices(), row.values(), local, &mut buf);
+            for &(li, v) in buf.iter() {
+                col_sum[li] += v;
+            }
+            // Outer product of the row's selected nonzeros with itself (both triangles filled).
+            for &(la, va) in buf.iter() {
+                let base = la * n_sel;
+                for &(lb, vb) in buf.iter() {
+                    gram[base + lb] += va * vb;
+                }
+            }
+        },
+        |(ag, acs), (g, cs)| {
+            for (a, b) in ag.iter_mut().zip(g.iter()) {
+                *a += *b;
+            }
+            for (a, b) in acs.iter_mut().zip(cs.iter()) {
+                *a += *b;
+            }
+        },
+    )
+}
+
+/// Project a chunk's cells onto the top axes; returns one `k`-vector per row (in row order).
+/// Rows are independent so they are computed in parallel; each value is a fixed-order sum, so the
+/// result does not depend on the thread count.
+fn project_chunk(
+    chunk: &ArrayData,
+    local: &[i64],
+    eig: &SymmetricEigen<f64, nalgebra::Dyn>,
+    top: &[usize],
+    offset: &[f64],
+    k: usize,
+) -> anyhow::Result<Vec<Vec<f64>>> {
+    // Only the two floating-point CSR variants are handled; any other chunk type is an error.
+    match chunk {
+        ArrayData::CsrMatrix(DynCsrMatrix::F32(m)) => Ok(project_rows(m, local, eig, top, offset, k)),
+        ArrayData::CsrMatrix(DynCsrMatrix::F64(m)) => Ok(project_rows(m, local, eig, top, offset, k)),
+        other => bail!("OOC PCA supports only F32/F64 CSR matrices, got {:?}", other),
+    }
+}
+
+/// Embed every row of one CSR chunk; returns one `k`-long vector per row, in row order.
+///
+/// Component `c` of a row is `-offset[c]` plus, for each selected nonzero `(li, v)` of that row,
+/// `v * eig.eigenvectors[(li, top[c])]`, added in the row's stored column order.
+// NOTE(review): the `<…>` generic lists were missing from the text this was recovered from; the
+// bound on `T` and the `SymmetricEigen` parameters are a reconstruction — confirm against the
+// repository.
+fn project_rows<T: num_traits::ToPrimitive + Copy + Sync>(
+    m: &nalgebra_sparse::CsrMatrix<T>,
+    local: &[i64],
+    eig: &SymmetricEigen<f64, nalgebra::Dyn>,
+    top: &[usize],
+    offset: &[f64],
+    k: usize,
+) -> Vec<Vec<f64>> {
+    (0..m.nrows())
+        .into_par_iter()
+        // `map_init` hands each rayon job its own scratch buffer, so the gather step reuses one
+        // allocation across rows instead of building a new `Vec` per row. `gather_selected`
+        // clears it first, so no row sees data from another row.
+        .map_init(Vec::new, |buf: &mut Vec<(usize, f64)>, r| {
+            let row = m.row(r);
+            gather_selected(row.col_indices(), row.values(), local, buf);
+            // Start at minus the centering offset, i.e. subtract the projected mean.
+            let mut out = vec![0.0_f64; k];
+            for (c, o) in out.iter_mut().enumerate() {
+                *o = -offset[c];
+            }
+            for &(li, v) in buf.iter() {
+                for (c, o) in out.iter_mut().enumerate() {
+                    *o += v * eig.eigenvectors[(li, top[c])];
+                }
+            }
+            out
+        })
+        .collect()
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use nalgebra::DMatrix;
+    use nalgebra_sparse::{CooMatrix, CsrMatrix};
+
+    /// Path `name` inside the system temp directory; a stale file at that path is removed first.
+    fn tmp(name: &str) -> std::path::PathBuf {
+        let mut p = std::env::temp_dir();
+        p.push(name);
+        let _ = std::fs::remove_file(&p);
+        p
+    }
+
+    /// OOC covariance-PCA must match a dense covariance-PCA reference computed in-test: identical
+    /// variance ratios and, up to per-component sign, identical embeddings.
+    #[test]
+    fn ooc_pca_matches_dense_reference() -> anyhow::Result<()> {
+        let (nr, nc) = (60usize, 8usize);
+        let mut dense = vec![vec![0.0_f64; nc]; nr];
+        let mut coo = CooMatrix::<f32>::new(nr, nc);
+        for i in 0..nr {
+            for j in 0..nc {
+                let v = (((i * 7 + j * 5) % 9) as f64) * 0.5;
+                dense[i][j] = v;
+                if v != 0.0 {
+                    coo.push(i, j, v as f32);
+                }
+            }
+        }
+        // All genes selected as HVG for the test.
+        let path = tmp("sr_ooc_pca.h5ad");
+        let adata = AnnData::<H5>
::new(&path)?; + adata.set_obs_names((0..nr).map(|i| format!("c{i}")).collect::>().into())?; + adata.set_var_names((0..nc).map(|j| format!("g{j}")).collect::>().into())?; + adata.set_x(ArrayData::CsrMatrix(DynCsrMatrix::F32(CsrMatrix::from(&coo))))?; + let mut var = adata.read_var()?; + var.with_column(polars::prelude::Column::new( + "highly_variable".into(), + vec![true; nc], + ))?; + adata.set_var(var)?; + adata.close()?; + + let k = 3; + pca_backed(&path, k, None, Some(16))?; + + // Read OOC results. + let back = AnnData::
::open(H5::open(&path)?)?; + let emb_ooc = match back.obsm().get_item::("X_pca")?.unwrap() { + ArrayData::Array(DynArray::F64(a)) => a.into_dimensionality::()?, + other => panic!("unexpected obsm type {:?}", other), + }; + let vr_ooc = match back.uns().get_item::("pca_variance_ratio")?.unwrap() { + Data::ArrayData(ArrayData::Array(DynArray::F64(a))) => { + a.into_dimensionality::()?.to_vec() + } + other => panic!("unexpected uns type {:?}", other), + }; + back.close()?; + + // ---- dense reference: covariance PCA on the same matrix ---- + let n = nr as f64; + let mut means = vec![0.0; nc]; + for r in &dense { + for j in 0..nc { + means[j] += r[j] / n; + } + } + let mut cov = DMatrix::::zeros(nc, nc); + for r in &dense { + for a in 0..nc { + for b in 0..nc { + cov[(a, b)] += (r[a] - means[a]) * (r[b] - means[b]) / (n - 1.0); + } + } + } + let eig = SymmetricEigen::new(cov); + let mut order: Vec = (0..nc).collect(); + order.sort_by(|&a, &b| eig.eigenvalues[b].partial_cmp(&eig.eigenvalues[a]).unwrap()); + let total: f64 = eig.eigenvalues.iter().sum(); + let vr_ref: Vec = order[..k].iter().map(|&t| eig.eigenvalues[t] / total).collect(); + + // variance ratios must match closely + for (a, b) in vr_ooc.iter().zip(vr_ref.iter()) { + assert!((a - b).abs() < 1e-9, "var ratio {a} vs {b}"); + } + // embeddings: column norms match the reference (sign-independent check) + for c in 0..k { + let t = order[c]; + // reference embedding column = centered @ eigenvector_t + let mut ref_col = vec![0.0; nr]; + for (i, r) in dense.iter().enumerate() { + ref_col[i] = (0..nc).map(|j| (r[j] - means[j]) * eig.eigenvectors[(j, t)]).sum(); + } + let norm_ref: f64 = ref_col.iter().map(|x| x * x).sum::().sqrt(); + let norm_ooc: f64 = (0..nr).map(|i| emb_ooc[[i, c]].powi(2)).sum::().sqrt(); + assert!((norm_ref - norm_ooc).abs() < 1e-6, "col {c} norm {norm_ref} vs {norm_ooc}"); + } + std::fs::remove_file(path).ok(); + Ok(()) + } + + /// Write a CSR fixture with all genes flagged highly_variable. 
+ fn write_pca_fixture(path: &Path, nr: usize, nc: usize) -> anyhow::Result<()> { + let mut coo = CooMatrix::::new(nr, nc); + for i in 0..nr { + for j in 0..nc { + let v = (((i * 31 + j * 17) % 13) as f32) * 0.25; + if v != 0.0 { + coo.push(i, j, v); + } + } + } + let a = AnnData::
::new(path)?; + a.set_obs_names((0..nr).map(|i| format!("c{i}")).collect::>().into())?; + a.set_var_names((0..nc).map(|j| format!("g{j}")).collect::>().into())?; + a.set_x(ArrayData::CsrMatrix(DynCsrMatrix::F32(CsrMatrix::from(&coo))))?; + let mut var = a.read_var()?; + var.with_column(polars::prelude::Column::new("highly_variable".into(), vec![true; nc]))?; + a.set_var(var)?; + a.close()?; + Ok(()) + } + + fn read_pca(path: &Path) -> anyhow::Result<(Vec, Vec)> { + let a = AnnData::
::open(H5::open(path)?)?; + let emb = match a.obsm().get_item::("X_pca")?.unwrap() { + ArrayData::Array(DynArray::F64(arr)) => arr.into_raw_vec_and_offset().0, + other => panic!("unexpected obsm {:?}", other), + }; + let vr = match a.uns().get_item::("pca_variance_ratio")?.unwrap() { + Data::ArrayData(ArrayData::Array(DynArray::F64(arr))) => arr.into_raw_vec_and_offset().0, + other => panic!("unexpected uns {:?}", other), + }; + a.close()?; + Ok((emb, vr)) + } + + /// Determinism: PCA run with 1 thread vs 8 threads (and a span of chunk sizes that force + /// different row-block partitions) must produce **bit-identical** embeddings and variance + /// ratios. This guards the parallel Gram reduction's fixed summation order. + #[test] + fn ooc_pca_is_deterministic_across_thread_counts() -> anyhow::Result<()> { + let (nr, nc, k) = (5000usize, 60usize, 10usize); + + let run = |threads: usize, chunk: usize, tag: &str| -> anyhow::Result<(Vec, Vec)> { + let path = tmp(&format!("sr_pca_det_{tag}.h5ad")); + write_pca_fixture(&path, nr, nc)?; + let pool = rayon::ThreadPoolBuilder::new().num_threads(threads).build()?; + pool.install(|| pca_backed(&path, k, None, Some(chunk)))?; + let out = read_pca(&path)?; + std::fs::remove_file(path).ok(); + Ok(out) + }; + + // Baseline: single-threaded. + let (emb1, vr1) = run(1, 1000, "t1")?; + // 8 threads, same chunking -> must match bit-for-bit. + let (emb8, vr8) = run(8, 1000, "t8")?; + assert_eq!(emb1, emb8, "embeddings differ between 1 and 8 threads"); + assert_eq!(vr1, vr8, "variance ratios differ between 1 and 8 threads"); + + // Repeat the 8-thread run -> identical to itself (no run-to-run drift). 
+ let (emb8b, vr8b) = run(8, 1000, "t8b")?; + assert_eq!(emb8, emb8b, "8-thread run not reproducible across repeats"); + assert_eq!(vr8, vr8b); + + Ok(()) + } +} diff --git a/src/backed/processing/pipeline.rs b/src/backed/processing/pipeline.rs new file mode 100644 index 0000000..10af679 --- /dev/null +++ b/src/backed/processing/pipeline.rs @@ -0,0 +1,167 @@ +//! # Fused out-of-core preprocessing +//! +//! `preprocess_backed` runs QC + `normalize_total` + `log1p` against a disk-backed `.h5ad` in a +//! single fused job, writing one output file. The key saving over calling the three ops +//! separately: the per-cell total computed for QC **is** the row-sum the normalization needs, so +//! it is computed once. Two streamed passes total (QC accumulate, then normalize+log1p write), +//! instead of the ~4 read/write passes the separate CLI commands incur — roughly halving I/O. + +use std::path::Path; + +use anndata::{AnnData, AnnDataOp, ArrayData, ArrayElemOp, Backend}; +use anndata_hdf5::H5; + +use super::qc::{stream_qc, write_metrics}; +use super::transformation::{normalize_chunk, DEFAULT_CHUNK_SIZE}; + +/// QC (on raw counts) + `normalize_total(target_sum)` + optional `log1p`, fused. Writes a new +/// `.h5ad` at `output` with QC metrics in obs/var and the normalized (log1p) matrix in `X`. +pub fn preprocess_backed( + input: &Path, + output: &Path, + target_sum: f64, + log1p: bool, + chunk_size: Option, +) -> anyhow::Result<()> { + let chunk_size = chunk_size.unwrap_or(DEFAULT_CHUNK_SIZE); + let adata = AnnData::
::open(H5::open(input)?)?; + let n_obs = adata.n_obs(); + + // Pass 1: QC on raw counts. acc.total holds per-cell totals == normalization row-sums. + let (acc, mito_mask) = stream_qc(&adata, chunk_size)?; + + // Output: copy obs/var (+ names), then layer the QC metrics on top. + let out = AnnData::
::new(output)?; + out.set_obs_names(adata.obs_names())?; + out.set_var_names(adata.var_names())?; + out.set_obs(adata.read_obs()?)?; + out.set_var(adata.read_var()?)?; + write_metrics(&out, &mito_mask, &acc, n_obs)?; + + // Pass 2: normalize (reusing the totals from pass 1) + log1p, streamed straight to X. + let sums = acc.total; + let iter = adata + .x() + .iter::(chunk_size) + .map(move |(chunk, start, _end)| { + normalize_chunk(chunk, start, &sums, target_sum, log1p).expect("normalize chunk") + }); + out.set_x_from_iter(iter)?; + + out.close()?; + adata.close()?; + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use anndata::data::DynCsrMatrix; + use nalgebra_sparse::{CooMatrix, CsrMatrix}; + + fn tmp(name: &str) -> std::path::PathBuf { + let mut p = std::env::temp_dir(); + p.push(name); + let _ = std::fs::remove_file(&p); + p + } + + #[test] + fn preprocess_fused_matches_separate_semantics() -> anyhow::Result<()> { + let inp = tmp("sr_pre_in.h5ad"); + let out = tmp("sr_pre_out.h5ad"); + let rows = vec![ + vec![10.0_f64, 0.0, 5.0, 5.0], + vec![1.0, 2.0, 3.0, 4.0], + vec![0.0, 0.0, 0.0, 0.0], + ]; + let (nr, nc) = (rows.len(), rows[0].len()); + let mut coo = CooMatrix::::new(nr, nc); + for (i, r) in rows.iter().enumerate() { + for (j, &v) in r.iter().enumerate() { + if v != 0.0 { + coo.push(i, j, v as f32); + } + } + } + let adata = AnnData::
::new(&inp)?; + adata.set_obs_names((0..nr).map(|i| format!("c{i}")).collect::>().into())?; + adata.set_var_names(["MT-a", "g1", "g2", "g3"].iter().map(|s| s.to_string()).collect::>().into())?; + adata.set_x(ArrayData::CsrMatrix(DynCsrMatrix::F32(CsrMatrix::from(&coo))))?; + adata.close()?; + + preprocess_backed(&inp, &out, 1e4, true, Some(2))?; + + // Verify: obs has QC (total_counts on raw), X is normalize_total+log1p. + let res = AnnData::
::open(H5::open(&out)?)?; + let obs = res.read_obs()?; + let total = obs.column("total_counts")?.f64()?; + assert_eq!(total.get(0), Some(20.0)); // raw total of row 0 + assert_eq!(total.get(1), Some(10.0)); + let pct_mito = obs.column("pct_counts_mito")?.f64()?; + assert!((pct_mito.get(0).unwrap() - 50.0).abs() < 1e-6); // 10/20 + + let x = res.x().get::()?.unwrap(); + let m = match x { + ArrayData::CsrMatrix(DynCsrMatrix::F32(m)) => m, + _ => panic!("expected CSR f32"), + }; + // row 0, col 0: raw 10, total 20 -> 10/20*1e4 = 5000 -> ln1p(5000) + let row0: Vec<(usize, f32)> = m.row(0).col_indices().iter().zip(m.row(0).values()).map(|(&c, &v)| (c, v)).collect(); + let v00 = row0.iter().find(|(c, _)| *c == 0).unwrap().1 as f64; + assert!((v00 - (5000.0_f64).ln_1p()).abs() < 1e-2, "v00={v00}"); + + res.close()?; + for p in [inp, out] { + std::fs::remove_file(p).ok(); + } + Ok(()) + } + + /// Determinism: the fused preprocess (parallel QC + normalize/log1p) under 1 vs 8 threads must + /// produce bit-identical X values and obs totals. + #[test] + fn preprocess_is_deterministic_across_thread_counts() -> anyhow::Result<()> { + let (nr, nc) = (4000usize, 60usize); + let build = |tag: &str| -> anyhow::Result { + let mut coo = CooMatrix::::new(nr, nc); + for i in 0..nr { + for j in 0..nc { + let v = (((i * 23 + j * 13) % 9) as f32) * 0.5; + if v != 0.0 { + coo.push(i, j, v); + } + } + } + let p = tmp(&format!("sr_pre_det_{tag}.h5ad")); + let a = AnnData::
::new(&p)?; + a.set_obs_names((0..nr).map(|i| format!("c{i}")).collect::>().into())?; + a.set_var_names((0..nc).map(|j| format!("g{j}")).collect::>().into())?; + a.set_x(ArrayData::CsrMatrix(DynCsrMatrix::F32(CsrMatrix::from(&coo))))?; + a.close()?; + Ok(p) + }; + let run = |threads: usize, tag: &str| -> anyhow::Result<(Vec, Vec)> { + let inp = build(tag)?; + let out = tmp(&format!("sr_pre_det_out_{tag}.h5ad")); + let pool = rayon::ThreadPoolBuilder::new().num_threads(threads).build()?; + pool.install(|| preprocess_backed(&inp, &out, 1e4, true, Some(700)))?; + let a = AnnData::
::open(H5::open(&out)?)?; + let xvals = match a.x().get::()?.unwrap() { + ArrayData::CsrMatrix(DynCsrMatrix::F32(m)) => m.values().to_vec(), + _ => panic!("expected CSR f32"), + }; + let total = a.read_obs()?.column("total_counts")?.f64()?.into_iter().map(|x| x.unwrap_or(f64::NAN)).collect(); + a.close()?; + for p in [inp, out] { + std::fs::remove_file(p).ok(); + } + Ok((xvals, total)) + }; + let (x1, t1) = run(1, "t1")?; + let (x8, t8) = run(8, "t8")?; + assert_eq!(x1, x8, "X values differ across thread counts"); + assert_eq!(t1, t8, "obs total_counts differ across thread counts"); + Ok(()) + } +} diff --git a/src/backed/processing/qc.rs b/src/backed/processing/qc.rs new file mode 100644 index 0000000..ef8f0a5 --- /dev/null +++ b/src/backed/processing/qc.rs @@ -0,0 +1,336 @@ +//! # Out-of-core QC metrics (disk-backed, single streaming pass) +//! +//! Computes the same cell- and gene-level QC metrics as +//! [`crate::memory::statistics::qc::qc_metrics`], but streams `X` in row-chunks so peak memory +//! is bounded by one chunk plus a few per-cell / per-gene accumulator vectors. `X` is never +//! modified — the metrics are written back into `obs`/`var` of the same file in place. + +use std::path::Path; + +use anndata::data::DynCsrMatrix; +use anndata::{AnnData, AnnDataOp, ArrayData, ArrayElemOp, Backend}; +use anndata_hdf5::H5; +use anyhow::bail; +use polars::prelude::Column; + +use super::det::det_block_reduce; +use super::transformation::DEFAULT_CHUNK_SIZE; + +const PERCENT_TOP: [usize; 4] = [50, 100, 200, 500]; + +/// Per-cell metrics for one cell: (n_genes, total, mito_total, top-N proportions). +type CellMetrics = (u32, f64, f64, [f64; PERCENT_TOP.len()]); + +/// One chunk's QC contribution: per-cell metrics (in row order) + per-gene partials. +struct QcChunk { + cells: Vec, + col_total: Vec, + col_nnz: Vec, +} + +/// Per-cell and per-gene accumulators filled during the streaming pass. 
+pub(crate) struct QcAcc { + // cell-level (indexed by global obs index) + n_genes: Vec, + /// Per-cell total counts — this is also the row-sum the normalization step needs, so a fused + /// pipeline computes it once here instead of in a separate pass. + pub(crate) total: Vec, + mito_total: Vec, + pct_top: Vec<[f64; PERCENT_TOP.len()]>, + // gene-level (indexed by var index, accumulated across chunks) + col_total: Vec, + col_nnz: Vec, +} + +/// Stream `X` once and accumulate all QC metrics; returns the accumulators and the mito mask. +/// Shared by [`qc_metrics_backed`] and the fused preprocess pipeline. +pub(crate) fn stream_qc( + adata: &AnnData
, + chunk_size: usize, +) -> anyhow::Result<(QcAcc, Vec)> { + let (n_obs, n_vars) = (adata.n_obs(), adata.n_vars()); + let mito_mask: Vec = adata + .var_names() + .into_vec() + .iter() + .map(|n| n.starts_with("MT-") || n.starts_with("mt-")) + .collect(); + let mut acc = QcAcc { + n_genes: vec![0; n_obs], + total: vec![0.0; n_obs], + mito_total: vec![0.0; n_obs], + pct_top: vec![[0.0; PERCENT_TOP.len()]; n_obs], + col_total: vec![0.0; n_vars], + col_nnz: vec![0; n_vars], + }; + for (chunk, start, _end) in adata.x().iter::(chunk_size) { + let qc = qc_chunk(&chunk, &mito_mask, n_vars)?; + for (i, (ng, tot, mito, pct)) in qc.cells.into_iter().enumerate() { + let g = start + i; + acc.n_genes[g] = ng; + acc.total[g] = tot; + acc.mito_total[g] = mito; + acc.pct_top[g] = pct; + } + for (a, b) in acc.col_total.iter_mut().zip(qc.col_total.iter()) { + *a += *b; + } + for (a, b) in acc.col_nnz.iter_mut().zip(qc.col_nnz.iter()) { + *a += *b; + } + } + Ok((acc, mito_mask)) +} + +/// Compute standard QC metrics out-of-core and store them in `obs`/`var` of the file in place. +/// +/// Mitochondrial genes are detected by an `MT-`/`mt-` var-name prefix (matching the in-memory +/// `qc_metrics`). Adds to `obs`: `n_genes_by_counts`, `total_counts`, +/// `pct_counts_in_top_{50,100,200,500}_genes`, `total_counts_mito`, `pct_counts_mito` +/// (+ `log1p_` variants); to `var`: `mito`, `n_cells_by_counts`, `mean_counts`, +/// `pct_dropout_by_counts`, `total_counts` (+ `log1p_` variants). +pub fn qc_metrics_backed(path: &Path, chunk_size: Option) -> anyhow::Result<()> { + let chunk_size = chunk_size.unwrap_or(DEFAULT_CHUNK_SIZE); + let adata = AnnData::
::open(H5::open_rw(path)?)?; + let n_obs = adata.n_obs(); + let (acc, mito_mask) = stream_qc(&adata, chunk_size)?; + write_metrics(&adata, &mito_mask, &acc, n_obs)?; + adata.close()?; + Ok(()) +} + +/// One chunk's QC contribution via a deterministic block reduction over rows. Per-cell metrics +/// come out in row order; per-gene partials are summed in fixed (block) order. +fn qc_chunk(chunk: &ArrayData, mito_mask: &[bool], n_vars: usize) -> anyhow::Result { + match chunk { + ArrayData::CsrMatrix(DynCsrMatrix::F32(m)) => Ok(qc_rows(m, mito_mask, n_vars)), + ArrayData::CsrMatrix(DynCsrMatrix::F64(m)) => Ok(qc_rows(m, mito_mask, n_vars)), + other => bail!("OOC qc supports only F32/F64 CSR matrices, got {:?}", other), + } +} + +fn qc_rows( + m: &nalgebra_sparse::CsrMatrix, + mito_mask: &[bool], + n_vars: usize, +) -> QcChunk { + det_block_reduce( + m.nrows(), + || QcChunk { + cells: Vec::new(), + col_total: vec![0.0_f64; n_vars], + col_nnz: vec![0_u32; n_vars], + }, + |p, r| { + let row = m.row(r); + let cols = row.col_indices(); + let vals = row.values(); + let mut total = 0.0; + let mut mito = 0.0; + for (&c, &v) in cols.iter().zip(vals.iter()) { + let v = v.to_f64().unwrap_or(0.0); + total += v; + if mito_mask[c] { + mito += v; + } + p.col_total[c] += v; + p.col_nnz[c] += 1; + } + let mut pct = [0.0_f64; PERCENT_TOP.len()]; + if total > 0.0 { + let mut scratch: Vec = + vals.iter().map(|&v| v.to_f64().unwrap_or(0.0)).collect(); + scratch.sort_unstable_by(|a, b| { + b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal) + }); + let mut running = 0.0; + let mut next = 0; + for (rank, &v) in scratch.iter().enumerate() { + running += v; + while next < PERCENT_TOP.len() + && rank + 1 == PERCENT_TOP[next].min(scratch.len()) + { + pct[next] = running / total * 100.0; + next += 1; + } + } + while next < PERCENT_TOP.len() { + pct[next] = 100.0; + next += 1; + } + } + p.cells.push((vals.len() as u32, total, mito, pct)); + }, + |acc, p| { + acc.cells.extend(p.cells); + for (a, 
b) in acc.col_total.iter_mut().zip(p.col_total.iter()) { + *a += *b; + } + for (a, b) in acc.col_nnz.iter_mut().zip(p.col_nnz.iter()) { + *a += *b; + } + }, + ) +} + +pub(crate) fn write_metrics( + adata: &AnnData
, + mito_mask: &[bool], + acc: &QcAcc, + n_obs: usize, +) -> anyhow::Result<()> { + let log1p = |xs: &[f64]| -> Vec { xs.iter().map(|&x| x.ln_1p()).collect() }; + + // ----- obs ----- + let mut obs = adata.read_obs()?; + let n_genes_f: Vec = acc.n_genes.iter().map(|&n| n as f64).collect(); + obs.with_column(Column::new("n_genes_by_counts".into(), acc.n_genes.clone()))?; + obs.with_column(Column::new("log1p_n_genes_by_counts".into(), log1p(&n_genes_f)))?; + obs.with_column(Column::new("total_counts".into(), acc.total.clone()))?; + obs.with_column(Column::new("log1p_total_counts".into(), log1p(&acc.total)))?; + for (k, n) in PERCENT_TOP.iter().enumerate() { + let col: Vec = acc.pct_top.iter().map(|p| p[k]).collect(); + obs.with_column(Column::new(format!("pct_counts_in_top_{n}_genes").into(), col))?; + } + obs.with_column(Column::new("total_counts_mito".into(), acc.mito_total.clone()))?; + obs.with_column(Column::new("log1p_total_counts_mito".into(), log1p(&acc.mito_total)))?; + let pct_mito: Vec = acc + .mito_total + .iter() + .zip(acc.total.iter()) + .map(|(&m, &t)| if t > 0.0 { m / t * 100.0 } else { 0.0 }) + .collect(); + obs.with_column(Column::new("pct_counts_mito".into(), pct_mito))?; + adata.set_obs(obs)?; + + // ----- var ----- + let mut var = adata.read_var()?; + let mean: Vec = acc.col_total.iter().map(|&t| t / n_obs as f64).collect(); + let pct_dropout: Vec = acc + .col_nnz + .iter() + .map(|&n| (1.0 - n as f64 / n_obs as f64) * 100.0) + .collect(); + var.with_column(Column::new("mito".into(), mito_mask.to_vec()))?; + var.with_column(Column::new("n_cells_by_counts".into(), acc.col_nnz.clone()))?; + var.with_column(Column::new("mean_counts".into(), mean.clone()))?; + var.with_column(Column::new("log1p_mean_counts".into(), log1p(&mean)))?; + var.with_column(Column::new("pct_dropout_by_counts".into(), pct_dropout))?; + var.with_column(Column::new("total_counts".into(), acc.col_total.clone()))?; + var.with_column(Column::new("log1p_total_counts".into(), 
log1p(&acc.col_total)))?;
+    adata.set_var(var)?;
+
+    Ok(())
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::backed::processing::transformation::log1p_backed; // ensure module wiring
+    use anndata::data::DynCsrMatrix;
+    use nalgebra_sparse::{CooMatrix, CsrMatrix};
+
+    /// Write a small CSR `f32` fixture to `path`: one row of `rows` per cell, zeros left
+    /// unstored, obs named `c0..`, var named after `var_names`.
+    fn write_fixture(path: &Path, rows: &[Vec<f64>], var_names: &[&str]) -> anyhow::Result<()> {
+        let (nr, nc) = (rows.len(), rows[0].len());
+        let mut coo = CooMatrix::<f32>::new(nr, nc);
+        for (i, r) in rows.iter().enumerate() {
+            for (j, &v) in r.iter().enumerate() {
+                if v != 0.0 {
+                    coo.push(i, j, v as f32);
+                }
+            }
+        }
+        let adata = AnnData::<H5>::new(path)?;
+        adata.set_obs_names((0..nr).map(|i| format!("c{i}")).collect::<Vec<_>>().into())?;
+        adata.set_var_names(var_names.iter().map(|s| s.to_string()).collect::<Vec<_>>().into())?;
+        adata.set_x(ArrayData::CsrMatrix(DynCsrMatrix::F32(CsrMatrix::from(&coo))))?;
+        adata.close()?;
+        Ok(())
+    }
+
+    /// Path of `name` inside the system temp dir; any stale file from a previous run is removed.
+    fn tmp(name: &str) -> std::path::PathBuf {
+        let mut p = std::env::temp_dir();
+        p.push(name);
+        let _ = std::fs::remove_file(&p);
+        p
+    }
+
+    #[test]
+    fn ooc_qc_matches_hand_computed() -> anyhow::Result<()> {
+        let _ = log1p_backed; // silence unused import if test layout changes
+        let path = tmp("sr_ooc_qc.h5ad");
+        // 3 cells × 4 genes; gene 0 is "MT-x" (mitochondrial).
+        let rows = vec![
+            vec![10.0, 0.0, 5.0, 5.0], // total 20, mito 10 -> 50%
+            vec![0.0, 0.0, 0.0, 0.0],  // empty
+            vec![1.0, 2.0, 3.0, 4.0],  // total 10, mito 1 -> 10%
+        ];
+        write_fixture(&path, &rows, &["MT-x", "g1", "g2", "g3"])?;
+        qc_metrics_backed(&path, Some(2))?;
+
+        let adata = AnnData::<H5>::open(H5::open(&path)?)?;
+        let obs = adata.read_obs()?;
+        let var = adata.read_var()?;
+
+        let total = obs.column("total_counts")?.f64()?;
+        assert_eq!(total.get(0), Some(20.0));
+        assert_eq!(total.get(2), Some(10.0));
+        let n_genes = obs.column("n_genes_by_counts")?.u32()?;
+        assert_eq!(n_genes.get(0), Some(3));
+        assert_eq!(n_genes.get(1), Some(0));
+        let pct_mito = obs.column("pct_counts_mito")?.f64()?;
+        assert!((pct_mito.get(0).unwrap() - 50.0).abs() < 1e-6);
+        assert!((pct_mito.get(2).unwrap() - 10.0).abs() < 1e-6);
+
+        // var: gene 0 total = 11 across cells, detected in 2 cells, mito=true
+        let vtotal = var.column("total_counts")?.f64()?;
+        assert_eq!(vtotal.get(0), Some(11.0));
+        let ncells = var.column("n_cells_by_counts")?.u32()?;
+        assert_eq!(ncells.get(0), Some(2));
+        let mito = var.column("mito")?.bool()?;
+        assert_eq!(mito.get(0), Some(true));
+        assert_eq!(mito.get(1), Some(false));
+
+        adata.close()?;
+        std::fs::remove_file(path).ok();
+        Ok(())
+    }
+
+    /// Determinism: QC under 1 vs 8 threads must yield bit-identical obs/var metrics (guards the
+    /// parallel per-gene reduction and per-cell ordering).
+    #[test]
+    fn ooc_qc_is_deterministic_across_thread_counts() -> anyhow::Result<()> {
+        let (nr, nc) = (4000usize, 50usize);
+        let names: Vec<&str> = (0..nc)
+            .map(|j| if j < 3 { "MT-x" } else { "g" })
+            .collect();
+        let build = |tag: &str| -> anyhow::Result<std::path::PathBuf> {
+            let rows: Vec<Vec<f64>> = (0..nr)
+                .map(|i| (0..nc).map(|j| (((i * 19 + j * 7) % 11) as f64)).collect())
+                .collect();
+            let p = tmp(&format!("sr_qc_det_{tag}.h5ad"));
+            write_fixture(&p, &rows, &names)?;
+            Ok(p)
+        };
+        let run = |threads: usize, tag: &str| -> anyhow::Result<(Vec<f64>, Vec<f64>, Vec<f64>)> {
+            let p = build(tag)?;
+            let pool = rayon::ThreadPoolBuilder::new().num_threads(threads).build()?;
+            pool.install(|| qc_metrics_backed(&p, Some(700)))?;
+            let a = AnnData::<H5>::open(H5::open(&p)?)?;
+            let obs = a.read_obs()?;
+            let var = a.read_var()?;
+            let total = obs.column("total_counts")?.f64()?.into_iter().map(|x| x.unwrap_or(f64::NAN)).collect();
+            let pct = obs.column("pct_counts_in_top_50_genes")?.f64()?.into_iter().map(|x| x.unwrap_or(f64::NAN)).collect();
+            let vtot = var.column("total_counts")?.f64()?.into_iter().map(|x| x.unwrap_or(f64::NAN)).collect();
+            a.close()?;
+            std::fs::remove_file(p).ok();
+            Ok((total, pct, vtot))
+        };
+        let a = run(1, "t1")?;
+        let b = run(8, "t8")?;
+        assert_eq!(a.0, b.0, "obs total_counts differ across thread counts");
+        assert_eq!(a.1, b.1, "obs pct_top differ across thread counts");
+        assert_eq!(a.2, b.2, "var total_counts differ across thread counts");
+        Ok(())
+    }
+}
diff --git a/src/backed/processing/transformation.rs b/src/backed/processing/transformation.rs
new file mode 100644
index 0000000..8cd11a7
--- /dev/null
+++ b/src/backed/processing/transformation.rs
@@ -0,0 +1,290 @@
+//! # Out-of-core transformations (disk-backed, streaming)
+//!
+//! These operate on a disk-backed `.h5ad` without ever holding the full expression matrix in
+//! memory. The expression matrix `X` is streamed in row-chunks (`x().iter::<ArrayData>(n)`),
+//! each chunk is transformed, and the result is streamed straight back out to a new `.h5ad`
+//! (`set_x_from_iter`, which appends to an extendable HDF5 dataset). Peak memory is therefore
+//! bounded by one chunk, not the dataset — this is the basis for processing larger-than-RAM data
+//! natively, the niche scanpy reaches for Dask to fill.
+//!
+//! Only CSR `f32`/`f64` matrices are supported (the standard scRNA-seq layout).
+
+use std::path::Path;
+
+use anndata::data::DynCsrMatrix;
+use anndata::{AnnData, AnnDataOp, ArrayData, ArrayElemOp, Backend};
+use anndata_hdf5::H5;
+use anyhow::bail;
+use single_algebra::Log1P;
+
+/// Default streaming chunk size (rows/cells per block).
+pub const DEFAULT_CHUNK_SIZE: usize = 5_000;
+
+/// Open a backed `.h5ad` for reading and create a fresh backed `.h5ad` for the result, copying
+/// the (small) obs/var annotations and dimension names across. Returns `(input, output)`.
+fn open_io(input: &Path, output: &Path) -> anyhow::Result<(AnnData<H5>, AnnData<H5>)> {
+    let adata = AnnData::<H5>::open(H5::open(input)?)?;
+    let out = AnnData::<H5>::new(output)?;
+    out.set_obs_names(adata.obs_names())?;
+    out.set_var_names(adata.var_names())?;
+    out.set_obs(adata.read_obs()?)?;
+    out.set_var(adata.read_var()?)?;
+    Ok((adata, out))
+}
+
+/// Apply `log1p` (natural `ln(1 + x)`) to every value of `X`, streaming, writing the result to
+/// `output`. Memory stays bounded by `chunk_size` rows regardless of dataset size.
+///
+/// `chunk_size` falls back to [`DEFAULT_CHUNK_SIZE`] when `None`.
+///
+/// # Errors
+/// Returns an error if the files cannot be opened, created or written, or if the first chunk
+/// of `X` is not an F32/F64 CSR matrix. In the latter case `output` has already been created
+/// (annotations only) by the time the error is returned.
+pub fn log1p_backed(input: &Path, output: &Path, chunk_size: Option<usize>) -> anyhow::Result<()> {
+    let chunk_size = chunk_size.unwrap_or(DEFAULT_CHUNK_SIZE);
+    let (adata, out) = open_io(input, output)?;
+
+    // Check the dtype on the first chunk so an unsupported matrix (e.g. integer counts) is
+    // reported as an error rather than as a panic from inside the streaming closure below.
+    let mut chunks = adata.x().iter::<ArrayData>(chunk_size).peekable();
+    if let Some((first, _start, _end)) = chunks.peek() {
+        if !matches!(first, ArrayData::CsrMatrix(DynCsrMatrix::F32(_) | DynCsrMatrix::F64(_))) {
+            bail!("OOC log1p supports only F32/F64 CSR matrices");
+        }
+    }
+
+    let iter = chunks.map(|(chunk, _start, _end)| {
+        // The iterator can't yield a Result, so a transform failure still panics with context.
+        // The dtype case was ruled out on the first chunk above (assumes every chunk of `X`
+        // has the same variant as the first).
+        log1p_chunk(chunk).expect("log1p chunk transform")
+    });
+    out.set_x_from_iter(iter)?;
+
+    out.close()?;
+    adata.close()?;
+    Ok(())
+}
+
+/// Library-size normalize each cell to `target_sum`, then optionally `log1p`, streaming to
+/// `output`. Two passes over `X`: pass 1 accumulates per-cell sums (one f64 per cell — tiny),
+/// pass 2 streams chunks, scales each row, and writes. Memory stays bounded by `chunk_size`.
+///
+/// `chunk_size` falls back to [`DEFAULT_CHUNK_SIZE`] when `None`.
+///
+/// # Errors
+/// Returns an error if the files cannot be opened, created or written, or if pass 1 meets a
+/// chunk that is not an F32/F64 CSR matrix (this happens before `output` is created).
+pub fn normalize_total_backed(
+    input: &Path,
+    output: &Path,
+    target_sum: f64,
+    log1p: bool,
+    chunk_size: Option<usize>,
+) -> anyhow::Result<()> {
+    let chunk_size = chunk_size.unwrap_or(DEFAULT_CHUNK_SIZE);
+
+    // Pass 1: per-cell totals (streamed). Keeps only n_obs f64s in memory.
+    let sums = {
+        let adata = AnnData::<H5>::open(H5::open(input)?)?;
+        let n_obs = adata.n_obs();
+        let mut sums = vec![0.0_f64; n_obs];
+        for (chunk, start, _end) in adata.x().iter::<ArrayData>(chunk_size) {
+            accumulate_row_sums(&chunk, start, &mut sums)?;
+        }
+        adata.close()?;
+        sums
+    };
+
+    // Pass 2: scale (and optionally log1p) each row, streamed back out.
+    let (adata, out) = open_io(input, output)?;
+    let iter = adata
+        .x()
+        .iter::<ArrayData>(chunk_size)
+        .map(move |(chunk, start, _end)| {
+            // Pass 1 already rejected unsupported dtypes for every chunk, so this only panics
+            // if the chunks read here differ from the ones read in pass 1.
+            normalize_chunk(chunk, start, &sums, target_sum, log1p)
+                .expect("normalize chunk transform")
+        });
+    out.set_x_from_iter(iter)?;
+    out.close()?;
+    adata.close()?;
+    Ok(())
+}
+
+/// `ln(1 + x)` over a chunk's stored values, preserving sparsity (zeros map to 0).
+///
+/// Errors on anything other than an F32/F64 CSR chunk, or if `log1p_normalize` itself fails.
+fn log1p_chunk(chunk: ArrayData) -> anyhow::Result<ArrayData> {
+    match chunk {
+        ArrayData::CsrMatrix(DynCsrMatrix::F32(mut m)) => {
+            m.log1p_normalize()?;
+            Ok(ArrayData::CsrMatrix(DynCsrMatrix::F32(m)))
+        }
+        ArrayData::CsrMatrix(DynCsrMatrix::F64(mut m)) => {
+            m.log1p_normalize()?;
+            Ok(ArrayData::CsrMatrix(DynCsrMatrix::F64(m)))
+        }
+        other => bail!("OOC log1p supports only F32/F64 CSR matrices, got {:?}", other),
+    }
+}
+
+/// Store each row's value-sum in `sums[start + local_row]`, overwriting the previous entry.
+///
+/// Sums are accumulated in `f64` even for `f32` matrices. Panics if `start + local_row` is out
+/// of bounds for `sums`; errors on anything other than an F32/F64 CSR chunk.
+fn accumulate_row_sums(chunk: &ArrayData, start: usize, sums: &mut [f64]) -> anyhow::Result<()> {
+    match chunk {
+        ArrayData::CsrMatrix(DynCsrMatrix::F32(m)) => {
+            for (i, row) in m.row_iter().enumerate() {
+                sums[start + i] = row.values().iter().map(|&v| v as f64).sum();
+            }
+            Ok(())
+        }
+        ArrayData::CsrMatrix(DynCsrMatrix::F64(m)) => {
+            for (i, row) in m.row_iter().enumerate() {
+                sums[start + i] = row.values().iter().sum();
+            }
+            Ok(())
+        }
+        other => bail!("OOC normalize supports only F32/F64 CSR matrices, got {:?}", other),
+    }
+}
+
+/// Scale each row to `target_sum` (rows with zero total are left untouched), optionally log1p.
+pub(crate) fn normalize_chunk(
+    chunk: ArrayData,
+    start: usize,
+    sums: &[f64],
+    target_sum: f64,
+    log1p: bool,
+) -> anyhow::Result<ArrayData> {
+    // `start` is the global row index of the chunk's first row; `sums` is indexed globally.
+    match chunk {
+        ArrayData::CsrMatrix(DynCsrMatrix::F32(mut m)) => {
+            scale_rows_f32(&mut m, start, sums, target_sum, log1p);
+            Ok(ArrayData::CsrMatrix(DynCsrMatrix::F32(m)))
+        }
+        ArrayData::CsrMatrix(DynCsrMatrix::F64(mut m)) => {
+            scale_rows_f64(&mut m, start, sums, target_sum, log1p);
+            Ok(ArrayData::CsrMatrix(DynCsrMatrix::F64(m)))
+        }
+        other => bail!("OOC normalize supports only F32/F64 CSR matrices, got {:?}", other),
+    }
+}
+
+// Generates `fn $name(m, start, sums, target_sum, log1p)` for one element type: every stored
+// value of row `r` is multiplied by `target_sum / sums[start + r]` (computed in f64) and, when
+// `log1p` is set, passed through `ln_1p`. Rows whose total is <= 0 are skipped entirely, so
+// they receive neither the scaling nor the log1p.
+macro_rules! impl_scale_rows {
+    ($name:ident, $t:ty) => {
+        fn $name(
+            m: &mut nalgebra_sparse::CsrMatrix<$t>,
+            start: usize,
+            sums: &[f64],
+            target_sum: f64,
+            log1p: bool,
+        ) {
+            // Borrow offsets and values together instead of cloning the offsets per chunk.
+            let (offsets, _col_indices, values) = m.csr_data_mut();
+            for (row, span) in offsets.windows(2).enumerate() {
+                let total = sums[start + row];
+                if total <= 0.0 {
+                    continue;
+                }
+                let scale = target_sum / total;
+                for v in &mut values[span[0]..span[1]] {
+                    let mut x = (*v as f64) * scale;
+                    if log1p {
+                        x = x.ln_1p();
+                    }
+                    *v = x as $t;
+                }
+            }
+        }
+    };
+}
+impl_scale_rows!(scale_rows_f32, f32);
+impl_scale_rows!(scale_rows_f64, f64);
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use anndata::data::DynCsrMatrix;
+    use anndata::ArrayData;
+    use nalgebra_sparse::{CooMatrix, CsrMatrix};
+
+    /// Write a small CSR `f32` fixture to `path`: one row of `rows` per cell, zeros left
+    /// unstored, obs named `c0..`, var named `g0..`.
+    fn write_fixture(path: &Path, rows: &[Vec<f64>]) -> anyhow::Result<()> {
+        let (nr, nc) = (rows.len(), rows[0].len());
+        let mut coo = CooMatrix::<f32>::new(nr, nc);
+        for (i, r) in rows.iter().enumerate() {
+            for (j, &v) in r.iter().enumerate() {
+                if v != 0.0 {
+                    coo.push(i, j, v as f32);
+                }
+            }
+        }
+        let csr = CsrMatrix::from(&coo);
+        let adata = AnnData::<H5>::new(path)?;
+        adata.set_obs_names((0..nr).map(|i| format!("c{i}")).collect::<Vec<_>>().into())?;
+        adata.set_var_names((0..nc).map(|j| format!("g{j}")).collect::<Vec<_>>().into())?;
+        adata.set_x(ArrayData::CsrMatrix(DynCsrMatrix::F32(csr)))?;
+        adata.close()?;
+        Ok(())
+    }
+
+    /// Read `X` back from `path` and densify it into an `f64` array.
+    fn read_dense(path: &Path) -> anyhow::Result<ndarray::Array2<f64>> {
+        let adata = AnnData::<H5>::open(H5::open(path)?)?;
+        let x = adata.x().get::<ArrayData>()?.unwrap();
+        // Densify the CSR ourselves (independent of crate conversion helpers).
+        let dense = match x {
+            ArrayData::CsrMatrix(DynCsrMatrix::F32(m)) => {
+                let mut d = ndarray::Array2::<f64>::zeros((m.nrows(), m.ncols()));
+                for (i, row) in m.row_iter().enumerate() {
+                    for (&j, &v) in row.col_indices().iter().zip(row.values().iter()) {
+                        d[[i, j]] = v as f64;
+                    }
+                }
+                d
+            }
+            other => anyhow::bail!("unexpected X type {:?}", other),
+        };
+        adata.close()?;
+        Ok(dense)
+    }
+
+    /// Path of `name` inside the system temp dir; any stale file from a previous run is removed.
+    fn tmp(name: &str) -> std::path::PathBuf {
+        let mut p = std::env::temp_dir();
+        p.push(name);
+        let _ = std::fs::remove_file(&p);
+        p
+    }
+
+    /// Streaming log1p over a backed file must equal elementwise ln(1+x), with a tiny chunk size
+    /// so multiple chunks are exercised.
+    #[test]
+    fn ooc_log1p_matches_dense() -> anyhow::Result<()> {
+        let inp = tmp("sr_ooc_log1p_in.h5ad");
+        let out = tmp("sr_ooc_log1p_out.h5ad");
+        let rows = vec![
+            vec![1.0, 0.0, 3.0],
+            vec![0.0, 7.0, 0.0],
+            vec![2.0, 2.0, 2.0],
+            vec![0.0, 0.0, 0.0],
+            vec![5.0, 0.0, 1.0],
+        ];
+        write_fixture(&inp, &rows)?;
+        log1p_backed(&inp, &out, Some(2))?; // chunk_size=2 -> 3 chunks
+
+        let got = read_dense(&out)?;
+        for (i, r) in rows.iter().enumerate() {
+            for (j, &v) in r.iter().enumerate() {
+                assert!((got[[i, j]] - (v).ln_1p()).abs() < 1e-6, "[{i},{j}] {}", got[[i, j]]);
+            }
+        }
+        for p in [inp, out] {
+            std::fs::remove_file(p).ok();
+        }
+        Ok(())
+    }
+
+    /// Streaming normalize_total(1e4)+log1p must match the row-scaled-then-log1p reference.
+    #[test]
+    fn ooc_normalize_total_matches_reference() -> anyhow::Result<()> {
+        let inp = tmp("sr_ooc_norm_in.h5ad");
+        let out = tmp("sr_ooc_norm_out.h5ad");
+        let rows = vec![
+            vec![1.0, 3.0, 0.0], // sum 4
+            vec![0.0, 0.0, 0.0], // zero row -> untouched
+            vec![2.0, 2.0, 6.0], // sum 10
+        ];
+        write_fixture(&inp, &rows)?;
+        normalize_total_backed(&inp, &out, 1e4, true, Some(2))?;
+
+        let got = read_dense(&out)?;
+        for (i, r) in rows.iter().enumerate() {
+            let total: f64 = r.iter().sum();
+            for (j, &v) in r.iter().enumerate() {
+                let expect = if total > 0.0 { (v / total * 1e4).ln_1p() } else { 0.0 };
+                assert!((got[[i, j]] - expect).abs() < 1e-3, "[{i},{j}] {} vs {}", got[[i, j]], expect);
+            }
+        }
+        for p in [inp, out] {
+            std::fs::remove_file(p).ok();
+        }
+        Ok(())
+    }
+}
diff --git a/src/memory/processing/hvg/mod.rs b/src/memory/processing/hvg/mod.rs
index e706847..56db748 100644
--- a/src/memory/processing/hvg/mod.rs
+++ b/src/memory/processing/hvg/mod.rs
@@ -167,6 +167,54 @@ pub fn compute_highly_variable_genes(
     }
 }
 
+/// Seurat-flavor dispersion processing from per-gene means and variances.
+///
+/// Shared by the in-memory and out-of-core HVG paths: given the same `raw_means`/`variances`
+/// (column mean and sample variance of `X`) it returns identical
+/// `(log1p_means, log_dispersions, dispersions_norm, highly_variable)`. Factoring it out keeps the
+/// disk-backed implementation byte-for-byte consistent with the in-memory one.
+/// `(log1p_means, log_dispersions, dispersions_norm, highly_variable)` from [`seurat_select`].
+pub(crate) type SeuratResult = (Vec<f64>, Vec<f64>, Vec<f64>, Vec<bool>);
+
+/// Compute Seurat-flavor dispersion statistics and the highly-variable mask from per-gene
+/// means and variances.
+///
+/// Steps, in order:
+/// 1. dispersion = `variance / mean`, with the mean floored at `1e-12`;
+/// 2. `log1p_means = ln(mean + 1)`; `log_dispersions = ln(dispersion)`, or `NaN` where the
+///    dispersion is not positive;
+/// 3. genes are binned by `log1p_means` into `params.n_bins` equal-width bins, per-bin
+///    mean/std of `log_dispersions` are computed and post-processed, and the dispersions are
+///    normalized against their bin;
+/// 4. `subset_genes` selects genes using `params.n_top_genes`, `min_mean`, `max_mean` and
+///    `min_dispersion`.
+///
+/// Returns `(log1p_means, log_dispersions, normalized_dispersions, highly_variable)`.
+///
+/// # Errors
+/// Fails if `raw_means` and `variances` differ in length, or if any binning, statistics or
+/// selection helper fails.
+pub(crate) fn seurat_select(
+    raw_means: &[f64],
+    variances: &[f64],
+    params: &HVGParams,
+) -> anyhow::Result<SeuratResult> {
+    // `zip` below would silently truncate to the shorter slice and misalign per-gene outputs.
+    if raw_means.len() != variances.len() {
+        anyhow::bail!(
+            "seurat_select: got {} means but {} variances",
+            raw_means.len(),
+            variances.len()
+        );
+    }
+
+    let dispersions: Vec<f64> = raw_means
+        .iter()
+        .zip(variances.iter())
+        .map(|(&mean, &var)| {
+            // Floor the mean so all-zero genes do not divide by zero.
+            let safe_mean = if mean > 1e-12 { mean } else { 1e-12 };
+            var / safe_mean
+        })
+        .collect();
+
+    let log1p_means: Vec<f64> = raw_means.iter().map(|&x| (x + 1.0).ln()).collect();
+    let log_dispersions: Vec<f64> = dispersions
+        .iter()
+        .map(|&x| if x > 0.0 { x.ln() } else { f64::NAN })
+        .collect();
+
+    let (bin_indices, _) = equal_width_binning(&log1p_means, params.n_bins)?;
+    let (mut bin_means, mut bin_stds) =
+        calculate_bin_stats(&log_dispersions, &bin_indices, params.n_bins)?;
+    postprocess_seurat_dispersions(&mut bin_means, &mut bin_stds)?;
+    let normalized_dispersions =
+        normalize_dispersions(&log_dispersions, &bin_indices, &bin_means, &bin_stds)?;
+
+    let highly_variable = subset_genes(
+        &log1p_means,
+        &normalized_dispersions,
+        params.n_top_genes,
+        params.min_mean,
+        params.max_mean,
+        params.min_dispersion,
+    )?;
+
+    Ok((log1p_means, log_dispersions, normalized_dispersions, highly_variable))
+}
+
 /// Post-process dispersion statistics for Seurat method to handle edge cases.
 ///
 /// Handles bins with single genes where standard deviation cannot be computed.
@@ -506,57 +554,10 @@ fn compute_seurat_hvg( let variances: Vec = x.variance_whole::(&Direction::COLUMN)?; - // Calculate dispersions with proper handling of zero means - let dispersions: Vec = raw_means - .iter() - .zip(variances.iter()) - .map(|(&mean, &var)| { - let safe_mean = if mean > 1e-12 { mean } else { 1e-12 }; - var / safe_mean - }) - .collect(); - - // For Seurat flavor, use log1p of means for binning and storage - // This matches what Python does AFTER reverting log normalization - let log1p_means: Vec = raw_means.iter().map(|&x| (x + 1.0).ln()).collect(); - - // Log dispersions with NaN for zero dispersions (matching Python) - let log_dispersions: Vec = dispersions - .iter() - .map(|&x| { - if x > 0.0 { - x.ln() - } else { - f64::NAN // Python sets dispersion[dispersion == 0] = np.nan - } - }) - .collect(); - - let n_bins = params.n_bins; - - // Use equal-width binning on log1p_means (like Python's pd.cut) - let (bin_indices, _) = equal_width_binning(&log1p_means, n_bins)?; - - // Calculate mean and std for each bin - let (mut bin_means, mut bin_stds) = - calculate_bin_stats(&log_dispersions, &bin_indices, n_bins)?; - - // Handle single-gene bins (like Python's _postprocess_dispersions_seurat) - postprocess_seurat_dispersions(&mut bin_means, &mut bin_stds)?; - - // Normalize dispersions - let normalized_dispersions = - normalize_dispersions(&log_dispersions, &bin_indices, &bin_means, &bin_stds)?; - - // Select highly variable genes using raw means for filtering - let highly_variable = subset_genes( - &log1p_means, // Pass log-transformed means - &normalized_dispersions, - params.n_top_genes, - params.min_mean, - params.max_mean, - params.min_dispersion, - )?; + // Seurat dispersion binning/normalization + selection (shared with the out-of-core path so + // both produce identical results from the same per-gene means/variances). 
+ let (log1p_means, log_dispersions, normalized_dispersions, highly_variable) = + seurat_select(&raw_means, &variances, &params)?; // Store results - IMPORTANT: Store log1p means to match Python let mut var_df = adata.var().get_data(); diff --git a/src/memory/processing/mod.rs b/src/memory/processing/mod.rs index 7f0a73e..097186d 100644 --- a/src/memory/processing/mod.rs +++ b/src/memory/processing/mod.rs @@ -13,7 +13,7 @@ pub mod diffexp; pub mod dimred; pub mod filtering; -mod hvg; +pub(crate) mod hvg; mod transformation; pub mod enrichment;