[Conformal EEG] Neighborhood Conformal Prediction#824
[Conformal EEG] Neighborhood Conformal Prediction#824lehendo wants to merge 15 commits intosunlabuiuc:masterfrom
Conversation
jhnwu3
left a comment
There was a problem hiding this comment.
Just quick change, can we make sure to update the docs/ to include this new method? So when we do merge, it'll update the site to have it as an option?
| n_cal = self.cal_conformity_scores_.shape[0] | ||
| k = min(self.k_neighbors, n_cal) | ||
|
|
||
| distances, indices = self._nn.kneighbors(test_emb, n_neighbors=k) |
There was a problem hiding this comment.
This implementation for NCP is most likely wrong.
check equations 2 in the paper: https://arxiv.org/pdf/2303.10694
The calibration is broken into two steps: 1) a q_ncp function, this calibrates local neighborhood of each calibration point. 2) write an alpha^tilde_NCP function for computing the ncp quantile. 3) check the distance functions defined in equations 3-5, we might want to check exponential transformation of the distance.
It appears this is what the current code does - identify K nearest neighbors of each test point; set importance weights based on distances from test point to these calibration points. This does not yield marginal coverage guarentees.
Arjun Chatterjee
This an implementation of the neighborhood conformal prediction from the paper that Sid shared. Added a test file and changed some of the init files.
Main implementation:
Calibrate: Stores calibration embeddings and conformity scores (prob of true class). Fits a k-NN index on calibration embeddings. Optional cal_embeddings or batch_size for extraction.
Forward: One model call with embed=True. For each test embedding: k-NN in calibration space, weights = exp(-dist/λ_L) over k-NN and normalized, threshold = weighted α-quantile of calibration conformity scores, y_predset = (y_prob >= threshold) with correct batch handling.
Args: model, alpha, k_neighbors=50, lambda_L=100.0, debug=False.
Validation: alpha in (0,1), k_neighbors positive int, multiclass-only.
Fix #789