diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3cc558c..cd57a37 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -27,7 +27,11 @@ jobs: uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - cache: pip + # No `cache: pip` on purpose: the resolved env (torch + lightning + + # scanpy + numba + cuda libs) is >2 GB of wheels, and the post-job + # cache-save tar reliably hits the 10-minute step timeout on the + # GitHub-hosted runners. A cold resolve from a warm PyPI mirror + # only takes ~25s, so caching is a net loss here. - name: Install scsims with test extras run: | diff --git a/README.md b/README.md index 8d205cf..17f3f54 100755 --- a/README.md +++ b/README.md @@ -116,10 +116,10 @@ labeled fine-tuning set. from scsims import SIMS import scanpy as sc +# Stage 1: unsupervised pretraining on a large unlabeled corpus. +# No class_label needed -- the pretrainer never reads cell type labels. unlabeled = sc.read_h5ad("big_unlabeled_corpus.h5ad") -sims = SIMS(data=unlabeled, class_label="cell_type") # label col is ignored at this stage - -# Stage 1: unsupervised pretraining. Cell type labels are not used. +sims = SIMS(data=unlabeled) sims.pretrain( pretraining_ratio=0.2, accelerator="gpu", @@ -127,9 +127,10 @@ sims.pretrain( max_epochs=50, ) -# Stage 2: supervised fine-tuning. SIMS automatically detects the -# attached pretrainer and warm-starts the classifier from its encoder -# weights before fitting. +# Stage 2: supervised fine-tuning on a smaller labeled set. Build a +# fresh SIMS with the labels, load the pretrainer checkpoint, and +# train(). SIMS automatically warm-starts the classifier from the +# pretrained encoder before fitting. labeled = sc.read_h5ad("smaller_labeled_set.h5ad") sims = SIMS(data=labeled, class_label="cell_type") sims.load_pretrainer("./sims_pretrain_checkpoints/best.ckpt") diff --git a/scsims/data.py b/scsims/data.py index 70ebab8..335829e 100644 --- a/scsims/data.py +++ b/scsims/data.py @@ -8,6 +8,7 @@ import anndata as an import numpy as np +import pandas as pd import torch from scipy.sparse import issparse from sklearn.model_selection import train_test_split @@ -186,7 +187,7 @@ def clean_sample( def generate_dataloaders( data: an.AnnData, - class_label: str, + class_label: Optional[str] = None, test_prop=0.2, stratify=True, batch_size: int = 16, @@ -199,11 +200,21 @@ def generate_dataloaders( if isinstance(data, str): data = an.read_h5ad(data) - current_labels = data.obs.loc[:, class_label] + if class_label is None: + # Unsupervised mode (e.g. pretraining): no real labels exist. Use a + # zero placeholder so AnnDatasetMatrix still yields (features, label) + # tuples — the pretrainer's training step discards the label + # component anyway. Stratification has no meaning here. + current_labels = pd.Series( + np.zeros(data.shape[0], dtype=np.int64), + index=data.obs.index, + ) + stratify = False + else: + current_labels = data.obs.loc[:, class_label] - # make sure data can be stratified - if stratify: - if len(current_labels.unique()) < 3: + # make sure data can be stratified + if stratify and len(current_labels.unique()) < 3: warnings.warn( "One class has less than 3 samples, disabling stratification" ) diff --git a/scsims/lightning_train.py b/scsims/lightning_train.py index 82b181e..841d857 100644 --- a/scsims/lightning_train.py +++ b/scsims/lightning_train.py @@ -57,10 +57,14 @@ def __init__( self.setup() def prepare_data(self): - unique_targets = self.data.obs.loc[:, self.class_label].unique() - label_encoder = LabelEncoder().fit(unique_targets) + # `class_label` is optional: pretraining workflows pass an unlabeled + # AnnData and never set it. In that case there's nothing to encode. + if self.class_label is None: + self.label_encoder = None + return - self.label_encoder = label_encoder + unique_targets = self.data.obs.loc[:, self.class_label].unique() + self.label_encoder = LabelEncoder().fit(unique_targets) if not pd.api.types.is_numeric_dtype(self.data.obs.loc[:, self.class_label]): print("Numerically encoding class labels") @@ -86,18 +90,22 @@ def setup(self, stage: Optional[str] = None): self.trainloader = loaders[0] self.valloader = None self.testloader = None - else: + else: self.trainloader, self.valloader, self.testloader = loaders - print("Calculating weights") - labels = self.data.obs.loc[:, self.class_label].values - self.weights = torch.from_numpy( - compute_class_weight( - y=labels, - classes=np.unique(labels), - class_weight="balanced", - ) - ).float() + if self.class_label is None: + # No labels => no class weights. Pretraining doesn't use them. + self.weights = None + else: + print("Calculating weights") + labels = self.data.obs.loc[:, self.class_label].values + self.weights = torch.from_numpy( + compute_class_weight( + y=labels, + classes=np.unique(labels), + class_weight="balanced", + ) + ).float() self.setuped = True @@ -112,6 +120,8 @@ def test_dataloader(self): @cached_property def num_labels(self): + if self.class_label is None: + return None return self.data.obs.loc[:, self.class_label].nunique() @cached_property diff --git a/scsims/scvi_api.py b/scsims/scvi_api.py index c31b844..3a4a9aa 100644 --- a/scsims/scvi_api.py +++ b/scsims/scvi_api.py @@ -203,6 +203,13 @@ def load_pretrainer(self, weights_path: str, **kwargs) -> SIMSPretrainer: # ------------------------------------------------------------------ def train(self, *args, **kwargs): + if self.datamodule.class_label is None: + raise ValueError( + "SIMS.train() is supervised and requires a class_label. Build " + "the SIMS instance with `SIMS(data=adata, class_label=...)` " + "before calling .train(). For unsupervised pretraining on an " + "unlabeled dataset, use SIMS(data=adata).pretrain(...) instead." + ) print("Beginning training") if not hasattr(self, "_trainer"): self.setup_trainer(*args, **kwargs) diff --git a/tests/test_pretraining.py b/tests/test_pretraining.py index a31d47e..d5a1867 100644 --- a/tests/test_pretraining.py +++ b/tests/test_pretraining.py @@ -96,6 +96,42 @@ def test_pretrain_warm_start_transfers_encoder_weights(synthetic_anndata, tmp_pa raise AssertionError("no shape-matching encoder keys to spot-check") +def test_pretrain_on_truly_unlabeled_anndata(synthetic_anndata, tmp_path): + """`SIMS(data=adata)` (no class_label) should support pretraining on + a literally unlabeled AnnData. The 3.x API forced users to pass a + class_label even when pretraining ignored it; v4 cleaned that up.""" + # Drop the labels column entirely so any code path that tries to + # read it would crash. + del synthetic_anndata.obs["blobs"] + assert "blobs" not in synthetic_anndata.obs.columns + + sims = SIMS(data=synthetic_anndata) # no class_label kwarg! + assert sims.datamodule.class_label is None + assert sims.datamodule.label_encoder is None + assert sims.datamodule.weights is None + + sims.pretrain( + accelerator="cpu", + devices=1, + max_epochs=2, + enable_progress_bar=False, + logger=False, + checkpoint_dir=str(tmp_path / "pretrain"), + ) + assert isinstance(sims.pretrainer, SIMSPretrainer) + + +def test_train_without_class_label_raises_clear_error(synthetic_anndata): + """Calling .train() on a SIMS built without a class_label should raise + a ValueError that points users at the right API.""" + import pytest + + del synthetic_anndata.obs["blobs"] + sims = SIMS(data=synthetic_anndata) + with pytest.raises(ValueError, match="class_label"): + sims.train() + + def test_load_pretrainer_round_trip(synthetic_anndata, tmp_path): """Two-process workflow: pretrain in 'process A', save .ckpt, fresh SIMS in 'process B' loads it via load_pretrainer() and warm-starts