Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions pyhealth/datasets/ehrshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,15 @@ class EHRShotDataset(BaseDataset):
tables (List[str]): A list of tables to be included in the dataset.
dataset_name (Optional[str]): The name of the dataset.
config_path (Optional[str]): The path to the configuration file.
Examples:
>>> from pyhealth.datasets import EHRShotDataset
>>> # Load EHRShot dataset with benchmark tables
>>> dataset = EHRShotDataset(
... root="/path/to/ehrshot/data",
... tables=["ehrshot", "chexpert", "guo_icu", "lab_anemia"],
... )
>>> dataset.stats()
"""

def __init__(
Expand All @@ -28,7 +37,7 @@ def __init__(
tables: List[str],
dataset_name: Optional[str] = None,
config_path: Optional[str] = None,
**kwargs
**kwargs,
) -> None:
if config_path is None:
logger.info("No config path provided, using default config")
Expand All @@ -38,6 +47,6 @@ def __init__(
tables=tables,
dataset_name=dataset_name or "ehrshot",
config_path=config_path,
**kwargs
**kwargs,
)
return
15 changes: 12 additions & 3 deletions pyhealth/datasets/mimic3.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,15 @@ class MIMIC3Dataset(BaseDataset):
tables (List[str]): A list of tables to be included in the dataset.
dataset_name (Optional[str]): The name of the dataset.
config_path (Optional[str]): The path to the configuration file.
Examples:
>>> from pyhealth.datasets import MIMIC3Dataset
>>> # Load MIMIC-III dataset with clinical tables
>>> dataset = MIMIC3Dataset(
... root="/path/to/mimic-iii/1.4",
... tables=["diagnoses_icd", "procedures_icd", "labevents"],
... )
>>> dataset.stats()
"""

def __init__(
Expand All @@ -30,7 +39,7 @@ def __init__(
tables: List[str],
dataset_name: Optional[str] = None,
config_path: Optional[str] = None,
**kwargs
**kwargs,
) -> None:
"""
Initializes the MIMIC3Dataset with the specified parameters.
Expand All @@ -57,14 +66,14 @@ def __init__(
tables=tables,
dataset_name=dataset_name or "mimic3",
config_path=config_path,
**kwargs
**kwargs,
)
return

def preprocess_noteevents(self, df: pl.LazyFrame) -> pl.LazyFrame:
"""
Table-specific preprocess function which will be called by BaseDataset.load_table().
Preprocesses the noteevents table by ensuring that the charttime column
is populated. If charttime is null, it uses chartdate with a default
time of 00:00:00.
Expand Down
49 changes: 49 additions & 0 deletions pyhealth/datasets/mimic4.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,15 @@ class MIMIC4EHRDataset(BaseDataset):
tables (List[str]): A list of tables to be included in the dataset.
dataset_name (Optional[str]): The name of the dataset.
config_path (Optional[str]): The path to the configuration file.

Examples:
>>> from pyhealth.datasets import MIMIC4EHRDataset
>>> # Load MIMIC-IV EHR dataset with clinical tables
>>> dataset = MIMIC4EHRDataset(
... root="/path/to/mimic-iv/2.2",
... tables=["diagnoses_icd", "procedures_icd", "labevents"],
... )
>>> dataset.stats()
"""

def __init__(
Expand Down Expand Up @@ -83,6 +92,15 @@ class MIMIC4NoteDataset(BaseDataset):
tables (List[str]): A list of tables to be included in the dataset.
dataset_name (Optional[str]): The name of the dataset.
config_path (Optional[str]): The path to the configuration file.

Examples:
>>> from pyhealth.datasets import MIMIC4NoteDataset
>>> # Load MIMIC-IV clinical notes dataset
>>> dataset = MIMIC4NoteDataset(
... root="/path/to/mimic-iv-note/2.2",
... tables=["discharge", "radiology"],
... )
>>> dataset.stats()
"""

def __init__(
Expand Down Expand Up @@ -135,6 +153,15 @@ class MIMIC4CXRDataset(BaseDataset):
tables (List[str]): A list of tables to be included in the dataset.
dataset_name (Optional[str]): The name of the dataset.
config_path (Optional[str]): The path to the configuration file.

Examples:
>>> from pyhealth.datasets import MIMIC4CXRDataset
>>> # Load MIMIC-CXR dataset with chest X-ray images and labels
>>> dataset = MIMIC4CXRDataset(
... root="/path/to/mimic-cxr/2.0.0",
... tables=["metadata", "chexpert"],
... )
>>> dataset.stats()
"""

def __init__(
Expand Down Expand Up @@ -217,6 +244,28 @@ class MIMIC4Dataset(BaseDataset):
cxr_config_path: Path to the CXR config file
dataset_name: Name of the dataset
dev: Whether to enable dev mode (limit to 1000 patients)

Examples:
>>> from pyhealth.datasets import MIMIC4Dataset
>>> # Load unified MIMIC-IV dataset with EHR, notes, and CXR data
>>> dataset = MIMIC4Dataset(
... ehr_root="/path/to/mimic-iv/2.2",
... note_root="/path/to/mimic-iv-note/2.2",
... cxr_root="/path/to/mimic-cxr/2.0.0",
... ehr_tables=["diagnoses_icd", "procedures_icd", "labevents"],
... note_tables=["discharge", "radiology"],
... cxr_tables=["metadata", "chexpert"],
... )
>>> dataset.stats()
>>>
>>> # Load with only EHR and notes (without CXR)
>>> dataset = MIMIC4Dataset(
... ehr_root="/path/to/mimic-iv/2.2",
... note_root="/path/to/mimic-iv-note/2.2",
... ehr_tables=["diagnoses_icd", "labevents"],
... note_tables=["discharge"],
... )
>>> dataset.stats()
"""

def __init__(
Expand Down
39 changes: 20 additions & 19 deletions pyhealth/tasks/benchmark_ehrshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,32 +6,37 @@


class BenchmarkEHRShot(BaseTask):
"""Benchmark predictive tasks using EHRShot."""
"""Benchmark predictive tasks using EHRShot.

Examples:
>>> from pyhealth.datasets import EHRShotDataset
>>> from pyhealth.tasks import BenchmarkEHRShot
>>> dataset = EHRShotDataset(
... root="/path/to/ehrshot/data",
... tables=["ehrshot", "splits", "guo_icu"],
... )
>>> task = BenchmarkEHRShot(task="guo_icu")
>>> samples = dataset.set_task(task)
"""

tasks = {
"operational_outcomes": [
"guo_los",
"guo_readmission",
"guo_icu"
],
"operational_outcomes": ["guo_los", "guo_readmission", "guo_icu"],
"lab_values": [
"lab_thrombocytopenia",
"lab_hyperkalemia",
"lab_hypoglycemia",
"lab_hyponatremia",
"lab_anemia"
"lab_anemia",
],
"new_diagnoses": [
"new_hypertension",
"new_hyperlipidemia",
"new_pancan",
"new_celiac",
"new_lupus",
"new_acutemi"
"new_acutemi",
],
"chexpert": [
"chexpert"
]
"chexpert": ["chexpert"],
}

def __init__(self, task: str, omop_tables: Optional[List[str]] = None) -> None:
Expand All @@ -53,13 +58,13 @@ def __init__(self, task: str, omop_tables: Optional[List[str]] = None) -> None:
self.output_schema = {"label": "binary"}
elif task in self.tasks["chexpert"]:
self.output_schema = {"label": "multilabel"}

def pre_filter(self, df: pl.LazyFrame) -> pl.LazyFrame:
if self.omop_tables is None:
return df
filtered_df = df.filter(
(pl.col("event_type") != "ehrshot") |
(pl.col("ehrshot/omop_table").is_in(self.omop_tables))
(pl.col("event_type") != "ehrshot")
| (pl.col("ehrshot/omop_table").is_in(self.omop_tables))
)
return filtered_df

Expand All @@ -81,9 +86,5 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]:
label_value = int(label_value)
label_value = [i for i in range(14) if (label_value >> i) & 1]
label_value = [13 - i for i in label_value[::-1]]
samples.append({
"feature": codes,
"label": label_value,
"split": split
})
samples.append({"feature": codes, "label": label_value, "split": split})
return samples
7 changes: 7 additions & 0 deletions pyhealth/tasks/bmd_hs_disease_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,13 @@ class BMDHSDiseaseClassification(BaseTask):
output_schema (Dict[str, str]): The output schema specifying the output
format. Contains:
- "diagnosis": "multilabel"
Examples:
>>> from pyhealth.datasets import BMDHSDataset
>>> from pyhealth.tasks import BMDHSDiseaseClassification
>>> dataset = BMDHSDataset(root="/path/to/bmd_hs")
>>> task = BMDHSDiseaseClassification()
>>> samples = dataset.set_task(task)
"""

task_name: str = "BMDHSDiseaseClassification"
Expand Down
13 changes: 12 additions & 1 deletion pyhealth/tasks/chestxray14_binary_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
Author:
Eric Schrock (ejs9@illinois.edu)
"""

import logging
from typing import Dict, List

Expand All @@ -24,6 +25,7 @@

logger = logging.getLogger(__name__)


class ChestXray14BinaryClassification(BaseTask):
"""
A PyHealth task class for binary classification of a specific disease
Expand All @@ -34,7 +36,15 @@ class ChestXray14BinaryClassification(BaseTask):
input_schema (Dict[str, str]): The schema for the task input.
output_schema (Dict[str, str]): The schema for the task output.
disease (str): The disease label to classify.

Examples:
>>> from pyhealth.datasets import ChestXray14Dataset
>>> from pyhealth.tasks import ChestXray14BinaryClassification
>>> dataset = ChestXray14Dataset(root="/path/to/chestxray14")
>>> task = ChestXray14BinaryClassification(disease="pneumonia")
>>> samples = dataset.set_task(task)
"""

task_name: str = "ChestXray14BinaryClassification"
input_schema: Dict[str, str] = {"image": "image"}
output_schema: Dict[str, str] = {"label": "binary"}
Expand All @@ -50,7 +60,8 @@ def __init__(self, disease: str) -> None:
Raises:
ValueError: If the specified disease is not a valid class in the dataset.
"""
from pyhealth.datasets import ChestXray14Dataset # Avoid circular import
from pyhealth.datasets import ChestXray14Dataset # Avoid circular import

if disease not in ChestXray14Dataset.classes:
msg = f"Invalid disease: '{disease}'! Must be one of {ChestXray14Dataset.classes}."
logger.error(msg)
Expand Down
24 changes: 22 additions & 2 deletions pyhealth/tasks/chestxray14_multilabel_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
Author:
Eric Schrock (ejs9@illinois.edu)
"""

import logging
from typing import Dict, List

Expand All @@ -24,6 +25,7 @@

logger = logging.getLogger(__name__)


class ChestXray14MultilabelClassification(BaseTask):
"""
A PyHealth task class for multilabel classification of all fourteen diseases
Expand All @@ -33,7 +35,15 @@ class ChestXray14MultilabelClassification(BaseTask):
task_name (str): The name of the task.
input_schema (Dict[str, str]): The schema for the task input.
output_schema (Dict[str, str]): The schema for the task output.
Examples:
>>> from pyhealth.datasets import ChestXray14Dataset
>>> from pyhealth.tasks import ChestXray14MultilabelClassification
>>> dataset = ChestXray14Dataset(root="/path/to/chestxray14")
>>> task = ChestXray14MultilabelClassification()
>>> samples = dataset.set_task(task)
"""

task_name: str = "ChestXray14MultilabelClassification"
input_schema: Dict[str, str] = {"image": "image"}
output_schema: Dict[str, str] = {"labels": "multilabel"}
Expand All @@ -54,8 +64,18 @@ def __call__(self, patient: Patient) -> List[Dict]:
events: List[Event] = patient.get_events(event_type="chestxray14")

samples = []
from pyhealth.datasets import ChestXray14Dataset # Avoid circular import
from pyhealth.datasets import ChestXray14Dataset # Avoid circular import

for event in events:
samples.append({"image": event["path"], "labels": [disease for disease in ChestXray14Dataset.classes if int(event[disease])]})
samples.append(
{
"image": event["path"],
"labels": [
disease
for disease in ChestXray14Dataset.classes
if int(event[disease])
],
}
)

return samples
7 changes: 7 additions & 0 deletions pyhealth/tasks/covid19_cxr_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@ class COVID19CXRClassification(BaseTask):
input format. Contains a single key "image" with value "image".
output_schema (Dict[str, str]): The output schema specifying the output
format. Contains a single key "disease" with value "multiclass".
Examples:
>>> from pyhealth.datasets import COVID19CXRDataset
>>> from pyhealth.tasks import COVID19CXRClassification
>>> dataset = COVID19CXRDataset(root="/path/to/covid19_cxr")
>>> task = COVID19CXRClassification()
>>> samples = dataset.set_task(task)
"""

task_name: str = "COVID19CXRClassification"
Expand Down
10 changes: 10 additions & 0 deletions pyhealth/tasks/drug_recommendation.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,16 @@ class DrugRecommendationMIMIC4(BaseTask):
- drugs_hist: Nested list of drug codes from history (current visit excluded)
output_schema (Dict[str, str]): The schema for output data:
- drugs: List of drugs to predict for current visit
Examples:
>>> from pyhealth.datasets import MIMIC4EHRDataset
>>> from pyhealth.tasks import DrugRecommendationMIMIC4
>>> dataset = MIMIC4EHRDataset(
... root="/path/to/mimic-iv/2.2",
... tables=["diagnoses_icd", "procedures_icd", "prescriptions"],
... )
>>> task = DrugRecommendationMIMIC4()
>>> samples = dataset.set_task(task)
"""

task_name: str = "DrugRecommendationMIMIC4"
Expand Down
10 changes: 10 additions & 0 deletions pyhealth/tasks/in_hospital_mortality_mimic4.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,16 @@ class InHospitalMortalityMIMIC4(BaseTask):
- labs: A timeseries of lab results.
output_schema (Dict[str, str]): The schema for output data, which includes:
- mortality: A binary indicator of mortality.

Examples:
>>> from pyhealth.datasets import MIMIC4EHRDataset
>>> from pyhealth.tasks import InHospitalMortalityMIMIC4
>>> dataset = MIMIC4EHRDataset(
... root="/path/to/mimic-iv/2.2",
... tables=["labevents"],
... )
>>> task = InHospitalMortalityMIMIC4()
>>> samples = dataset.set_task(task)
"""

task_name: str = "InHospitalMortalityMIMIC4"
Expand Down
Loading