diff --git a/docs/api/calib.rst b/docs/api/calib.rst index daff8e8c7..7599b2a10 100644 --- a/docs/api/calib.rst +++ b/docs/api/calib.rst @@ -22,6 +22,8 @@ confidence levels: - :class:`~pyhealth.calib.predictionset.SCRIB`: Class-specific risk control - :class:`~pyhealth.calib.predictionset.FavMac`: Value-maximizing sets with cost control - :class:`~pyhealth.calib.predictionset.CovariateLabel`: Covariate shift adaptive conformal +- :class:`~pyhealth.calib.predictionset.ClusterLabel`: K-means cluster-based conformal prediction +- :class:`~pyhealth.calib.predictionset.NeighborhoodLabel`: Neighborhood Conformal Prediction (NCP) Getting Started --------------- diff --git a/docs/api/calib/pyhealth.calib.predictionset.rst b/docs/api/calib/pyhealth.calib.predictionset.rst index 63d314203..fe445ea1b 100644 --- a/docs/api/calib/pyhealth.calib.predictionset.rst +++ b/docs/api/calib/pyhealth.calib.predictionset.rst @@ -16,6 +16,8 @@ Available Methods pyhealth.calib.predictionset.SCRIB pyhealth.calib.predictionset.FavMac pyhealth.calib.predictionset.CovariateLabel + pyhealth.calib.predictionset.ClusterLabel + pyhealth.calib.predictionset.NeighborhoodLabel LABEL (Least Ambiguous Set-valued Classifier) ---------------------------------------------- @@ -49,6 +51,22 @@ CovariateLabel (Covariate Shift Adaptive) :undoc-members: :show-inheritance: +ClusterLabel (K-means Cluster-based Conformal) +---------------------------------------------- + +.. autoclass:: pyhealth.calib.predictionset.ClusterLabel + :members: + :undoc-members: + :show-inheritance: + +NeighborhoodLabel (Neighborhood Conformal Prediction) +----------------------------------------------------- + +.. autoclass:: pyhealth.calib.predictionset.NeighborhoodLabel + :members: + :undoc-members: + :show-inheritance: + Helper Functions ---------------- diff --git a/examples/conformal_eeg/tuev_kmeans_conformal.py b/examples/conformal_eeg/tuev_kmeans_conformal.py index 82aa27e3f..b5a043dad 100644 --- a/examples/conformal_eeg/tuev_kmeans_conformal.py +++ b/examples/conformal_eeg/tuev_kmeans_conformal.py @@ -9,7 +9,8 @@ 6) Evaluates prediction-set coverage/miscoverage and efficiency on the test split. Example (from repo root): - python examples/conformal_eeg/tuev_kmeans_conformal.py --root downloads/tuev/v2.0.1/edf --n-clusters 5 + python examples/conformal_eeg/tuev_kmeans_conformal.py --root /srv/local/data/TUH/tuh_eeg_events/v2.0.0/edf --n-clusters 5 + python examples/conformal_eeg/tuev_kmeans_conformal.py --quick-test --log-file quicktest_kmeans.log Notes: - ClusterLabel uses K-means clustering on embeddings to compute cluster-specific thresholds. @@ -20,11 +21,30 @@ import argparse import random +import sys from pathlib import Path import numpy as np import torch + +class _Tee: + """Writes to both a stream and a file.""" + + def __init__(self, stream, file): + self._stream = stream + self._file = file + + def write(self, data): + self._stream.write(data) + self._file.write(data) + self._file.flush() + + def flush(self): + self._stream.flush() + self._file.flush() + + from pyhealth.calib.predictionset.cluster import ClusterLabel from pyhealth.calib.utils import extract_embeddings from pyhealth.datasets import TUEVDataset, get_dataloader, split_by_sample_conformal @@ -40,13 +60,13 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--root", type=str, - default="downloads/tuev/v2.0.1/edf", + default="/srv/local/data/TUH/tuh_eeg_events/v2.0.0/edf", help="Path to TUEV edf/ folder.", ) - parser.add_argument("--subset", type=str, default="both", choices=["train", "eval", "both"]) + parser.add_argument("--subset", type=str, default="both", choices=["train", "eval", "both"]) parser.add_argument("--seed", type=int, default=42) - parser.add_argument("--batch-size", type=int, default=32) - parser.add_argument("--epochs", type=int, default=3) + parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument("--epochs", type=int, default=20) parser.add_argument("--alpha", type=float, default=0.1, help="Miscoverage rate (e.g., 0.1 => 90% target coverage).") parser.add_argument( "--ratios", @@ -69,6 +89,17 @@ def parse_args() -> argparse.Namespace: default=None, help="Device string, e.g. 'cuda:0' or 'cpu'. Defaults to auto-detect.", ) + parser.add_argument( + "--log-file", + type=str, + default=None, + help="Path to log file. Stdout and stderr are teed to this file.", + ) + parser.add_argument( + "--quick-test", + action="store_true", + help="Smoke test: dev=True, max 2000 samples, 2 epochs, ~5-10 min.", + ) return parser.parse_args() @@ -84,6 +115,23 @@ def main() -> None: args = parse_args() set_seed(args.seed) + orig_stdout, orig_stderr = sys.stdout, sys.stderr + log_file = None + if args.log_file: + log_file = open(args.log_file, "w", encoding="utf-8") + sys.stdout = _Tee(orig_stdout, log_file) + sys.stderr = _Tee(orig_stderr, log_file) + + try: + _run(args) + finally: + if log_file is not None: + sys.stdout = orig_stdout + sys.stderr = orig_stderr + log_file.close() + + +def _run(args: argparse.Namespace) -> None: device = args.device or ("cuda:0" if torch.cuda.is_available() else "cpu") root = Path(args.root) if not root.exists(): @@ -92,11 +140,19 @@ def main() -> None: "Pass --root to point to your downloaded TUEV edf/ directory." ) + epochs = 2 if args.quick_test else args.epochs + quick_test_max_samples = 2000 # cap samples so quick-test finishes in ~5-10 min + if args.quick_test: + print("*** QUICK TEST MODE (dev=True, 2 epochs, max 2000 samples) ***") + print("=" * 80) print("STEP 1: Load TUEV + build task dataset") print("=" * 80) - dataset = TUEVDataset(root=str(root), subset=args.subset) + dataset = TUEVDataset(root=str(root), subset=args.subset, dev=args.quick_test) sample_dataset = dataset.set_task(EEGEventsTUEV(), cache_dir="examples/conformal_eeg/cache") + if args.quick_test and len(sample_dataset) > quick_test_max_samples: + sample_dataset = sample_dataset.subset(range(quick_test_max_samples)) + print(f"Capped to {quick_test_max_samples} samples for quick-test.") print(f"Task samples: {len(sample_dataset)}") print(f"Input schema: {sample_dataset.input_schema}") @@ -129,7 +185,7 @@ def main() -> None: trainer.train( train_dataloader=train_loader, val_dataloader=val_loader, - epochs=args.epochs, + epochs=epochs, monitor="accuracy" if val_loader is not None else None, ) diff --git a/examples/conformal_eeg/tuev_ncp_conformal.py b/examples/conformal_eeg/tuev_ncp_conformal.py new file mode 100644 index 000000000..9b51a7756 --- /dev/null +++ b/examples/conformal_eeg/tuev_ncp_conformal.py @@ -0,0 +1,430 @@ +"""Neighborhood Conformal Prediction (NCP) on TUEV EEG Events using ContraWR. + +This script: +1) Loads the TUEV dataset and applies the EEGEventsTUEV task. +2) Splits into train/val/cal/test using split conformal protocol. +3) Trains a ContraWR model. +4) Extracts calibration embeddings and calibrates a NeighborhoodLabel (NCP) predictor. +5) Evaluates prediction-set coverage/miscoverage and efficiency on the test split. + +With --n-seeds > 1: fixes the test set (--split-seed), runs multiple training runs +with different seeds (different train/val/cal splits and model init), reports +coverage / set size / accuracy as mean ± std (error bars). + +Example (from repo root): + python examples/conformal_eeg/tuev_ncp_conformal.py --root /srv/local/data/TUH/tuh_eeg_events/v2.0.0/edf + python examples/conformal_eeg/tuev_ncp_conformal.py --quick-test --log-file quicktest_ncp.log + python examples/conformal_eeg/tuev_ncp_conformal.py --alpha 0.1 --n-seeds 5 --split-seed 0 --log-file ncp_seeds5.log +""" + +from __future__ import annotations + +import argparse +import random +import sys +from pathlib import Path + +import numpy as np +import torch + + +class _Tee: + """Writes to both a stream and a file.""" + + def __init__(self, stream, file): + self._stream = stream + self._file = file + + def write(self, data): + self._stream.write(data) + self._file.write(data) + self._file.flush() + + def flush(self): + self._stream.flush() + self._file.flush() + + +from pyhealth.calib.predictionset.cluster import NeighborhoodLabel +from pyhealth.calib.utils import extract_embeddings +from pyhealth.datasets import TUEVDataset, get_dataloader, split_by_sample_conformal +from pyhealth.models import ContraWR +from pyhealth.tasks import EEGEventsTUEV +from pyhealth.trainer import Trainer, get_metrics_fn + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Neighborhood conformal prediction (NCP) on TUEV EEG events using ContraWR." + ) + parser.add_argument( + "--root", + type=str, + default="/srv/local/data/TUH/tuh_eeg_events/v2.0.0/edf", + help="Path to TUEV edf/ folder.", + ) + parser.add_argument("--subset", type=str, default="both", choices=["train", "eval", "both"]) + parser.add_argument("--seed", type=int, default=42, help="Run seed (or first of run seeds when n-seeds > 1).") + parser.add_argument( + "--n-seeds", + type=int, + default=1, + help="Number of runs for mean±std. Test set fixed; train/val/cal vary by seed.", + ) + parser.add_argument( + "--split-seed", + type=int, + default=0, + help="Fixed seed for initial split (fixes test set when n-seeds > 1).", + ) + parser.add_argument( + "--seeds", + type=str, + default=None, + help="Comma-separated run seeds, e.g. 42,43,44,45,46. Overrides --seed and --n-seeds.", + ) + parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument("--epochs", type=int, default=20) + parser.add_argument("--alpha", type=float, default=0.1, help="Miscoverage rate (e.g., 0.1 => 90% target coverage).") + parser.add_argument( + "--ratios", + type=float, + nargs=4, + default=(0.6, 0.1, 0.15, 0.15), + metavar=("TRAIN", "VAL", "CAL", "TEST"), + help="Split ratios for train/val/cal/test. Must sum to 1.0.", + ) + parser.add_argument( + "--k-neighbors", + type=int, + default=50, + help="Number of nearest calibration neighbors for NCP.", + ) + parser.add_argument( + "--lambda-L", + type=float, + default=100.0, + help="Temperature for NCP exponential weights; smaller => more localization.", + ) + parser.add_argument("--n-fft", type=int, default=128, help="STFT FFT size used by ContraWR.") + parser.add_argument( + "--device", + type=str, + default=None, + help="Device string, e.g. 'cuda:0' or 'cpu'. Defaults to auto-detect.", + ) + parser.add_argument( + "--log-file", + type=str, + default=None, + help="Path to log file. Stdout and stderr are teed to this file.", + ) + parser.add_argument( + "--quick-test", + action="store_true", + help="Smoke test: dev=True, max 2000 samples, 2 epochs, ~5-10 min.", + ) + return parser.parse_args() + + +def set_seed(seed: int) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + +def _split_remainder_into_train_val_cal(sample_dataset, remainder_indices, ratios, run_seed): + """Split remainder indices into train/val/cal by renormalized ratios. Uses run_seed for shuffle.""" + r0, r1, r2, r3 = ratios + remainder_frac = 1.0 - r3 + if remainder_frac <= 0: + raise ValueError("Test ratio must be < 1 so remainder (train+val+cal) is non-empty.") + # Renormalize so train/val/cal ratios sum to 1 on the remainder + r_train = r0 / remainder_frac + r_val = r1 / remainder_frac + remainder = np.asarray(remainder_indices, dtype=np.int64) + np.random.seed(run_seed) + shuffled = np.random.permutation(remainder) + M = len(shuffled) + train_end = int(M * r_train) + val_end = int(M * (r_train + r_val)) + train_index = shuffled[:train_end] + val_index = shuffled[train_end:val_end] + cal_index = shuffled[val_end:] + train_ds = sample_dataset.subset(train_index.tolist()) + val_ds = sample_dataset.subset(val_index.tolist()) + cal_ds = sample_dataset.subset(cal_index.tolist()) + return train_ds, val_ds, cal_ds + + +def _run_one_ncp( + sample_dataset, + train_ds, + val_ds, + cal_ds, + test_loader, + args, + device, + epochs, + return_metrics=False, +): + """Train ContraWR, calibrate NCP, evaluate on test. Optionally return metrics dict for aggregation.""" + train_loader = get_dataloader(train_ds, batch_size=args.batch_size, shuffle=True) + val_loader = get_dataloader(val_ds, batch_size=args.batch_size, shuffle=False) if len(val_ds) else None + + print("\n" + "=" * 80) + print("STEP 3: Train ContraWR") + print("=" * 80) + model = ContraWR(dataset=sample_dataset, n_fft=args.n_fft).to(device) + trainer = Trainer(model=model, device=device, enable_logging=False) + trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + epochs=epochs, + monitor="accuracy" if val_loader is not None else None, + ) + + if not return_metrics: + print("\nBase model performance on test set:") + y_true_base, y_prob_base, _loss_base = trainer.inference(test_loader) + base_metrics = get_metrics_fn("multiclass")( + y_true_base, y_prob_base, metrics=["accuracy", "f1_weighted"] + ) + for metric, value in base_metrics.items(): + print(f" {metric}: {value:.4f}") + + print("\n" + "=" * 80) + print("STEP 4: Neighborhood Conformal Prediction (NCP / NeighborhoodLabel)") + print("=" * 80) + print(f"Target miscoverage alpha: {args.alpha} (target coverage {1 - args.alpha:.0%})") + print(f"k_neighbors: {args.k_neighbors}, lambda_L: {args.lambda_L}") + + cal_embeddings = extract_embeddings(model, cal_ds, batch_size=args.batch_size, device=device) + if not return_metrics: + print(f" cal_embeddings shape: {cal_embeddings.shape}") + + ncp_predictor = NeighborhoodLabel( + model=model, + alpha=float(args.alpha), + k_neighbors=args.k_neighbors, + lambda_L=args.lambda_L, + ) + ncp_predictor.calibrate(cal_dataset=cal_ds, cal_embeddings=cal_embeddings) + + y_true, y_prob, _loss, extra = Trainer(model=ncp_predictor).inference( + test_loader, additional_outputs=["y_predset"] + ) + ncp_metrics = get_metrics_fn("multiclass")( + y_true, y_prob, metrics=["accuracy", "miscoverage_ps"], y_predset=extra["y_predset"] + ) + predset = extra["y_predset"] + if isinstance(predset, np.ndarray): + predset_t = torch.tensor(predset) + else: + predset_t = predset + avg_set_size = predset_t.float().sum(dim=1).mean().item() + miscoverage = ncp_metrics["miscoverage_ps"] + if isinstance(miscoverage, np.ndarray): + miscoverage = float(miscoverage.item() if miscoverage.size == 1 else miscoverage.mean()) + else: + miscoverage = float(miscoverage) + coverage = 1.0 - miscoverage + + if return_metrics: + return { + "accuracy": float(ncp_metrics["accuracy"]), + "coverage": coverage, + "miscoverage": miscoverage, + "avg_set_size": avg_set_size, + } + + print("\nNCP (NeighborhoodLabel) Results:") + print(f" Accuracy: {ncp_metrics['accuracy']:.4f}") + print(f" Empirical miscoverage: {miscoverage:.4f}") + print(f" Empirical coverage: {coverage:.4f}") + print(f" Average set size: {avg_set_size:.2f}") + print(f" k_neighbors: {args.k_neighbors}") + print("\n--- Single-run summary (for reporting) ---") + print(f" alpha={args.alpha}, target_coverage={1 - args.alpha:.2f}, empirical_coverage={coverage:.4f}, miscoverage={miscoverage:.4f}, accuracy={ncp_metrics['accuracy']:.4f}, avg_set_size={avg_set_size:.2f}") + + +def main() -> None: + args = parse_args() + # Seed set per run in multi-seed mode; for single run set once here + if args.n_seeds <= 1 and args.seeds is None: + set_seed(args.seed) + + orig_stdout, orig_stderr = sys.stdout, sys.stderr + log_file = None + if args.log_file: + log_file = open(args.log_file, "w", encoding="utf-8") + sys.stdout = _Tee(orig_stdout, log_file) + sys.stderr = _Tee(orig_stderr, log_file) + + try: + _run(args) + finally: + if log_file is not None: + sys.stdout = orig_stdout + sys.stderr = orig_stderr + log_file.close() + + +def _run(args: argparse.Namespace) -> None: + device = args.device or ("cuda:0" if torch.cuda.is_available() else "cpu") + root = Path(args.root) + if not root.exists(): + raise FileNotFoundError( + f"TUEV root not found: {root}. " + "Pass --root to point to your downloaded TUEV edf/ directory." + ) + + epochs = 2 if args.quick_test else args.epochs + quick_test_max_samples = 2000 # cap samples so quick-test finishes in ~5-10 min + if args.quick_test: + print("*** QUICK TEST MODE (dev=True, 2 epochs, max 2000 samples) ***") + + print("=" * 80) + print("STEP 1: Load TUEV + build task dataset") + print("=" * 80) + dataset = TUEVDataset(root=str(root), subset=args.subset, dev=args.quick_test) + sample_dataset = dataset.set_task(EEGEventsTUEV(), cache_dir="examples/conformal_eeg/cache") + if args.quick_test and len(sample_dataset) > quick_test_max_samples: + sample_dataset = sample_dataset.subset(range(quick_test_max_samples)) + print(f"Capped to {quick_test_max_samples} samples for quick-test.") + + print(f"Task samples: {len(sample_dataset)}") + print(f"Input schema: {sample_dataset.input_schema}") + print(f"Output schema: {sample_dataset.output_schema}") + + # Experiment configuration (for PI / reporting) + print("\n--- Experiment configuration ---") + print(f" dataset_root: {root}") + print(f" subset: {args.subset}, ratios: train/val/cal/test = {args.ratios[0]:.2f}/{args.ratios[1]:.2f}/{args.ratios[2]:.2f}/{args.ratios[3]:.2f}") + print(f" alpha: {args.alpha} (target coverage {1 - args.alpha:.0%})") + print(f" k_neighbors: {args.k_neighbors}, lambda_L: {args.lambda_L}") + print(f" epochs: {epochs}, batch_size: {args.batch_size}, device: {device}, seed: {args.seed}") + + if len(sample_dataset) == 0: + raise RuntimeError("No samples produced. Verify TUEV root/subset/task.") + + ratios = list(args.ratios) + use_multi_seed = args.n_seeds > 1 or args.seeds is not None + if use_multi_seed: + run_seeds = ( + [int(s.strip()) for s in args.seeds.split(",")] + if args.seeds + else [args.seed + i for i in range(args.n_seeds)] + ) + n_runs = len(run_seeds) + print(f" multi_seed: n_runs={n_runs}, run_seeds={run_seeds}, split_seed={args.split_seed} (fixed test set)") + print(f"Multi-seed mode: {n_runs} runs (fixed test set), run seeds: {run_seeds}") + + if not use_multi_seed: + # Single run: original behavior + print("\n" + "=" * 80) + print("STEP 2: Split train/val/cal/test") + print("=" * 80) + train_ds, val_ds, cal_ds, test_ds = split_by_sample_conformal( + dataset=sample_dataset, ratios=ratios, seed=args.seed + ) + print(f"Train: {len(train_ds)}") + print(f"Val: {len(val_ds)}") + print(f"Cal: {len(cal_ds)}") + print(f"Test: {len(test_ds)}") + + test_loader = get_dataloader(test_ds, batch_size=args.batch_size, shuffle=False) + _run_one_ncp( + sample_dataset=sample_dataset, + train_ds=train_ds, + val_ds=val_ds, + cal_ds=cal_ds, + test_loader=test_loader, + args=args, + device=device, + epochs=epochs, + ) + print("\n--- Split sizes and seed (for reporting) ---") + print(f" train={len(train_ds)}, val={len(val_ds)}, cal={len(cal_ds)}, test={len(test_ds)}, seed={args.seed}") + return + + # Multi-seed: fix test set, vary train/val/cal per run + print("\n" + "=" * 80) + print("STEP 2: Fix test set (split-seed), then run multiple train/cal splits") + print("=" * 80) + train_idx, val_idx, cal_idx, test_idx = split_by_sample_conformal( + dataset=sample_dataset, ratios=ratios, seed=args.split_seed, get_index=True + ) + # Convert to numpy for indexing + train_index = train_idx.numpy() if hasattr(train_idx, "numpy") else np.array(train_idx) + val_index = val_idx.numpy() if hasattr(val_idx, "numpy") else np.array(val_idx) + cal_index = cal_idx.numpy() if hasattr(cal_idx, "numpy") else np.array(cal_idx) + test_index = test_idx.numpy() if hasattr(test_idx, "numpy") else np.array(test_idx) + remainder_indices = np.concatenate([train_index, val_index, cal_index]) + test_ds = sample_dataset.subset(test_index.tolist()) + test_loader = get_dataloader(test_ds, batch_size=args.batch_size, shuffle=False) + n_test = len(test_ds) + print(f"Fixed test set size: {n_test}") + + accs, coverages, miscoverages, set_sizes = [], [], [], [] + for run_i, run_seed in enumerate(run_seeds): + print("\n" + "=" * 80) + print(f"Run {run_i + 1} / {n_runs} (seed={run_seed})") + print("=" * 80) + set_seed(run_seed) + train_ds, val_ds, cal_ds = _split_remainder_into_train_val_cal( + sample_dataset, remainder_indices, ratios, run_seed + ) + print(f"Train: {len(train_ds)}, Val: {len(val_ds)}, Cal: {len(cal_ds)}") + + metrics = _run_one_ncp( + sample_dataset=sample_dataset, + train_ds=train_ds, + val_ds=val_ds, + cal_ds=cal_ds, + test_loader=test_loader, + args=args, + device=device, + epochs=epochs, + return_metrics=True, + ) + accs.append(metrics["accuracy"]) + coverages.append(metrics["coverage"]) + miscoverages.append(metrics["miscoverage"]) + set_sizes.append(metrics["avg_set_size"]) + + accs = np.array(accs) + coverages = np.array(coverages) + miscoverages_arr = np.array(miscoverages) + set_sizes = np.array(set_sizes) + + # Per-run table (for PI / reporting) + print("\n" + "=" * 80) + print("Per-run NCP results (fixed test set)") + print("=" * 80) + print(f" {'Run':<4} {'Seed':<6} {'Accuracy':<10} {'Coverage':<10} {'Miscoverage':<12} {'Avg set size':<12}") + print(" " + "-" * 54) + for i in range(n_runs): + print(f" {i+1:<4} {run_seeds[i]:<6} {accs[i]:<10.4f} {coverages[i]:<10.4f} {miscoverages_arr[i]:<12.4f} {set_sizes[i]:<12.2f}") + + print("\n" + "=" * 80) + print("NCP summary (mean ± std over {} runs, fixed test set)".format(n_runs)) + print("=" * 80) + print(f" Accuracy: {accs.mean():.4f} ± {accs.std():.4f}") + print(f" Empirical coverage: {coverages.mean():.4f} ± {coverages.std():.4f}") + print(f" Empirical miscoverage: {miscoverages_arr.mean():.4f} ± {miscoverages_arr.std():.4f}") + print(f" Average set size: {set_sizes.mean():.2f} ± {set_sizes.std():.2f}") + print(f" Target coverage: {1 - args.alpha:.0%} (alpha={args.alpha})") + print(f" k_neighbors: {args.k_neighbors}, lambda_L: {args.lambda_L}") + print(f" Test set size: {n_test} (fixed across runs)") + print(f" Run seeds: {run_seeds}") + print("\n--- Min / Max (across runs) ---") + print(f" Coverage: [{coverages.min():.4f}, {coverages.max():.4f}]") + print(f" Set size: [{set_sizes.min():.2f}, {set_sizes.max():.2f}]") + print(f" Accuracy: [{accs.min():.4f}, {accs.max():.4f}]") + + +if __name__ == "__main__": + main() diff --git a/pyhealth/calib/predictionset/__init__.py b/pyhealth/calib/predictionset/__init__.py index 46760945f..f8328d35c 100644 --- a/pyhealth/calib/predictionset/__init__.py +++ b/pyhealth/calib/predictionset/__init__.py @@ -1,10 +1,18 @@ """Prediction set construction methods""" from pyhealth.calib.predictionset.base_conformal import BaseConformal -from pyhealth.calib.predictionset.cluster import ClusterLabel +from pyhealth.calib.predictionset.cluster import ClusterLabel, NeighborhoodLabel from pyhealth.calib.predictionset.covariate import CovariateLabel from pyhealth.calib.predictionset.favmac import FavMac from pyhealth.calib.predictionset.label import LABEL from pyhealth.calib.predictionset.scrib import SCRIB -__all__ = ["BaseConformal", "LABEL", "SCRIB", "FavMac", "CovariateLabel", "ClusterLabel"] +__all__ = [ + "BaseConformal", + "LABEL", + "SCRIB", + "FavMac", + "CovariateLabel", + "ClusterLabel", + "NeighborhoodLabel", +] diff --git a/pyhealth/calib/predictionset/base_conformal/__init__.py b/pyhealth/calib/predictionset/base_conformal/__init__.py index b1ceefdee..8e37c0d6d 100644 --- a/pyhealth/calib/predictionset/base_conformal/__init__.py +++ b/pyhealth/calib/predictionset/base_conformal/__init__.py @@ -45,6 +45,32 @@ def _query_quantile(scores: np.ndarray, alpha: float) -> float: return -np.inf if loc == -1 else scores[loc] +def _query_weighted_quantile( + scores: np.ndarray, alpha: float, weights: np.ndarray +) -> float: + """Compute weighted quantile of scores (e.g. for NCP or covariate shift). + + Args: + scores: Array of conformity scores + alpha: Quantile level (between 0 and 1) + weights: Weights for each score (same length as scores) + + Returns: + The weighted alpha-quantile of scores + """ + sorted_indices = np.argsort(scores) + sorted_scores = scores[sorted_indices] + sorted_weights = weights[sorted_indices] + w_sum = np.sum(sorted_weights) + if w_sum <= 0: + return -np.inf + cum_weights = np.cumsum(sorted_weights) / w_sum + idx = np.searchsorted(cum_weights, alpha, side="left") + if idx >= len(sorted_scores): + idx = len(sorted_scores) - 1 + return float(sorted_scores[idx]) + + class BaseConformal(SetPredictor): """Base Conformal Prediction for multiclass classification. diff --git a/pyhealth/calib/predictionset/cluster/__init__.py b/pyhealth/calib/predictionset/cluster/__init__.py index eacd89967..44f26dc28 100644 --- a/pyhealth/calib/predictionset/cluster/__init__.py +++ b/pyhealth/calib/predictionset/cluster/__init__.py @@ -1,5 +1,6 @@ -"""Cluster-based prediction set methods""" +"""Cluster-based and neighborhood prediction set methods""" from pyhealth.calib.predictionset.cluster.cluster_label import ClusterLabel +from pyhealth.calib.predictionset.cluster.neighborhood_label import NeighborhoodLabel -__all__ = ["ClusterLabel"] +__all__ = ["ClusterLabel", "NeighborhoodLabel"] diff --git a/pyhealth/calib/predictionset/cluster/neighborhood_label.py b/pyhealth/calib/predictionset/cluster/neighborhood_label.py new file mode 100644 index 000000000..2d9f2dc6d --- /dev/null +++ b/pyhealth/calib/predictionset/cluster/neighborhood_label.py @@ -0,0 +1,234 @@ +""" +Neighborhood Conformal Prediction (NCP). + +""" + +from typing import Dict, Optional, Union + +import numpy as np +import torch +from sklearn.neighbors import NearestNeighbors +from torch.utils.data import IterableDataset + +from pyhealth.calib.base_classes import SetPredictor +from pyhealth.calib.predictionset.base_conformal import _query_weighted_quantile +from pyhealth.calib.utils import extract_embeddings, prepare_numpy_dataset +from pyhealth.models import BaseModel + +__all__ = ["NeighborhoodLabel"] + + +class NeighborhoodLabel(SetPredictor): + """Neighborhood Conformal Prediction (NCP) for multiclass classification. + + Reference: + Ghosh, S., Belkhouja, T., Yan, Y., & Doppa, J. R. (2023). + Improving Uncertainty Quantification of Deep Classifiers via + Neighborhood Conformal Prediction. + + Args: + model: A trained base model that supports embedding extraction + (must support `embed=True` in forward pass). + alpha: Target miscoverage rate; marginal coverage P(Y not in C(X)) <= alpha. + k_neighbors: Number of nearest calibration neighbors. Default 50. + lambda_L: Temperature for exponential weights; smaller => more localization. + Default 100.0. + debug: If True, process fewer samples for faster iteration. + + Examples: + >>> from pyhealth.datasets import TUEVDataset, split_by_sample_conformal + >>> from pyhealth.datasets import get_dataloader + >>> from pyhealth.models import ContraWR + >>> from pyhealth.tasks import EEGEventsTUEV + >>> from pyhealth.calib.predictionset.cluster import NeighborhoodLabel + >>> from pyhealth.calib.utils import extract_embeddings + >>> from pyhealth.trainer import Trainer, get_metrics_fn + >>> + >>> dataset = TUEVDataset(root="path/to/tuev") + >>> sample_dataset = dataset.set_task(EEGEventsTUEV()) + >>> train_ds, val_ds, cal_ds, test_ds = split_by_sample_conformal( + ... sample_dataset, ratios=[0.6, 0.1, 0.15, 0.15], seed=42 + ... ) + >>> model = ContraWR(dataset=sample_dataset) + >>> cal_embeddings = extract_embeddings(model, cal_ds, batch_size=32) + >>> ncp = NeighborhoodLabel(model=model, alpha=0.1, k_neighbors=50) + >>> ncp.calibrate(cal_dataset=cal_ds, cal_embeddings=cal_embeddings) + >>> test_loader = get_dataloader(test_ds, batch_size=32, shuffle=False) + >>> y_true, y_prob, _, extra = Trainer(model=ncp).inference( + ... test_loader, additional_outputs=["y_predset"] + ... ) + >>> metrics = get_metrics_fn(ncp.mode)( + ... y_true, y_prob, metrics=["accuracy", "miscoverage_ps"], + ... y_predset=extra["y_predset"] + ... ) + """ + + def __init__( + self, + model: BaseModel, + alpha: float, + k_neighbors: int = 50, + lambda_L: float = 100.0, + debug: bool = False, + **kwargs, + ) -> None: + super().__init__(model, **kwargs) + + if model.mode != "multiclass": + raise NotImplementedError( + "NeighborhoodLabel only supports multiclass classification" + ) + + self.mode = self.model.mode + + for param in model.parameters(): + param.requires_grad = False + self.model.eval() + + self.device = model.device + self.debug = debug + + if not (0.0 < alpha < 1.0): + raise ValueError(f"alpha must be in (0, 1), got {alpha!r}") + self.alpha = float(alpha) + + if not isinstance(k_neighbors, int) or k_neighbors <= 0: + raise ValueError( + f"k_neighbors must be a positive integer, got {k_neighbors!r}" + ) + self.k_neighbors = k_neighbors + self.lambda_L = float(lambda_L) + + self.cal_embeddings_ = None + self.cal_conformity_scores_ = None + self.alpha_tilde_ = None + self._nn = None + + def calibrate( + self, + cal_dataset: IterableDataset, + cal_embeddings: Optional[np.ndarray] = None, + batch_size: int = 32, + ) -> None: + """Calibrate NCP steps: + + Step 1: For each calibration point i, compute Q̃^NCP (weighted quantile) + over its k-NN in calibration using weights. + Step 2: Find ã^NCP(α) = largest ã such that empirical coverage on the + calibration set is >= 1-α; store as alpha_tilde_ for use at test time. + + Args: + cal_dataset: Calibration dataset (for labels and predictions if + cal_embeddings not provided). + cal_embeddings: Optional precomputed calibration embeddings + (n_cal, embedding_dim). If None, extracted from cal_dataset. + batch_size: Batch size for embedding extraction when cal_embeddings + is not provided. + """ + cal_dict = prepare_numpy_dataset( + self.model, + cal_dataset, + ["y_prob", "y_true"], + debug=self.debug, + ) + y_prob = cal_dict["y_prob"] + y_true = cal_dict["y_true"] + N = y_prob.shape[0] + + if cal_embeddings is None: + cal_embeddings = extract_embeddings( + self.model, cal_dataset, batch_size=batch_size, device=self.device + ) + else: + cal_embeddings = np.asarray(cal_embeddings) + + if cal_embeddings.shape[0] != N: + raise ValueError( + f"cal_embeddings length {cal_embeddings.shape[0]} must match " + f"cal_dataset size {N}" + ) + + conformity_scores = y_prob[np.arange(N), y_true] + + k = min(self.k_neighbors, N) + self._nn = NearestNeighbors(n_neighbors=k, metric="euclidean").fit( + np.atleast_2d(cal_embeddings) + ) + self.cal_embeddings_ = np.atleast_2d(cal_embeddings) + self.cal_conformity_scores_ = np.asarray(conformity_scores, dtype=np.float64) + + # this is the ncp calibration step + distances_cal, indices_cal = self._nn.kneighbors( + self.cal_embeddings_, n_neighbors=k + ) + cal_weights = np.exp(-distances_cal / self.lambda_L) + cal_weights = cal_weights / cal_weights.sum(axis=1, keepdims=True) + + def _empirical_coverage(alpha_tilde_cand: float) -> float: + t_all = np.zeros(N, dtype=np.float64) + for i in range(N): + t_all[i] = _query_weighted_quantile( + self.cal_conformity_scores_[indices_cal[i]], + alpha_tilde_cand, + cal_weights[i], + ) + return float(np.mean(self.cal_conformity_scores_ >= t_all)) + + low, high = 0.0, 1.0 + for _ in range(50): + mid = (low + high) / 2 + if _empirical_coverage(mid) >= 1.0 - self.alpha: + low = mid + else: + high = mid + self.alpha_tilde_ = float(low) + + def forward(self, **kwargs) -> Dict[str, torch.Tensor]: + """Forward with NCP: per-sample weighted quantile threshold.""" + if ( + self.cal_embeddings_ is None + or self.cal_conformity_scores_ is None + or self.alpha_tilde_ is None + ): + raise RuntimeError( + "NeighborhoodLabel must be calibrated before inference. " + "Call calibrate() first." + ) + + pred = self.model(**{**kwargs, "embed": True}) + if "embed" not in pred: + raise ValueError( + f"Model {type(self.model).__name__} does not return " + "embeddings. Ensure it supports embed=True in forward()." + ) + + test_emb = pred["embed"].detach().cpu().numpy() + test_emb = np.atleast_2d(test_emb) + batch_size = test_emb.shape[0] + n_cal = self.cal_conformity_scores_.shape[0] + k = min(self.k_neighbors, n_cal) + + distances, indices = self._nn.kneighbors(test_emb, n_neighbors=k) + thresholds = np.zeros(batch_size, dtype=np.float64) + for i in range(batch_size): + w = np.exp(-distances[i] / self.lambda_L) + w = w / np.sum(w) + scores_i = self.cal_conformity_scores_[indices[i]] + thresholds[i] = _query_weighted_quantile( + scores_i, self.alpha_tilde_, w + ) + + th = torch.as_tensor( + thresholds, device=self.device, dtype=pred["y_prob"].dtype + ) + if pred["y_prob"].ndim > 1: + th = th.view(-1, *([1] * (pred["y_prob"].ndim - 1))) + y_predset = pred["y_prob"] >= th + # if threshold is high, include at least argmax + empty = y_predset.sum(dim=1) == 0 + if empty.any(): + argmax_idx = pred["y_prob"].argmax(dim=1) + y_predset[empty, argmax_idx[empty]] = True + pred["y_predset"] = y_predset + pred.pop("embed", None) + return pred diff --git a/pyhealth/metrics/prediction_set.py b/pyhealth/metrics/prediction_set.py index 95d71e7e8..2b6f71705 100644 --- a/pyhealth/metrics/prediction_set.py +++ b/pyhealth/metrics/prediction_set.py @@ -33,8 +33,12 @@ def _missrate(y_pred:np.ndarray, y_true:np.ndarray, ignore_rejected=False): keep_msk = (y_pred.sum(1) == 1) if ignore_rejected else np.ones(len(y_true), dtype=bool) missed = [] for k in range(K): - missed.append(1-np.mean(y_pred[keep_msk & y_true[:, k], k])) - + msk = keep_msk & y_true[:, k] + n = msk.sum() + if n == 0: + missed.append(0.0) + else: + missed.append(1 - np.mean(y_pred[msk, k])) return np.asarray(missed) @@ -92,8 +96,8 @@ def miscoverage_overall_ps(y_pred:np.ndarray, y_true:np.ndarray): """ assert len(y_true.shape) == 1 truth_pred = y_pred[np.arange(len(y_true)), y_true] + return 1 - np.mean(truth_pred) if len(truth_pred) > 0 else 0.0 - return 1 - np.mean(truth_pred) def error_overall_ps(y_pred:np.ndarray, y_true:np.ndarray): """Overall error rate for the un-rejected samples. @@ -113,4 +117,4 @@ def error_overall_ps(y_pred:np.ndarray, y_true:np.ndarray): assert len(y_true.shape) == 1 truth_pred = y_pred[np.arange(len(y_true)), y_true] truth_pred = truth_pred[y_pred.sum(1) == 1] - return 1 - np.mean(truth_pred) + return 1 - np.mean(truth_pred) if len(truth_pred) > 0 else 0.0 diff --git a/tests/core/test_gamenet.py b/tests/core/test_gamenet.py index f456189c7..8b735b857 100644 --- a/tests/core/test_gamenet.py +++ b/tests/core/test_gamenet.py @@ -103,14 +103,15 @@ def test_model_backward(self): ) def test_loss_is_finite(self): + """Test that the loss is finite.""" + torch.manual_seed(42) # reproducibility: shuffle + dropout can rarely yield non-finite loss train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) data_batch = next(iter(train_loader)) - ret = self.model(**data_batch) + with torch.no_grad(): + ret = self.model(**data_batch) self.assertTrue(torch.isfinite(ret["loss"]).all()) - self.assertFalse(torch.isnan(ret["loss"]).any()) - self.assertFalse(torch.isinf(ret["loss"]).any()) def test_output_shapes(self): train_loader = get_dataloader(self.dataset, batch_size=3, shuffle=True) diff --git a/tests/core/test_neighborhood_label.py b/tests/core/test_neighborhood_label.py new file mode 100644 index 000000000..b33f4c4b0 --- /dev/null +++ b/tests/core/test_neighborhood_label.py @@ -0,0 +1,184 @@ +"""Tests for NeighborhoodLabel (NCP) prediction set constructor.""" + +import unittest +import numpy as np +import torch + +from pyhealth.datasets import create_sample_dataset, get_dataloader +from pyhealth.models import MLP +from pyhealth.calib.predictionset.cluster import NeighborhoodLabel +from pyhealth.calib.utils import extract_embeddings + + +class TestNeighborhoodLabel(unittest.TestCase): + """Test cases for the NeighborhoodLabel (NCP) prediction set constructor.""" + + def setUp(self): + """Set up test data and model.""" + self.samples = [ + {"patient_id": "p0", "visit_id": "v0", "conditions": ["c1"], "procedures": [1.0], "label": 0}, + {"patient_id": "p1", "visit_id": "v1", "conditions": ["c2"], "procedures": [2.0], "label": 1}, + {"patient_id": "p2", "visit_id": "v2", "conditions": ["c3"], "procedures": [3.0], "label": 2}, + {"patient_id": "p3", "visit_id": "v3", "conditions": ["c4"], "procedures": [1.5], "label": 0}, + {"patient_id": "p4", "visit_id": "v4", "conditions": ["c5"], "procedures": [2.5], "label": 1}, + {"patient_id": "p5", "visit_id": "v5", "conditions": ["c6"], "procedures": [3.5], "label": 2}, + ] + self.input_schema = {"conditions": "sequence", "procedures": "tensor"} + self.output_schema = {"label": "multiclass"} + self.dataset = create_sample_dataset( + samples=self.samples, + input_schema=self.input_schema, + output_schema=self.output_schema, + dataset_name="test", + ) + self.model = MLP( + dataset=self.dataset, + feature_keys=["conditions", "procedures"], + label_key="label", + mode="multiclass", + ) + self.model.eval() + + def _get_embeddings(self, dataset): + return extract_embeddings(self.model, dataset, batch_size=32, device="cpu") + + def test_initialization(self): + ncp = NeighborhoodLabel( + model=self.model, + alpha=0.1, + k_neighbors=5, + lambda_L=100.0, + ) + self.assertIsInstance(ncp, NeighborhoodLabel) + self.assertEqual(ncp.mode, "multiclass") + self.assertEqual(ncp.alpha, 0.1) + self.assertEqual(ncp.k_neighbors, 5) + self.assertEqual(ncp.lambda_L, 100.0) + self.assertIsNone(ncp.cal_embeddings_) + self.assertIsNone(ncp.cal_conformity_scores_) + + def test_initialization_invalid_alpha_raises(self): + with self.assertRaises(ValueError): + NeighborhoodLabel(model=self.model, alpha=0.0, k_neighbors=5) + with self.assertRaises(ValueError): + NeighborhoodLabel(model=self.model, alpha=1.0, k_neighbors=5) + with self.assertRaises(ValueError): + NeighborhoodLabel(model=self.model, alpha=-0.1, k_neighbors=5) + + def test_initialization_invalid_k_neighbors_raises(self): + with self.assertRaises(ValueError): + NeighborhoodLabel(model=self.model, alpha=0.1, k_neighbors=0) + with self.assertRaises(ValueError): + NeighborhoodLabel(model=self.model, alpha=0.1, k_neighbors=-1) + with self.assertRaises(ValueError): + NeighborhoodLabel(model=self.model, alpha=0.1, k_neighbors=2.5) + + def test_initialization_non_multiclass_raises(self): + binary_samples = [ + {"patient_id": "a", "visit_id": "a", "conditions": ["c"], "procedures": [1.0], "label": 0}, + {"patient_id": "b", "visit_id": "b", "conditions": ["d"], "procedures": [2.0], "label": 1}, + ] + binary_ds = create_sample_dataset( + samples=binary_samples, + input_schema={"conditions": "sequence", "procedures": "tensor"}, + output_schema={"label": "binary"}, + dataset_name="test", + ) + binary_model = MLP( + dataset=binary_ds, feature_keys=["conditions"], label_key="label", mode="binary" + ) + with self.assertRaises(NotImplementedError): + NeighborhoodLabel(model=binary_model, alpha=0.1, k_neighbors=2) + + def test_calibrate_and_forward_returns_predset(self): + ncp = NeighborhoodLabel(model=self.model, alpha=0.2, k_neighbors=3, lambda_L=50.0) + cal_indices = [3, 4, 5] + cal_dataset = self.dataset.subset(cal_indices) + cal_embeddings = self._get_embeddings(cal_dataset) + ncp.calibrate(cal_dataset=cal_dataset, cal_embeddings=cal_embeddings) + + self.assertIsNotNone(ncp.cal_embeddings_) + self.assertIsNotNone(ncp.cal_conformity_scores_) + self.assertEqual(ncp.cal_conformity_scores_.shape[0], 3) + + test_loader = get_dataloader(self.dataset, batch_size=2, shuffle=False) + batch = next(iter(test_loader)) + with torch.no_grad(): + out = ncp(**batch) + + self.assertIn("y_predset", out) + self.assertIn("y_prob", out) + self.assertEqual(out["y_predset"].dtype, torch.bool) + self.assertEqual(out["y_predset"].shape, out["y_prob"].shape) + + def test_forward_before_calibration_raises(self): + ncp = NeighborhoodLabel(model=self.model, alpha=0.1, k_neighbors=5) + loader = get_dataloader(self.dataset, batch_size=1, shuffle=False) + batch = next(iter(loader)) + with self.assertRaises(RuntimeError): + with torch.no_grad(): + ncp(**batch) + + def test_prediction_sets_nonempty_batch(self): + ncp = NeighborhoodLabel(model=self.model, alpha=0.3, k_neighbors=2, lambda_L=100.0) + cal_dataset = self.dataset.subset([2, 3, 4, 5]) + cal_emb = self._get_embeddings(cal_dataset) + ncp.calibrate(cal_dataset=cal_dataset, cal_embeddings=cal_emb) + + loader = get_dataloader(self.dataset, batch_size=2, shuffle=False) + with torch.no_grad(): + for batch in loader: + out = ncp(**batch) + set_sizes = out["y_predset"].sum(dim=1) + self.assertTrue(torch.all(set_sizes > 0), "Prediction sets should be non-empty") + + def test_calibrate_without_embeddings_extracts(self): + ncp = NeighborhoodLabel(model=self.model, alpha=0.1, k_neighbors=2) + cal_dataset = self.dataset.subset([3, 4, 5]) + ncp.calibrate(cal_dataset=cal_dataset, batch_size=2) + self.assertIsNotNone(ncp.cal_embeddings_) + self.assertIsNotNone(ncp.cal_conformity_scores_) + + def test_calibration_empirical_coverage_at_least_1_minus_alpha(self): + """After calibrate(), empirical coverage on calibration set >= 1-alpha.""" + from pyhealth.calib.predictionset.base_conformal import _query_weighted_quantile + + ncp = NeighborhoodLabel(model=self.model, alpha=0.2, k_neighbors=3, lambda_L=50.0) + cal_indices = [0, 1, 2, 3, 4, 5] + cal_dataset = self.dataset.subset(cal_indices) + cal_emb = self._get_embeddings(cal_dataset) + ncp.calibrate(cal_dataset=cal_dataset, cal_embeddings=cal_emb) + + self.assertIsNotNone(ncp.alpha_tilde_) + self.assertGreaterEqual(ncp.alpha_tilde_, 0.0) + self.assertLessEqual(ncp.alpha_tilde_, 1.0) + + # Recompute per-sample thresholds using alpha_tilde (Q^NCP definition: alpha_tilde-quantile of conformity) + N = ncp.cal_conformity_scores_.shape[0] + k = min(ncp.k_neighbors, N) + distances_cal, indices_cal = ncp._nn.kneighbors( + ncp.cal_embeddings_, n_neighbors=k + ) + cal_weights = np.exp(-distances_cal / ncp.lambda_L) + cal_weights = cal_weights / cal_weights.sum(axis=1, keepdims=True) + + covered = 0 + for i in range(N): + t_i = _query_weighted_quantile( + ncp.cal_conformity_scores_[indices_cal[i]], + ncp.alpha_tilde_, + cal_weights[i], + ) + # Covered = true label in set = conformity_i >= threshold_i (paper: V_i <= t in non-conf space) + if ncp.cal_conformity_scores_[i] >= t_i: + covered += 1 + empirical_coverage = covered / N + self.assertGreaterEqual( + empirical_coverage, + 1.0 - ncp.alpha - 1e-6, + msg=f"Calibration empirical coverage {empirical_coverage:.4f} should be >= 1-alpha={1 - ncp.alpha}", + ) + + +if __name__ == "__main__": + unittest.main()