From 9cabac56e92ffa7fc87860e8040ed1055098b013 Mon Sep 17 00:00:00 2001
From: Yongda Fan
Date: Thu, 5 Feb 2026 17:42:22 -0600
Subject: [PATCH 1/4] define interface

---
 pyhealth/models/base_model.py | 61 +++++++++++++++++++++++++++++++++--
 1 file changed, 59 insertions(+), 2 deletions(-)

diff --git a/pyhealth/models/base_model.py b/pyhealth/models/base_model.py
index ea8046355..8d0bebaa1 100644
--- a/pyhealth/models/base_model.py
+++ b/pyhealth/models/base_model.py
@@ -1,5 +1,5 @@
 from abc import ABC
-from typing import Callable, Any
+from typing import Callable, Any, Optional
 import inspect
 
 import torch
@@ -16,9 +16,20 @@ class BaseModel(ABC, nn.Module):
     Args:
         dataset (SampleDataset): The dataset to train the model. It is used to query
             certain information such as the set of all tokens.
+
+    Interpretability
+    --------
+    To use a model with interpretability methods, the model must implement a method
+    `forward_from_embedding` that takes embeddings as input instead of raw features;
+    for models that already take dense features as input, this method can simply
+    call the existing `forward` method.
+
+    For certain gradient-based interpretability methods (e.g., DeepLIFT), the model must also
+    ensure all non-linearities (e.g., ReLU, Sigmoid, Softmax) use nn.Module versions instead of
+    functional versions (e.g., F.relu, F.sigmoid, F.softmax) so that hooks can be registered properly.
     """
 
-    def __init__(self, dataset: SampleDataset = None):
+    def __init__(self, dataset: Optional[SampleDataset] = None):
         """
         Initializes the BaseModel.
 
@@ -44,6 +55,49 @@ def __init__(self, dataset: SampleDataset = None):
         self._dummy_param = nn.Parameter(torch.empty(0))
 
         self.mode = getattr(self, "mode", None)  # legacy API
 
+    def forward(self, *kwargs: dict[str, torch.Tensor | str]) -> dict[str, torch.Tensor]:
+        """Forward pass of the model.
+
+        Args:
+            *kwargs: A variable number of keyword arguments representing input features.
+                Each keyword argument is a tensor of shape (batch_size, ...).
+
+        Returns:
+            A dictionary with the following keys:
+                distance: list of tensors of stage variation.
+                y_prob: a tensor of predicted probabilities.
+                loss [optional]: a scalar tensor representing the final loss, if the label key is present in kwargs.
+                y_true [optional]: a tensor representing the true labels, if the label key is present in kwargs.
+        """
+        raise NotImplementedError
+
+    def forward_from_embedding(self, *kwargs: dict[str, torch.Tensor | str]) -> dict[str, torch.Tensor]:
+        """Forward pass of the model from embeddings.
+
+        This method should be implemented for interpretability methods that require
+        access to the model's forward pass from embeddings.
+
+        Args:
+            *kwargs: A variable number of keyword arguments representing input features
+                as embeddings. Each keyword argument is a tensor of shape (batch_size, ...).
+
+        Returns:
+            A dictionary with the following keys:
+                distance: list of tensors of stage variation.
+                y_prob: a tensor of predicted probabilities.
+                loss [optional]: a scalar tensor representing the final loss, if the label key is present in kwargs.
+                y_true [optional]: a tensor representing the true labels, if the label key is present in kwargs.
+        """
+        raise NotImplementedError
+
+    def get_embedding_model(self) -> nn.Module | None:
+        """Get the embedding model if applicable. This is used in tandem with `forward_from_embedding`.
+
+        Returns:
+            nn.Module | None: The embedding model or None if not applicable.
+        """
+        raise NotImplementedError
 
     # ------------------------------------------------------------------
     # Internal helpers
@@ -94,6 +148,7 @@ def get_output_size(self) -> int:
         assert (
             len(self.label_keys) == 1
         ), "Only one label key is supported if get_output_size is called"
+        assert self.dataset is not None, "Dataset must be provided to get output size"
         output_size = self.dataset.output_processors[self.label_keys[0]].size()
         return output_size
 
@@ -113,6 +168,7 @@ def get_loss_function(self) -> Callable:
         assert (
             len(self.label_keys) == 1
         ), "Only one label key is supported if get_loss_function is called"
+        assert self.dataset is not None, "Dataset must be provided to get loss function"
         label_key = self.label_keys[0]
         mode = self._resolve_mode(self.dataset.output_schema[label_key])
         if mode == "binary":
@@ -150,6 +206,7 @@ def prepare_y_prob(self, logits: torch.Tensor) -> torch.Tensor:
         assert (
             len(self.label_keys) == 1
         ), "Only one label key is supported if get_loss_function is called"
+        assert self.dataset is not None, "Dataset must be provided to prepare y_prob"
         label_key = self.label_keys[0]
         mode = self._resolve_mode(self.dataset.output_schema[label_key])
         if mode in ["binary"]:

From 06658002245d2b2ffc56d1260fb2d7380176917e Mon Sep 17 00:00:00 2001
From: Yongda Fan
Date: Thu, 5 Feb 2026 17:46:45 -0600
Subject: [PATCH 2/4] Fix type-hint

---
 pyhealth/models/base_model.py | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/pyhealth/models/base_model.py b/pyhealth/models/base_model.py
index 8d0bebaa1..376cb00a1 100644
--- a/pyhealth/models/base_model.py
+++ b/pyhealth/models/base_model.py
@@ -29,7 +29,7 @@ class BaseModel(ABC, nn.Module):
     functional versions (e.g., F.relu, F.sigmoid, F.softmax) so that hooks can be registered properly.
     """
 
-    def __init__(self, dataset: Optional[SampleDataset] = None):
+    def __init__(self, dataset: SampleDataset):
         """
         Initializes the BaseModel.
 
@@ -56,7 +56,7 @@ def __init__(self, dataset: SampleDataset):
 
         self.mode = getattr(self, "mode", None)  # legacy API
 
-    def forward(self, *kwargs: dict[str, torch.Tensor | str]) -> dict[str, torch.Tensor]:
+    def forward(self, **kwargs: dict[str, torch.Tensor | str]) -> dict[str, torch.Tensor]:
         """Forward pass of the model.
 
         Args:
@@ -72,7 +72,7 @@ def forward(self, **kwargs: dict[str, torch.Tensor | str]) -> dict[str, torch.Tensor]:
         """
         raise NotImplementedError
 
-    def forward_from_embedding(self, *kwargs: dict[str, torch.Tensor | str]) -> dict[str, torch.Tensor]:
+    def forward_from_embedding(self, **kwargs: dict[str, torch.Tensor | str]) -> dict[str, torch.Tensor]:
         """Forward pass of the model from embeddings.
 
         This method should be implemented for interpretability methods that require
@@ -148,7 +148,6 @@ def get_output_size(self) -> int:
         assert (
             len(self.label_keys) == 1
         ), "Only one label key is supported if get_output_size is called"
-        assert self.dataset is not None, "Dataset must be provided to get output size"
         output_size = self.dataset.output_processors[self.label_keys[0]].size()
         return output_size
 
@@ -168,7+167,6 @@ def get_loss_function(self) -> Callable:
         assert (
             len(self.label_keys) == 1
         ), "Only one label key is supported if get_loss_function is called"
-        assert self.dataset is not None, "Dataset must be provided to get loss function"
         label_key = self.label_keys[0]
         mode = self._resolve_mode(self.dataset.output_schema[label_key])
         if mode == "binary":
@@ -206,7 +204,6 @@ def prepare_y_prob(self, logits: torch.Tensor) -> torch.Tensor:
         assert (
             len(self.label_keys) == 1
         ), "Only one label key is supported if get_loss_function is called"
-        assert self.dataset is not None, "Dataset must be provided to prepare y_prob"
         label_key = self.label_keys[0]
         mode = self._resolve_mode(self.dataset.output_schema[label_key])
         if mode in ["binary"]:

From f7d6b110fcb8b99bbdd6337488b98280072e78c8 Mon Sep 17 00:00:00 2001
From: Yongda Fan
Date: Thu, 5 Feb 2026 18:11:56 -0600
Subject: [PATCH 3/4] new API changes

---
 pyhealth/models/base_model.py |  22 +--
 pyhealth/models/stagenet.py   | 242 +++++++++++++---------------------
 2 files changed, 102 insertions(+), 162 deletions(-)

diff --git a/pyhealth/models/base_model.py b/pyhealth/models/base_model.py
index 376cb00a1..772b36e26 100644
--- a/pyhealth/models/base_model.py
+++ b/pyhealth/models/base_model.py
@@ -56,35 +56,41 @@ def __init__(self, dataset: SampleDataset):
 
         self.mode = getattr(self, "mode", None)  # legacy API
 
-    def forward(self, **kwargs: dict[str, torch.Tensor | str]) -> dict[str, torch.Tensor]:
+    def forward(self,
+                **kwargs: torch.Tensor | tuple[torch.Tensor, ...]
+                ) -> dict[str, torch.Tensor]:
         """Forward pass of the model.
 
         Args:
-            *kwargs: A variable number of keyword arguments representing input features.
-                Each keyword argument is a tensor of shape (batch_size, ...).
+            **kwargs: A variable number of keyword arguments representing input features.
+                Each keyword argument is a tensor or a tuple of tensors of shape (batch_size, ...).
 
         Returns:
             A dictionary with the following keys:
-                distance: list of tensors of stage variation.
+                logit: a tensor of predicted logits.
                 y_prob: a tensor of predicted probabilities.
                 loss [optional]: a scalar tensor representing the final loss, if the label key is present in kwargs.
                 y_true [optional]: a tensor representing the true labels, if the label key is present in kwargs.
         """
         raise NotImplementedError
 
-    def forward_from_embedding(self, **kwargs: dict[str, torch.Tensor | str]) -> dict[str, torch.Tensor]:
+    def forward_from_embedding(
+        self,
+        **kwargs: torch.Tensor | tuple[torch.Tensor, ...]
+    ) -> dict[str, torch.Tensor]:
         """Forward pass of the model from embeddings.
 
         This method should be implemented for interpretability methods that require
         access to the model's forward pass from embeddings.
 
         Args:
-            *kwargs: A variable number of keyword arguments representing input features
-                as embeddings. Each keyword argument is a tensor of shape (batch_size, ...).
+            **kwargs: A variable number of keyword arguments representing input features
+                as embeddings. Each keyword argument is a tensor or a tuple of tensors of
+                shape (batch_size, ...).
 
         Returns:
             A dictionary with the following keys:
-                distance: list of tensors of stage variation.
+                logit: a tensor of predicted logits.
                 y_prob: a tensor of predicted probabilities.
                 loss [optional]: a scalar tensor representing the final loss, if the label key is present in kwargs.
                 y_true [optional]: a tensor representing the true labels, if the label key is present in kwargs.
diff --git a/pyhealth/models/stagenet.py b/pyhealth/models/stagenet.py
index 973934904..d091a1564 100644
--- a/pyhealth/models/stagenet.py
+++ b/pyhealth/models/stagenet.py
@@ -1,4 +1,5 @@
-from typing import Dict, Optional, Tuple
+import warnings
+from typing import Dict, Optional, Tuple, cast
 
 import torch
 import torch.nn as nn
@@ -373,27 +374,9 @@ def __init__(
             len(self.feature_keys) * self.chunk_size * self.levels, output_size
         )
 
-        self._deeplift_hooks = None
-
-    # ------------------------------------------------------------------
-    # Interpretability support (e.g., DeepLIFT)
-    # ------------------------------------------------------------------
-    def set_deeplift_hooks(self, hooks) -> None:
-        """Backward-compatibility stub; activation swapping occurs in interpreters."""
-
-        self._deeplift_hooks = hooks
-
-    def clear_deeplift_hooks(self) -> None:
-        """Backward-compatibility stub; activation swapping occurs in interpreters."""
-
-        self._deeplift_hooks = None
-
     def forward_from_embedding(
         self,
-        feature_embeddings: Dict[str, torch.Tensor],
-        time_info: Optional[Dict[str, torch.Tensor]] = None,
-        mask_info: Optional[Dict[str, torch.Tensor]] = None,
-        **kwargs,
+        **kwargs: torch.Tensor | tuple[torch.Tensor, ...],
     ) -> Dict[str, torch.Tensor]:
         """Forward pass starting from feature embeddings.
 
@@ -403,17 +386,11 @@ def forward_from_embedding(
         interpolate in embedding space.
 
         Args:
-            feature_embeddings: Dictionary mapping feature keys to their
-                embedded representations. Each tensor should have shape
-                [batch_size, seq_len, embedding_dim].
-            time_info: Optional dictionary mapping feature keys to their
-                time information tensors of shape [batch_size, seq_len].
-                If None, uniform time intervals are assumed.
-            mask_info: Optional dictionary mapping feature keys to masks
-                of shape [batch_size, seq_len]. When provided, these masks
-                override the automatic mask derived from the embeddings.
-            **kwargs: Additional keyword arguments, must include the label
-                key for loss computation.
+            **kwargs: keyword arguments for the model. The keys must contain
+                all the feature keys and the label key. Feature keys should
+                contain tuples of tensors (time, values, mask) from temporal processors.
+                But the feature keys can also contain just the values without time and mask for backward compatibility.
+                The label key should contain the true labels for loss computation.
 
         Returns:
             A dictionary with the following keys:
@@ -427,31 +404,48 @@ def forward_from_embedding(
         distance = []
 
         for feature_key in self.feature_keys:
-            # Get embedded feature
-            x = feature_embeddings[feature_key].to(self.device)
-            # x: [batch, seq_len, embedding_dim] or 4D nested
-
-            # Handle nested sequences (4D) by pooling over inner dim
-            # This matches forward() processing for consistency
-            if x.dim() == 4:  # [batch, seq_len, inner_len, embedding_dim]
-                # Sum pool over inner dimension
-                x = x.sum(dim=2)  # [batch, seq_len, embedding_dim]
-
-            # Get time information if available
-            time = None
-            if time_info is not None and feature_key in time_info:
-                if time_info[feature_key] is not None:
-                    time = time_info[feature_key].to(self.device)
-                    # Ensure time is 2D [batch, seq_len]
-                    if time.dim() == 1:
-                        time = time.unsqueeze(0)
-
-            # Create mask from embedded values unless an explicit one is provided
-            if mask_info is not None and feature_key in mask_info:
-                mask = mask_info[feature_key].to(self.device)
+            feature = kwargs[feature_key]
+
+            # Unpack feature tuple
+            if isinstance(feature, tuple):
+                time, x, mask = cast(tuple[torch.Tensor, ...], feature)
             else:
-                mask = (x.sum(dim=-1) != 0).int()  # [batch, seq_len]
-
+                x = cast(torch.Tensor, feature)
+                time = None
+                mask = None
+
+            x = x.to(self.device)
+
+            if time is None:
+                warnings.warn(
+                    f"Feature '{feature_key}' does not have time "
+                    f"intervals. StageNet's temporal modeling "
+                    f"capabilities will be limited. Consider using "
+                    f"StageNet format with time intervals for "
+                    f"better performance.",
+                    UserWarning,
+                )
+            else:
+                time = time.to(self.device)
+                # Ensure time is 2D [batch, seq_len]
+                if time.dim() == 1:
+                    time = time.unsqueeze(0)
+
+            if mask is None:
+                warnings.warn(
+                    f"Feature '{feature_key}' does not have mask "
+                    f"information. Default mask will be created from "
+                    f"embedded values, which may not be accurate.",
+                    UserWarning,
+                )
+                mask = (x.abs().sum(dim=-1) != 0).int()
+            else:
+                mask = mask.to(self.device)
+
+            if x.dim() == 4:
+                # Nested sequences: [batch, seq_len, inner_len, embedding_dim]
+                x = x.sum(dim=2)  # Sum pool over inner dimension
+
             # Pass through StageNet layer with embedded features
             last_output, _, cur_dis = self.stagenet[feature_key](
                 x, time=time, mask=mask
             )
 
             patient_emb.append(last_output)
             distance.append(cur_dis)
 
-        # Concatenate all feature embeddings
         patient_emb = torch.cat(patient_emb, dim=1)
-
-        # Register hook if needed for gradient tracking
-        if patient_emb.requires_grad:
-            patient_emb.register_hook(lambda grad: grad)
-
-        # Pass through final classification layer
+        # (patient, label_size)
         logits = self.fc(patient_emb)
-
-        # Obtain y_true, loss, y_prob
-        y_true = kwargs[self.label_key].to(self.device)
-        loss = self.get_loss_function()(logits, y_true)
         y_prob = self.prepare_y_prob(logits)
+
         results = {
-            "loss": loss,
+            "logit": logits,
             "y_prob": y_prob,
-            "y_true": y_true,
-            "logit": logits,
         }
-
+
+        # obtain y_true, loss, y_prob
+        if self.label_key in kwargs:
+            y_true = cast(torch.Tensor, kwargs[self.label_key]).to(self.device)
+            loss = self.get_loss_function()(logits, y_true)
+            results["loss"] = loss
+            results["y_true"] = y_true
+
         # Optionally return embeddings
         if kwargs.get("embed", False):
             results["embed"] = patient_emb
-
         return results
 
-    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
+    def forward(
+        self,
+        **kwargs: torch.Tensor | tuple[torch.Tensor, ...]
+    ) -> Dict[str, torch.Tensor]:
         """Forward propagation.
 
-        The label `kwargs[self.label_key]` is a list of labels for each
-        patient.
-
         Args:
-            **kwargs: keyword arguments for the model. The keys must contain
-                all the feature keys and the label key. Feature keys should
-                contain tuples of (time, values) from temporal processors.
+            **kwargs: keyword arguments for the model.
+
+                The keys must contain all the feature keys and the label key.
+
+                Feature keys should contain tuples of tensors (time, values) from temporal processors.
+                But the feature keys can also contain just the values without time
+                at the cost of degraded performance.
+
+                The label key should contain the true labels for loss computation.
 
         Returns:
             A dictionary with the following keys:
@@ -556,6 +550,7 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
                 x = x.sum(dim=2)  # [batch, seq_len, embedding_dim]
 
             # Create mask from embedded values
+            # TODO: mask should be created in embedding model, fix later
             mask = (x.sum(dim=-1) != 0).int()  # [batch, seq_len]
 
             # Move time to correct device if present
@@ -577,86 +572,25 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
         patient_emb = torch.cat(patient_emb, dim=1)
         # (patient, label_size)
         logits = self.fc(patient_emb)
-
-        # obtain y_true, loss, y_prob
-        y_true = kwargs[self.label_key].to(self.device)
-        loss = self.get_loss_function()(logits, y_true)
         y_prob = self.prepare_y_prob(logits)
+
         results = {
-            "loss": loss,
+            "logit": logits,
             "y_prob": y_prob,
-            "y_true": y_true,
-            "logit": logits,
         }
+
+        # obtain y_true, loss, y_prob
+        if self.label_key in kwargs:
+            y_true = cast(torch.Tensor, kwargs[self.label_key]).to(self.device)
+            loss = self.get_loss_function()(logits, y_true)
+            results["loss"] = loss
+            results["y_true"] = y_true
+
+        # Optionally return embeddings
         if kwargs.get("embed", False):
             results["embed"] = patient_emb
         return results
+
+    def get_embedding_model(self) -> nn.Module | None:
+        return self.embedding_model
-
-if __name__ == "__main__":
-    from pyhealth.datasets import create_sample_dataset
-
-    samples = [
-        {
-            "patient_id": "patient-0",
-            "visit_id": "visit-0",
-            "codes": (
-                [0.0, 2.0, 1.3],
-                ["505800458", "50580045810", "50580045811"],
-            ),
-            "procedures": (
-                [0.0, 1.5],
-                [["A05B", "A05C", "A06A"], ["A11D", "A11E"]],
-            ),
-            "label": 1,
-        },
-        {
-            "patient_id": "patient-0",
-            "visit_id": "visit-1",
-            "codes": (
-                [0.0, 2.0, 1.3, 1.0, 2.0],
-                [
-                    "55154191800",
-                    "551541928",
-                    "55154192800",
-                    "705182798",
-                    "70518279800",
-                ],
-            ),
-            "procedures": (
-                [0.0],
-                [["A04A", "B035", "C129"]],
-            ),
-            "label": 0,
-        },
-    ]
-
-    # dataset
-    dataset = create_sample_dataset(
-        samples=samples,
-        input_schema={
-            "codes": "stagenet",
-            "procedures": "stagenet",
-        },
-        output_schema={"label": "binary"},
-        dataset_name="test",
-    )
-
-    # data loader
-    from pyhealth.datasets import get_dataloader
-
-    train_loader = get_dataloader(dataset, batch_size=2, shuffle=True)
-
-    # model
-    model = StageNet(dataset=dataset)
-
-    # data batch
-    data_batch = next(iter(train_loader))
-
-    # try the model
-    ret = model(**data_batch)
-    print(ret)
-
-    # try loss backward
-    ret["loss"].backward()

From 1144bbcd4e0208917e10152a1d25af Mon Sep 17 00:00:00 2001
From: Yongda Fan
Date: Thu, 5 Feb 2026 18:16:11 -0600
Subject: [PATCH 4/4] Fix comment

---
 pyhealth/models/stagenet.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/pyhealth/models/stagenet.py b/pyhealth/models/stagenet.py
index d091a1564..7237221f4 100644
--- a/pyhealth/models/stagenet.py
+++ b/pyhealth/models/stagenet.py
@@ -387,9 +387,13 @@ def forward_from_embedding(
 
         Args:
             **kwargs: keyword arguments for the model. The keys must contain
-                all the feature keys and the label key. Feature keys should
-                contain tuples of tensors (time, values, mask) from temporal processors.
-                But the feature keys can also contain just the values without time and mask for backward compatibility.
+                all the feature keys and the label key.
+
+                Feature keys should contain tuples of tensors (time, embedding, mask)
+                from temporal processors.
+                But the feature keys can also contain just the embedding values
+                without time and mask for backward compatibility.
+
                 The label key should contain the true labels for loss computation.
 
         Returns:
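
Usage sketch for reviewers (not part of the patch): the snippet below exercises the revised contract end to end, reusing the sample setup from the `__main__` demo that PATCH 3/4 removes. The raw-feature calls show that `loss`/`y_true` are attached only when the label key is in the batch. The embedding-space lines at the end are assumptions about the `(time, values, mask)` tuple layout emitted by the "stagenet" processors and about the embedding model's call signature, so they are left commented.

# Hypothetical end-to-end check of the new API (a sketch, not a test in the patch).
# Dataset construction is copied from the removed __main__ demo.
from pyhealth.datasets import create_sample_dataset, get_dataloader
from pyhealth.models import StageNet

samples = [
    {
        "patient_id": "patient-0",
        "visit_id": "visit-0",
        "codes": ([0.0, 2.0, 1.3], ["505800458", "50580045810", "50580045811"]),
        "procedures": ([0.0, 1.5], [["A05B", "A05C", "A06A"], ["A11D", "A11E"]]),
        "label": 1,
    },
    {
        "patient_id": "patient-0",
        "visit_id": "visit-1",
        "codes": (
            [0.0, 2.0, 1.3, 1.0, 2.0],
            ["55154191800", "551541928", "55154192800", "705182798", "70518279800"],
        ),
        "procedures": ([0.0], [["A04A", "B035", "C129"]]),
        "label": 0,
    },
]
dataset = create_sample_dataset(
    samples=samples,
    input_schema={"codes": "stagenet", "procedures": "stagenet"},
    output_schema={"label": "binary"},
    dataset_name="test",
)
model = StageNet(dataset=dataset)
batch = next(iter(get_dataloader(dataset, batch_size=2, shuffle=True)))

# Label key present in the batch -> loss and y_true are populated.
ret = model(**batch)
assert {"logit", "y_prob", "loss", "y_true"} <= ret.keys()
ret["loss"].backward()

# Label key absent -> only logit and y_prob come back (inference mode).
features_only = {k: v for k, v in batch.items() if k != "label"}
assert "loss" not in model(**features_only)

# Embedding pathway for interpretability; the exact tuple layout and the
# embedding model's call signature are assumptions, so this stays commented:
# emb_model = model.get_embedding_model()
# time, values, mask = batch["codes"]
# ret = model.forward_from_embedding(
#     codes=(time, emb_model(values), mask), label=batch["label"])

Returning `logit`/`y_prob` unconditionally and attaching `loss`/`y_true` only when the label is supplied lets the same forward path serve both training and label-free interpretability runs.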