3 changes: 2 additions & 1 deletion .gitignore
@@ -133,7 +133,8 @@ leaderboard/rtd_token.txt

# locally pre-trained models
pyhealth/medcode/pretrained_embeddings/kg_emb/examples/pretrained_model

data/physionet.org/

# VSCode settings
.vscode/
.vscode/
28 changes: 26 additions & 2 deletions README.rst
@@ -247,7 +247,31 @@ Module 4: <pyhealth.trainer>
monitor="pr_auc_samples",
)

Module 5: <pyhealth.metrics>
Module 5: <pyhealth.models.generators>
""""""""""""""""""""""""""""""""""""""""""""

``pyhealth.models.generators`` provides **synthetic data generation models** for creating artificial EHR data that preserves the statistical properties and medical code correlations of the real records.

**CorGAN (Correlation-capturing Generative Adversarial Network)**: Generates synthetic patient records that maintain realistic correlations between medical codes.

.. code-block:: python

from pyhealth.models.generators.corgan import CorGAN
from pyhealth.datasets import MIMIC3Dataset

# Load real EHR data
dataset = MIMIC3Dataset(root="./data", tables=["DIAGNOSES_ICD"])

# Train CorGAN model
corgan = CorGAN(dataset)
corgan.fit(autoencoder_epochs=10, gan_epochs=50)

# Generate synthetic patients
synthetic_data = corgan.generate(n_samples=1000)

**Example Script**: See ``examples/synthetic_data_generation_mimic3_corgan.py`` for a complete end-to-end example using native ICD-9 codes from MIMIC-III data.
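
**HALO**: a longitudinal generator that models multi-visit patient records autoregressively. The sketch below condenses ``examples/halo_mimic3_training.py`` and ``examples/generate_synthetic_mimic3_halo.py`` into one flow; the directory paths are placeholders, and it assumes ``synthesize_dataset`` can be called directly on a freshly trained instance.

.. code-block:: python

    import pickle
    from pyhealth.datasets.halo_mimic3 import HALO_MIMIC3Dataset
    from pyhealth.models.generators.halo import HALO
    from pyhealth.models.generators.halo_resources.halo_config import HALOConfig

    # Preprocess MIMIC-III into HALO's pickle format
    dataset = HALO_MIMIC3Dataset(
        mimic3_dir="./data/mimic3/",
        pkl_data_dir="./halo_output/pkl_data/",
        gzip=False,
    )

    # Vocabulary sizes are derived from the preprocessed pickles
    code_to_index = pickle.load(open("./halo_output/pkl_data/codeToIndex.pkl", "rb"))
    id_to_label = pickle.load(open("./halo_output/pkl_data/idToLabel.pkl", "rb"))
    code_vocab, label_vocab, special_vocab = len(code_to_index), len(id_to_label), 3

    config = HALOConfig(
        total_vocab_size=code_vocab + label_vocab + special_vocab,
        code_vocab_size=code_vocab,
        label_vocab_size=label_vocab,
        special_vocab_size=special_vocab,
        n_positions=56, n_ctx=48, n_embd=768, n_layer=12, n_head=12,
        layer_norm_epsilon=1e-5, initializer_range=0.02,
        batch_size=48, sample_batch_size=256, epoch=80,
        pos_loss_weight=None, lr=1e-4,
    )

    # train_on_init=True trains the model in the constructor;
    # synthesize_dataset then writes haloDataset.pkl into the given directory
    halo = HALO(dataset=dataset, config=config, save_dir="./halo_output/", train_on_init=True)
    halo.synthesize_dataset(pkl_save_dir="./halo_output/")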

Module 6: <pyhealth.metrics>
""""""""""""""""""""""""""""""""""""

``pyhealth.metrics`` provides several **common evaluation metrics** (refer to the `Doc <https://pyhealth.readthedocs.io/en/latest/api/metrics.html>`_ to see what is available).
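
For example, a minimal sketch of the binary-classification helper (the exact metric names available are listed in the linked doc):

.. code-block:: python

    import numpy as np
    from pyhealth.metrics import binary_metrics_fn

    y_true = np.array([0, 1, 1, 0, 1])
    y_prob = np.array([0.1, 0.8, 0.6, 0.3, 0.9])

    binary_metrics_fn(y_true, y_prob, metrics=["pr_auc", "roc_auc", "f1"])
    # returns a dict, e.g. {"pr_auc": ..., "roc_auc": ..., "f1": ...}
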
@@ -308,7 +332,7 @@ Module 5: <pyhealth.metrics>
codemap.map("50090539100")
# ['A10AC04', 'A10AD04', 'A10AB04']

5. Medical Code Tokenizer :speech_balloon:
6. Medical Code Tokenizer :speech_balloon:
---------------------------------------------

``pyhealth.tokenizer`` is used for transformations between string-based tokens and integer-based indices, based on the overall token space. We provide flexible functions to tokenize 1D, 2D and 3D lists. **This module can be used independently.**
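
A minimal sketch of the 1D and 2D interfaces (the token list below is an arbitrary placeholder):

.. code-block:: python

    from pyhealth.tokenizer import Tokenizer

    token_space = ["A01A", "A02B", "B01A", "C02C"]
    tokenizer = Tokenizer(tokens=token_space, special_tokens=["<pad>", "<unk>"])

    # 1D: token <-> index conversion
    tokenizer.convert_tokens_to_indices(["A01A", "B01A"])

    # 2D: a batch of token lists, padded to equal length
    tokenizer.batch_encode_2d([["A01A", "A02B"], ["B01A"]])
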
143 changes: 143 additions & 0 deletions examples/generate_synthetic_mimic3_halo.py
@@ -0,0 +1,143 @@
#!/usr/bin/env python3
"""
Generate synthetic MIMIC-III patients using a trained HALO checkpoint.
Outputs sequential visit data (temporal format) as a pickle file.
"""

import os
import sys
# If PyHealth is not installed as a package, add your local checkout to sys.path, e.g.:
# sys.path.insert(0, "/path/to/PyHealth")
import argparse
import torch
import pickle
import pandas as pd
from pyhealth.datasets.halo_mimic3 import HALO_MIMIC3Dataset
from pyhealth.models.generators.halo import HALO
from pyhealth.models.generators.halo_resources.halo_config import HALOConfig

def main():
parser = argparse.ArgumentParser(description="Generate synthetic patients using trained HALO")
parser.add_argument("--checkpoint", required=True, help="Path to trained HALO checkpoint directory")
parser.add_argument("--output", required=True, help="Path to output pickle file")
parser.add_argument("--csv_output", help="Optional: Path to output CSV file (converts from pickle)")
args = parser.parse_args()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load vocabulary and configuration from checkpoint directory
pkl_data_dir = os.path.join(args.checkpoint, "pkl_data/")
print(f"\nLoading vocabulary from {pkl_data_dir}")

code_to_index = pickle.load(open(f"{pkl_data_dir}codeToIndex.pkl", "rb"))
index_to_code = pickle.load(open(f"{pkl_data_dir}indexToCode.pkl", "rb"))
id_to_label = pickle.load(open(f"{pkl_data_dir}idToLabel.pkl", "rb"))
train_dataset = pickle.load(open(f"{pkl_data_dir}trainDataset.pkl", "rb"))

code_vocab_size = len(code_to_index)
label_vocab_size = len(id_to_label)
special_vocab_size = 3
total_vocab_size = code_vocab_size + label_vocab_size + special_vocab_size
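# These sizes must match the checkpoint's training configuration, or load_state_dict below will fail.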

print(f"Vocabulary sizes:")
print(f" Code vocabulary: {code_vocab_size}")
print(f" Label vocabulary: {label_vocab_size}")
print(f" Total vocabulary: {total_vocab_size}")

# Create config with same parameters as training
config = HALOConfig(
total_vocab_size=total_vocab_size,
code_vocab_size=code_vocab_size,
label_vocab_size=label_vocab_size,
special_vocab_size=special_vocab_size,
n_positions=56,
n_ctx=48,
n_embd=768,
n_layer=12,
n_head=12,
layer_norm_epsilon=1e-5,
initializer_range=0.02,
batch_size=48,
sample_batch_size=256, # Generation batch size
epoch=50,
pos_loss_weight=None,
lr=1e-4
)

# Create a minimal dataset object (just for interface compatibility)
class MinimalDataset:
def __init__(self, pkl_data_dir):
self.pkl_data_dir = pkl_data_dir

dataset = MinimalDataset(pkl_data_dir)

# Load trained model
print(f"\nLoading checkpoint from {args.checkpoint}halo_model")
from pyhealth.models.generators.halo_resources.halo_model import HALOModel

model = HALOModel(config).to(device)
checkpoint = torch.load(f'{args.checkpoint}halo_model', map_location=device)
model.load_state_dict(checkpoint['model'])
model.eval()

print("Model loaded successfully")

# Generate synthetic patients
n_samples = args.n_samples
print(f"\nGenerating {n_samples} synthetic patients...")
print("This can take on the order of hours depending on the GPU and sample count...")

# Create HALO instance for generation
halo = HALO(dataset=dataset, config=config, save_dir=args.checkpoint, train_on_init=False)
halo.model = model
halo.train_ehr_dataset = train_dataset[:n_samples]  # Limit to the requested number of samples
halo.index_to_code = index_to_code

# Generate synthetic data
output_dir = os.path.dirname(args.output) or "."
os.makedirs(output_dir, exist_ok=True)

# synthesize_dataset writes its output as haloDataset.pkl inside pkl_save_dir
halo.synthesize_dataset(pkl_save_dir=output_dir + "/")

# Move the generated file to the requested output path
generated_file = os.path.join(output_dir, "haloDataset.pkl")
if generated_file != args.output:
os.rename(generated_file, args.output)

print(f"\nGeneration complete!")
print(f"Output saved to: {args.output}")

# Load and print statistics
synthetic_data = pickle.load(open(args.output, "rb"))
print(f"\nSynthetic data statistics:")
print(f" Total patients: {len(synthetic_data)}")
print(f" Avg visits per patient: {sum(len(p['visits']) for p in synthetic_data) / len(synthetic_data):.2f}")
print(f" Total visits: {sum(len(p['visits']) for p in synthetic_data)}")
print(f" Avg codes per visit: {sum(len(c) for p in synthetic_data for v in p['visits'] for c in v) / sum(len(p['visits']) for p in synthetic_data):.2f}")

# Optionally convert to CSV format
if args.csv_output:
print(f"\nConverting to CSV format: {args.csv_output}")
convert_to_csv(synthetic_data, index_to_code, args.csv_output)

def convert_to_csv(synthetic_data, index_to_code, csv_path):
"""Convert pickle format to CSV with temporal information."""
records = []
for patient_idx, patient in enumerate(synthetic_data):
patient_id = f"SYNTHETIC_{patient_idx+1:06d}"
for visit_num, visit in enumerate(patient['visits'], 1):
for code_idx in visit:
icd9_code = index_to_code[code_idx]
records.append({
'SUBJECT_ID': patient_id,
'VISIT_NUM': visit_num,
'ICD9_CODE': icd9_code
})

df = pd.DataFrame(records)
df.to_csv(csv_path, index=False)
print(f"CSV saved with {len(df)} records")

if __name__ == '__main__':
main()
150 changes: 150 additions & 0 deletions examples/halo_mimic3_training.py
@@ -0,0 +1,150 @@
#!/usr/bin/env python3
"""
Train HALO on the MIMIC-III dataset.
Example script demonstrating HALO training with configurable parameters.
"""

import os
import argparse
import torch
import pickle
import shutil
from pyhealth.datasets.halo_mimic3 import HALO_MIMIC3Dataset
from pyhealth.models.generators.halo import HALO
from pyhealth.models.generators.halo_resources.halo_config import HALOConfig


def main():
parser = argparse.ArgumentParser(description="Train HALO on MIMIC-III dataset")
parser.add_argument("--mimic3_dir", required=True, help="Path to MIMIC-III data directory")
parser.add_argument("--output_dir", required=True, help="Directory for saving checkpoints and results")
parser.add_argument("--epochs", type=int, default=80, help="Number of training epochs (default: 80)")
parser.add_argument("--batch_size", type=int, default=48, help="Training batch size (default: 48)")
parser.add_argument("--learning_rate", type=float, default=0.0001, help="Learning rate (default: 0.0001)")
parser.add_argument("--save_best", action="store_true", help="Save best checkpoint (lowest validation loss)")
parser.add_argument("--save_final", action="store_true", help="Save final checkpoint after training")
args = parser.parse_args()

# Setup directories
pkl_data_dir = os.path.join(args.output_dir, "pkl_data/")
os.makedirs(args.output_dir, exist_ok=True)
os.makedirs(pkl_data_dir, exist_ok=True)

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}", flush=True)
if torch.cuda.is_available():
print(f"GPU: {torch.cuda.get_device_name(0)}", flush=True)
print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB", flush=True)

# Load and preprocess dataset
print(f"\n{'='*60}", flush=True)
print("Loading and preprocessing MIMIC-III dataset...", flush=True)
print(f"{'='*60}", flush=True)
print(f"Data directory: {args.mimic3_dir}", flush=True)
print(f"Output directory: {args.output_dir}", flush=True)

dataset = HALO_MIMIC3Dataset(
mimic3_dir=args.mimic3_dir,
pkl_data_dir=pkl_data_dir,
gzip=False
)

print(f"\n{'='*60}", flush=True)
print("Dataset preprocessing complete!", flush=True)
print(f"{'='*60}", flush=True)

# Load vocabulary sizes
code_to_index = pickle.load(open(f"{pkl_data_dir}codeToIndex.pkl", "rb"))
id_to_label = pickle.load(open(f"{pkl_data_dir}idToLabel.pkl", "rb"))

code_vocab_size = len(code_to_index)
label_vocab_size = len(id_to_label)
special_vocab_size = 3
total_vocab_size = code_vocab_size + label_vocab_size + special_vocab_size
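# These sizes are derived from the pkl_data_dir pickles and are re-derived the same way at generation time.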

print(f"Vocabulary sizes:", flush=True)
print(f" Code vocabulary: {code_vocab_size}", flush=True)
print(f" Label vocabulary: {label_vocab_size}", flush=True)
print(f" Special tokens: {special_vocab_size}", flush=True)
print(f" Total vocabulary: {total_vocab_size}", flush=True)

# HALO configuration
print(f"\n{'='*60}", flush=True)
print("Initializing HALO configuration", flush=True)
print(f"{'='*60}", flush=True)

config = HALOConfig(
total_vocab_size=total_vocab_size,
code_vocab_size=code_vocab_size,
label_vocab_size=label_vocab_size,
special_vocab_size=special_vocab_size,
n_positions=56,
n_ctx=48,
n_embd=768,
n_layer=12,
n_head=12,
layer_norm_epsilon=1e-5,
initializer_range=0.02,
batch_size=args.batch_size,
sample_batch_size=256,
epoch=args.epochs,
pos_loss_weight=None,
lr=args.learning_rate
)

print("Configuration:", flush=True)
print(f" Embedding dim: {config.n_embd}", flush=True)
print(f" Layers: {config.n_layer}", flush=True)
print(f" Attention heads: {config.n_head}", flush=True)
print(f" Batch size: {config.batch_size}", flush=True)
print(f" Epochs: {config.epoch}", flush=True)
print(f" Learning rate: {config.lr}", flush=True)

# Train HALO model
print(f"\n{'='*60}", flush=True)
print("Training HALO model...", flush=True)
print(f"{'='*60}", flush=True)
print(f"Training for {args.epochs} epochs", flush=True)
print(f"Progress updates every 1,000 iterations", flush=True)
print(f"Checkpoints saved when validation loss improves", flush=True)
print(f"{'='*60}\n", flush=True)

# train_on_init=True runs the full training loop inside the constructor
model = HALO(
dataset=dataset,
config=config,
save_dir=args.output_dir,
train_on_init=True
)

print(f"\n{'='*60}", flush=True)
print("TRAINING COMPLETE!", flush=True)
print(f"{'='*60}", flush=True)

# Save final checkpoint if requested
if args.save_final:
final_state = {
'model': model.model.state_dict(),
'optimizer': model.optimizer.state_dict(),
'iteration': 'final',
'epoch': config.epoch
}
torch.save(final_state, os.path.join(args.output_dir, 'halo_model_final'))
print(f"Final checkpoint saved to: {args.output_dir}/halo_model_final", flush=True)

# Copy best checkpoint if requested
if args.save_best:
best_path = os.path.join(args.output_dir, 'halo_model')
if os.path.exists(best_path):
shutil.copy(best_path, os.path.join(args.output_dir, 'halo_model_best'))
print(f"Best checkpoint copied to: {args.output_dir}/halo_model_best", flush=True)

print(f"Vocabulary files saved to: {pkl_data_dir}", flush=True)
print(f"\nTraining artifacts:", flush=True)
print(f" - Checkpoints: {args.output_dir}", flush=True)
print(f" - Vocabulary: {pkl_data_dir}", flush=True)
print(f"\nNext step: Generate synthetic data using trained checkpoint", flush=True)


if __name__ == "__main__":
main()
36 changes: 36 additions & 0 deletions examples/slurm/generate_halo_mimic3.slurm
@@ -0,0 +1,36 @@
#!/bin/bash
#SBATCH --job-name=halo_generate
#SBATCH --partition=gpu
#SBATCH --gres=gpu:1
#SBATCH --cpus-per-task=4
#SBATCH --mem=16G
#SBATCH --time=2:00:00
#SBATCH --output=/scratch/%u/logs/halo_generate_%j.out

# Canonical SLURM script for generating synthetic data with HALO
# Adjust paths, partition names, and resource allocations for your cluster

# Navigate to working directory
cd "${SLURM_SUBMIT_DIR}" || exit 1

echo "SLURM_JOB_ID: ${SLURM_JOB_ID}"
echo "Starting HALO generation at: $(date)"
echo "========================================"

# Activate your Python environment
# Example: conda activate pyhealth
# Example: source venv/bin/activate

# Set Python path if needed
# export PYTHONPATH=/path/to/PyHealth:${PYTHONPATH}

# Generation script
python examples/generate_synthetic_mimic3_halo.py \
    --checkpoint /scratch/jalenj4/halo_results/ \
    --output /scratch/jalenj4/halo_results/synthetic/halo_synthetic_10k.pkl \
    --csv_output /scratch/jalenj4/halo_results/synthetic/halo_synthetic_10k.csv \
    --n_samples 10000

echo "========================================"
echo "Generation completed at: $(date)"