diff --git a/pyhealth/datasets/ehrshot.py b/pyhealth/datasets/ehrshot.py index 6b822dc74..878295bea 100644 --- a/pyhealth/datasets/ehrshot.py +++ b/pyhealth/datasets/ehrshot.py @@ -20,6 +20,15 @@ class EHRShotDataset(BaseDataset): tables (List[str]): A list of tables to be included in the dataset. dataset_name (Optional[str]): The name of the dataset. config_path (Optional[str]): The path to the configuration file. + + Examples: + >>> from pyhealth.datasets import EHRShotDataset + >>> # Load EHRShot dataset with benchmark tables + >>> dataset = EHRShotDataset( + ... root="/path/to/ehrshot/data", + ... tables=["ehrshot", "chexpert", "guo_icu", "lab_anemia"], + ... ) + >>> dataset.stats() """ def __init__( @@ -28,7 +37,7 @@ def __init__( tables: List[str], dataset_name: Optional[str] = None, config_path: Optional[str] = None, - **kwargs + **kwargs, ) -> None: if config_path is None: logger.info("No config path provided, using default config") @@ -38,6 +47,6 @@ def __init__( tables=tables, dataset_name=dataset_name or "ehrshot", config_path=config_path, - **kwargs + **kwargs, ) return diff --git a/pyhealth/datasets/mimic3.py b/pyhealth/datasets/mimic3.py index 22ca79d5c..7e569d2f3 100644 --- a/pyhealth/datasets/mimic3.py +++ b/pyhealth/datasets/mimic3.py @@ -22,6 +22,15 @@ class MIMIC3Dataset(BaseDataset): tables (List[str]): A list of tables to be included in the dataset. dataset_name (Optional[str]): The name of the dataset. config_path (Optional[str]): The path to the configuration file. + + Examples: + >>> from pyhealth.datasets import MIMIC3Dataset + >>> # Load MIMIC-III dataset with clinical tables + >>> dataset = MIMIC3Dataset( + ... root="/path/to/mimic-iii/1.4", + ... tables=["diagnoses_icd", "procedures_icd", "labevents"], + ... ) + >>> dataset.stats() """ def __init__( @@ -30,7 +39,7 @@ def __init__( tables: List[str], dataset_name: Optional[str] = None, config_path: Optional[str] = None, - **kwargs + **kwargs, ) -> None: """ Initializes the MIMIC4Dataset with the specified parameters. @@ -57,14 +66,14 @@ def __init__( tables=tables, dataset_name=dataset_name or "mimic3", config_path=config_path, - **kwargs + **kwargs, ) return def preprocess_noteevents(self, df: pl.LazyFrame) -> pl.LazyFrame: """ Table-specific preprocess function which will be called by BaseDataset.load_table(). - + Preprocesses the noteevents table by ensuring that the charttime column is populated. If charttime is null, it uses chartdate with a default time of 00:00:00. diff --git a/pyhealth/datasets/mimic4.py b/pyhealth/datasets/mimic4.py index 6a01033be..9d1aa55d8 100644 --- a/pyhealth/datasets/mimic4.py +++ b/pyhealth/datasets/mimic4.py @@ -40,6 +40,15 @@ class MIMIC4EHRDataset(BaseDataset): tables (List[str]): A list of tables to be included in the dataset. dataset_name (Optional[str]): The name of the dataset. config_path (Optional[str]): The path to the configuration file. + + Examples: + >>> from pyhealth.datasets import MIMIC4EHRDataset + >>> # Load MIMIC-IV EHR dataset with clinical tables + >>> dataset = MIMIC4EHRDataset( + ... root="/path/to/mimic-iv/2.2", + ... tables=["diagnoses_icd", "procedures_icd", "labevents"], + ... ) + >>> dataset.stats() """ def __init__( @@ -83,6 +92,15 @@ class MIMIC4NoteDataset(BaseDataset): tables (List[str]): A list of tables to be included in the dataset. dataset_name (Optional[str]): The name of the dataset. config_path (Optional[str]): The path to the configuration file. + + Examples: + >>> from pyhealth.datasets import MIMIC4NoteDataset + >>> # Load MIMIC-IV clinical notes dataset + >>> dataset = MIMIC4NoteDataset( + ... root="/path/to/mimic-iv-note/2.2", + ... tables=["discharge", "radiology"], + ... ) + >>> dataset.stats() """ def __init__( @@ -135,6 +153,15 @@ class MIMIC4CXRDataset(BaseDataset): tables (List[str]): A list of tables to be included in the dataset. dataset_name (Optional[str]): The name of the dataset. config_path (Optional[str]): The path to the configuration file. + + Examples: + >>> from pyhealth.datasets import MIMIC4CXRDataset + >>> # Load MIMIC-CXR dataset with chest X-ray images and labels + >>> dataset = MIMIC4CXRDataset( + ... root="/path/to/mimic-cxr/2.0.0", + ... tables=["metadata", "chexpert"], + ... ) + >>> dataset.stats() """ def __init__( @@ -217,6 +244,28 @@ class MIMIC4Dataset(BaseDataset): cxr_config_path: Path to the CXR config file dataset_name: Name of the dataset dev: Whether to enable dev mode (limit to 1000 patients) + + Examples: + >>> from pyhealth.datasets import MIMIC4Dataset + >>> # Load unified MIMIC-IV dataset with EHR, notes, and CXR data + >>> dataset = MIMIC4Dataset( + ... ehr_root="/path/to/mimic-iv/2.2", + ... note_root="/path/to/mimic-iv-note/2.2", + ... cxr_root="/path/to/mimic-cxr/2.0.0", + ... ehr_tables=["diagnoses_icd", "procedures_icd", "labevents"], + ... note_tables=["discharge", "radiology"], + ... cxr_tables=["metadata", "chexpert"], + ... ) + >>> dataset.stats() + >>> + >>> # Load with only EHR and notes (without CXR) + >>> dataset = MIMIC4Dataset( + ... ehr_root="/path/to/mimic-iv/2.2", + ... note_root="/path/to/mimic-iv-note/2.2", + ... ehr_tables=["diagnoses_icd", "labevents"], + ... note_tables=["discharge"], + ... ) + >>> dataset.stats() """ def __init__( diff --git a/pyhealth/tasks/benchmark_ehrshot.py b/pyhealth/tasks/benchmark_ehrshot.py index a528b0db9..1d72ec2f1 100644 --- a/pyhealth/tasks/benchmark_ehrshot.py +++ b/pyhealth/tasks/benchmark_ehrshot.py @@ -6,20 +6,27 @@ class BenchmarkEHRShot(BaseTask): - """Benchmark predictive tasks using EHRShot.""" + """Benchmark predictive tasks using EHRShot. + + Examples: + >>> from pyhealth.datasets import EHRShotDataset + >>> from pyhealth.tasks import BenchmarkEHRShot + >>> dataset = EHRShotDataset( + ... root="/path/to/ehrshot/data", + ... tables=["ehrshot", "splits", "guo_icu"], + ... ) + >>> task = BenchmarkEHRShot(task="guo_icu") + >>> samples = dataset.set_task(task) + """ tasks = { - "operational_outcomes": [ - "guo_los", - "guo_readmission", - "guo_icu" - ], + "operational_outcomes": ["guo_los", "guo_readmission", "guo_icu"], "lab_values": [ "lab_thrombocytopenia", "lab_hyperkalemia", "lab_hypoglycemia", "lab_hyponatremia", - "lab_anemia" + "lab_anemia", ], "new_diagnoses": [ "new_hypertension", @@ -27,11 +34,9 @@ class BenchmarkEHRShot(BaseTask): "new_pancan", "new_celiac", "new_lupus", - "new_acutemi" + "new_acutemi", ], - "chexpert": [ - "chexpert" - ] + "chexpert": ["chexpert"], } def __init__(self, task: str, omop_tables: Optional[List[str]] = None) -> None: @@ -53,13 +58,13 @@ def __init__(self, task: str, omop_tables: Optional[List[str]] = None) -> None: self.output_schema = {"label": "binary"} elif task in self.tasks["chexpert"]: self.output_schema = {"label": "multilabel"} - + def pre_filter(self, df: pl.LazyFrame) -> pl.LazyFrame: if self.omop_tables is None: return df filtered_df = df.filter( - (pl.col("event_type") != "ehrshot") | - (pl.col("ehrshot/omop_table").is_in(self.omop_tables)) + (pl.col("event_type") != "ehrshot") + | (pl.col("ehrshot/omop_table").is_in(self.omop_tables)) ) return filtered_df @@ -81,9 +86,5 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: label_value = int(label_value) label_value = [i for i in range(14) if (label_value >> i) & 1] label_value = [13 - i for i in label_value[::-1]] - samples.append({ - "feature": codes, - "label": label_value, - "split": split - }) + samples.append({"feature": codes, "label": label_value, "split": split}) return samples diff --git a/pyhealth/tasks/bmd_hs_disease_classification.py b/pyhealth/tasks/bmd_hs_disease_classification.py index c7c68586f..79d4a06f9 100644 --- a/pyhealth/tasks/bmd_hs_disease_classification.py +++ b/pyhealth/tasks/bmd_hs_disease_classification.py @@ -33,6 +33,13 @@ class BMDHSDiseaseClassification(BaseTask): output_schema (Dict[str, str]): The output schema specifying the output format. Contains: - "diagnosis": "multilabel" + + Examples: + >>> from pyhealth.datasets import BMDHSDataset + >>> from pyhealth.tasks import BMDHSDiseaseClassification + >>> dataset = BMDHSDataset(root="/path/to/bmd_hs") + >>> task = BMDHSDiseaseClassification() + >>> samples = dataset.set_task(task) """ task_name: str = "BMDHSDiseaseClassification" diff --git a/pyhealth/tasks/chestxray14_binary_classification.py b/pyhealth/tasks/chestxray14_binary_classification.py index 3b28b54b2..020299ff7 100644 --- a/pyhealth/tasks/chestxray14_binary_classification.py +++ b/pyhealth/tasks/chestxray14_binary_classification.py @@ -16,6 +16,7 @@ Author: Eric Schrock (ejs9@illinois.edu) """ + import logging from typing import Dict, List @@ -24,6 +25,7 @@ logger = logging.getLogger(__name__) + class ChestXray14BinaryClassification(BaseTask): """ A PyHealth task class for binary classification of a specific disease @@ -34,7 +36,15 @@ class ChestXray14BinaryClassification(BaseTask): input_schema (Dict[str, str]): The schema for the task input. output_schema (Dict[str, str]): The schema for the task output. disease (str): The disease label to classify. + + Examples: + >>> from pyhealth.datasets import ChestXray14Dataset + >>> from pyhealth.tasks import ChestXray14BinaryClassification + >>> dataset = ChestXray14Dataset(root="/path/to/chestxray14") + >>> task = ChestXray14BinaryClassification(disease="pneumonia") + >>> samples = dataset.set_task(task) """ + task_name: str = "ChestXray14BinaryClassification" input_schema: Dict[str, str] = {"image": "image"} output_schema: Dict[str, str] = {"label": "binary"} @@ -50,7 +60,8 @@ def __init__(self, disease: str) -> None: Raises: ValueError: If the specified disease is not a valid class in the dataset. """ - from pyhealth.datasets import ChestXray14Dataset # Avoid circular import + from pyhealth.datasets import ChestXray14Dataset # Avoid circular import + if disease not in ChestXray14Dataset.classes: msg = f"Invalid disease: '{disease}'! Must be one of {ChestXray14Dataset.classes}." logger.error(msg) diff --git a/pyhealth/tasks/chestxray14_multilabel_classification.py b/pyhealth/tasks/chestxray14_multilabel_classification.py index 77b2b9f15..43ad3569d 100644 --- a/pyhealth/tasks/chestxray14_multilabel_classification.py +++ b/pyhealth/tasks/chestxray14_multilabel_classification.py @@ -16,6 +16,7 @@ Author: Eric Schrock (ejs9@illinois.edu) """ + import logging from typing import Dict, List @@ -24,6 +25,7 @@ logger = logging.getLogger(__name__) + class ChestXray14MultilabelClassification(BaseTask): """ A PyHealth task class for multilabel classification of all fourteen diseases @@ -33,7 +35,15 @@ class ChestXray14MultilabelClassification(BaseTask): task_name (str): The name of the task. input_schema (Dict[str, str]): The schema for the task input. output_schema (Dict[str, str]): The schema for the task output. + + Examples: + >>> from pyhealth.datasets import ChestXray14Dataset + >>> from pyhealth.tasks import ChestXray14MultilabelClassification + >>> dataset = ChestXray14Dataset(root="/path/to/chestxray14") + >>> task = ChestXray14MultilabelClassification() + >>> samples = dataset.set_task(task) """ + task_name: str = "ChestXray14MultilabelClassification" input_schema: Dict[str, str] = {"image": "image"} output_schema: Dict[str, str] = {"labels": "multilabel"} @@ -54,8 +64,18 @@ def __call__(self, patient: Patient) -> List[Dict]: events: List[Event] = patient.get_events(event_type="chestxray14") samples = [] - from pyhealth.datasets import ChestXray14Dataset # Avoid circular import + from pyhealth.datasets import ChestXray14Dataset # Avoid circular import + for event in events: - samples.append({"image": event["path"], "labels": [disease for disease in ChestXray14Dataset.classes if int(event[disease])]}) + samples.append( + { + "image": event["path"], + "labels": [ + disease + for disease in ChestXray14Dataset.classes + if int(event[disease]) + ], + } + ) return samples diff --git a/pyhealth/tasks/covid19_cxr_classification.py b/pyhealth/tasks/covid19_cxr_classification.py index 44c9c6776..ae6814928 100644 --- a/pyhealth/tasks/covid19_cxr_classification.py +++ b/pyhealth/tasks/covid19_cxr_classification.py @@ -17,6 +17,13 @@ class COVID19CXRClassification(BaseTask): input format. Contains a single key "image" with value "image". output_schema (Dict[str, str]): The output schema specifying the output format. Contains a single key "disease" with value "multiclass". + + Examples: + >>> from pyhealth.datasets import COVID19CXRDataset + >>> from pyhealth.tasks import COVID19CXRClassification + >>> dataset = COVID19CXRDataset(root="/path/to/covid19_cxr") + >>> task = COVID19CXRClassification() + >>> samples = dataset.set_task(task) """ task_name: str = "COVID19CXRClassification" diff --git a/pyhealth/tasks/drug_recommendation.py b/pyhealth/tasks/drug_recommendation.py index ae0432b64..e42358299 100644 --- a/pyhealth/tasks/drug_recommendation.py +++ b/pyhealth/tasks/drug_recommendation.py @@ -153,6 +153,16 @@ class DrugRecommendationMIMIC4(BaseTask): - drugs_hist: Nested list of drug codes from history (current visit excluded) output_schema (Dict[str, str]): The schema for output data: - drugs: List of drugs to predict for current visit + + Examples: + >>> from pyhealth.datasets import MIMIC4EHRDataset + >>> from pyhealth.tasks import DrugRecommendationMIMIC4 + >>> dataset = MIMIC4EHRDataset( + ... root="/path/to/mimic-iv/2.2", + ... tables=["diagnoses_icd", "procedures_icd", "prescriptions"], + ... ) + >>> task = DrugRecommendationMIMIC4() + >>> samples = dataset.set_task(task) """ task_name: str = "DrugRecommendationMIMIC4" diff --git a/pyhealth/tasks/in_hospital_mortality_mimic4.py b/pyhealth/tasks/in_hospital_mortality_mimic4.py index 6835648b5..334239714 100644 --- a/pyhealth/tasks/in_hospital_mortality_mimic4.py +++ b/pyhealth/tasks/in_hospital_mortality_mimic4.py @@ -18,6 +18,16 @@ class InHospitalMortalityMIMIC4(BaseTask): - labs: A timeseries of lab results. output_schema (Dict[str, str]): The schema for output data, which includes: - mortality: A binary indicator of mortality. + + Examples: + >>> from pyhealth.datasets import MIMIC4EHRDataset + >>> from pyhealth.tasks import InHospitalMortalityMIMIC4 + >>> dataset = MIMIC4EHRDataset( + ... root="/path/to/mimic-iv/2.2", + ... tables=["labevents"], + ... ) + >>> task = InHospitalMortalityMIMIC4() + >>> samples = dataset.set_task(task) """ task_name: str = "InHospitalMortalityMIMIC4" diff --git a/pyhealth/tasks/length_of_stay_stagenet_mimic4.py b/pyhealth/tasks/length_of_stay_stagenet_mimic4.py index 8a509f068..be05a22b6 100644 --- a/pyhealth/tasks/length_of_stay_stagenet_mimic4.py +++ b/pyhealth/tasks/length_of_stay_stagenet_mimic4.py @@ -29,6 +29,16 @@ class LengthOfStayStageNetMIMIC4(BaseTask): Args: padding: Optional padding forwarded to the StageNet processor for nested sequences. Default is 0. + + Examples: + >>> from pyhealth.datasets import MIMIC4EHRDataset + >>> from pyhealth.tasks import LengthOfStayStageNetMIMIC4 + >>> dataset = MIMIC4EHRDataset( + ... root="/path/to/mimic-iv/2.2", + ... tables=["diagnoses_icd", "procedures_icd", "labevents"], + ... ) + >>> task = LengthOfStayStageNetMIMIC4() + >>> samples = dataset.set_task(task) """ task_name: str = "LengthOfStayStageNetMIMIC4" diff --git a/pyhealth/tasks/medical_coding.py b/pyhealth/tasks/medical_coding.py index 23b3cb2e4..739c674d1 100644 --- a/pyhealth/tasks/medical_coding.py +++ b/pyhealth/tasks/medical_coding.py @@ -25,6 +25,16 @@ class MIMIC3ICD9Coding(BaseTask): task_name: Name of the task input_schema: Definition of the input data schema output_schema: Definition of the output data schema + + Examples: + >>> from pyhealth.datasets import MIMIC3Dataset + >>> from pyhealth.tasks import MIMIC3ICD9Coding + >>> dataset = MIMIC3Dataset( + ... root="/path/to/mimic-iii/1.4", + ... tables=["diagnoses_icd", "procedures_icd", "noteevents"], + ... ) + >>> task = MIMIC3ICD9Coding() + >>> samples = dataset.set_task(task) """ task_name: str = "mimic3_icd9_coding" diff --git a/pyhealth/tasks/medical_transcriptions_classification.py b/pyhealth/tasks/medical_transcriptions_classification.py index c4e27f89a..37a84fa21 100644 --- a/pyhealth/tasks/medical_transcriptions_classification.py +++ b/pyhealth/tasks/medical_transcriptions_classification.py @@ -16,7 +16,17 @@ class MedicalTranscriptionsClassification(BaseTask): task_name (str): Name of the task input_schema (Dict[str, str]): Schema defining input features output_schema (Dict[str, str]): Schema defining output features + + Examples: + >>> from pyhealth.datasets import MedicalTranscriptionsDataset + >>> from pyhealth.tasks import MedicalTranscriptionsClassification + >>> dataset = MedicalTranscriptionsDataset( + ... root="/path/to/medical_transcriptions", + ... ) + >>> task = MedicalTranscriptionsClassification() + >>> samples = dataset.set_task(task) """ + task_name: str = "MedicalTranscriptionsClassification" input_schema: Dict[str, str] = {"transcription": "text"} output_schema: Dict[str, str] = {"medical_specialty": "multiclass"} diff --git a/pyhealth/tasks/mortality_prediction.py b/pyhealth/tasks/mortality_prediction.py index b1bed4665..ba5d42977 100644 --- a/pyhealth/tasks/mortality_prediction.py +++ b/pyhealth/tasks/mortality_prediction.py @@ -9,6 +9,16 @@ class MortalityPredictionMIMIC3(BaseTask): This task aims to predict whether the patient will decease in the next hospital visit based on clinical information from the current visit. + + Examples: + >>> from pyhealth.datasets import MIMIC3Dataset + >>> from pyhealth.tasks import MortalityPredictionMIMIC3 + >>> dataset = MIMIC3Dataset( + ... root="/path/to/mimic-iii/1.4", + ... tables=["diagnoses_icd", "procedures_icd", "prescriptions"], + ... ) + >>> task = MortalityPredictionMIMIC3() + >>> samples = dataset.set_task(task) """ task_name: str = "MortalityPredictionMIMIC3" @@ -78,6 +88,17 @@ class MultimodalMortalityPredictionMIMIC3(BaseTask): This task aims to predict whether the patient will decease in the next hospital visit based on clinical information from the current visit. + + Examples: + >>> from pyhealth.datasets import MIMIC3Dataset + >>> from pyhealth.tasks import MultimodalMortalityPredictionMIMIC3 + >>> dataset = MIMIC3Dataset( + ... root="/path/to/mimic-iii/1.4", + ... tables=["diagnoses_icd", "procedures_icd", "prescriptions", + ... "noteevents"], + ... ) + >>> task = MultimodalMortalityPredictionMIMIC3() + >>> samples = dataset.set_task(task) """ task_name: str = "MultimodalMortalityPredictionMIMIC3" @@ -149,7 +170,18 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: class MortalityPredictionMIMIC4(BaseTask): - """Task for predicting mortality using MIMIC-IV EHR data only.""" + """Task for predicting mortality using MIMIC-IV EHR data only. + + Examples: + >>> from pyhealth.datasets import MIMIC4EHRDataset + >>> from pyhealth.tasks import MortalityPredictionMIMIC4 + >>> dataset = MIMIC4EHRDataset( + ... root="/path/to/mimic-iv/2.2", + ... tables=["diagnoses_icd", "procedures_icd", "prescriptions"], + ... ) + >>> task = MortalityPredictionMIMIC4() + >>> samples = dataset.set_task(task) + """ task_name: str = "MortalityPredictionMIMIC4" input_schema: Dict[str, str] = { @@ -280,6 +312,21 @@ class MultimodalMortalityPredictionMIMIC4(BaseTask): - Lab events: 10-dimensional lab value vectors (time-series) - Chest X-rays: Must have an image path available + Examples: + >>> from pyhealth.datasets import MIMIC4Dataset + >>> from pyhealth.tasks import MultimodalMortalityPredictionMIMIC4 + >>> dataset = MIMIC4Dataset( + ... ehr_root="/path/to/mimic-iv/2.2", + ... note_root="/path/to/mimic-iv-note/2.2", + ... cxr_root="/path/to/mimic-cxr/2.0.0", + ... ehr_tables=["diagnoses_icd", "procedures_icd", + ... "prescriptions", "labevents"], + ... note_tables=["discharge", "radiology"], + ... cxr_tables=["metadata", "negbio"], + ... ) + >>> task = MultimodalMortalityPredictionMIMIC4() + >>> samples = dataset.set_task(task) + Patient-Level Aggregation: - Mortality is determined iteratively by checking if the NEXT admission has the death flag @@ -373,8 +420,11 @@ def _clean_text(self, text: Optional[str]) -> Optional[str]: return text if text else None def _process_lab_events( - self, patient: Any, admission_time: datetime, admission_dischtime: datetime, - reference_time: Optional[datetime] = None + self, + patient: Any, + admission_time: datetime, + admission_dischtime: datetime, + reference_time: Optional[datetime] = None, ) -> Optional[tuple]: """Process lab events into 10-dimensional vectors with timestamps. @@ -416,9 +466,7 @@ def _process_lab_events( # Parse storetime and filter (matching stagenet implementation) labevents_df = labevents_df.with_columns( - pl.col("labevents/storetime").str.strptime( - pl.Datetime, "%Y-%m-%d %H:%M:%S" - ) + pl.col("labevents/storetime").str.strptime(pl.Datetime, "%Y-%m-%d %H:%M:%S") ) labevents_df = labevents_df.filter( pl.col("labevents/storetime") <= admission_dischtime @@ -459,9 +507,7 @@ def _process_lab_events( lab_vector.append(category_value) # Calculate time from reference time (hours) - time_from_reference = ( - lab_ts - reference_time - ).total_seconds() / 3600.0 + time_from_reference = (lab_ts - reference_time).total_seconds() / 3600.0 lab_times.append(time_from_reference) lab_values.append(lab_vector) @@ -599,25 +645,23 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: # Get clinical codes using hadm_id filtering diagnoses_icd = patient.get_events( event_type="diagnoses_icd", - filters=[("hadm_id", "==", admission.hadm_id)] + filters=[("hadm_id", "==", admission.hadm_id)], ) procedures_icd = patient.get_events( event_type="procedures_icd", - filters=[("hadm_id", "==", admission.hadm_id)] + filters=[("hadm_id", "==", admission.hadm_id)], ) prescriptions = patient.get_events( event_type="prescriptions", - filters=[("hadm_id", "==", admission.hadm_id)] + filters=[("hadm_id", "==", admission.hadm_id)], ) # Get notes using hadm_id filtering discharge_notes = patient.get_events( - event_type="discharge", - filters=[("hadm_id", "==", admission.hadm_id)] + event_type="discharge", filters=[("hadm_id", "==", admission.hadm_id)] ) radiology_notes = patient.get_events( - event_type="radiology", - filters=[("hadm_id", "==", admission.hadm_id)] + event_type="radiology", filters=[("hadm_id", "==", admission.hadm_id)] ) # Extract clinical codes per visit (nested structure) @@ -627,9 +671,7 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: procedures_list = self._clean_sequence( [event.icd_code for event in procedures_icd] ) - drugs = self._clean_sequence( - [event.ndc for event in prescriptions] - ) + drugs = self._clean_sequence([event.ndc for event in prescriptions]) # Append as nested lists (one list per visit) for nested_sequence all_conditions.append(conditions) @@ -656,8 +698,10 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: # Process lab events with reference to first admission time labs_data = self._process_lab_events( - patient, admission.timestamp, admission_dischtime, - reference_time=first_admission_time + patient, + admission.timestamp, + admission_dischtime, + reference_time=first_admission_time, ) if labs_data is not None: @@ -685,8 +729,14 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: has_image = bool(image_path) # Return empty list if any required modality is missing - if not (has_conditions and has_procedures and has_drugs - and has_notes and has_labs and has_image): + if not ( + has_conditions + and has_procedures + and has_drugs + and has_notes + and has_labs + and has_image + ): return [] # Sort lab events by time and create aggregated labs data @@ -735,6 +785,16 @@ class MortalityPredictionEICU(BaseTask): - using diagnosis table (ICD9CM and ICD10CM) as condition codes - using physicalExam table as procedure codes - using medication table as drugs codes + + Examples: + >>> from pyhealth.datasets import eICUDataset + >>> from pyhealth.tasks import MortalityPredictionEICU + >>> dataset = eICUDataset( + ... root="/path/to/eicu-crd/2.0", + ... tables=["diagnosis", "medication", "physicalExam"], + ... ) + >>> task = MortalityPredictionEICU() + >>> samples = dataset.set_task(task) """ task_name: str = "MortalityPredictionEICU" @@ -818,6 +878,16 @@ class MortalityPredictionEICU2(BaseTask): Similar to MortalityPredictionEICU, but with different code mapping: - using admissionDx table and diagnosisString under diagnosis table as condition codes - using treatment table as procedure codes + + Examples: + >>> from pyhealth.datasets import eICUDataset + >>> from pyhealth.tasks import MortalityPredictionEICU2 + >>> dataset = eICUDataset( + ... root="/path/to/eicu-crd/2.0", + ... tables=["diagnosis", "treatment", "admissionDx"], + ... ) + >>> task = MortalityPredictionEICU2() + >>> samples = dataset.set_task(task) """ task_name: str = "MortalityPredictionEICU2" diff --git a/pyhealth/tasks/mortality_prediction_stagenet_mimic4.py b/pyhealth/tasks/mortality_prediction_stagenet_mimic4.py index 91e1f94cd..4c0505f2d 100644 --- a/pyhealth/tasks/mortality_prediction_stagenet_mimic4.py +++ b/pyhealth/tasks/mortality_prediction_stagenet_mimic4.py @@ -36,6 +36,16 @@ class MortalityPredictionStageNetMIMIC4(BaseTask): - labs: Lab results (stagenet_tensor, 10D vectors per timestamp) output_schema (Dict[str, str]): The schema for output data: - mortality: Binary indicator (1 if any admission had mortality) + + Examples: + >>> from pyhealth.datasets import MIMIC4EHRDataset + >>> from pyhealth.tasks import MortalityPredictionStageNetMIMIC4 + >>> dataset = MIMIC4EHRDataset( + ... root="/path/to/mimic-iv/2.2", + ... tables=["diagnoses_icd", "procedures_icd", "labevents"], + ... ) + >>> task = MortalityPredictionStageNetMIMIC4() + >>> samples = dataset.set_task(task) """ task_name: str = "MortalityPredictionStageNetMIMIC4" diff --git a/pyhealth/tasks/readmission_prediction.py b/pyhealth/tasks/readmission_prediction.py index 478963bac..821a7fea4 100644 --- a/pyhealth/tasks/readmission_prediction.py +++ b/pyhealth/tasks/readmission_prediction.py @@ -16,6 +16,16 @@ class ReadmissionPredictionMIMIC3(BaseTask): task_name (str): The name of the task. input_schema (Dict[str, str]): The schema for the task input. output_schema (Dict[str, str]): The schema for the task output. + + Examples: + >>> from pyhealth.datasets import MIMIC3Dataset + >>> from pyhealth.tasks import ReadmissionPredictionMIMIC3 + >>> dataset = MIMIC3Dataset( + ... root="/path/to/mimic-iii/1.4", + ... tables=["diagnoses_icd", "procedures_icd", "prescriptions"], + ... ) + >>> task = ReadmissionPredictionMIMIC3() + >>> samples = dataset.set_task(task) """ task_name: str = "ReadmissionPredictionMIMIC3" input_schema: Dict[str, str] = {"conditions": "sequence", "procedures": "sequence", "drugs": "sequence"} @@ -123,6 +133,16 @@ class ReadmissionPredictionMIMIC4(BaseTask): task_name (str): The name of the task. input_schema (Dict[str, str]): The schema for the task input. output_schema (Dict[str, str]): The schema for the task output. + + Examples: + >>> from pyhealth.datasets import MIMIC4EHRDataset + >>> from pyhealth.tasks import ReadmissionPredictionMIMIC4 + >>> dataset = MIMIC4EHRDataset( + ... root="/path/to/mimic-iv/2.2", + ... tables=["diagnoses_icd", "procedures_icd", "prescriptions"], + ... ) + >>> task = ReadmissionPredictionMIMIC4() + >>> samples = dataset.set_task(task) """ task_name: str = "ReadmissionPredictionMIMIC4" input_schema: Dict[str, str] = {"conditions": "sequence", "procedures": "sequence", "drugs": "sequence"} @@ -364,6 +384,17 @@ class ReadmissionPredictionOMOP(BaseTask): task_name (str): The name of the task. input_schema (Dict[str, str]): The schema for the task input. output_schema (Dict[str, str]): The schema for the task output. + + Examples: + >>> from pyhealth.datasets import OMOPDataset + >>> from pyhealth.tasks import ReadmissionPredictionOMOP + >>> dataset = OMOPDataset( + ... root="/path/to/omop/data", + ... tables=["condition_occurrence", "procedure_occurrence", + ... "drug_exposure"], + ... ) + >>> task = ReadmissionPredictionOMOP() + >>> samples = dataset.set_task(task) """ task_name: str = "ReadmissionPredictionOMOP" input_schema: Dict[str, str] = {"conditions": "sequence", "procedures": "sequence", "drugs": "sequence"}