diff --git a/pyhealth/models/base_model.py b/pyhealth/models/base_model.py
index ea8046355..772b36e26 100644
--- a/pyhealth/models/base_model.py
+++ b/pyhealth/models/base_model.py
@@ -1,5 +1,5 @@
 from abc import ABC
-from typing import Callable, Any
+from typing import Callable, Any, Optional
 import inspect
 
 import torch
@@ -16,9 +16,20 @@ class BaseModel(ABC, nn.Module):
     Args:
         dataset (SampleDataset): The dataset to train the model. It is used to
             query certain information such as the set of all tokens.
+
+    Interpretability
+    ----------------
+    To use a model with interpretability methods, the model must implement a
+    method `forward_from_embedding` that takes embeddings as input instead of
+    raw features; for models that already take dense features as input, this
+    method can simply call the existing `forward` method.
+
+    For certain gradient-based interpretability methods (e.g., DeepLIFT), the model must also
+    ensure that all non-linearities (e.g., ReLU, Sigmoid, Softmax) use nn.Module versions instead
+    of functional versions (e.g., F.relu, F.sigmoid, F.softmax) so that hooks can be registered properly.
     """
 
-    def __init__(self, dataset: SampleDataset = None):
+    def __init__(self, dataset: SampleDataset):
         """
         Initializes the BaseModel.
 
@@ -44,6 +55,55 @@ def __init__(self, dataset: SampleDataset = None):
         self._dummy_param = nn.Parameter(torch.empty(0))
 
         self.mode = getattr(self, "mode", None)  # legacy API
+
+    def forward(self,
+                **kwargs: torch.Tensor | tuple[torch.Tensor, ...]
+                ) -> dict[str, torch.Tensor]:
+        """Forward pass of the model.
+
+        Args:
+            **kwargs: A variable number of keyword arguments representing input features.
+                Each keyword argument is a tensor or a tuple of tensors of shape (batch_size, ...).
+
+        Returns:
+            A dictionary with the following keys:
+                logit: a tensor of predicted logits.
+                y_prob: a tensor of predicted probabilities.
+                loss [optional]: a scalar tensor with the final loss, present when the label key is in kwargs.
+                y_true [optional]: a tensor with the true labels, present when the label key is in kwargs.
+        """
+        raise NotImplementedError
+
+    def forward_from_embedding(
+        self,
+        **kwargs: torch.Tensor | tuple[torch.Tensor, ...]
+    ) -> dict[str, torch.Tensor]:
+        """Forward pass of the model from embeddings.
+
+        This method should be implemented for interpretability methods that require
+        access to the model's forward pass from embeddings.
+
+        Args:
+            **kwargs: A variable number of keyword arguments representing input features
+                as embeddings. Each keyword argument is a tensor or a tuple of tensors of
+                shape (batch_size, ...).
+
+        Returns:
+            A dictionary with the following keys:
+                logit: a tensor of predicted logits.
+                y_prob: a tensor of predicted probabilities.
+                loss [optional]: a scalar tensor with the final loss, present when the label key is in kwargs.
+                y_true [optional]: a tensor with the true labels, present when the label key is in kwargs.
+        """
+        raise NotImplementedError
+
+    def get_embedding_model(self) -> nn.Module | None:
+        """Get the embedding model if applicable. This is used in tandem with `forward_from_embedding`.
+
+        Returns:
+            nn.Module | None: The embedding model, or None if not applicable.
+ """ + raise NotImplementedError # ------------------------------------------------------------------ # Internal helpers diff --git a/pyhealth/models/stagenet.py b/pyhealth/models/stagenet.py index 973934904..7237221f4 100644 --- a/pyhealth/models/stagenet.py +++ b/pyhealth/models/stagenet.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional, Tuple +from typing import Dict, Optional, Tuple, cast import torch import torch.nn as nn @@ -373,27 +373,9 @@ def __init__( len(self.feature_keys) * self.chunk_size * self.levels, output_size ) - self._deeplift_hooks = None - - # ------------------------------------------------------------------ - # Interpretability support (e.g., DeepLIFT) - # ------------------------------------------------------------------ - def set_deeplift_hooks(self, hooks) -> None: - """Backward-compatibility stub; activation swapping occurs in interpreters.""" - - self._deeplift_hooks = hooks - - def clear_deeplift_hooks(self) -> None: - """Backward-compatibility stub; activation swapping occurs in interpreters.""" - - self._deeplift_hooks = None - def forward_from_embedding( self, - feature_embeddings: Dict[str, torch.Tensor], - time_info: Optional[Dict[str, torch.Tensor]] = None, - mask_info: Optional[Dict[str, torch.Tensor]] = None, - **kwargs, + **kwargs: Dict[str, tuple[torch.Tensor, ...]], ) -> Dict[str, torch.Tensor]: """Forward pass starting from feature embeddings. @@ -403,17 +385,15 @@ def forward_from_embedding( interpolate in embedding space. Args: - feature_embeddings: Dictionary mapping feature keys to their - embedded representations. Each tensor should have shape - [batch_size, seq_len, embedding_dim]. - time_info: Optional dictionary mapping feature keys to their - time information tensors of shape [batch_size, seq_len]. - If None, uniform time intervals are assumed. - mask_info: Optional dictionary mapping feature keys to masks - of shape [batch_size, seq_len]. When provided, these masks - override the automatic mask derived from the embeddings. - **kwargs: Additional keyword arguments, must include the label - key for loss computation. + **kwargs: keyword arguments for the model. The keys must contain + all the feature keys and the label key. + + Feature keys should contain tuples of tensors (time, embedding, mask) + from temporal processors. + But the featurs keys can also contain just the embedding values + without time and mask for backward compatibility. + + The label key should contain the true labels for loss computation. 
 
         Returns:
             A dictionary with the following keys:
@@ -427,31 +407,50 @@
         distance = []
 
         for feature_key in self.feature_keys:
-            # Get embedded feature
-            x = feature_embeddings[feature_key].to(self.device)
-            # x: [batch, seq_len, embedding_dim] or 4D nested
-
-            # Handle nested sequences (4D) by pooling over inner dim
-            # This matches forward() processing for consistency
-            if x.dim() == 4:  # [batch, seq_len, inner_len, embedding_dim]
-                # Sum pool over inner dimension
-                x = x.sum(dim=2)  # [batch, seq_len, embedding_dim]
-
-            # Get time information if available
-            time = None
-            if time_info is not None and feature_key in time_info:
-                if time_info[feature_key] is not None:
-                    time = time_info[feature_key].to(self.device)
-                    # Ensure time is 2D [batch, seq_len]
-                    if time.dim() == 1:
-                        time = time.unsqueeze(0)
-
-            # Create mask from embedded values unless an explicit one is provided
-            if mask_info is not None and feature_key in mask_info:
-                mask = mask_info[feature_key].to(self.device)
-            else:
-                mask = (x.sum(dim=-1) != 0).int()  # [batch, seq_len]
-
+            feature = kwargs[feature_key]
+
+            # Unpack feature tuple
+            if isinstance(feature, tuple):
+                time, x, mask = cast(tuple[torch.Tensor, ...], feature)
+            else:
+                x = cast(torch.Tensor, feature)
+                time = None
+                mask = None
+
+            x = x.to(self.device)
+
+            if x.dim() == 4:
+                # Nested sequences: [batch, seq_len, inner_len, embedding_dim]
+                x = x.sum(dim=2)  # Sum pool over inner dimension
+
+            if time is None:
+                import warnings
+                warnings.warn(
+                    f"Feature '{feature_key}' does not have time "
+                    f"intervals. StageNet's temporal modeling "
+                    f"capabilities will be limited. Consider using "
+                    f"the StageNet format with time intervals for "
+                    f"better performance.",
+                    UserWarning,
+                )
+            else:
+                time = time.to(self.device)
+                # Ensure time is 2D [batch, seq_len]
+                if time.dim() == 1:
+                    time = time.unsqueeze(0)
+
+            if mask is None:
+                import warnings
+                warnings.warn(
+                    f"Feature '{feature_key}' does not have mask "
+                    f"information. A default mask will be created from "
+                    f"the embedded values, which may not be accurate.",
+                    UserWarning,
+                )
+                mask = (x.abs().sum(dim=-1) != 0).int()
+            else:
+                mask = mask.to(self.device)
+
             # Pass through StageNet layer with embedded features
             last_output, _, cur_dis = self.stagenet[feature_key](
                 x, time=time, mask=mask
@@ -460,44 +459,44 @@ def forward_from_embedding(
             patient_emb.append(last_output)
             distance.append(cur_dis)
 
-        # Concatenate all feature embeddings
         patient_emb = torch.cat(patient_emb, dim=1)
-
-        # Register hook if needed for gradient tracking
-        if patient_emb.requires_grad:
-            patient_emb.register_hook(lambda grad: grad)
-
-        # Pass through final classification layer
+        # (patient, label_size)
         logits = self.fc(patient_emb)
-
-        # Obtain y_true, loss, y_prob
-        y_true = kwargs[self.label_key].to(self.device)
-        loss = self.get_loss_function()(logits, y_true)
-        y_prob = self.prepare_y_prob(logits)
+        y_prob = self.prepare_y_prob(logits)
+
+        results = {
-            "loss": loss,
+            "logit": logits,
             "y_prob": y_prob,
-            "y_true": y_true,
-            "logit": logits,
         }
-
+
+        # obtain y_true and loss when the label is available
+        if self.label_key in kwargs:
+            y_true = cast(torch.Tensor, kwargs[self.label_key]).to(self.device)
+            loss = self.get_loss_function()(logits, y_true)
+            results["loss"] = loss
+            results["y_true"] = y_true
+
         # Optionally return embeddings
         if kwargs.get("embed", False):
            results["embed"] = patient_emb
-
         return results
 
-    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
+    def forward(
+        self,
+        **kwargs: torch.Tensor | tuple[torch.Tensor, ...]
+    ) -> Dict[str, torch.Tensor]:
         """Forward propagation.
 
-        The label `kwargs[self.label_key]` is a list of labels for each
-        patient.
-
         Args:
-            **kwargs: keyword arguments for the model. The keys must contain
-                all the feature keys and the label key. Feature keys should
-                contain tuples of (time, values) from temporal processors.
+            **kwargs: keyword arguments for the model.
+
+                The keys must contain all the feature keys and the label key.
+
+                Feature keys should contain tuples of tensors (time, values)
+                produced by temporal processors.
+                For backward compatibility, a feature key may also contain just
+                the values without time, at the cost of degraded temporal modeling.
+
+                The label key should contain the true labels for loss computation.
 
         Returns:
             A dictionary with the following keys:
@@ -556,6 +555,7 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
                 x = x.sum(dim=2)  # [batch, seq_len, embedding_dim]
 
             # Create mask from embedded values
+            # TODO: mask should be created in the embedding model; fix later
             mask = (x.sum(dim=-1) != 0).int()  # [batch, seq_len]
 
             # Move time to correct device if present
@@ -577,86 +577,25 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
         patient_emb = torch.cat(patient_emb, dim=1)
         # (patient, label_size)
         logits = self.fc(patient_emb)
-
-        # obtain y_true, loss, y_prob
-        y_true = kwargs[self.label_key].to(self.device)
-        loss = self.get_loss_function()(logits, y_true)
-        y_prob = self.prepare_y_prob(logits)
+        y_prob = self.prepare_y_prob(logits)
+
+        results = {
-            "loss": loss,
+            "logit": logits,
             "y_prob": y_prob,
-            "y_true": y_true,
-            "logit": logits,
         }
+
+        # obtain y_true and loss when the label is available
+        if self.label_key in kwargs:
+            y_true = cast(torch.Tensor, kwargs[self.label_key]).to(self.device)
+            loss = self.get_loss_function()(logits, y_true)
+            results["loss"] = loss
+            results["y_true"] = y_true
+
+        # Optionally return embeddings
         if kwargs.get("embed", False):
             results["embed"] = patient_emb
 
         return results
+
+    def get_embedding_model(self) -> nn.Module | None:
+        return self.embedding_model
-
-
-if __name__ == "__main__":
-    from pyhealth.datasets import create_sample_dataset
-
-    samples = [
-        {
-            "patient_id": "patient-0",
-            "visit_id": "visit-0",
-            "codes": (
-                [0.0, 2.0, 1.3],
-                ["505800458", "50580045810", "50580045811"],
-            ),
-            "procedures": (
-                [0.0, 1.5],
-                [["A05B", "A05C", "A06A"], ["A11D", "A11E"]],
-            ),
-            "label": 1,
-        },
-        {
-            "patient_id": "patient-0",
-            "visit_id": "visit-1",
-            "codes": (
-                [0.0, 2.0, 1.3, 1.0, 2.0],
-                [
-                    "55154191800",
-                    "551541928",
-                    "55154192800",
-                    "705182798",
-                    "70518279800",
-                ],
-            ),
-            "procedures": (
-                [0.0],
-                [["A04A", "B035", "C129"]],
-            ),
-            "label": 0,
-        },
-    ]
-
-    # dataset
-    dataset = create_sample_dataset(
-        samples=samples,
-        input_schema={
-            "codes": "stagenet",
-            "procedures": "stagenet",
-        },
-        output_schema={"label": "binary"},
-        dataset_name="test",
-    )
-
-    # data loader
-    from pyhealth.datasets import get_dataloader
-
-    train_loader = get_dataloader(dataset, batch_size=2, shuffle=True)
-
-    # model
-    model = StageNet(dataset=dataset)
-
-    # data batch
-    data_batch = next(iter(train_loader))
-
-    # try the model
-    ret = model(**data_batch)
-    print(ret)
-
-    # try loss backward
-    ret["loss"].backward()
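
Below is a minimal sketch of the contract the new BaseModel docstring describes, for a hypothetical model whose inputs are already dense: `forward` embeds the raw features and delegates to `forward_from_embedding`, the non-linearity is an nn.Module (so DeepLIFT-style hooks can be registered), and `get_embedding_model` exposes the embedding stage. `ToyDenseModel`, its "features" key, and the layer sizes are illustrative assumptions, not part of the patch.

    import torch
    import torch.nn as nn

    from pyhealth.models import BaseModel

    class ToyDenseModel(BaseModel):
        """Hypothetical dense-input model; names and sizes are illustrative."""

        def __init__(self, dataset):
            super().__init__(dataset)
            self.embed = nn.Linear(16, 32)  # stands in for an embedding model
            self.act = nn.ReLU()            # nn.Module activation, not F.relu
            self.fc = nn.Linear(32, 1)

        def forward(self, **kwargs):
            # Embed raw features, then reuse the embedding-level forward pass.
            emb = self.embed(kwargs["features"].to(self.device))
            return self.forward_from_embedding(features=emb)

        def forward_from_embedding(self, **kwargs):
            logits = self.fc(self.act(kwargs["features"]))
            return {"logit": logits, "y_prob": torch.sigmoid(logits)}

        def get_embedding_model(self):
            return self.embed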
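The next sketch shows how an interpretability method might call StageNet's `forward_from_embedding` with a (time, embedding, mask) tuple, e.g. to take gradients with respect to embeddings. It assumes a trained `model`, a PyHealth dataloader `batch` whose "codes" value is a (time, values) tuple, and a "label" key matching the dataset schema; the embedding model's call signature here is an assumption, not a documented API.

    # Embed the raw feature values outside the model.
    time, values = batch["codes"]                    # StageNet temporal format
    embedding_model = model.get_embedding_model()
    emb = embedding_model(values.to(model.device))   # assumed call signature
    emb = emb.detach().requires_grad_()              # leaf tensor for attribution
    mask = (emb.abs().sum(dim=-1) != 0).int()

    out = model.forward_from_embedding(
        codes=(time, emb, mask),    # full (time, embedding, mask) tuple
        label=batch["label"],       # label key enables loss / y_true
    )
    out["loss"].backward()          # gradients now land in emb.grad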
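Because both forward paths now attach "loss" and "y_true" only when the label key is present in kwargs, label-free inference reduces to omitting the label key. A sketch, under the same assumptions as above:

    # Keep only feature keys; "logit" and "y_prob" are always returned,
    # while "loss" and "y_true" appear only when the label key is passed.
    features_only = {k: batch[k] for k in model.feature_keys}
    scores = model(**features_only)
    assert "y_prob" in scores and "loss" not in scores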