From 5f19ba867a97aac8c5e29403a58a426cf4f2b5cb Mon Sep 17 00:00:00 2001 From: lehendo Date: Fri, 6 Feb 2026 13:07:13 -0600 Subject: [PATCH 01/16] NCP implementation --- .../cluster/neighborhood_label.py | 206 ++++++++++++++++++ 1 file changed, 206 insertions(+) create mode 100644 pyhealth/calib/predictionset/cluster/neighborhood_label.py diff --git a/pyhealth/calib/predictionset/cluster/neighborhood_label.py b/pyhealth/calib/predictionset/cluster/neighborhood_label.py new file mode 100644 index 000000000..db583b32e --- /dev/null +++ b/pyhealth/calib/predictionset/cluster/neighborhood_label.py @@ -0,0 +1,206 @@ +""" +Neighborhood Conformal Prediction (NCP). + +This module implements Neighborhood Conformal Prediction. + +NCP uses the learned representation (embeddings) to assign importance weights +to calibration examples: for each test input it finds k nearest calibration +points in embedding space and computes a weighted quantile of their conformity +scores to form an adaptive, per-sample threshold. This can yield smaller +prediction sets than standard conformal prediction when the representation +has good separation between classes. +""" + +from typing import Dict, Optional, Union + +import numpy as np +import torch +from sklearn.neighbors import NearestNeighbors +from torch.utils.data import IterableDataset + +from pyhealth.calib.base_classes import SetPredictor +from pyhealth.calib.predictionset.base_conformal import _query_weighted_quantile +from pyhealth.calib.utils import extract_embeddings, prepare_numpy_dataset +from pyhealth.models import BaseModel + +__all__ = ["NeighborhoodLabel"] + + +class NeighborhoodLabel(SetPredictor): + """Neighborhood Conformal Prediction (NCP) for multiclass classification. + + For each test input, NCP finds its k nearest calibration examples in + embedding space and assigns them importance weights proportional to + distance (exponential kernel). The prediction set threshold is the + weighted quantile of calibration conformity scores, so the threshold + is adaptive per test sample. + + Reference: + Ghosh, S., Belkhouja, T., Yan, Y., & Doppa, J. R. (2023). + Improving Uncertainty Quantification of Deep Classifiers via + Neighborhood Conformal Prediction. + + Args: + model: A trained base model that supports embedding extraction + (must support `embed=True` in forward pass). + alpha: Target miscoverage rate; marginal coverage P(Y not in C(X)) <= alpha. + k_neighbors: Number of nearest calibration neighbors. Default 50. + lambda_L: Temperature for exponential weights; smaller => more localization. + Default 100.0. + debug: If True, process fewer samples for faster iteration. + + Examples: + >>> from pyhealth.datasets import TUEVDataset, split_by_sample_conformal + >>> from pyhealth.datasets import get_dataloader + >>> from pyhealth.models import ContraWR + >>> from pyhealth.tasks import EEGEventsTUEV + >>> from pyhealth.calib.predictionset.cluster import NeighborhoodLabel + >>> from pyhealth.calib.utils import extract_embeddings + >>> from pyhealth.trainer import Trainer, get_metrics_fn + >>> + >>> dataset = TUEVDataset(root="path/to/tuev") + >>> sample_dataset = dataset.set_task(EEGEventsTUEV()) + >>> train_ds, val_ds, cal_ds, test_ds = split_by_sample_conformal( + ... sample_dataset, ratios=[0.6, 0.1, 0.15, 0.15], seed=42 + ... 
) + >>> model = ContraWR(dataset=sample_dataset) + >>> cal_embeddings = extract_embeddings(model, cal_ds, batch_size=32) + >>> ncp = NeighborhoodLabel(model=model, alpha=0.1, k_neighbors=50) + >>> ncp.calibrate(cal_dataset=cal_ds, cal_embeddings=cal_embeddings) + >>> test_loader = get_dataloader(test_ds, batch_size=32, shuffle=False) + >>> y_true, y_prob, _, extra = Trainer(model=ncp).inference( + ... test_loader, additional_outputs=["y_predset"] + ... ) + >>> metrics = get_metrics_fn(ncp.mode)( + ... y_true, y_prob, metrics=["accuracy", "miscoverage_ps"], + ... y_predset=extra["y_predset"] + ... ) + """ + + def __init__( + self, + model: BaseModel, + alpha: float, + k_neighbors: int = 50, + lambda_L: float = 100.0, + debug: bool = False, + **kwargs, + ) -> None: + super().__init__(model, **kwargs) + + if model.mode != "multiclass": + raise NotImplementedError( + "NeighborhoodLabel only supports multiclass classification" + ) + + self.mode = self.model.mode + + for param in model.parameters(): + param.requires_grad = False + self.model.eval() + + self.device = model.device + self.debug = debug + + if not (0.0 < alpha < 1.0): + raise ValueError(f"alpha must be in (0, 1), got {alpha!r}") + self.alpha = float(alpha) + + if not isinstance(k_neighbors, int) or k_neighbors <= 0: + raise ValueError( + f"k_neighbors must be a positive integer, got {k_neighbors!r}" + ) + self.k_neighbors = k_neighbors + self.lambda_L = float(lambda_L) + + self.cal_embeddings_ = None + self.cal_conformity_scores_ = None + self._nn = None + + def calibrate( + self, + cal_dataset: IterableDataset, + cal_embeddings: Optional[np.ndarray] = None, + batch_size: int = 32, + ) -> None: + """Calibrate NCP: store calibration embeddings and conformity scores. + + Args: + cal_dataset: Calibration dataset (for labels and predictions if + cal_embeddings not provided). + cal_embeddings: Optional precomputed calibration embeddings + (n_cal, embedding_dim). If None, extracted from cal_dataset. + batch_size: Batch size for embedding extraction when cal_embeddings + is not provided. + """ + cal_dict = prepare_numpy_dataset( + self.model, + cal_dataset, + ["y_prob", "y_true"], + debug=self.debug, + ) + y_prob = cal_dict["y_prob"] + y_true = cal_dict["y_true"] + N = y_prob.shape[0] + + if cal_embeddings is None: + cal_embeddings = extract_embeddings( + self.model, cal_dataset, batch_size=batch_size, device=self.device + ) + else: + cal_embeddings = np.asarray(cal_embeddings) + + if cal_embeddings.shape[0] != N: + raise ValueError( + f"cal_embeddings length {cal_embeddings.shape[0]} must match " + f"cal_dataset size {N}" + ) + + conformity_scores = y_prob[np.arange(N), y_true] + + k = min(self.k_neighbors, N) + self._nn = NearestNeighbors(n_neighbors=k, metric="euclidean").fit( + np.atleast_2d(cal_embeddings) + ) + self.cal_embeddings_ = np.atleast_2d(cal_embeddings) + self.cal_conformity_scores_ = np.asarray(conformity_scores, dtype=np.float64) + + def forward(self, **kwargs) -> Dict[str, torch.Tensor]: + """Forward with NCP: per-sample weighted quantile threshold.""" + if self.cal_embeddings_ is None or self.cal_conformity_scores_ is None: + raise RuntimeError( + "NeighborhoodLabel must be calibrated before inference. " + "Call calibrate() first." + ) + + pred = self.model(**{**kwargs, "embed": True}) + if "embed" not in pred: + raise ValueError( + f"Model {type(self.model).__name__} does not return " + "embeddings. Ensure it supports embed=True in forward()." 
+ ) + + test_emb = pred["embed"].detach().cpu().numpy() + test_emb = np.atleast_2d(test_emb) + batch_size = test_emb.shape[0] + n_cal = self.cal_conformity_scores_.shape[0] + k = min(self.k_neighbors, n_cal) + + distances, indices = self._nn.kneighbors(test_emb, n_neighbors=k) + thresholds = np.zeros(batch_size, dtype=np.float64) + for i in range(batch_size): + w = np.exp(-distances[i] / self.lambda_L) + w = w / np.sum(w) + scores_i = self.cal_conformity_scores_[indices[i]] + thresholds[i] = _query_weighted_quantile( + scores_i, self.alpha, w + ) + + th = torch.as_tensor( + thresholds, device=self.device, dtype=pred["y_prob"].dtype + ) + if pred["y_prob"].ndim > 1: + th = th.view(-1, *([1] * (pred["y_prob"].ndim - 1))) + pred["y_predset"] = pred["y_prob"] >= th + pred.pop("embed", None) + return pred From 5369bf27f850e8471565d6e54ab777c233bf8ece Mon Sep 17 00:00:00 2001 From: lehendo Date: Fri, 6 Feb 2026 13:10:27 -0600 Subject: [PATCH 02/16] test script --- tests/core/test_neighborhood_label.py | 144 ++++++++++++++++++++++++++ 1 file changed, 144 insertions(+) create mode 100644 tests/core/test_neighborhood_label.py diff --git a/tests/core/test_neighborhood_label.py b/tests/core/test_neighborhood_label.py new file mode 100644 index 000000000..8c4bf17a1 --- /dev/null +++ b/tests/core/test_neighborhood_label.py @@ -0,0 +1,144 @@ +"""Tests for NeighborhoodLabel (NCP) prediction set constructor.""" + +import unittest +import numpy as np +import torch + +from pyhealth.datasets import create_sample_dataset, get_dataloader +from pyhealth.models import MLP +from pyhealth.calib.predictionset.cluster import NeighborhoodLabel +from pyhealth.calib.utils import extract_embeddings + + +class TestNeighborhoodLabel(unittest.TestCase): + """Test cases for the NeighborhoodLabel (NCP) prediction set constructor.""" + + def setUp(self): + """Set up test data and model.""" + self.samples = [ + {"patient_id": "p0", "visit_id": "v0", "conditions": ["c1"], "procedures": [1.0], "label": 0}, + {"patient_id": "p1", "visit_id": "v1", "conditions": ["c2"], "procedures": [2.0], "label": 1}, + {"patient_id": "p2", "visit_id": "v2", "conditions": ["c3"], "procedures": [3.0], "label": 2}, + {"patient_id": "p3", "visit_id": "v3", "conditions": ["c4"], "procedures": [1.5], "label": 0}, + {"patient_id": "p4", "visit_id": "v4", "conditions": ["c5"], "procedures": [2.5], "label": 1}, + {"patient_id": "p5", "visit_id": "v5", "conditions": ["c6"], "procedures": [3.5], "label": 2}, + ] + self.input_schema = {"conditions": "sequence", "procedures": "tensor"} + self.output_schema = {"label": "multiclass"} + self.dataset = create_sample_dataset( + samples=self.samples, + input_schema=self.input_schema, + output_schema=self.output_schema, + dataset_name="test", + ) + self.model = MLP( + dataset=self.dataset, + feature_keys=["conditions", "procedures"], + label_key="label", + mode="multiclass", + ) + self.model.eval() + + def _get_embeddings(self, dataset): + return extract_embeddings(self.model, dataset, batch_size=32, device="cpu") + + def test_initialization(self): + ncp = NeighborhoodLabel( + model=self.model, + alpha=0.1, + k_neighbors=5, + lambda_L=100.0, + ) + self.assertIsInstance(ncp, NeighborhoodLabel) + self.assertEqual(ncp.mode, "multiclass") + self.assertEqual(ncp.alpha, 0.1) + self.assertEqual(ncp.k_neighbors, 5) + self.assertEqual(ncp.lambda_L, 100.0) + self.assertIsNone(ncp.cal_embeddings_) + self.assertIsNone(ncp.cal_conformity_scores_) + + def test_initialization_invalid_alpha_raises(self): + with 
self.assertRaises(ValueError): + NeighborhoodLabel(model=self.model, alpha=0.0, k_neighbors=5) + with self.assertRaises(ValueError): + NeighborhoodLabel(model=self.model, alpha=1.0, k_neighbors=5) + with self.assertRaises(ValueError): + NeighborhoodLabel(model=self.model, alpha=-0.1, k_neighbors=5) + + def test_initialization_invalid_k_neighbors_raises(self): + with self.assertRaises(ValueError): + NeighborhoodLabel(model=self.model, alpha=0.1, k_neighbors=0) + with self.assertRaises(ValueError): + NeighborhoodLabel(model=self.model, alpha=0.1, k_neighbors=-1) + with self.assertRaises(ValueError): + NeighborhoodLabel(model=self.model, alpha=0.1, k_neighbors=2.5) + + def test_initialization_non_multiclass_raises(self): + binary_samples = [ + {"patient_id": "a", "visit_id": "a", "conditions": ["c"], "procedures": [1.0], "label": 0}, + {"patient_id": "b", "visit_id": "b", "conditions": ["d"], "procedures": [2.0], "label": 1}, + ] + binary_ds = create_sample_dataset( + samples=binary_samples, + input_schema={"conditions": "sequence", "procedures": "tensor"}, + output_schema={"label": "binary"}, + dataset_name="test", + ) + binary_model = MLP( + dataset=binary_ds, feature_keys=["conditions"], label_key="label", mode="binary" + ) + with self.assertRaises(NotImplementedError): + NeighborhoodLabel(model=binary_model, alpha=0.1, k_neighbors=2) + + def test_calibrate_and_forward_returns_predset(self): + ncp = NeighborhoodLabel(model=self.model, alpha=0.2, k_neighbors=3, lambda_L=50.0) + cal_indices = [3, 4, 5] + cal_dataset = self.dataset.subset(cal_indices) + cal_embeddings = self._get_embeddings(cal_dataset) + ncp.calibrate(cal_dataset=cal_dataset, cal_embeddings=cal_embeddings) + + self.assertIsNotNone(ncp.cal_embeddings_) + self.assertIsNotNone(ncp.cal_conformity_scores_) + self.assertEqual(ncp.cal_conformity_scores_.shape[0], 3) + + test_loader = get_dataloader(self.dataset, batch_size=2, shuffle=False) + batch = next(iter(test_loader)) + with torch.no_grad(): + out = ncp(**batch) + + self.assertIn("y_predset", out) + self.assertIn("y_prob", out) + self.assertEqual(out["y_predset"].dtype, torch.bool) + self.assertEqual(out["y_predset"].shape, out["y_prob"].shape) + + def test_forward_before_calibration_raises(self): + ncp = NeighborhoodLabel(model=self.model, alpha=0.1, k_neighbors=5) + loader = get_dataloader(self.dataset, batch_size=1, shuffle=False) + batch = next(iter(loader)) + with self.assertRaises(RuntimeError): + with torch.no_grad(): + ncp(**batch) + + def test_prediction_sets_nonempty_batch(self): + ncp = NeighborhoodLabel(model=self.model, alpha=0.3, k_neighbors=2, lambda_L=100.0) + cal_dataset = self.dataset.subset([2, 3, 4, 5]) + cal_emb = self._get_embeddings(cal_dataset) + ncp.calibrate(cal_dataset=cal_dataset, cal_embeddings=cal_emb) + + loader = get_dataloader(self.dataset, batch_size=2, shuffle=False) + with torch.no_grad(): + for batch in loader: + out = ncp(**batch) + set_sizes = out["y_predset"].sum(dim=1) + self.assertTrue(torch.all(set_sizes > 0), "Prediction sets should be non-empty") + + def test_calibrate_without_embeddings_extracts(self): + ncp = NeighborhoodLabel(model=self.model, alpha=0.1, k_neighbors=2) + cal_dataset = self.dataset.subset([3, 4, 5]) + ncp.calibrate(cal_dataset=cal_dataset, batch_size=2) + self.assertIsNotNone(ncp.cal_embeddings_) + self.assertIsNotNone(ncp.cal_conformity_scores_) + + +if __name__ == "__main__": + unittest.main() From 854344a8897ada06500ad3d26b4a4b1486e09e65 Mon Sep 17 00:00:00 2001 From: lehendo Date: Fri, 6 Feb 2026 13:11:00 
-0600 Subject: [PATCH 03/16] All the init files --- pyhealth/calib/predictionset/__init__.py | 12 +++++++-- .../predictionset/base_conformal/__init__.py | 26 +++++++++++++++++++ .../calib/predictionset/cluster/__init__.py | 5 ++-- 3 files changed, 39 insertions(+), 4 deletions(-) diff --git a/pyhealth/calib/predictionset/__init__.py b/pyhealth/calib/predictionset/__init__.py index 46760945f..f8328d35c 100644 --- a/pyhealth/calib/predictionset/__init__.py +++ b/pyhealth/calib/predictionset/__init__.py @@ -1,10 +1,18 @@ """Prediction set construction methods""" from pyhealth.calib.predictionset.base_conformal import BaseConformal -from pyhealth.calib.predictionset.cluster import ClusterLabel +from pyhealth.calib.predictionset.cluster import ClusterLabel, NeighborhoodLabel from pyhealth.calib.predictionset.covariate import CovariateLabel from pyhealth.calib.predictionset.favmac import FavMac from pyhealth.calib.predictionset.label import LABEL from pyhealth.calib.predictionset.scrib import SCRIB -__all__ = ["BaseConformal", "LABEL", "SCRIB", "FavMac", "CovariateLabel", "ClusterLabel"] +__all__ = [ + "BaseConformal", + "LABEL", + "SCRIB", + "FavMac", + "CovariateLabel", + "ClusterLabel", + "NeighborhoodLabel", +] diff --git a/pyhealth/calib/predictionset/base_conformal/__init__.py b/pyhealth/calib/predictionset/base_conformal/__init__.py index b1ceefdee..8e37c0d6d 100644 --- a/pyhealth/calib/predictionset/base_conformal/__init__.py +++ b/pyhealth/calib/predictionset/base_conformal/__init__.py @@ -45,6 +45,32 @@ def _query_quantile(scores: np.ndarray, alpha: float) -> float: return -np.inf if loc == -1 else scores[loc] +def _query_weighted_quantile( + scores: np.ndarray, alpha: float, weights: np.ndarray +) -> float: + """Compute weighted quantile of scores (e.g. for NCP or covariate shift). + + Args: + scores: Array of conformity scores + alpha: Quantile level (between 0 and 1) + weights: Weights for each score (same length as scores) + + Returns: + The weighted alpha-quantile of scores + """ + sorted_indices = np.argsort(scores) + sorted_scores = scores[sorted_indices] + sorted_weights = weights[sorted_indices] + w_sum = np.sum(sorted_weights) + if w_sum <= 0: + return -np.inf + cum_weights = np.cumsum(sorted_weights) / w_sum + idx = np.searchsorted(cum_weights, alpha, side="left") + if idx >= len(sorted_scores): + idx = len(sorted_scores) - 1 + return float(sorted_scores[idx]) + + class BaseConformal(SetPredictor): """Base Conformal Prediction for multiclass classification. 
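
A minimal sketch of how `_query_weighted_quantile` above resolves a query
(synthetic scores and weights, illustrative only):

    import numpy as np
    from pyhealth.calib.predictionset.base_conformal import _query_weighted_quantile

    scores = np.array([0.2, 0.5, 0.9])   # conformity scores
    weights = np.array([0.2, 0.3, 0.5])  # normalized cumsum: 0.2, 0.5, 1.0
    # searchsorted(cum_weights, 0.5, side="left") -> index 1, so the
    # weighted 0.5-quantile is the second sorted score.
    assert _query_weighted_quantile(scores, 0.5, weights) == 0.5
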
diff --git a/pyhealth/calib/predictionset/cluster/__init__.py b/pyhealth/calib/predictionset/cluster/__init__.py
index eacd89967..44f26dc28 100644
--- a/pyhealth/calib/predictionset/cluster/__init__.py
+++ b/pyhealth/calib/predictionset/cluster/__init__.py
@@ -1,5 +1,6 @@
-"""Cluster-based prediction set methods"""
+"""Cluster-based and neighborhood prediction set methods"""
 
 from pyhealth.calib.predictionset.cluster.cluster_label import ClusterLabel
+from pyhealth.calib.predictionset.cluster.neighborhood_label import NeighborhoodLabel
 
-__all__ = ["ClusterLabel"]
+__all__ = ["ClusterLabel", "NeighborhoodLabel"]

From ba3975aded0d5d7d61cb490c391534c856bf90f0 Mon Sep 17 00:00:00 2001
From: lehendo
Date: Fri, 6 Feb 2026 13:23:12 -0600
Subject: [PATCH 04/16] Check build

---
 pyhealth/calib/predictionset/cluster/neighborhood_label.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyhealth/calib/predictionset/cluster/neighborhood_label.py b/pyhealth/calib/predictionset/cluster/neighborhood_label.py
index db583b32e..c9bf47bbd 100644
--- a/pyhealth/calib/predictionset/cluster/neighborhood_label.py
+++ b/pyhealth/calib/predictionset/cluster/neighborhood_label.py
@@ -1,7 +1,7 @@
 """
 Neighborhood Conformal Prediction (NCP).
 
-This module implements Neighborhood Conformal Prediction.
+This module implements Neighborhood Conformal Prediction!
 
 NCP uses the learned representation (embeddings) to assign importance weights
 to calibration examples: for each test input it finds k nearest calibration
 points in embedding space and computes a weighted quantile of their conformity
 scores to form an adaptive, per-sample threshold. This can yield smaller
 prediction sets than standard conformal prediction when the representation

From 11b9eaa4bf3ebde84cc7ac7971695fc69fd9f9be Mon Sep 17 00:00:00 2001
From: lehendo
Date: Fri, 6 Feb 2026 13:40:30 -0600
Subject: [PATCH 05/16] unrelated fix for a flaky test

---
 tests/core/test_gamenet.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/tests/core/test_gamenet.py b/tests/core/test_gamenet.py
index f456189c7..8b735b857 100644
--- a/tests/core/test_gamenet.py
+++ b/tests/core/test_gamenet.py
@@ -103,14 +103,15 @@ def test_model_backward(self):
         )
 
     def test_loss_is_finite(self):
+        """Test that the loss is finite."""
+        torch.manual_seed(42)  # reproducibility: shuffle + dropout can rarely yield non-finite loss
         train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True)
         data_batch = next(iter(train_loader))
 
-        ret = self.model(**data_batch)
+        with torch.no_grad():
+            ret = self.model(**data_batch)
 
         self.assertTrue(torch.isfinite(ret["loss"]).all())
-        self.assertFalse(torch.isnan(ret["loss"]).any())
-        self.assertFalse(torch.isinf(ret["loss"]).any())
 
     def test_output_shapes(self):
         train_loader = get_dataloader(self.dataset, batch_size=3, shuffle=True)

From 836b6f2d32bcc14af50d4f44a35b7beb3ce9b2be Mon Sep 17 00:00:00 2001
From: lehendo
Date: Fri, 6 Feb 2026 14:35:44 -0600
Subject: [PATCH 06/16] Add docs

---
 docs/api/calib.rst | 2 ++
 .../api/calib/pyhealth.calib.predictionset.rst | 18 ++++++++++++++++++
 2 files changed, 20 insertions(+)

diff --git a/docs/api/calib.rst b/docs/api/calib.rst
index daff8e8c7..7599b2a10 100644
--- a/docs/api/calib.rst
+++ b/docs/api/calib.rst
@@ -22,6 +22,8 @@ confidence levels:
 - :class:`~pyhealth.calib.predictionset.SCRIB`: Class-specific risk control
 - :class:`~pyhealth.calib.predictionset.FavMac`: Value-maximizing sets with cost control
 - :class:`~pyhealth.calib.predictionset.CovariateLabel`: Covariate shift adaptive conformal
+- :class:`~pyhealth.calib.predictionset.ClusterLabel`: K-means cluster-based conformal prediction
+- :class:`~pyhealth.calib.predictionset.NeighborhoodLabel`: Neighborhood Conformal Prediction
(NCP) Getting Started --------------- diff --git a/docs/api/calib/pyhealth.calib.predictionset.rst b/docs/api/calib/pyhealth.calib.predictionset.rst index 63d314203..fe445ea1b 100644 --- a/docs/api/calib/pyhealth.calib.predictionset.rst +++ b/docs/api/calib/pyhealth.calib.predictionset.rst @@ -16,6 +16,8 @@ Available Methods pyhealth.calib.predictionset.SCRIB pyhealth.calib.predictionset.FavMac pyhealth.calib.predictionset.CovariateLabel + pyhealth.calib.predictionset.ClusterLabel + pyhealth.calib.predictionset.NeighborhoodLabel LABEL (Least Ambiguous Set-valued Classifier) ---------------------------------------------- @@ -49,6 +51,22 @@ CovariateLabel (Covariate Shift Adaptive) :undoc-members: :show-inheritance: +ClusterLabel (K-means Cluster-based Conformal) +---------------------------------------------- + +.. autoclass:: pyhealth.calib.predictionset.ClusterLabel + :members: + :undoc-members: + :show-inheritance: + +NeighborhoodLabel (Neighborhood Conformal Prediction) +----------------------------------------------------- + +.. autoclass:: pyhealth.calib.predictionset.NeighborhoodLabel + :members: + :undoc-members: + :show-inheritance: + Helper Functions ---------------- From aa5f8728b9ab7a4989c73c940086ff4b1fc842df Mon Sep 17 00:00:00 2001 From: lehendo Date: Fri, 6 Feb 2026 15:29:39 -0600 Subject: [PATCH 07/16] scripts for baseline testing --- .../conformal_eeg/tuev_kmeans_conformal.py | 62 ++++- examples/conformal_eeg/tuev_ncp_conformal.py | 251 ++++++++++++++++++ 2 files changed, 308 insertions(+), 5 deletions(-) create mode 100644 examples/conformal_eeg/tuev_ncp_conformal.py diff --git a/examples/conformal_eeg/tuev_kmeans_conformal.py b/examples/conformal_eeg/tuev_kmeans_conformal.py index 82aa27e3f..18ab5acfe 100644 --- a/examples/conformal_eeg/tuev_kmeans_conformal.py +++ b/examples/conformal_eeg/tuev_kmeans_conformal.py @@ -10,6 +10,7 @@ Example (from repo root): python examples/conformal_eeg/tuev_kmeans_conformal.py --root downloads/tuev/v2.0.1/edf --n-clusters 5 + python examples/conformal_eeg/tuev_kmeans_conformal.py --quick-test --log-file quicktest_kmeans.log Notes: - ClusterLabel uses K-means clustering on embeddings to compute cluster-specific thresholds. 
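
The cluster-specific thresholds mentioned above can be sketched in a few lines
(an illustration of the idea, not the ClusterLabel internals): one conformal
threshold per K-means cluster, versus NCP's weighted quantile per test sample.

    import numpy as np
    from sklearn.cluster import KMeans

    def cluster_thresholds(cal_emb, cal_scores, alpha, n_clusters=5):
        # One threshold per cluster: the alpha-quantile of the conformity
        # scores of the calibration points assigned to that cluster.
        km = KMeans(n_clusters=n_clusters, n_init=10).fit(cal_emb)
        return km, np.array([
            np.quantile(cal_scores[km.labels_ == c], alpha)
            for c in range(n_clusters)
        ])

    # At test time: per-sample threshold = thresholds[km.predict(test_emb)].
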
@@ -20,11 +21,30 @@ import argparse import random +import sys from pathlib import Path import numpy as np import torch + +class _Tee: + """Writes to both a stream and a file.""" + + def __init__(self, stream, file): + self._stream = stream + self._file = file + + def write(self, data): + self._stream.write(data) + self._file.write(data) + self._file.flush() + + def flush(self): + self._stream.flush() + self._file.flush() + + from pyhealth.calib.predictionset.cluster import ClusterLabel from pyhealth.calib.utils import extract_embeddings from pyhealth.datasets import TUEVDataset, get_dataloader, split_by_sample_conformal @@ -43,10 +63,10 @@ def parse_args() -> argparse.Namespace: default="downloads/tuev/v2.0.1/edf", help="Path to TUEV edf/ folder.", ) - parser.add_argument("--subset", type=str, default="both", choices=["train", "eval", "both"]) + parser.add_argument("--subset", type=str, default="both", choices=["train", "eval", "both"]) parser.add_argument("--seed", type=int, default=42) - parser.add_argument("--batch-size", type=int, default=32) - parser.add_argument("--epochs", type=int, default=3) + parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument("--epochs", type=int, default=20) parser.add_argument("--alpha", type=float, default=0.1, help="Miscoverage rate (e.g., 0.1 => 90% target coverage).") parser.add_argument( "--ratios", @@ -69,6 +89,17 @@ def parse_args() -> argparse.Namespace: default=None, help="Device string, e.g. 'cuda:0' or 'cpu'. Defaults to auto-detect.", ) + parser.add_argument( + "--log-file", + type=str, + default=None, + help="Path to log file. Stdout and stderr are teed to this file.", + ) + parser.add_argument( + "--quick-test", + action="store_true", + help="Dev mode: dev=True, 2 epochs, ~5 min smoke test.", + ) return parser.parse_args() @@ -84,6 +115,23 @@ def main() -> None: args = parse_args() set_seed(args.seed) + orig_stdout, orig_stderr = sys.stdout, sys.stderr + log_file = None + if args.log_file: + log_file = open(args.log_file, "w", encoding="utf-8") + sys.stdout = _Tee(orig_stdout, log_file) + sys.stderr = _Tee(orig_stderr, log_file) + + try: + _run(args) + finally: + if log_file is not None: + sys.stdout = orig_stdout + sys.stderr = orig_stderr + log_file.close() + + +def _run(args: argparse.Namespace) -> None: device = args.device or ("cuda:0" if torch.cuda.is_available() else "cpu") root = Path(args.root) if not root.exists(): @@ -92,10 +140,14 @@ def main() -> None: "Pass --root to point to your downloaded TUEV edf/ directory." 
) + epochs = 2 if args.quick_test else args.epochs + if args.quick_test: + print("*** QUICK TEST MODE (dev=True, 2 epochs) ***") + print("=" * 80) print("STEP 1: Load TUEV + build task dataset") print("=" * 80) - dataset = TUEVDataset(root=str(root), subset=args.subset) + dataset = TUEVDataset(root=str(root), subset=args.subset, dev=args.quick_test) sample_dataset = dataset.set_task(EEGEventsTUEV(), cache_dir="examples/conformal_eeg/cache") print(f"Task samples: {len(sample_dataset)}") @@ -129,7 +181,7 @@ def main() -> None: trainer.train( train_dataloader=train_loader, val_dataloader=val_loader, - epochs=args.epochs, + epochs=epochs, monitor="accuracy" if val_loader is not None else None, ) diff --git a/examples/conformal_eeg/tuev_ncp_conformal.py b/examples/conformal_eeg/tuev_ncp_conformal.py new file mode 100644 index 000000000..53d379ae6 --- /dev/null +++ b/examples/conformal_eeg/tuev_ncp_conformal.py @@ -0,0 +1,251 @@ +"""Neighborhood Conformal Prediction (NCP) on TUEV EEG Events using ContraWR. + +This script: +1) Loads the TUEV dataset and applies the EEGEventsTUEV task. +2) Splits into train/val/cal/test using split conformal protocol. +3) Trains a ContraWR model. +4) Extracts calibration embeddings and calibrates a NeighborhoodLabel (NCP) predictor. +5) Evaluates prediction-set coverage/miscoverage and efficiency on the test split. + +Example (from repo root): + python examples/conformal_eeg/tuev_ncp_conformal.py --root downloads/tuev/v2.0.1/edf + python examples/conformal_eeg/tuev_ncp_conformal.py --quick-test --log-file quicktest_ncp.log +""" + +from __future__ import annotations + +import argparse +import random +import sys +from pathlib import Path + +import numpy as np +import torch + + +class _Tee: + """Writes to both a stream and a file.""" + + def __init__(self, stream, file): + self._stream = stream + self._file = file + + def write(self, data): + self._stream.write(data) + self._file.write(data) + self._file.flush() + + def flush(self): + self._stream.flush() + self._file.flush() + + +from pyhealth.calib.predictionset.cluster import NeighborhoodLabel +from pyhealth.calib.utils import extract_embeddings +from pyhealth.datasets import TUEVDataset, get_dataloader, split_by_sample_conformal +from pyhealth.models import ContraWR +from pyhealth.tasks import EEGEventsTUEV +from pyhealth.trainer import Trainer, get_metrics_fn + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Neighborhood conformal prediction (NCP) on TUEV EEG events using ContraWR." + ) + parser.add_argument( + "--root", + type=str, + default="downloads/tuev/v2.0.1/edf", + help="Path to TUEV edf/ folder.", + ) + parser.add_argument("--subset", type=str, default="both", choices=["train", "eval", "both"]) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument("--epochs", type=int, default=20) + parser.add_argument("--alpha", type=float, default=0.1, help="Miscoverage rate (e.g., 0.1 => 90% target coverage).") + parser.add_argument( + "--ratios", + type=float, + nargs=4, + default=(0.6, 0.1, 0.15, 0.15), + metavar=("TRAIN", "VAL", "CAL", "TEST"), + help="Split ratios for train/val/cal/test. 
Must sum to 1.0.", + ) + parser.add_argument( + "--k-neighbors", + type=int, + default=50, + help="Number of nearest calibration neighbors for NCP.", + ) + parser.add_argument( + "--lambda-L", + type=float, + default=100.0, + help="Temperature for NCP exponential weights; smaller => more localization.", + ) + parser.add_argument("--n-fft", type=int, default=128, help="STFT FFT size used by ContraWR.") + parser.add_argument( + "--device", + type=str, + default=None, + help="Device string, e.g. 'cuda:0' or 'cpu'. Defaults to auto-detect.", + ) + parser.add_argument( + "--log-file", + type=str, + default=None, + help="Path to log file. Stdout and stderr are teed to this file.", + ) + parser.add_argument( + "--quick-test", + action="store_true", + help="Dev mode: dev=True, 2 epochs, ~5 min smoke test.", + ) + return parser.parse_args() + + +def set_seed(seed: int) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + +def main() -> None: + args = parse_args() + set_seed(args.seed) + + orig_stdout, orig_stderr = sys.stdout, sys.stderr + log_file = None + if args.log_file: + log_file = open(args.log_file, "w", encoding="utf-8") + sys.stdout = _Tee(orig_stdout, log_file) + sys.stderr = _Tee(orig_stderr, log_file) + + try: + _run(args) + finally: + if log_file is not None: + sys.stdout = orig_stdout + sys.stderr = orig_stderr + log_file.close() + + +def _run(args: argparse.Namespace) -> None: + device = args.device or ("cuda:0" if torch.cuda.is_available() else "cpu") + root = Path(args.root) + if not root.exists(): + raise FileNotFoundError( + f"TUEV root not found: {root}. " + "Pass --root to point to your downloaded TUEV edf/ directory." + ) + + epochs = 2 if args.quick_test else args.epochs + if args.quick_test: + print("*** QUICK TEST MODE (dev=True, 2 epochs) ***") + + print("=" * 80) + print("STEP 1: Load TUEV + build task dataset") + print("=" * 80) + dataset = TUEVDataset(root=str(root), subset=args.subset, dev=args.quick_test) + sample_dataset = dataset.set_task(EEGEventsTUEV(), cache_dir="examples/conformal_eeg/cache") + + print(f"Task samples: {len(sample_dataset)}") + print(f"Input schema: {sample_dataset.input_schema}") + print(f"Output schema: {sample_dataset.output_schema}") + + if len(sample_dataset) == 0: + raise RuntimeError("No samples produced. 
Verify TUEV root/subset/task.") + + print("\n" + "=" * 80) + print("STEP 2: Split train/val/cal/test") + print("=" * 80) + train_ds, val_ds, cal_ds, test_ds = split_by_sample_conformal( + dataset=sample_dataset, ratios=list(args.ratios), seed=args.seed + ) + print(f"Train: {len(train_ds)}") + print(f"Val: {len(val_ds)}") + print(f"Cal: {len(cal_ds)}") + print(f"Test: {len(test_ds)}") + + train_loader = get_dataloader(train_ds, batch_size=args.batch_size, shuffle=True) + val_loader = get_dataloader(val_ds, batch_size=args.batch_size, shuffle=False) if len(val_ds) else None + test_loader = get_dataloader(test_ds, batch_size=args.batch_size, shuffle=False) + + print("\n" + "=" * 80) + print("STEP 3: Train ContraWR") + print("=" * 80) + model = ContraWR(dataset=sample_dataset, n_fft=args.n_fft).to(device) + trainer = Trainer(model=model, device=device, enable_logging=False) + + trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + epochs=epochs, + monitor="accuracy" if val_loader is not None else None, + ) + + print("\nBase model performance on test set:") + y_true_base, y_prob_base, _loss_base = trainer.inference(test_loader) + base_metrics = get_metrics_fn("multiclass")(y_true_base, y_prob_base, metrics=["accuracy", "f1_weighted"]) + for metric, value in base_metrics.items(): + print(f" {metric}: {value:.4f}") + + print("\n" + "=" * 80) + print("STEP 4: Neighborhood Conformal Prediction (NCP / NeighborhoodLabel)") + print("=" * 80) + print(f"Target miscoverage alpha: {args.alpha} (target coverage {1 - args.alpha:.0%})") + print(f"k_neighbors: {args.k_neighbors}, lambda_L: {args.lambda_L}") + + print("Extracting embeddings for calibration split...") + cal_embeddings = extract_embeddings(model, cal_ds, batch_size=args.batch_size, device=device) + print(f" cal_embeddings shape: {cal_embeddings.shape}") + + ncp_predictor = NeighborhoodLabel( + model=model, + alpha=float(args.alpha), + k_neighbors=args.k_neighbors, + lambda_L=args.lambda_L, + ) + print("Calibrating NCP predictor (store cal embeddings and conformity scores)...") + ncp_predictor.calibrate( + cal_dataset=cal_ds, + cal_embeddings=cal_embeddings, + ) + + print("Evaluating NCP predictor on test set...") + y_true, y_prob, _loss, extra = Trainer(model=ncp_predictor).inference( + test_loader, additional_outputs=["y_predset"] + ) + + ncp_metrics = get_metrics_fn("multiclass")( + y_true, + y_prob, + metrics=["accuracy", "miscoverage_ps"], + y_predset=extra["y_predset"], + ) + + predset = extra["y_predset"] + if isinstance(predset, np.ndarray): + predset_t = torch.tensor(predset) + else: + predset_t = predset + avg_set_size = predset_t.float().sum(dim=1).mean().item() + + miscoverage = ncp_metrics["miscoverage_ps"] + if isinstance(miscoverage, np.ndarray): + miscoverage = float(miscoverage.item() if miscoverage.size == 1 else miscoverage.mean()) + else: + miscoverage = float(miscoverage) + + print("\nNCP (NeighborhoodLabel) Results:") + print(f" Accuracy: {ncp_metrics['accuracy']:.4f}") + print(f" Empirical miscoverage: {miscoverage:.4f}") + print(f" Empirical coverage: {1 - miscoverage:.4f}") + print(f" Average set size: {avg_set_size:.2f}") + print(f" k_neighbors: {args.k_neighbors}") + + +if __name__ == "__main__": + main() From 99aad036fef96003ab4aa8cdd929e63cd832b1a3 Mon Sep 17 00:00:00 2001 From: lehendo Date: Fri, 6 Feb 2026 15:45:29 -0600 Subject: [PATCH 08/16] update data path --- examples/conformal_eeg/tuev_kmeans_conformal.py | 4 ++-- examples/conformal_eeg/tuev_ncp_conformal.py | 4 ++-- 2 files changed, 4 
insertions(+), 4 deletions(-) diff --git a/examples/conformal_eeg/tuev_kmeans_conformal.py b/examples/conformal_eeg/tuev_kmeans_conformal.py index 18ab5acfe..62ab2273e 100644 --- a/examples/conformal_eeg/tuev_kmeans_conformal.py +++ b/examples/conformal_eeg/tuev_kmeans_conformal.py @@ -9,7 +9,7 @@ 6) Evaluates prediction-set coverage/miscoverage and efficiency on the test split. Example (from repo root): - python examples/conformal_eeg/tuev_kmeans_conformal.py --root downloads/tuev/v2.0.1/edf --n-clusters 5 + python examples/conformal_eeg/tuev_kmeans_conformal.py --root /srv/local/data/TUH/tuh_eeg_events/v2.0.0/edf --n-clusters 5 python examples/conformal_eeg/tuev_kmeans_conformal.py --quick-test --log-file quicktest_kmeans.log Notes: @@ -60,7 +60,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--root", type=str, - default="downloads/tuev/v2.0.1/edf", + default="/srv/local/data/TUH/tuh_eeg_events/v2.0.0/edf", help="Path to TUEV edf/ folder.", ) parser.add_argument("--subset", type=str, default="both", choices=["train", "eval", "both"]) diff --git a/examples/conformal_eeg/tuev_ncp_conformal.py b/examples/conformal_eeg/tuev_ncp_conformal.py index 53d379ae6..3cdb432c8 100644 --- a/examples/conformal_eeg/tuev_ncp_conformal.py +++ b/examples/conformal_eeg/tuev_ncp_conformal.py @@ -8,7 +8,7 @@ 5) Evaluates prediction-set coverage/miscoverage and efficiency on the test split. Example (from repo root): - python examples/conformal_eeg/tuev_ncp_conformal.py --root downloads/tuev/v2.0.1/edf + python examples/conformal_eeg/tuev_ncp_conformal.py --root /srv/local/data/TUH/tuh_eeg_events/v2.0.0/edf python examples/conformal_eeg/tuev_ncp_conformal.py --quick-test --log-file quicktest_ncp.log """ @@ -55,7 +55,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--root", type=str, - default="downloads/tuev/v2.0.1/edf", + default="/srv/local/data/TUH/tuh_eeg_events/v2.0.0/edf", help="Path to TUEV edf/ folder.", ) parser.add_argument("--subset", type=str, default="both", choices=["train", "eval", "both"]) From 928931927c0c97354c63637d897addebda49d46d Mon Sep 17 00:00:00 2001 From: lehendo Date: Fri, 6 Feb 2026 16:31:32 -0600 Subject: [PATCH 09/16] quick test script changes --- examples/conformal_eeg/tuev_kmeans_conformal.py | 8 ++++++-- examples/conformal_eeg/tuev_ncp_conformal.py | 8 ++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/examples/conformal_eeg/tuev_kmeans_conformal.py b/examples/conformal_eeg/tuev_kmeans_conformal.py index 62ab2273e..b5a043dad 100644 --- a/examples/conformal_eeg/tuev_kmeans_conformal.py +++ b/examples/conformal_eeg/tuev_kmeans_conformal.py @@ -98,7 +98,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--quick-test", action="store_true", - help="Dev mode: dev=True, 2 epochs, ~5 min smoke test.", + help="Smoke test: dev=True, max 2000 samples, 2 epochs, ~5-10 min.", ) return parser.parse_args() @@ -141,14 +141,18 @@ def _run(args: argparse.Namespace) -> None: ) epochs = 2 if args.quick_test else args.epochs + quick_test_max_samples = 2000 # cap samples so quick-test finishes in ~5-10 min if args.quick_test: - print("*** QUICK TEST MODE (dev=True, 2 epochs) ***") + print("*** QUICK TEST MODE (dev=True, 2 epochs, max 2000 samples) ***") print("=" * 80) print("STEP 1: Load TUEV + build task dataset") print("=" * 80) dataset = TUEVDataset(root=str(root), subset=args.subset, dev=args.quick_test) sample_dataset = dataset.set_task(EEGEventsTUEV(), cache_dir="examples/conformal_eeg/cache") + if args.quick_test and 
len(sample_dataset) > quick_test_max_samples: + sample_dataset = sample_dataset.subset(range(quick_test_max_samples)) + print(f"Capped to {quick_test_max_samples} samples for quick-test.") print(f"Task samples: {len(sample_dataset)}") print(f"Input schema: {sample_dataset.input_schema}") diff --git a/examples/conformal_eeg/tuev_ncp_conformal.py b/examples/conformal_eeg/tuev_ncp_conformal.py index 3cdb432c8..7441cbae9 100644 --- a/examples/conformal_eeg/tuev_ncp_conformal.py +++ b/examples/conformal_eeg/tuev_ncp_conformal.py @@ -99,7 +99,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--quick-test", action="store_true", - help="Dev mode: dev=True, 2 epochs, ~5 min smoke test.", + help="Smoke test: dev=True, max 2000 samples, 2 epochs, ~5-10 min.", ) return parser.parse_args() @@ -142,14 +142,18 @@ def _run(args: argparse.Namespace) -> None: ) epochs = 2 if args.quick_test else args.epochs + quick_test_max_samples = 2000 # cap samples so quick-test finishes in ~5-10 min if args.quick_test: - print("*** QUICK TEST MODE (dev=True, 2 epochs) ***") + print("*** QUICK TEST MODE (dev=True, 2 epochs, max 2000 samples) ***") print("=" * 80) print("STEP 1: Load TUEV + build task dataset") print("=" * 80) dataset = TUEVDataset(root=str(root), subset=args.subset, dev=args.quick_test) sample_dataset = dataset.set_task(EEGEventsTUEV(), cache_dir="examples/conformal_eeg/cache") + if args.quick_test and len(sample_dataset) > quick_test_max_samples: + sample_dataset = sample_dataset.subset(range(quick_test_max_samples)) + print(f"Capped to {quick_test_max_samples} samples for quick-test.") print(f"Task samples: {len(sample_dataset)}") print(f"Input schema: {sample_dataset.input_schema}") From d99c67736e6ccb86b9835fd486394c69d66e231e Mon Sep 17 00:00:00 2001 From: lehendo Date: Fri, 6 Feb 2026 16:44:10 -0600 Subject: [PATCH 10/16] fix nan error --- pyhealth/metrics/prediction_set.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/pyhealth/metrics/prediction_set.py b/pyhealth/metrics/prediction_set.py index 95d71e7e8..2b6f71705 100644 --- a/pyhealth/metrics/prediction_set.py +++ b/pyhealth/metrics/prediction_set.py @@ -33,8 +33,12 @@ def _missrate(y_pred:np.ndarray, y_true:np.ndarray, ignore_rejected=False): keep_msk = (y_pred.sum(1) == 1) if ignore_rejected else np.ones(len(y_true), dtype=bool) missed = [] for k in range(K): - missed.append(1-np.mean(y_pred[keep_msk & y_true[:, k], k])) - + msk = keep_msk & y_true[:, k] + n = msk.sum() + if n == 0: + missed.append(0.0) + else: + missed.append(1 - np.mean(y_pred[msk, k])) return np.asarray(missed) @@ -92,8 +96,8 @@ def miscoverage_overall_ps(y_pred:np.ndarray, y_true:np.ndarray): """ assert len(y_true.shape) == 1 truth_pred = y_pred[np.arange(len(y_true)), y_true] + return 1 - np.mean(truth_pred) if len(truth_pred) > 0 else 0.0 - return 1 - np.mean(truth_pred) def error_overall_ps(y_pred:np.ndarray, y_true:np.ndarray): """Overall error rate for the un-rejected samples. 
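
The empty-slice guards in this patch matter because `np.mean` of an empty
array is NaN; a minimal reproduction of the failure mode (synthetic arrays,
illustrative only):

    import numpy as np

    y_pred = np.array([[True, False], [True, False]])  # prediction sets
    y_true = np.array([0, 0])                          # class 1 never appears
    empty_slice = y_pred[y_true == 1, 1]               # shape (0,)
    # np.mean(empty_slice) -> nan with a RuntimeWarning; the patched
    # metrics return 0.0 for such empty groups instead.
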
@@ -113,4 +117,4 @@ def error_overall_ps(y_pred:np.ndarray, y_true:np.ndarray):
     assert len(y_true.shape) == 1
     truth_pred = y_pred[np.arange(len(y_true)), y_true]
     truth_pred = truth_pred[y_pred.sum(1) == 1]
-    return 1 - np.mean(truth_pred)
+    return 1 - np.mean(truth_pred) if len(truth_pred) > 0 else 0.0

From 16c7ac0c8741a670613421c7d7adb30e131e2762 Mon Sep 17 00:00:00 2001
From: lehendo
Date: Sat, 7 Feb 2026 15:15:20 -0600
Subject: [PATCH 11/16] Fixed NCP implementation

---
 .../cluster/neighborhood_label.py | 58 +++++++++++++------
 tests/core/test_neighborhood_label.py | 38 ++++++++++++
 2 files changed, 78 insertions(+), 18 deletions(-)

diff --git a/pyhealth/calib/predictionset/cluster/neighborhood_label.py b/pyhealth/calib/predictionset/cluster/neighborhood_label.py
index c9bf47bbd..18698d938 100644
--- a/pyhealth/calib/predictionset/cluster/neighborhood_label.py
+++ b/pyhealth/calib/predictionset/cluster/neighborhood_label.py
@@ -1,14 +1,6 @@
 """
 Neighborhood Conformal Prediction (NCP).
 
-This module implements Neighborhood Conformal Prediction!
-
-NCP uses the learned representation (embeddings) to assign importance weights
-to calibration examples: for each test input it finds k nearest calibration
-points in embedding space and computes a weighted quantile of their conformity
-scores to form an adaptive, per-sample threshold. This can yield smaller
-prediction sets than standard conformal prediction when the representation
-has good separation between classes.
 """
 
 from typing import Dict, Optional, Union
@@ -29,12 +21,6 @@ class NeighborhoodLabel(SetPredictor):
     """Neighborhood Conformal Prediction (NCP) for multiclass classification.
 
-    For each test input, NCP finds its k nearest calibration examples in
-    embedding space and assigns them importance weights proportional to
-    distance (exponential kernel). The prediction set threshold is the
-    weighted quantile of calibration conformity scores, so the threshold
-    is adaptive per test sample.
-
     Reference:
         Ghosh, S., Belkhouja, T., Yan, Y., & Doppa, J. R. (2023).
         Improving Uncertainty Quantification of Deep Classifiers via
@@ -115,6 +101,7 @@ def __init__(
 
         self.cal_embeddings_ = None
         self.cal_conformity_scores_ = None
+        self.alpha_tilde_ = None
         self._nn = None
@@ -123,7 +110,12 @@ def calibrate(
         cal_embeddings: Optional[np.ndarray] = None,
         batch_size: int = 32,
     ) -> None:
-        """Calibrate NCP: store calibration embeddings and conformity scores.
+        """Calibrate NCP in two steps:
+
+        Step 1: For each calibration point i, compute Q̃^NCP (weighted quantile)
+            over its k nearest calibration neighbors, using exponential distance weights.
+        Step 2: Find ã^NCP(α) = largest ã such that empirical coverage on the
+            calibration set is >= 1-α; store as alpha_tilde_ for use at test time.
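+
+        A sketch of the Step 2 search: coverage(ã) is monotone non-increasing
+        in ã, so a bisection over [0, 1] suffices (`tol` is illustrative; the
+        implementation below runs a fixed 50 bisection iterations):
+
+            low, high = 0.0, 1.0
+            while high - low > tol:
+                mid = (low + high) / 2
+                if coverage(mid) >= 1 - alpha:
+                    low = mid   # mid still meets coverage; push alpha_tilde up
+                else:
+                    high = mid
+            alpha_tilde = low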
Args: cal_dataset: Calibration dataset (for labels and predictions if @@ -165,9 +157,39 @@ def calibrate( self.cal_embeddings_ = np.atleast_2d(cal_embeddings) self.cal_conformity_scores_ = np.asarray(conformity_scores, dtype=np.float64) + # this is the ncp calibration step + distances_cal, indices_cal = self._nn.kneighbors( + self.cal_embeddings_, n_neighbors=k + ) + cal_weights = np.exp(-distances_cal / self.lambda_L) + cal_weights = cal_weights / cal_weights.sum(axis=1, keepdims=True) + + def _empirical_coverage(alpha_tilde_cand: float) -> float: + t_all = np.zeros(N, dtype=np.float64) + for i in range(N): + t_all[i] = _query_weighted_quantile( + self.cal_conformity_scores_[indices_cal[i]], + 1.0 - alpha_tilde_cand, + cal_weights[i], + ) + return float(np.mean(self.cal_conformity_scores_ <= t_all)) + + low, high = 0.0, 1.0 + for _ in range(50): + mid = (low + high) / 2 + if _empirical_coverage(mid) >= 1.0 - self.alpha: + low = mid + else: + high = mid + self.alpha_tilde_ = float(low) + def forward(self, **kwargs) -> Dict[str, torch.Tensor]: - """Forward with NCP: per-sample weighted quantile threshold.""" - if self.cal_embeddings_ is None or self.cal_conformity_scores_ is None: + """Forward with NCP: per-sample weighted quantile threshold (Eq 2, Q̃^NCP).""" + if ( + self.cal_embeddings_ is None + or self.cal_conformity_scores_ is None + or self.alpha_tilde_ is None + ): raise RuntimeError( "NeighborhoodLabel must be calibrated before inference. " "Call calibrate() first." @@ -193,7 +215,7 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: w = w / np.sum(w) scores_i = self.cal_conformity_scores_[indices[i]] thresholds[i] = _query_weighted_quantile( - scores_i, self.alpha, w + scores_i, 1.0 - self.alpha_tilde_, w ) th = torch.as_tensor( diff --git a/tests/core/test_neighborhood_label.py b/tests/core/test_neighborhood_label.py index 8c4bf17a1..a6b4d1219 100644 --- a/tests/core/test_neighborhood_label.py +++ b/tests/core/test_neighborhood_label.py @@ -139,6 +139,44 @@ def test_calibrate_without_embeddings_extracts(self): self.assertIsNotNone(ncp.cal_embeddings_) self.assertIsNotNone(ncp.cal_conformity_scores_) + def test_calibration_empirical_coverage_at_least_1_minus_alpha(self): + """After calibrate(), empirical coverage on calibration set >= 1-alpha (Eq 2).""" + from pyhealth.calib.predictionset.base_conformal import _query_weighted_quantile + + ncp = NeighborhoodLabel(model=self.model, alpha=0.2, k_neighbors=3, lambda_L=50.0) + cal_indices = [0, 1, 2, 3, 4, 5] + cal_dataset = self.dataset.subset(cal_indices) + cal_emb = self._get_embeddings(cal_dataset) + ncp.calibrate(cal_dataset=cal_dataset, cal_embeddings=cal_emb) + + self.assertIsNotNone(ncp.alpha_tilde_) + self.assertGreaterEqual(ncp.alpha_tilde_, 0.0) + self.assertLessEqual(ncp.alpha_tilde_, 1.0) + + N = ncp.cal_conformity_scores_.shape[0] + k = min(ncp.k_neighbors, N) + distances_cal, indices_cal = ncp._nn.kneighbors( + ncp.cal_embeddings_, n_neighbors=k + ) + cal_weights = np.exp(-distances_cal / ncp.lambda_L) + cal_weights = cal_weights / cal_weights.sum(axis=1, keepdims=True) + + covered = 0 + for i in range(N): + t_i = _query_weighted_quantile( + ncp.cal_conformity_scores_[indices_cal[i]], + 1.0 - ncp.alpha_tilde_, + cal_weights[i], + ) + if ncp.cal_conformity_scores_[i] <= t_i: + covered += 1 + empirical_coverage = covered / N + self.assertGreaterEqual( + empirical_coverage, + 1.0 - ncp.alpha - 1e-6, + msg=f"Calibration empirical coverage {empirical_coverage:.4f} should be >= 1-alpha={1 - ncp.alpha}", + ) + if 
__name__ == "__main__": unittest.main() From 0c8e9cf093366ad76c136ef9460e38f930be20c0 Mon Sep 17 00:00:00 2001 From: lehendo Date: Sat, 7 Feb 2026 15:17:06 -0600 Subject: [PATCH 12/16] minor fixes --- pyhealth/calib/predictionset/cluster/neighborhood_label.py | 2 +- tests/core/test_neighborhood_label.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyhealth/calib/predictionset/cluster/neighborhood_label.py b/pyhealth/calib/predictionset/cluster/neighborhood_label.py index 18698d938..62c243877 100644 --- a/pyhealth/calib/predictionset/cluster/neighborhood_label.py +++ b/pyhealth/calib/predictionset/cluster/neighborhood_label.py @@ -184,7 +184,7 @@ def _empirical_coverage(alpha_tilde_cand: float) -> float: self.alpha_tilde_ = float(low) def forward(self, **kwargs) -> Dict[str, torch.Tensor]: - """Forward with NCP: per-sample weighted quantile threshold (Eq 2, Q̃^NCP).""" + """Forward with NCP: per-sample weighted quantile threshold.""" if ( self.cal_embeddings_ is None or self.cal_conformity_scores_ is None diff --git a/tests/core/test_neighborhood_label.py b/tests/core/test_neighborhood_label.py index a6b4d1219..266baa1dd 100644 --- a/tests/core/test_neighborhood_label.py +++ b/tests/core/test_neighborhood_label.py @@ -140,7 +140,7 @@ def test_calibrate_without_embeddings_extracts(self): self.assertIsNotNone(ncp.cal_conformity_scores_) def test_calibration_empirical_coverage_at_least_1_minus_alpha(self): - """After calibrate(), empirical coverage on calibration set >= 1-alpha (Eq 2).""" + """After calibrate(), empirical coverage on calibration set >= 1-alpha.""" from pyhealth.calib.predictionset.base_conformal import _query_weighted_quantile ncp = NeighborhoodLabel(model=self.model, alpha=0.2, k_neighbors=3, lambda_L=50.0) From 0b2b466b5e23a4ff9357ef4814a6fcc8983be9de Mon Sep 17 00:00:00 2001 From: lehendo Date: Sat, 7 Feb 2026 15:27:04 -0600 Subject: [PATCH 13/16] empty set logic --- .../calib/predictionset/cluster/neighborhood_label.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pyhealth/calib/predictionset/cluster/neighborhood_label.py b/pyhealth/calib/predictionset/cluster/neighborhood_label.py index 62c243877..e28405ef7 100644 --- a/pyhealth/calib/predictionset/cluster/neighborhood_label.py +++ b/pyhealth/calib/predictionset/cluster/neighborhood_label.py @@ -223,6 +223,12 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: ) if pred["y_prob"].ndim > 1: th = th.view(-1, *([1] * (pred["y_prob"].ndim - 1))) - pred["y_predset"] = pred["y_prob"] >= th + y_predset = pred["y_prob"] >= th + # if threshold is high, include at least argmax + empty = y_predset.sum(dim=1) == 0 + if empty.any(): + argmax_idx = pred["y_prob"].argmax(dim=1) + y_predset[empty, argmax_idx[empty]] = True + pred["y_predset"] = y_predset pred.pop("embed", None) return pred From 75df772714163c7e6f027fb6261e991aad59c680 Mon Sep 17 00:00:00 2001 From: lehendo Date: Sat, 7 Feb 2026 20:54:38 -0600 Subject: [PATCH 14/16] Math fix --- pyhealth/calib/predictionset/cluster/neighborhood_label.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyhealth/calib/predictionset/cluster/neighborhood_label.py b/pyhealth/calib/predictionset/cluster/neighborhood_label.py index e28405ef7..2d9f2dc6d 100644 --- a/pyhealth/calib/predictionset/cluster/neighborhood_label.py +++ b/pyhealth/calib/predictionset/cluster/neighborhood_label.py @@ -169,10 +169,10 @@ def _empirical_coverage(alpha_tilde_cand: float) -> float: for i in range(N): t_all[i] = 
_query_weighted_quantile(
                     self.cal_conformity_scores_[indices_cal[i]],
-                    1.0 - alpha_tilde_cand,
+                    alpha_tilde_cand,
                     cal_weights[i],
                 )
-        return float(np.mean(self.cal_conformity_scores_ <= t_all))
+        return float(np.mean(self.cal_conformity_scores_ >= t_all))
 
         low, high = 0.0, 1.0
         for _ in range(50):
@@ -215,7 +215,7 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
             w = w / np.sum(w)
             scores_i = self.cal_conformity_scores_[indices[i]]
             thresholds[i] = _query_weighted_quantile(
-                scores_i, 1.0 - self.alpha_tilde_, w
+                scores_i, self.alpha_tilde_, w
             )
 
         th = torch.as_tensor(

From 4712945ef15a253185a18594b62686e0572304de Mon Sep 17 00:00:00 2001
From: lehendo
Date: Sat, 7 Feb 2026 21:08:19 -0600
Subject: [PATCH 15/16] Modify tests to reflect fix

---
 tests/core/test_neighborhood_label.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/tests/core/test_neighborhood_label.py b/tests/core/test_neighborhood_label.py
index 266baa1dd..b33f4c4b0 100644
--- a/tests/core/test_neighborhood_label.py
+++ b/tests/core/test_neighborhood_label.py
@@ -153,6 +153,7 @@ def test_calibration_empirical_coverage_at_least_1_minus_alpha(self):
         self.assertGreaterEqual(ncp.alpha_tilde_, 0.0)
         self.assertLessEqual(ncp.alpha_tilde_, 1.0)
 
+        # Recompute per-sample thresholds using alpha_tilde (Q^NCP definition: alpha_tilde-quantile of conformity)
         N = ncp.cal_conformity_scores_.shape[0]
         k = min(ncp.k_neighbors, N)
         distances_cal, indices_cal = ncp._nn.kneighbors(
             ncp.cal_embeddings_, n_neighbors=k
         )
@@ -165,10 +166,11 @@
         for i in range(N):
             t_i = _query_weighted_quantile(
                 ncp.cal_conformity_scores_[indices_cal[i]],
-                1.0 - ncp.alpha_tilde_,
+                ncp.alpha_tilde_,
                 cal_weights[i],
             )
-            if ncp.cal_conformity_scores_[i] <= t_i:
+            # Covered = true label in set = conformity_i >= threshold_i (paper: V_i <= t in nonconformity space)
+            if ncp.cal_conformity_scores_[i] >= t_i:
                 covered += 1
         empirical_coverage = covered / N
         self.assertGreaterEqual(

From 60c7796760fb212109b048875330c91fb46e668c Mon Sep 17 00:00:00 2001
From: lehendo
Date: Sun, 8 Feb 2026 13:51:08 -0600
Subject: [PATCH 16/16] Updated tests

---
 examples/conformal_eeg/tuev_ncp_conformal.py | 333 ++++++++++++++-----
 1 file changed, 254 insertions(+), 79 deletions(-)

diff --git a/examples/conformal_eeg/tuev_ncp_conformal.py b/examples/conformal_eeg/tuev_ncp_conformal.py
index 7441cbae9..9b51a7756 100644
--- a/examples/conformal_eeg/tuev_ncp_conformal.py
+++ b/examples/conformal_eeg/tuev_ncp_conformal.py
@@ -7,9 +7,14 @@
 4) Extracts calibration embeddings and calibrates a NeighborhoodLabel (NCP) predictor.
 5) Evaluates prediction-set coverage/miscoverage and efficiency on the test split.
 
+With --n-seeds > 1: fixes the test set (--split-seed), performs multiple training
+runs with different seeds (different train/val/cal splits and model init), and
+reports coverage / set size / accuracy as mean ± std (error bars).
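+
+The reported error bars are plain mean/std over per-run metrics, e.g.
+(illustrative values, not measured results):
+
+    import numpy as np
+    coverages = np.array([0.91, 0.89, 0.92, 0.90, 0.88])
+    print(f"coverage: {coverages.mean():.3f} +/- {coverages.std():.3f}")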
+ Example (from repo root): python examples/conformal_eeg/tuev_ncp_conformal.py --root /srv/local/data/TUH/tuh_eeg_events/v2.0.0/edf python examples/conformal_eeg/tuev_ncp_conformal.py --quick-test --log-file quicktest_ncp.log + python examples/conformal_eeg/tuev_ncp_conformal.py --alpha 0.1 --n-seeds 5 --split-seed 0 --log-file ncp_seeds5.log """ from __future__ import annotations @@ -59,7 +64,25 @@ def parse_args() -> argparse.Namespace: help="Path to TUEV edf/ folder.", ) parser.add_argument("--subset", type=str, default="both", choices=["train", "eval", "both"]) - parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--seed", type=int, default=42, help="Run seed (or first of run seeds when n-seeds > 1).") + parser.add_argument( + "--n-seeds", + type=int, + default=1, + help="Number of runs for mean±std. Test set fixed; train/val/cal vary by seed.", + ) + parser.add_argument( + "--split-seed", + type=int, + default=0, + help="Fixed seed for initial split (fixes test set when n-seeds > 1).", + ) + parser.add_argument( + "--seeds", + type=str, + default=None, + help="Comma-separated run seeds, e.g. 42,43,44,45,46. Overrides --seed and --n-seeds.", + ) parser.add_argument("--batch-size", type=int, default=64) parser.add_argument("--epochs", type=int, default=20) parser.add_argument("--alpha", type=float, default=0.1, help="Miscoverage rate (e.g., 0.1 => 90% target coverage).") @@ -112,9 +135,126 @@ def set_seed(seed: int) -> None: torch.cuda.manual_seed_all(seed) +def _split_remainder_into_train_val_cal(sample_dataset, remainder_indices, ratios, run_seed): + """Split remainder indices into train/val/cal by renormalized ratios. Uses run_seed for shuffle.""" + r0, r1, r2, r3 = ratios + remainder_frac = 1.0 - r3 + if remainder_frac <= 0: + raise ValueError("Test ratio must be < 1 so remainder (train+val+cal) is non-empty.") + # Renormalize so train/val/cal ratios sum to 1 on the remainder + r_train = r0 / remainder_frac + r_val = r1 / remainder_frac + remainder = np.asarray(remainder_indices, dtype=np.int64) + np.random.seed(run_seed) + shuffled = np.random.permutation(remainder) + M = len(shuffled) + train_end = int(M * r_train) + val_end = int(M * (r_train + r_val)) + train_index = shuffled[:train_end] + val_index = shuffled[train_end:val_end] + cal_index = shuffled[val_end:] + train_ds = sample_dataset.subset(train_index.tolist()) + val_ds = sample_dataset.subset(val_index.tolist()) + cal_ds = sample_dataset.subset(cal_index.tolist()) + return train_ds, val_ds, cal_ds + + +def _run_one_ncp( + sample_dataset, + train_ds, + val_ds, + cal_ds, + test_loader, + args, + device, + epochs, + return_metrics=False, +): + """Train ContraWR, calibrate NCP, evaluate on test. 
Optionally return metrics dict for aggregation.""" + train_loader = get_dataloader(train_ds, batch_size=args.batch_size, shuffle=True) + val_loader = get_dataloader(val_ds, batch_size=args.batch_size, shuffle=False) if len(val_ds) else None + + print("\n" + "=" * 80) + print("STEP 3: Train ContraWR") + print("=" * 80) + model = ContraWR(dataset=sample_dataset, n_fft=args.n_fft).to(device) + trainer = Trainer(model=model, device=device, enable_logging=False) + trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + epochs=epochs, + monitor="accuracy" if val_loader is not None else None, + ) + + if not return_metrics: + print("\nBase model performance on test set:") + y_true_base, y_prob_base, _loss_base = trainer.inference(test_loader) + base_metrics = get_metrics_fn("multiclass")( + y_true_base, y_prob_base, metrics=["accuracy", "f1_weighted"] + ) + for metric, value in base_metrics.items(): + print(f" {metric}: {value:.4f}") + + print("\n" + "=" * 80) + print("STEP 4: Neighborhood Conformal Prediction (NCP / NeighborhoodLabel)") + print("=" * 80) + print(f"Target miscoverage alpha: {args.alpha} (target coverage {1 - args.alpha:.0%})") + print(f"k_neighbors: {args.k_neighbors}, lambda_L: {args.lambda_L}") + + cal_embeddings = extract_embeddings(model, cal_ds, batch_size=args.batch_size, device=device) + if not return_metrics: + print(f" cal_embeddings shape: {cal_embeddings.shape}") + + ncp_predictor = NeighborhoodLabel( + model=model, + alpha=float(args.alpha), + k_neighbors=args.k_neighbors, + lambda_L=args.lambda_L, + ) + ncp_predictor.calibrate(cal_dataset=cal_ds, cal_embeddings=cal_embeddings) + + y_true, y_prob, _loss, extra = Trainer(model=ncp_predictor).inference( + test_loader, additional_outputs=["y_predset"] + ) + ncp_metrics = get_metrics_fn("multiclass")( + y_true, y_prob, metrics=["accuracy", "miscoverage_ps"], y_predset=extra["y_predset"] + ) + predset = extra["y_predset"] + if isinstance(predset, np.ndarray): + predset_t = torch.tensor(predset) + else: + predset_t = predset + avg_set_size = predset_t.float().sum(dim=1).mean().item() + miscoverage = ncp_metrics["miscoverage_ps"] + if isinstance(miscoverage, np.ndarray): + miscoverage = float(miscoverage.item() if miscoverage.size == 1 else miscoverage.mean()) + else: + miscoverage = float(miscoverage) + coverage = 1.0 - miscoverage + + if return_metrics: + return { + "accuracy": float(ncp_metrics["accuracy"]), + "coverage": coverage, + "miscoverage": miscoverage, + "avg_set_size": avg_set_size, + } + + print("\nNCP (NeighborhoodLabel) Results:") + print(f" Accuracy: {ncp_metrics['accuracy']:.4f}") + print(f" Empirical miscoverage: {miscoverage:.4f}") + print(f" Empirical coverage: {coverage:.4f}") + print(f" Average set size: {avg_set_size:.2f}") + print(f" k_neighbors: {args.k_neighbors}") + print("\n--- Single-run summary (for reporting) ---") + print(f" alpha={args.alpha}, target_coverage={1 - args.alpha:.2f}, empirical_coverage={coverage:.4f}, miscoverage={miscoverage:.4f}, accuracy={ncp_metrics['accuracy']:.4f}, avg_set_size={avg_set_size:.2f}") + + def main() -> None: args = parse_args() - set_seed(args.seed) + # Seed set per run in multi-seed mode; for single run set once here + if args.n_seeds <= 1 and args.seeds is None: + set_seed(args.seed) orig_stdout, orig_stderr = sys.stdout, sys.stderr log_file = None @@ -159,96 +299,131 @@ def _run(args: argparse.Namespace) -> None: print(f"Input schema: {sample_dataset.input_schema}") print(f"Output schema: {sample_dataset.output_schema}") + # 
+
+
 def main() -> None:
     args = parse_args()
-    set_seed(args.seed)
+    # Seed is set per run in multi-seed mode; for a single run, set it once here.
+    if args.n_seeds <= 1 and args.seeds is None:
+        set_seed(args.seed)
 
     orig_stdout, orig_stderr = sys.stdout, sys.stderr
     log_file = None
@@ -159,96 +299,131 @@ def _run(args: argparse.Namespace) -> None:
     print(f"Input schema: {sample_dataset.input_schema}")
     print(f"Output schema: {sample_dataset.output_schema}")
 
+    # Experiment configuration (for PI / reporting)
+    print("\n--- Experiment configuration ---")
+    print(f" dataset_root: {root}")
+    print(f" subset: {args.subset}, ratios: train/val/cal/test = {args.ratios[0]:.2f}/{args.ratios[1]:.2f}/{args.ratios[2]:.2f}/{args.ratios[3]:.2f}")
+    print(f" alpha: {args.alpha} (target coverage {1 - args.alpha:.0%})")
+    print(f" k_neighbors: {args.k_neighbors}, lambda_L: {args.lambda_L}")
+    print(f" epochs: {epochs}, batch_size: {args.batch_size}, device: {device}, seed: {args.seed}")
+
     if len(sample_dataset) == 0:
         raise RuntimeError("No samples produced. Verify TUEV root/subset/task.")
 
+    ratios = list(args.ratios)
+    use_multi_seed = args.n_seeds > 1 or args.seeds is not None
+    if use_multi_seed:
+        run_seeds = (
+            [int(s.strip()) for s in args.seeds.split(",")]
+            if args.seeds
+            else [args.seed + i for i in range(args.n_seeds)]
+        )
+        n_runs = len(run_seeds)
+        print(f" multi_seed: n_runs={n_runs}, run_seeds={run_seeds}, split_seed={args.split_seed} (fixed test set)")
+        print(f"Multi-seed mode: {n_runs} runs (fixed test set), run seeds: {run_seeds}")
+
+    if not use_multi_seed:
+        # Single run: original behavior
+        print("\n" + "=" * 80)
+        print("STEP 2: Split train/val/cal/test")
+        print("=" * 80)
+        train_ds, val_ds, cal_ds, test_ds = split_by_sample_conformal(
+            dataset=sample_dataset, ratios=ratios, seed=args.seed
+        )
+        print(f"Train: {len(train_ds)}")
+        print(f"Val: {len(val_ds)}")
+        print(f"Cal: {len(cal_ds)}")
+        print(f"Test: {len(test_ds)}")
+
+        test_loader = get_dataloader(test_ds, batch_size=args.batch_size, shuffle=False)
+        _run_one_ncp(
+            sample_dataset=sample_dataset,
+            train_ds=train_ds,
+            val_ds=val_ds,
+            cal_ds=cal_ds,
+            test_loader=test_loader,
+            args=args,
+            device=device,
+            epochs=epochs,
+        )
+        print("\n--- Split sizes and seed (for reporting) ---")
+        print(f" train={len(train_ds)}, val={len(val_ds)}, cal={len(cal_ds)}, test={len(test_ds)}, seed={args.seed}")
+        return
+
+    # Multi-seed: fix the test set, vary train/val/cal per run
     print("\n" + "=" * 80)
-    print("STEP 2: Split train/val/cal/test")
+    print("STEP 2: Fix test set (split-seed), then run multiple train/cal splits")
     print("=" * 80)
-    train_ds, val_ds, cal_ds, test_ds = split_by_sample_conformal(
-        dataset=sample_dataset, ratios=list(args.ratios), seed=args.seed
+    train_idx, val_idx, cal_idx, test_idx = split_by_sample_conformal(
+        dataset=sample_dataset, ratios=ratios, seed=args.split_seed, get_index=True
    )
-    print(f"Train: {len(train_ds)}")
-    print(f"Val: {len(val_ds)}")
-    print(f"Cal: {len(cal_ds)}")
-    print(f"Test: {len(test_ds)}")
-
-    train_loader = get_dataloader(train_ds, batch_size=args.batch_size, shuffle=True)
-    val_loader = get_dataloader(val_ds, batch_size=args.batch_size, shuffle=False) if len(val_ds) else None
+    # Convert to numpy for indexing
+    train_index = train_idx.numpy() if hasattr(train_idx, "numpy") else np.array(train_idx)
+    val_index = val_idx.numpy() if hasattr(val_idx, "numpy") else np.array(val_idx)
+    cal_index = cal_idx.numpy() if hasattr(cal_idx, "numpy") else np.array(cal_idx)
+    test_index = test_idx.numpy() if hasattr(test_idx, "numpy") else np.array(test_idx)
+    remainder_indices = np.concatenate([train_index, val_index, cal_index])
 
+    test_ds = sample_dataset.subset(test_index.tolist())
     test_loader = get_dataloader(test_ds, batch_size=args.batch_size, shuffle=False)
+    n_test = len(test_ds)
+    print(f"Fixed test set size: {n_test}")
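The protocol above, one split seed pinning the test indices while each run seed reshuffles only the remainder, can be expressed without the pyhealth helpers. A minimal sketch of the same idea (sizes and seeds are illustrative):

    import numpy as np

    n, test_frac = 1000, 0.15
    perm = np.random.default_rng(0).permutation(n)   # --split-seed
    test_index = perm[: int(n * test_frac)]          # fixed for every run
    remainder = perm[int(n * test_frac):]

    for run_seed in [42, 43, 44]:                    # run seeds
        shuffled = np.random.default_rng(run_seed).permutation(remainder)
        # ...split `shuffled` into train/val/cal and run one experiment...

Because the test indices never change, run-to-run variation in the summary reflects only training and calibration resampling, not test-set resampling.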
print(f"Run {run_i + 1} / {n_runs} (seed={run_seed})") + print("=" * 80) + set_seed(run_seed) + train_ds, val_ds, cal_ds = _split_remainder_into_train_val_cal( + sample_dataset, remainder_indices, ratios, run_seed + ) + print(f"Train: {len(train_ds)}, Val: {len(val_ds)}, Cal: {len(cal_ds)}") + + metrics = _run_one_ncp( + sample_dataset=sample_dataset, + train_ds=train_ds, + val_ds=val_ds, + cal_ds=cal_ds, + test_loader=test_loader, + args=args, + device=device, + epochs=epochs, + return_metrics=True, + ) + accs.append(metrics["accuracy"]) + coverages.append(metrics["coverage"]) + miscoverages.append(metrics["miscoverage"]) + set_sizes.append(metrics["avg_set_size"]) + + accs = np.array(accs) + coverages = np.array(coverages) + miscoverages_arr = np.array(miscoverages) + set_sizes = np.array(set_sizes) + # Per-run table (for PI / reporting) print("\n" + "=" * 80) - print("STEP 3: Train ContraWR") + print("Per-run NCP results (fixed test set)") print("=" * 80) - model = ContraWR(dataset=sample_dataset, n_fft=args.n_fft).to(device) - trainer = Trainer(model=model, device=device, enable_logging=False) - - trainer.train( - train_dataloader=train_loader, - val_dataloader=val_loader, - epochs=epochs, - monitor="accuracy" if val_loader is not None else None, - ) - - print("\nBase model performance on test set:") - y_true_base, y_prob_base, _loss_base = trainer.inference(test_loader) - base_metrics = get_metrics_fn("multiclass")(y_true_base, y_prob_base, metrics=["accuracy", "f1_weighted"]) - for metric, value in base_metrics.items(): - print(f" {metric}: {value:.4f}") + print(f" {'Run':<4} {'Seed':<6} {'Accuracy':<10} {'Coverage':<10} {'Miscoverage':<12} {'Avg set size':<12}") + print(" " + "-" * 54) + for i in range(n_runs): + print(f" {i+1:<4} {run_seeds[i]:<6} {accs[i]:<10.4f} {coverages[i]:<10.4f} {miscoverages_arr[i]:<12.4f} {set_sizes[i]:<12.2f}") print("\n" + "=" * 80) - print("STEP 4: Neighborhood Conformal Prediction (NCP / NeighborhoodLabel)") + print("NCP summary (mean ± std over {} runs, fixed test set)".format(n_runs)) print("=" * 80) - print(f"Target miscoverage alpha: {args.alpha} (target coverage {1 - args.alpha:.0%})") - print(f"k_neighbors: {args.k_neighbors}, lambda_L: {args.lambda_L}") - - print("Extracting embeddings for calibration split...") - cal_embeddings = extract_embeddings(model, cal_ds, batch_size=args.batch_size, device=device) - print(f" cal_embeddings shape: {cal_embeddings.shape}") - - ncp_predictor = NeighborhoodLabel( - model=model, - alpha=float(args.alpha), - k_neighbors=args.k_neighbors, - lambda_L=args.lambda_L, - ) - print("Calibrating NCP predictor (store cal embeddings and conformity scores)...") - ncp_predictor.calibrate( - cal_dataset=cal_ds, - cal_embeddings=cal_embeddings, - ) - - print("Evaluating NCP predictor on test set...") - y_true, y_prob, _loss, extra = Trainer(model=ncp_predictor).inference( - test_loader, additional_outputs=["y_predset"] - ) - - ncp_metrics = get_metrics_fn("multiclass")( - y_true, - y_prob, - metrics=["accuracy", "miscoverage_ps"], - y_predset=extra["y_predset"], - ) - - predset = extra["y_predset"] - if isinstance(predset, np.ndarray): - predset_t = torch.tensor(predset) - else: - predset_t = predset - avg_set_size = predset_t.float().sum(dim=1).mean().item() - - miscoverage = ncp_metrics["miscoverage_ps"] - if isinstance(miscoverage, np.ndarray): - miscoverage = float(miscoverage.item() if miscoverage.size == 1 else miscoverage.mean()) - else: - miscoverage = float(miscoverage) - - print("\nNCP (NeighborhoodLabel) 
Results:") - print(f" Accuracy: {ncp_metrics['accuracy']:.4f}") - print(f" Empirical miscoverage: {miscoverage:.4f}") - print(f" Empirical coverage: {1 - miscoverage:.4f}") - print(f" Average set size: {avg_set_size:.2f}") - print(f" k_neighbors: {args.k_neighbors}") + print(f" Accuracy: {accs.mean():.4f} ± {accs.std():.4f}") + print(f" Empirical coverage: {coverages.mean():.4f} ± {coverages.std():.4f}") + print(f" Empirical miscoverage: {miscoverages_arr.mean():.4f} ± {miscoverages_arr.std():.4f}") + print(f" Average set size: {set_sizes.mean():.2f} ± {set_sizes.std():.2f}") + print(f" Target coverage: {1 - args.alpha:.0%} (alpha={args.alpha})") + print(f" k_neighbors: {args.k_neighbors}, lambda_L: {args.lambda_L}") + print(f" Test set size: {n_test} (fixed across runs)") + print(f" Run seeds: {run_seeds}") + print("\n--- Min / Max (across runs) ---") + print(f" Coverage: [{coverages.min():.4f}, {coverages.max():.4f}]") + print(f" Set size: [{set_sizes.min():.2f}, {set_sizes.max():.2f}]") + print(f" Accuracy: [{accs.min():.4f}, {accs.max():.4f}]") if __name__ == "__main__":