diff --git a/docs/api/processors.rst b/docs/api/processors.rst index 25de2fece..a27000f08 100644 --- a/docs/api/processors.rst +++ b/docs/api/processors.rst @@ -36,6 +36,7 @@ Available Processors - ``AudioProcessor``: For audio signal data - ``SignalProcessor``: For general signal data (e.g., EEG, ECG) - ``TimeseriesProcessor``: For time-series data +- ``TimeImageProcessor``: For time-stamped image sequences (e.g., serial X-rays) - ``TensorProcessor``: For pre-processed tensor data - ``RawProcessor``: Pass-through processor for raw data @@ -270,6 +271,7 @@ Common string keys for automatic processor selection: - ``"audio"``: For audio data - ``"signal"``: For signal data - ``"timeseries"``: For time-series data +- ``"time_image"``: For time-stamped image sequences - ``"tensor"``: For pre-processed tensors - ``"raw"``: For raw/unprocessed data @@ -459,6 +461,7 @@ API Reference processors/pyhealth.processors.AudioProcessor processors/pyhealth.processors.SignalProcessor processors/pyhealth.processors.TimeseriesProcessor + processors/pyhealth.processors.TimeImageProcessor processors/pyhealth.processors.TensorProcessor processors/pyhealth.processors.RawProcessor processors/pyhealth.processors.IgnoreProcessor diff --git a/docs/api/processors/pyhealth.processors.TimeImageProcessor.rst b/docs/api/processors/pyhealth.processors.TimeImageProcessor.rst new file mode 100644 index 000000000..6abdb26a4 --- /dev/null +++ b/docs/api/processors/pyhealth.processors.TimeImageProcessor.rst @@ -0,0 +1,9 @@ +pyhealth.processors.TimeImageProcessor +======================================= + +Processor for time-aware image data. + +.. autoclass:: pyhealth.processors.TimeImageProcessor + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/time_image_processor_tutorial.ipynb b/examples/time_image_processor_tutorial.ipynb new file mode 100644 index 000000000..54d2987dd --- /dev/null +++ b/examples/time_image_processor_tutorial.ipynb @@ -0,0 +1,505 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# TimeImageProcessor Tutorial\n", + "\n", + "This notebook demonstrates how to use the `TimeImageProcessor` for multimodal PyHealth pipelines.\n", + "\n", + "**Contributors:** Josh Steier\n", + "\n", + "## Overview\n", + "\n", + "The `TimeImageProcessor` is a time-aware image processor that pairs image loading with temporal metadata. It is designed for tasks where each patient has **multiple images taken at different times** (e.g., serial chest X-rays during an ICU stay).\n", + "\n", + "**Input:** `(List[image_path], List[time_diff_from_first_admission])`\n", + "\n", + "**Output:** `(N×C×H×W image tensor, N timestamp tensor, \"image\")`\n", + "\n", + "### Steps\n", + "1. Create synthetic time-stamped chest X-ray data\n", + "2. Standalone processor usage and verification\n", + "3. Processor with normalization and truncation\n", + "4. Integration with `create_sample_dataset`\n", + "5. Verify multimodal compatibility for downstream fusion" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. 
Environment Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import random\n", + "import shutil\n", + "import tempfile\n", + "\n", + "import numpy as np\n", + "import torch\n", + "from PIL import Image\n", + "\n", + "from pyhealth.datasets import create_sample_dataset\n", + "from pyhealth.processors.time_image_processor import TimeImageProcessor\n", + "\n", + "SEED = 42\n", + "random.seed(SEED)\n", + "np.random.seed(SEED)\n", + "torch.manual_seed(SEED)\n", + "\n", + "print(f\"PyTorch version: {torch.__version__}\")\n", + "print(f\"CUDA available: {torch.cuda.is_available()}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Create Synthetic Time-Stamped X-ray Data\n", + "\n", + "We simulate a scenario where each patient has 1–5 chest X-rays taken at different times during their hospital stay. Each image gets a timestamp representing days from the patient's first admission." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "DATA_ROOT = tempfile.mkdtemp(prefix=\"time_image_example_\")\n", + "images_dir = os.path.join(DATA_ROOT, \"images\")\n", + "os.makedirs(images_dir, exist_ok=True)\n", + "\n", + "NUM_PATIENTS = 20\n", + "MAX_IMAGES_PER_PATIENT = 5\n", + "\n", + "samples = []\n", + "for pid in range(NUM_PATIENTS):\n", + " # Each patient has 1-5 X-rays taken at different times\n", + " n_images = np.random.randint(1, MAX_IMAGES_PER_PATIENT + 1)\n", + "\n", + " # Time differences from first admission in days\n", + " time_diffs = sorted(np.random.uniform(0, 30, size=n_images).tolist())\n", + " time_diffs[0] = 0.0 # First image always at t=0\n", + "\n", + " image_paths = []\n", + " for j in range(n_images):\n", + " # Synthetic grayscale X-ray with noise\n", + " img_array = np.random.normal(80, 25, (224, 224))\n", + "\n", + " # Add lung-shaped regions\n", + " y, x = np.ogrid[:224, :224]\n", + " left_mask = ((x - 72)**2 / 3000 + (y - 112)**2 / 8000) < 1\n", + " right_mask = ((x - 152)**2 / 3000 + (y - 112)**2 / 8000) < 1\n", + " img_array[left_mask] -= 20\n", + " img_array[right_mask] -= 20\n", + "\n", + " img_array = np.clip(img_array, 0, 255).astype(np.uint8)\n", + " img = Image.fromarray(img_array, mode=\"L\")\n", + "\n", + " img_path = os.path.join(images_dir, f\"p{pid:03d}_t{j:02d}.png\")\n", + " img.save(img_path)\n", + " image_paths.append(img_path)\n", + "\n", + " # Binary mortality label\n", + " label = pid % 2\n", + "\n", + " samples.append({\n", + " \"patient_id\": f\"p{pid}\",\n", + " \"visit_id\": f\"v{pid}\",\n", + " \"chest_xray\": (image_paths, time_diffs),\n", + " \"label\": label,\n", + " })\n", + "\n", + "print(f\"Created {NUM_PATIENTS} patients in {DATA_ROOT}\")\n", + "print(f\"Images per patient: 1-{MAX_IMAGES_PER_PATIENT}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Inspect a sample patient\n", + "sample = samples[0]\n", + "paths, times = sample[\"chest_xray\"]\n", + "\n", + "print(f\"Patient {sample['patient_id']}:\")\n", + "print(f\" Number of images: {len(paths)}\")\n", + "print(f\" Times (days from admission): {[round(t, 1) for t in times]}\")\n", + "print(f\" Mortality label: {sample['label']}\")\n", + "print(f\" Image paths:\")\n", + "for p, t in zip(paths, times):\n", + " print(f\" t={t:5.1f}d {os.path.basename(p)}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. 
Standalone Processor Usage\n", + "\n", + "The `TimeImageProcessor` takes a tuple `(image_paths, time_diffs)` and returns `(images_tensor, timestamps_tensor, \"image\")`.\n", + "\n", + "Key behaviors:\n", + "- **Sorts images chronologically** by timestamp\n", + "- **Truncates** to `max_images` most recent if set\n", + "- Returns the `\"image\"` tag for modality routing in the multimodal embedding model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "proc = TimeImageProcessor(\n", + " image_size=224,\n", + " mode=\"L\",\n", + ")\n", + "\n", + "images, timestamps, tag = proc.process(sample[\"chest_xray\"])\n", + "\n", + "print(f\"Input: ({len(paths)} paths, {len(times)} timestamps)\")\n", + "print(f\"\")\n", + "print(f\"Output:\")\n", + "print(f\" images shape: {images.shape} # (N, C, H, W)\")\n", + "print(f\" timestamps shape: {timestamps.shape} # (N,)\")\n", + "print(f\" modality tag: {tag!r}\")\n", + "print(f\"\")\n", + "print(f\" images dtype: {images.dtype}\")\n", + "print(f\" timestamps dtype: {timestamps.dtype}\")\n", + "print(f\" pixel range: [{images.min():.3f}, {images.max():.3f}]\")\n", + "print(f\" timestamps: {timestamps.tolist()}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Chronological Sorting Verification\n", + "\n", + "Even if image paths are provided in random order, the processor always returns them sorted by timestamp." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Provide images in reverse order\n", + "reversed_paths = list(reversed(paths))\n", + "reversed_times = list(reversed(times))\n", + "\n", + "print(f\"Input order (reversed):\")\n", + "for p, t in zip(reversed_paths, reversed_times):\n", + " print(f\" t={t:5.1f}d {os.path.basename(p)}\")\n", + "\n", + "_, sorted_timestamps, _ = proc.process((reversed_paths, reversed_times))\n", + "\n", + "print(f\"\\nOutput timestamps (sorted): {sorted_timestamps.tolist()}\")\n", + "print(f\"Correctly sorted: {all(sorted_timestamps[i] <= sorted_timestamps[i+1] for i in range(len(sorted_timestamps)-1))}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. Truncation with `max_images`\n", + "\n", + "When `max_images` is set, the processor keeps only the **most recent** images (by timestamp). This is useful for patients with many X-rays where you want to cap compute." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "proc_truncated = TimeImageProcessor(\n", + " image_size=224,\n", + " mode=\"L\",\n", + " max_images=2,\n", + ")\n", + "\n", + "imgs_trunc, ts_trunc, _ = proc_truncated.process(sample[\"chest_xray\"])\n", + "\n", + "print(f\"Original images: {len(paths)}\")\n", + "print(f\"max_images: 2\")\n", + "print(f\"Output images: {imgs_trunc.shape[0]}\")\n", + "print(f\"Kept timestamps: {ts_trunc.tolist()}\")\n", + "print(f\"(These are the 2 most recent observations)\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 6. Normalization\n", + "\n", + "ImageNet-style normalization can be applied for pretrained backbone compatibility." 
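+    ,
+    "\n",
+    "\n",
+    "The cell below keeps things simple with a mean/std of 0.5 for the single grayscale channel. For RGB inputs feeding an ImageNet-pretrained backbone, the usual ImageNet statistics apply instead, e.g. (a sketch, using the same values as the class docstring example):\n",
+    "\n",
+    "```python\n",
+    "proc_imagenet = TimeImageProcessor(\n",
+    "    image_size=224,\n",
+    "    mode=\"RGB\",\n",
+    "    normalize=True,\n",
+    "    mean=[0.485, 0.456, 0.406],\n",
+    "    std=[0.229, 0.224, 0.225],\n",
+    ")\n",
+    "```"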
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "proc_norm = TimeImageProcessor(\n", + " image_size=128,\n", + " mode=\"L\",\n", + " normalize=True,\n", + " mean=[0.5],\n", + " std=[0.5],\n", + ")\n", + "\n", + "imgs_norm, ts_norm, _ = proc_norm.process(sample[\"chest_xray\"])\n", + "\n", + "print(f\"Without normalization:\")\n", + "print(f\" pixel range: [{images.min():.3f}, {images.max():.3f}]\")\n", + "print(f\"\")\n", + "print(f\"With normalization (mean=0.5, std=0.5):\")\n", + "print(f\" pixel range: [{imgs_norm.min():.3f}, {imgs_norm.max():.3f}]\")\n", + "print(f\" output shape: {imgs_norm.shape}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 7. Integration with `create_sample_dataset`\n", + "\n", + "The processor is registered as `\"time_image\"` in PyHealth's processor registry, so it can be used in task schemas. Here we show it working with `create_sample_dataset`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dataset = create_sample_dataset(\n", + " samples=samples,\n", + " input_schema={\n", + " \"chest_xray\": \"time_image\",\n", + " },\n", + " output_schema={\n", + " \"label\": \"binary\",\n", + " },\n", + " input_processors={\n", + " \"chest_xray\": TimeImageProcessor(\n", + " image_size=224,\n", + " mode=\"L\",\n", + " max_images=4,\n", + " ),\n", + " },\n", + " dataset_name=\"time_xray_example\",\n", + ")\n", + "\n", + "print(f\"Dataset: {dataset}\")\n", + "print(f\"Total samples: {len(dataset)}\")\n", + "print(f\"Input schema: {dataset.input_schema}\")\n", + "print(f\"Output schema: {dataset.output_schema}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Inspect a processed sample\n", + "processed = dataset[0]\n", + "print(f\"Processed sample keys: {list(processed.keys())}\")\n", + "print()\n", + "\n", + "xray_data = processed[\"chest_xray\"]\n", + "if isinstance(xray_data, tuple):\n", + " img_tensor, ts_tensor, modality_tag = xray_data\n", + " print(f\"chest_xray output:\")\n", + " print(f\" images shape: {img_tensor.shape} # (N, C, H, W)\")\n", + " print(f\" timestamps shape: {ts_tensor.shape} # (N,)\")\n", + " print(f\" modality tag: {modality_tag!r}\")\n", + "else:\n", + " print(f\" type: {type(xray_data)}\")\n", + "\n", + "print(f\"\\nlabel: {processed['label']}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 8. Per-Patient Output Shape Summary\n", + "\n", + "Since patients have different numbers of X-rays, the output tensor shapes vary per patient. This is expected — the multimodal embedding model handles variable-length inputs via masking." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "proc_demo = TimeImageProcessor(image_size=224, mode=\"L\", max_images=4)\n", + "\n", + "print(f\"{'Patient':<10} {'N imgs':<8} {'Output Shape':<25} {'Time Range (days)'}\")\n", + "print(\"-\" * 65)\n", + "\n", + "for i in range(min(10, len(samples))):\n", + " s = samples[i]\n", + " paths_i, times_i = s[\"chest_xray\"]\n", + " imgs_i, ts_i, _ = proc_demo.process((paths_i, times_i))\n", + " print(\n", + " f\"{s['patient_id']:<10} \"\n", + " f\"{len(paths_i):<8} \"\n", + " f\"{str(tuple(imgs_i.shape)):<25} \"\n", + " f\"[{ts_i[0]:.1f}, {ts_i[-1]:.1f}]\"\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 9. 
Multimodal Compatibility\n", + "\n", + "The `TimeImageProcessor` output format is designed to feed directly into the **unified multimodal embedding model**:\n", + "\n", + "```\n", + "TimeImageProcessor\n", + " ↓\n", + "(N, C, H, W) images + (N,) timestamps + \"image\" tag\n", + " ↓\n", + "VisionEncoder(images) → (B, P, E') patch embeddings\n", + "TimeEmbedding(timestamps) → temporal encoding\n", + "ModalityEmbedding(\"image\") → modality type encoding\n", + " ↓\n", + "Combined: (B, P, E') vision tokens\n", + " ↓\n", + "Concatenate with other modalities:\n", + " TextEncoder → (B, T, E') text tokens\n", + " TimeseriesProc → (B, S, E') timeseries tokens\n", + " ↓\n", + "(B, P+T+S, E') → BottleneckTransformer\n", + "```\n", + "\n", + "This matches the architecture specified in the multimodal design doc." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Simulated multimodal input shapes for one patient\n", + "E_prime = 128 # shared embedding dimension\n", + "\n", + "# Vision: TimeImageProcessor -> VisionEncoder\n", + "P = 49 # patches from CNN/ResNet backbone\n", + "vision_tokens = torch.randn(1, P, E_prime)\n", + "\n", + "# Text: TextProcessor -> TextEncoder (Medical RoBERTa)\n", + "T = 64 # 128-token chunks\n", + "text_tokens = torch.randn(1, T, E_prime)\n", + "\n", + "# Timeseries: TimeseriesProcessor -> TimeseriesEncoder\n", + "S = 48 # hourly lab values over 2 days\n", + "ts_tokens = torch.randn(1, S, E_prime)\n", + "\n", + "# Concatenate for transformer fusion\n", + "combined = torch.cat([vision_tokens, text_tokens, ts_tokens], dim=1)\n", + "\n", + "print(f\"Vision tokens: {tuple(vision_tokens.shape)} # P={P} patches\")\n", + "print(f\"Text tokens: {tuple(text_tokens.shape)} # T={T} tokens\")\n", + "print(f\"Timeseries tokens: {tuple(ts_tokens.shape)} # S={S} steps\")\n", + "print(f\"\")\n", + "print(f\"Combined sequence: {tuple(combined.shape)} # P+T+S={P+T+S} tokens\")\n", + "print(f\"Ready for BottleneckTransformer input\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 10. 
Cleanup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "shutil.rmtree(DATA_ROOT)\n", + "print(f\"Cleaned up: {DATA_ROOT}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Summary\n", + "\n", + "### TimeImageProcessor\n", + "\n", + "| Feature | Details |\n", + "|---|---|\n", + "| **Registry name** | `\"time_image\"` |\n", + "| **Input** | `(List[image_path], List[time_diff])` |\n", + "| **Output** | `(N×C×H×W tensor, N tensor, \"image\")` |\n", + "| **Sorting** | Chronological by timestamp |\n", + "| **Truncation** | `max_images` keeps most recent |\n", + "| **Normalization** | Optional ImageNet-style |\n", + "| **Mode** | RGB, L (grayscale), RGBA |\n", + "\n", + "### Usage in task schema\n", + "\n", + "```python\n", + "input_schema = {\n", + " \"chest_xray\": (\"time_image\", {\n", + " \"image_size\": 224,\n", + " \"mode\": \"RGB\",\n", + " \"normalize\": True,\n", + " \"mean\": [0.485, 0.456, 0.406],\n", + " \"std\": [0.229, 0.224, 0.225],\n", + " \"max_images\": 8,\n", + " }),\n", + "}\n", + "```\n", + "\n", + "### Downstream pipeline\n", + "\n", + "```\n", + "TimeImageProcessor → VisionEmbeddingModel → MultimodalEmbeddingModel → BottleneckTransformer\n", + "```" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/pyhealth/processors/__init__.py b/pyhealth/processors/__init__.py index 15512c2d7..6d8b9ad33 100644 --- a/pyhealth/processors/__init__.py +++ b/pyhealth/processors/__init__.py @@ -44,6 +44,7 @@ def get_processor(name: str): from .tensor_processor import TensorProcessor from .text_processor import TextProcessor from .timeseries_processor import TimeseriesProcessor +from .time_image_processor import TimeImageProcessor from .audio_processor import AudioProcessor from .ignore_processor import IgnoreProcessor @@ -63,5 +64,6 @@ def get_processor(name: str): "TensorProcessor", "TextProcessor", "TimeseriesProcessor", + "TimeImageProcessor", "AudioProcessor", ] diff --git a/pyhealth/processors/time_image_processor.py b/pyhealth/processors/time_image_processor.py new file mode 100644 index 000000000..cdd853367 --- /dev/null +++ b/pyhealth/processors/time_image_processor.py @@ -0,0 +1,328 @@ +# Author: Joshua Steier +# Description: Time-aware image processor for multimodal PyHealth +# pipelines. Pairs image loading with temporal metadata for +# unified multimodal embedding models. Designed for tasks +# where each patient has multiple images taken at different +# times (e.g., serial chest X-rays during an ICU stay). + +from functools import partial +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union + +import torch +import torchvision.transforms as transforms +from PIL import Image + +from . import register_processor +from .base_processor import FeatureProcessor + + +def _convert_mode(img: Image.Image, mode: str) -> Image.Image: + """Convert a PIL image to the requested mode. + + Args: + img: Input PIL image. + mode: Target PIL image mode (e.g., "RGB", "L"). + + Returns: + Converted PIL image. + """ + return img.convert(mode) + + +@register_processor("time_image") +class TimeImageProcessor(FeatureProcessor): + """Feature processor that loads images and pairs them with timestamps. 
+ + Takes a tuple of (image_paths, time_differences) and returns a tuple + of (stacked_image_tensor, timestamp_tensor, "image") suitable for + the unified multimodal embedding model. + + The processor sorts images chronologically by timestamp and + optionally caps the number of images per patient, keeping the most + recent observations. + + Input: + - image_paths: List[str | Path] + - time_diffs: List[float] (e.g., days from first admission) + + Processing: + 1. Sort (path, time) pairs chronologically. + 2. Truncate to max_images most recent if set. + 3. Load, resize, and transform each image. + 4. Stack into a single tensor. + + Output: + - Tuple of (images, timestamps, "image") where: + - images: torch.Tensor of shape (N, C, H, W) + - timestamps: torch.Tensor of shape (N,) + - "image": str literal for modality routing + + Args: + image_size: Resize images to (image_size, image_size). + Defaults to 224. + to_tensor: Whether to convert images to tensors. + Defaults to True. + normalize: Whether to normalize pixel values. + Defaults to False. + mean: Per-channel means for normalization. Required if + normalize is True. + std: Per-channel standard deviations for normalization. + Required if normalize is True. + mode: PIL image mode conversion (e.g., "RGB", "L"). If + None, keeps the original mode. Defaults to None. + max_images: Maximum number of images per patient. If a + patient has more images, the most recent (by timestamp) + are kept. If None, all images are kept. Defaults to + None. + + Raises: + ValueError: If normalize is True but mean or std is missing. + ValueError: If mean/std are provided but normalize is False. + + Example: + >>> proc = TimeImageProcessor( + ... image_size=224, + ... normalize=True, + ... mean=[0.485, 0.456, 0.406], + ... std=[0.229, 0.224, 0.225], + ... ) + >>> paths = ["/data/xray1.png", "/data/xray2.png"] + >>> times = [0.0, 2.5] + >>> images, timestamps, tag = proc.process((paths, times)) + >>> images.shape + torch.Size([2, 3, 224, 224]) + >>> timestamps + tensor([0.0000, 2.5000]) + >>> tag + 'image' + """ + + def __init__( + self, + image_size: int = 224, + to_tensor: bool = True, + normalize: bool = False, + mean: Optional[List[float]] = None, + std: Optional[List[float]] = None, + mode: Optional[str] = None, + max_images: Optional[int] = None, + ) -> None: + self.image_size = image_size + self.to_tensor = to_tensor + self.normalize = normalize + self.mean = mean + self.std = std + self.mode = mode + self.max_images = max_images + self.n_channels = None + + if self.normalize and ( + self.mean is None or self.std is None + ): + raise ValueError( + "Normalization requires both mean and std to be " + "provided." + ) + if not self.normalize and ( + self.mean is not None or self.std is not None + ): + raise ValueError( + "Mean and std are provided but normalize is set " + "to False. Either provide normalize=True, or " + "remove mean and std." + ) + + self.transform = self._build_transform() + + def _build_transform(self) -> transforms.Compose: + """Build the torchvision transform pipeline. + + Returns: + A composed transform that applies mode conversion, + resizing, tensor conversion, and normalization as + configured. 
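+
+        Note:
+            Transforms run in that order (mode conversion,
+            resize, ToTensor, normalize). Because process()
+            stacks per-image outputs with torch.stack,
+            to_tensor=False will break process(); it is only
+            usable when applying the transform directly.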
+ """ + transform_list = [] + if self.mode is not None: + transform_list.append( + transforms.Lambda( + partial(_convert_mode, mode=self.mode) + ) + ) + if self.image_size is not None: + transform_list.append( + transforms.Resize( + (self.image_size, self.image_size) + ) + ) + if self.to_tensor: + transform_list.append(transforms.ToTensor()) + if self.normalize: + transform_list.append( + transforms.Normalize( + mean=self.mean, std=self.std + ) + ) + return transforms.Compose(transform_list) + + def _load_single_image( + self, path: Union[str, Path] + ) -> torch.Tensor: + """Load and transform a single image from disk. + + Called internally by process() for each image path in + the input list. + + Args: + path: Path to the image file. + + Returns: + Transformed image tensor of shape (C, H, W). + + Raises: + FileNotFoundError: If the image file does not exist. + """ + image_path = Path(path) + if not image_path.exists(): + raise FileNotFoundError( + f"Image file not found: {image_path}" + ) + with Image.open(image_path) as img: + img.load() + return self.transform(img) + + def fit( + self, samples: Iterable[Dict[str, Any]], field: str + ) -> None: + """Fit the processor by inferring n_channels from data. + + Scans samples to find the first valid entry for the given + field and infers the number of image channels from mode. + + Args: + samples: Iterable of sample dictionaries. + field: The field name to extract from samples. + """ + if self.mode == "L": + self.n_channels = 1 + elif self.mode == "RGBA": + self.n_channels = 4 + elif self.mode is not None: + self.n_channels = 3 + else: + for sample in samples: + if field in sample and sample[field] is not None: + image_paths, _ = sample[field] + if len(image_paths) > 0: + path = Path(image_paths[0]) + if path.exists(): + with Image.open(path) as img: + if img.mode == "L": + self.n_channels = 1 + elif img.mode == "RGBA": + self.n_channels = 4 + else: + self.n_channels = 3 + break + if self.n_channels is None: + self.n_channels = 3 + + def process( + self, + value: Tuple[ + List[Union[str, Path]], List[float] + ], + ) -> Tuple[torch.Tensor, torch.Tensor, str]: + """Process paired image paths and timestamps. + + Takes a tuple of (image_paths, time_differences) where + each image path corresponds to the time difference at the + same index. Images are sorted chronologically. If + max_images is set, only the most recent images are kept. + + This method is called by SampleBuilder.transform during + dataset processing. + + Args: + value: A tuple of two lists: + - image_paths: List of file paths to images. + - time_diffs: List of float time differences + from the patient's first admission (e.g., + in days). + Both lists must have the same length. + + Returns: + A tuple of: + - images: Stacked image tensor of shape + (N, C, H, W) where N is the number of images. + - timestamps: Float tensor of shape (N,) + containing the time differences. + - tag: The literal string "image" for modality + routing in the multimodal embedding model. + + Raises: + ValueError: If image_paths and time_diffs have + different lengths. + ValueError: If image_paths is empty. + FileNotFoundError: If any image file does not exist. + """ + image_paths, time_diffs = value + + if len(image_paths) != len(time_diffs): + raise ValueError( + f"image_paths length ({len(image_paths)}) and " + f"time_diffs length ({len(time_diffs)}) must " + f"match." 
+ ) + if len(image_paths) == 0: + raise ValueError("image_paths must be non-empty.") + + paired = sorted( + zip(time_diffs, image_paths), key=lambda x: x[0] + ) + + if ( + self.max_images is not None + and len(paired) > self.max_images + ): + paired = paired[-self.max_images:] + + timestamps = [] + image_tensors = [] + for t, p in paired: + image_tensors.append(self._load_single_image(p)) + timestamps.append(t) + + images = torch.stack(image_tensors, dim=0) + timestamps = torch.tensor( + timestamps, dtype=torch.float32 + ) + + if self.n_channels is None: + self.n_channels = images.shape[1] + + return images, timestamps, "image" + + def size(self) -> Optional[int]: + """Return number of image channels. + + Mirrors the TimeseriesProcessor.size() pattern. Returns + None if fit() or process() has not been called yet. + + Returns: + Number of channels, or None if unknown. + """ + return self.n_channels + + def __repr__(self) -> str: + return ( + f"TimeImageProcessor(" + f"image_size={self.image_size}, " + f"to_tensor={self.to_tensor}, " + f"normalize={self.normalize}, " + f"mean={self.mean}, " + f"std={self.std}, " + f"mode={self.mode}, " + f"max_images={self.max_images})" + ) \ No newline at end of file diff --git a/tests/core/test_time_image_processor.py b/tests/core/test_time_image_processor.py new file mode 100644 index 000000000..51467aa5c --- /dev/null +++ b/tests/core/test_time_image_processor.py @@ -0,0 +1,315 @@ +# Author: Joshua Steier +# Description: Unit tests for TimeImageProcessor. + +import os +import shutil +import tempfile +import unittest +from pathlib import Path + +import torch +from PIL import Image + +from pyhealth.processors.time_image_processor import ( + TimeImageProcessor, +) + + +class TestTimeImageProcessor(unittest.TestCase): + """Tests for TimeImageProcessor.""" + + def setUp(self): + """Create temp directory with synthetic test images.""" + self.temp_dir = tempfile.mkdtemp() + + self.rgb_paths = [] + for i in range(5): + path = os.path.join( + self.temp_dir, f"rgb_{i}.png" + ) + img = Image.new( + "RGB", + (100 + i * 10, 100 + i * 10), + color=(255, i * 50, 0), + ) + img.save(path) + self.rgb_paths.append(path) + + self.gray_path = os.path.join( + self.temp_dir, "gray.png" + ) + Image.new("L", (80, 80), color=128).save( + self.gray_path + ) + + self.rgba_path = os.path.join( + self.temp_dir, "rgba.png" + ) + Image.new( + "RGBA", (90, 90), color=(255, 0, 0, 128) + ).save(self.rgba_path) + + self.times = [0.0, 1.5, 3.0, 7.0, 14.0] + + def tearDown(self): + """Remove temporary test directory.""" + shutil.rmtree(self.temp_dir) + + # ---- Initialization ---- + + def test_init_default(self): + """Default init sets expected attributes.""" + proc = TimeImageProcessor() + self.assertEqual(proc.image_size, 224) + self.assertTrue(proc.to_tensor) + self.assertFalse(proc.normalize) + self.assertIsNone(proc.mean) + self.assertIsNone(proc.std) + self.assertIsNone(proc.mode) + self.assertIsNone(proc.max_images) + + def test_init_custom(self): + """Custom init stores all arguments.""" + proc = TimeImageProcessor( + image_size=128, + to_tensor=True, + normalize=True, + mean=[0.5, 0.5, 0.5], + std=[0.5, 0.5, 0.5], + mode="RGB", + max_images=3, + ) + self.assertEqual(proc.image_size, 128) + self.assertEqual(proc.max_images, 3) + self.assertTrue(proc.normalize) + self.assertEqual(proc.mode, "RGB") + + def test_init_normalize_without_mean_std_raises(self): + """ValueError when normalize=True but no mean/std.""" + with self.assertRaises(ValueError): + TimeImageProcessor(normalize=True) 
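+
+    def test_init_normalize_with_partial_stats_raises(self):
+        """ValueError when normalize=True with only mean given."""
+        # Partial stats (mean without std) should hit the
+        # same validation branch as a fully missing pair.
+        with self.assertRaises(ValueError):
+            TimeImageProcessor(normalize=True, mean=[0.5])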
+ + def test_init_mean_std_without_normalize_raises(self): + """ValueError when mean/std given but normalize=False.""" + with self.assertRaises(ValueError): + TimeImageProcessor(mean=[0.5], std=[0.5]) + + # ---- Core process() ---- + + def test_process_basic_rgb(self): + """Basic RGB returns correct shapes and tag.""" + proc = TimeImageProcessor(image_size=64) + paths = self.rgb_paths[:3] + times = self.times[:3] + + images, timestamps, tag = proc.process( + (paths, times) + ) + + self.assertEqual(images.shape, (3, 3, 64, 64)) + self.assertEqual(timestamps.shape, (3,)) + self.assertEqual(tag, "image") + self.assertIsInstance(images, torch.Tensor) + self.assertIsInstance(timestamps, torch.Tensor) + + def test_process_single_image(self): + """Single image works correctly.""" + proc = TimeImageProcessor(image_size=32) + images, timestamps, tag = proc.process( + ([self.rgb_paths[0]], [0.0]) + ) + + self.assertEqual(images.shape, (1, 3, 32, 32)) + self.assertEqual(timestamps.shape, (1,)) + self.assertEqual(tag, "image") + + def test_process_grayscale(self): + """Grayscale mode produces single-channel output.""" + proc = TimeImageProcessor(image_size=64, mode="L") + images, timestamps, tag = proc.process( + ([self.gray_path], [0.0]) + ) + + self.assertEqual(images.shape, (1, 1, 64, 64)) + + def test_process_rgba_to_rgb(self): + """RGBA converted to RGB via mode parameter.""" + proc = TimeImageProcessor(image_size=64, mode="RGB") + images, timestamps, tag = proc.process( + ([self.rgba_path], [1.0]) + ) + + self.assertEqual(images.shape, (1, 3, 64, 64)) + + def test_process_with_normalization(self): + """Normalization applies without errors.""" + proc = TimeImageProcessor( + image_size=64, + normalize=True, + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225], + ) + images, timestamps, tag = proc.process( + (self.rgb_paths[:2], self.times[:2]) + ) + + self.assertEqual(images.shape, (2, 3, 64, 64)) + self.assertIsInstance(images, torch.Tensor) + + # ---- Chronological sorting ---- + + def test_process_sorts_by_timestamp(self): + """Images reordered chronologically by timestamp.""" + proc = TimeImageProcessor(image_size=32) + + paths = list(reversed(self.rgb_paths[:3])) + times = [10.0, 5.0, 1.0] + + _, timestamps, _ = proc.process((paths, times)) + + expected = torch.tensor([1.0, 5.0, 10.0]) + self.assertTrue(torch.equal(timestamps, expected)) + + # ---- max_images truncation ---- + + def test_max_images_truncates_to_most_recent(self): + """max_images keeps the N most recent images.""" + proc = TimeImageProcessor( + image_size=32, max_images=2 + ) + images, timestamps, tag = proc.process( + (self.rgb_paths[:4], self.times[:4]) + ) + + self.assertEqual(images.shape[0], 2) + self.assertEqual(timestamps.shape, (2,)) + self.assertAlmostEqual(timestamps[0].item(), 3.0) + self.assertAlmostEqual(timestamps[1].item(), 7.0) + + def test_max_images_no_truncation_when_under(self): + """No truncation when count is under max_images.""" + proc = TimeImageProcessor( + image_size=32, max_images=10 + ) + images, timestamps, _ = proc.process( + (self.rgb_paths[:3], self.times[:3]) + ) + + self.assertEqual(images.shape[0], 3) + + # ---- Error handling ---- + + def test_process_mismatched_lengths_raises(self): + """ValueError for mismatched paths and times.""" + proc = TimeImageProcessor() + with self.assertRaises(ValueError): + proc.process( + (self.rgb_paths[:3], self.times[:2]) + ) + + def test_process_empty_paths_raises(self): + """ValueError for empty image list.""" + proc = TimeImageProcessor() + with 
self.assertRaises(ValueError): + proc.process(([], [])) + + def test_process_invalid_path_raises(self): + """FileNotFoundError for nonexistent image.""" + proc = TimeImageProcessor() + with self.assertRaises(FileNotFoundError): + proc.process( + (["/nonexistent/img.png"], [0.0]) + ) + + # ---- Path types ---- + + def test_process_accepts_path_objects(self): + """Path objects accepted alongside strings.""" + proc = TimeImageProcessor(image_size=32) + paths = [Path(self.rgb_paths[0])] + images, _, _ = proc.process((paths, [0.0])) + + self.assertEqual(images.shape, (1, 3, 32, 32)) + + # ---- fit() ---- + + def test_fit_infers_channels_from_mode(self): + """fit() infers n_channels from mode parameter.""" + proc = TimeImageProcessor(mode="L") + proc.fit([], "xray") + self.assertEqual(proc.size(), 1) + + proc2 = TimeImageProcessor(mode="RGB") + proc2.fit([], "xray") + self.assertEqual(proc2.size(), 3) + + def test_fit_infers_channels_from_sample(self): + """fit() infers n_channels from actual image data.""" + proc = TimeImageProcessor(image_size=32) + samples = [ + { + "xray": ( + [self.rgb_paths[0]], + [0.0], + ), + } + ] + proc.fit(samples, "xray") + self.assertEqual(proc.size(), 3) + + def test_fit_defaults_to_3_channels(self): + """fit() defaults to 3 channels if nothing found.""" + proc = TimeImageProcessor() + proc.fit([], "xray") + self.assertEqual(proc.size(), 3) + + # ---- size() ---- + + def test_size_none_before_fit(self): + """size() returns None before fit or process.""" + proc = TimeImageProcessor() + self.assertIsNone(proc.size()) + + def test_size_set_after_process(self): + """size() returns channels after process().""" + proc = TimeImageProcessor(image_size=32) + proc.process(([self.rgb_paths[0]], [0.0])) + self.assertEqual(proc.size(), 3) + + # ---- Repr ---- + + def test_repr(self): + """__repr__ contains key parameter values.""" + proc = TimeImageProcessor( + image_size=128, max_images=5, mode="L" + ) + r = repr(proc) + self.assertIn("TimeImageProcessor", r) + self.assertIn("image_size=128", r) + self.assertIn("max_images=5", r) + self.assertIn("mode=L", r) + + # ---- Tensor properties ---- + + def test_output_values_in_valid_range(self): + """Without normalization, pixels are in [0, 1].""" + proc = TimeImageProcessor(image_size=32) + images, _, _ = proc.process( + (self.rgb_paths[:2], self.times[:2]) + ) + + self.assertTrue(torch.all(images >= 0)) + self.assertTrue(torch.all(images <= 1)) + + def test_timestamp_dtype_is_float32(self): + """Timestamps are always float32 tensors.""" + proc = TimeImageProcessor(image_size=32) + _, timestamps, _ = proc.process( + (self.rgb_paths[:2], self.times[:2]) + ) + + self.assertEqual(timestamps.dtype, torch.float32) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file