64 changes: 62 additions & 2 deletions pyhealth/models/base_model.py
@@ -1,5 +1,5 @@
from abc import ABC
from typing import Callable, Any
from typing import Callable, Any, Optional
import inspect

import torch
@@ -16,9 +16,20 @@ class BaseModel(ABC, nn.Module):
Args:
dataset (SampleDataset): The dataset to train the model. It is used to query certain
information such as the set of all tokens.

Interpretability
----------------
To use a model with interpretability methods, the model must implement a method
`forward_from_embedding` that takes embeddings as input instead of raw features;
for models that already take dense features as input, this method can simply
call the existing `forward` method.

For certain gradient-based interpretability methods (e.g., DeepLIFT), the model must also
ensure all non-linearities (e.g., ReLU, Sigmoid, Softmax) use nn.Module versions instead of
functional versions (e.g., F.relu, F.sigmoid, F.softmax) so that hooks can be registered properly.
"""

def __init__(self, dataset: SampleDataset = None):
def __init__(self, dataset: SampleDataset):
"""
Initializes the BaseModel.

@@ -44,6 +55,55 @@ def __init__(self, dataset: SampleDataset = None):
self._dummy_param = nn.Parameter(torch.empty(0))

self.mode = getattr(self, "mode", None) # legacy API

def forward(self,
**kwargs: torch.Tensor | tuple[torch.Tensor, ...]
) -> dict[str, torch.Tensor]:
"""Forward pass of the model.

Args:
**kwargs: A variable number of keyword arguments representing input features.
Each keyword argument is a tensor or a tuple of tensors of shape (batch_size, ...).

Returns:
A dictionary with the following keys:
logit: a tensor of predicted logits.
y_prob: a tensor of predicted probabilities.
loss [optional]: a scalar tensor representing the final loss, if self.label_keys in kwargs.
y_true [optional]: a tensor representing the true labels, if self.label_keys in kwargs.
"""
raise NotImplementedError

def forward_from_embedding(
self,
**kwargs: torch.Tensor | tuple[torch.Tensor, ...]
) -> dict[str, torch.Tensor]:
"""Forward pass of the model from embeddings.

This method should be implemented for interpretability methods that require
access to the model's forward pass from embeddings.

Args:
**kwargs: A variable number of keyword arguments representing input features
as embeddings. Each keyword argument is a tensor or a tuple of tensors of
shape (batch_size, ...).

Returns:
A dictionary with the following keys:
logit: a tensor of predicted logits.
y_prob: a tensor of predicted probabilities.
loss [optional]: a scalar tensor representing the final loss, if self.label_keys in kwargs.
y_true [optional]: a tensor representing the true labels, if self.label_keys in kwargs.
"""
raise NotImplementedError

def get_embedding_model(self) -> nn.Module | None:
"""Get the embedding model if applicable. This is used in pair with `forward_from_embedding`.

Returns:
nn.Module | None: The embedding model or None if not applicable.
"""
raise NotImplementedError

# ------------------------------------------------------------------
# Internal helpers
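To make the interpretability contract above concrete, here is a minimal sketch (not part of this PR) of a subclass that satisfies it. The class name ToyMLP, the layer sizes, and the feature key "x" are illustrative assumptions:

from pyhealth.models import BaseModel  # assumed public import path
import torch
import torch.nn as nn

class ToyMLP(BaseModel):
    """Toy model over dense features, satisfying the interpretability contract."""

    def __init__(self, dataset):
        super().__init__(dataset)
        # Use nn.Module activations (not F.relu) so interpretability
        # methods such as DeepLIFT can register hooks on them.
        self.net = nn.Sequential(
            nn.Linear(16, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
        )

    def forward(self, **kwargs):
        logits = self.net(kwargs["x"])
        return {"logit": logits, "y_prob": torch.sigmoid(logits)}

    def forward_from_embedding(self, **kwargs):
        # Dense inputs already act as embeddings, so simply delegate.
        return self.forward(**kwargs)

    def get_embedding_model(self):
        return None  # no discrete-token embedding layer in this toy model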
247 changes: 93 additions & 154 deletions pyhealth/models/stagenet.py
@@ -1,4 +1,4 @@
from typing import Dict, Optional, Tuple
from typing import Dict, Optional, Tuple, cast

import torch
import torch.nn as nn
@@ -373,27 +373,9 @@ def __init__(
len(self.feature_keys) * self.chunk_size * self.levels, output_size
)

self._deeplift_hooks = None

# ------------------------------------------------------------------
# Interpretability support (e.g., DeepLIFT)
# ------------------------------------------------------------------
def set_deeplift_hooks(self, hooks) -> None:
"""Backward-compatibility stub; activation swapping occurs in interpreters."""

self._deeplift_hooks = hooks

def clear_deeplift_hooks(self) -> None:
"""Backward-compatibility stub; activation swapping occurs in interpreters."""

self._deeplift_hooks = None

def forward_from_embedding(
self,
feature_embeddings: Dict[str, torch.Tensor],
time_info: Optional[Dict[str, torch.Tensor]] = None,
mask_info: Optional[Dict[str, torch.Tensor]] = None,
**kwargs,
**kwargs: torch.Tensor | tuple[torch.Tensor, ...],
) -> Dict[str, torch.Tensor]:
"""Forward pass starting from feature embeddings.

@@ -403,17 +385,15 @@ def forward_from_embedding(
interpolate in embedding space.

Args:
feature_embeddings: Dictionary mapping feature keys to their
embedded representations. Each tensor should have shape
[batch_size, seq_len, embedding_dim].
time_info: Optional dictionary mapping feature keys to their
time information tensors of shape [batch_size, seq_len].
If None, uniform time intervals are assumed.
mask_info: Optional dictionary mapping feature keys to masks
of shape [batch_size, seq_len]. When provided, these masks
override the automatic mask derived from the embeddings.
**kwargs: Additional keyword arguments, must include the label
key for loss computation.
**kwargs: keyword arguments for the model. The keys must contain
all the feature keys and the label key.

Feature keys should map to tuples of tensors (time, embedding, mask)
from temporal processors. For backward compatibility, a feature key
may also map to just the embedding tensor, without time and mask.

The label key should contain the true labels for loss computation.

Returns:
A dictionary with the following keys:
@@ -427,31 +407,50 @@
distance = []

for feature_key in self.feature_keys:
# Get embedded feature
x = feature_embeddings[feature_key].to(self.device)
# x: [batch, seq_len, embedding_dim] or 4D nested

# Handle nested sequences (4D) by pooling over inner dim
# This matches forward() processing for consistency
if x.dim() == 4: # [batch, seq_len, inner_len, embedding_dim]
# Sum pool over inner dimension
x = x.sum(dim=2) # [batch, seq_len, embedding_dim]

# Get time information if available
time = None
if time_info is not None and feature_key in time_info:
if time_info[feature_key] is not None:
time = time_info[feature_key].to(self.device)
# Ensure time is 2D [batch, seq_len]
if time.dim() == 1:
time = time.unsqueeze(0)

# Create mask from embedded values unless an explicit one is provided
if mask_info is not None and feature_key in mask_info:
mask = mask_info[feature_key].to(self.device)
feature = kwargs[feature_key]

# Unpack feature tuple
if isinstance(feature, tuple):
time, x, mask = cast(tuple[torch.Tensor, ...], feature)
else:
mask = (x.sum(dim=-1) != 0).int() # [batch, seq_len]

x = cast(torch.Tensor, feature)
time = None
mask = None

x = x.to(self.device)

if time is None:
import warnings
warnings.warn(
f"Feature '{feature_key}' does not have time "
f"intervals. StageNet's temporal modeling "
f"capabilities will be limited. Consider using "
f"StageNet format with time intervals for "
f"better performance.",
UserWarning,
)
else:
time = time.to(self.device)
# Ensure time is 2D [batch, seq_len]
if time.dim() == 1:
time = time.unsqueeze(0)

if mask is None:
import warnings
warnings.warn(
f"Feature '{feature_key}' does not have mask "
f"information. Default mask will be created from "
f"embedded values. But it may not be accurate.",
)
mask = (x.abs().sum(dim=-1) != 0).int()
else:
mask = mask.to(self.device)

if x.dim() == 4:
# Nested sequences: [batch, seq_len, inner_len, embedding_dim]
x = x.sum(dim=2) # Sum pool over inner dimension
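# Illustrative shapes (editor's note, not in the original code): a nested
# batch [2, 5, 3, 128] (3 inner events per visit) becomes [2, 5, 128]
# after sum pooling, matching forward()'s behavior.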

# Pass through StageNet layer with embedded features
last_output, _, cur_dis = self.stagenet[feature_key](
x, time=time, mask=mask
@@ -460,44 +459,44 @@
patient_emb.append(last_output)
distance.append(cur_dis)

# Concatenate all feature embeddings
patient_emb = torch.cat(patient_emb, dim=1)

# Register hook if needed for gradient tracking
if patient_emb.requires_grad:
patient_emb.register_hook(lambda grad: grad)

# Pass through final classification layer
# (patient, label_size)
logits = self.fc(patient_emb)

# Obtain y_true, loss, y_prob
y_true = kwargs[self.label_key].to(self.device)
loss = self.get_loss_function()(logits, y_true)

y_prob = self.prepare_y_prob(logits)

results = {
"loss": loss,
"logit": logits,
"y_prob": y_prob,
"y_true": y_true,
"logit": logits,
}


# obtain y_true, loss, y_prob
if self.label_key in kwargs:
y_true = cast(torch.Tensor, kwargs[self.label_key]).to(self.device)
loss = self.get_loss_function()(logits, y_true)
results["loss"] = loss
results["y_true"] = y_true

# Optionally return embeddings
if kwargs.get("embed", False):
results["embed"] = patient_emb

return results
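As a usage sketch of the tuple convention documented above: a trained StageNet instance model with a feature key "codes" is assumed, and the shapes and label layout are illustrative and must match the model's configuration.

import torch

batch, seq_len, emb_dim = 2, 5, 128  # emb_dim must equal the model's embedding size
codes_time = torch.zeros(batch, seq_len)                  # [batch, seq_len]
codes_emb = torch.randn(batch, seq_len, emb_dim, requires_grad=True)
codes_mask = torch.ones(batch, seq_len, dtype=torch.int)  # 1 = valid visit

out = model.forward_from_embedding(
    codes=(codes_time, codes_emb, codes_mask),      # (time, embedding, mask)
    label=torch.randint(0, 2, (batch, 1)).float(),  # shape depends on task/loss
)
out["loss"].backward()  # gradients reach codes_emb, enabling attribution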

def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
def forward(
self,
**kwargs: torch.Tensor | tuple[torch.Tensor, ...]
) -> Dict[str, torch.Tensor]:
"""Forward propagation.

The label `kwargs[self.label_key]` is a list of labels for each
patient.

Args:
**kwargs: keyword arguments for the model. The keys must contain
all the feature keys and the label key. Feature keys should
contain tuples of (time, values) from temporal processors.
**kwargs: keyword arguments for the model.

The keys must contain all the feature keys and the label key.

Feature keys should contain tuples of tensors (time, values) from temporal processors.
For backward compatibility, a feature key may also map to just the values
without time, at the cost of degraded performance.

The label key should contain the true labels for loss computation.

Returns:
A dictionary with the following keys:
@@ -556,6 +555,7 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
x = x.sum(dim=2) # [batch, seq_len, embedding_dim]

# Create mask from embedded values
# TODO: mask should be created in embedding model, fix later
mask = (x.sum(dim=-1) != 0).int() # [batch, seq_len]

# Move time to correct device if present
@@ -577,86 +577,25 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
patient_emb = torch.cat(patient_emb, dim=1)
# (patient, label_size)
logits = self.fc(patient_emb)

# obtain y_true, loss, y_prob
y_true = kwargs[self.label_key].to(self.device)
loss = self.get_loss_function()(logits, y_true)

y_prob = self.prepare_y_prob(logits)

results = {
"loss": loss,
"logit": logits,
"y_prob": y_prob,
"y_true": y_true,
"logit": logits,
}

# obtain y_true, loss, y_prob
if self.label_key in kwargs:
y_true = cast(torch.Tensor, kwargs[self.label_key]).to(self.device)
loss = self.get_loss_function()(logits, y_true)
results["loss"] = loss
results["y_true"] = y_true

# Optionally return embeddings
if kwargs.get("embed", False):
results["embed"] = patient_emb
return results

def get_embedding_model(self) -> nn.Module | None:
return self.embedding_model


if __name__ == "__main__":
from pyhealth.datasets import create_sample_dataset

samples = [
{
"patient_id": "patient-0",
"visit_id": "visit-0",
"codes": (
[0.0, 2.0, 1.3],
["505800458", "50580045810", "50580045811"],
),
"procedures": (
[0.0, 1.5],
[["A05B", "A05C", "A06A"], ["A11D", "A11E"]],
),
"label": 1,
},
{
"patient_id": "patient-0",
"visit_id": "visit-1",
"codes": (
[0.0, 2.0, 1.3, 1.0, 2.0],
[
"55154191800",
"551541928",
"55154192800",
"705182798",
"70518279800",
],
),
"procedures": (
[0.0],
[["A04A", "B035", "C129"]],
),
"label": 0,
},
]

# dataset
dataset = create_sample_dataset(
samples=samples,
input_schema={
"codes": "stagenet",
"procedures": "stagenet",
},
output_schema={"label": "binary"},
dataset_name="test",
)

# data loader
from pyhealth.datasets import get_dataloader

train_loader = get_dataloader(dataset, batch_size=2, shuffle=True)

# model
model = StageNet(dataset=dataset)

# data batch
data_batch = next(iter(train_loader))

# try the model
ret = model(**data_batch)
print(ret)

# try loss backward
ret["loss"].backward()