Skip to content
Open
24 changes: 22 additions & 2 deletions configs/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,21 @@ callbacks:
monitor: validation/loss
mode: min

# opt-in callbacks — add to trainer.callbacks when needed
tiles_export:
_target_: stain_normalization.callbacks.TilesExport
output_dir: ${output_dir}
n_first: 10
sample_rate: 0.0005 # approx. 100 tiles for the test dataset

analysis_export:
_target_: stain_normalization.callbacks.AnalysisExport
output_dir: ${output_dir}

wsi_assembler:
_target_: stain_normalization.callbacks.WSIAssembler
output_dir: ${output_dir}

early_stopping:
_target_: lightning.pytorch.callbacks.EarlyStopping
monitor: validation/loss
Expand All @@ -32,6 +47,8 @@ model:
lambda_l1: 0.2
lambda_lum: 0.2
lambda_gdl: 0.1
normalize_mean: ${data.test.normalize.mean}
normalize_std: ${data.test.normalize.std}

trainer:
enable_checkpointing: True
Expand All @@ -46,11 +63,14 @@ trainer:
data:
batch_size: 64
num_workers: 8

metadata:
user: ???
experiment_name: Stain-Normalization
run_name: ???
run_name: ???
description: ???
hyperparams: ${model}


output_dir: ??? # ./data/outputs

5 changes: 1 addition & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ dependencies = [
"scikit-image>=0.25.2",
"rationai-staining @ git+https://gitlab.ics.muni.cz/rationai/digital-pathology/libraries/staining.git",
"openslide-python>=1.4.3",
"kornia>=0.8.2",
]

[dependency-groups]
Expand All @@ -26,7 +27,3 @@ dev = ["mypy", "ruff"]
[tool.mypy]
ignore_missing_imports = true

[tool.uv]
environments = ["sys_platform == 'linux'"]

override-dependencies = ["mlflow>=2.15.1"]
7 changes: 7 additions & 0 deletions stain_normalization/callbacks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from stain_normalization.callbacks._base import ImageCallback
from stain_normalization.callbacks.analysis_export import AnalysisExport
from stain_normalization.callbacks.tiles_export import TilesExport
from stain_normalization.callbacks.wsi_assembler import WSIAssembler


__all__ = ["AnalysisExport", "ImageCallback", "TilesExport", "WSIAssembler"]
17 changes: 17 additions & 0 deletions stain_normalization/callbacks/_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from typing import Any

import numpy as np
import torch
from lightning import Callback


class ImageCallback(Callback):
    """Base callback providing tensor-to-image conversion.

    Expects denormalized [0,1] tensors (denormalization is done in the model).
    """

    @staticmethod
    def tensor_to_image(tensor: torch.Tensor) -> np.ndarray[Any, Any]:
        """Convert a [0,1] CHW tensor to a uint8 HWC numpy array.

        Values are clamped before the byte cast: `.byte()` alone wraps
        out-of-range inputs (a value of 1.004 would become 0 instead of 255,
        and small float error above 1.0 is common in model outputs).
        Rounding instead of truncating avoids a systematic downward bias
        (e.g. 0.999 -> 255, not 254).
        """
        return (
            tensor.mul(255).round().clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
        )
84 changes: 84 additions & 0 deletions stain_normalization/callbacks/tiles_export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from pathlib import Path
from typing import Any

import torch
from lightning import LightningModule, Trainer
from PIL import Image

from stain_normalization.callbacks._base import ImageCallback
from stain_normalization.type_aliases import Outputs


class TilesExport(ImageCallback):
    """Exports predicted tiles as PNGs during test / predict runs.

    The first `n_first` tiles are always exported; afterwards tiles are
    randomly sampled at `sample_rate`. During test, the original and
    modified reference images are exported alongside the prediction.
    """

    def __init__(
        self,
        output_dir: str | Path,
        n_first: int = 10,
        sample_rate: float = 0.0005,
    ) -> None:
        """Create the exporter.

        Args:
            output_dir: Root directory; one subdirectory is created per slide.
            n_first: Number of initial tiles exported unconditionally.
            sample_rate: Probability of exporting each subsequent tile.
        """
        super().__init__()
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.n_first = n_first
        self.sample_rate = sample_rate
        # Counts every tile seen (saved or not) across all batches/loaders.
        self._global_count: int = 0

    def tensor_to_image(self, tensor: torch.Tensor) -> Image.Image:  # type: ignore[override] # intentional: PIL Image is not subtype of ndarray
        """Convert a [0,1] CHW tensor to a PIL image (wraps the base uint8 conversion)."""
        return Image.fromarray(super().tensor_to_image(tensor))

    def _should_save(self) -> bool:
        """Decide whether the current tile is exported; always advances the counter."""
        count = self._global_count
        self._global_count += 1
        if count < self.n_first:
            return True
        return torch.rand(1).item() < self.sample_rate

    def _export_batch(
        self,
        outputs: Outputs,
        data: list[dict[str, Any]],
        with_references: bool,
    ) -> None:
        """Shared export loop for the test and predict hooks.

        With `with_references=True` (test) the prediction is saved as
        `{xy}_predicted.png` together with `{xy}_original.png` and
        `{xy}_modified.png`; otherwise (predict) only `{xy}.png` is saved.
        """
        for b in range(len(outputs)):
            slide_name = data[b]["slide_name"]
            # _should_save() must run once per tile so the sampling counter
            # stays consistent regardless of what is actually written.
            if not self._should_save():
                continue

            xy = data[b]["xy"]
            slide_dir = self.output_dir / slide_name
            slide_dir.mkdir(parents=True, exist_ok=True)

            if with_references:
                self.tensor_to_image(outputs[b]).save(slide_dir / f"{xy}_predicted.png")

                # assumes original_image is a uint8-range HWC array — TODO confirm
                original_image = Image.fromarray(data[b]["original_image"].astype("uint8"))
                original_image.save(slide_dir / f"{xy}_original.png")

                # modified_image is stored in [0,1] and rescaled here.
                modified_image = Image.fromarray(
                    (data[b]["modified_image"] * 255).astype("uint8")
                )
                modified_image.save(slide_dir / f"{xy}_modified.png")
            else:
                self.tensor_to_image(outputs[b]).save(slide_dir / f"{xy}.png")

    def on_test_batch_end(  # type: ignore[override] # narrowed Lightning STEP_OUTPUT
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: Outputs,
        batch: tuple[torch.Tensor, list[dict[str, Any]]],
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Export sampled predictions with reference images after each test batch."""
        _, data = batch
        self._export_batch(outputs, data, with_references=True)

    def on_predict_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: Outputs,
        batch: tuple[torch.Tensor, list[dict[str, Any]]],
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Export sampled predictions (prediction only) after each predict batch."""
        _, data = batch
        self._export_batch(outputs, data, with_references=False)
220 changes: 220 additions & 0 deletions stain_normalization/callbacks/wsi_assembler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
import tempfile
import traceback
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import numpy as np
import torch
from lightning import LightningModule, Trainer

from rationai.mlkit.lightning.callbacks import MultiloaderLifecycle

from stain_normalization.callbacks._base import ImageCallback
from stain_normalization.type_aliases import Outputs


@dataclass
class _SlideMeta:
    # Static description of one whole-slide image, copied from the
    # datamodule's slide table when its dataloader starts.
    path: str  # source slide file path; stem is used as the output slide name
    level: int  # pyramid level the tiles were read from
    extent_x: int  # slide width in pixels (at `level`) — buffer width
    extent_y: int  # slide height in pixels (at `level`) — buffer height
    tile_extent_x: int  # tile width in pixels
    tile_extent_y: int  # tile height in pixels
    mpp_x: float  # resolution; presumably microns per pixel (used as 1000/mpp for TIFF xres) — verify
    mpp_y: float  # resolution; presumably microns per pixel (used as 1000/mpp for TIFF yres) — verify


@dataclass
class _SlideBuffers:
    # Working state for the slide currently being assembled: metadata plus
    # disk-backed memmap buffers living inside a private temp directory.
    meta: _SlideMeta  # static slide description
    temp_dir: tempfile.TemporaryDirectory[str]  # owns result.raw / count.raw on disk
    result_buffer: np.memmap[Any, Any]  # (H, W, 3) uint8 — assembled RGB pixels
    count_buffer: np.memmap[Any, Any]  # (H, W) uint8 — tiles written per pixel (overlap count)


class WSIAssembler(ImageCallback, MultiloaderLifecycle):
    """Assembles predicted tiles back into whole-slide pyramid TIFFs.

    Uses one dataloader per slide (via MultiloaderLifecycle) — buffers are
    opened on dataloader start and saved/freed on dataloader end.
    """

    def __init__(
        self,
        output_dir: str | Path,
        temp_dir: str | Path | None = None,
    ) -> None:
        """Create the assembler.

        Args:
            output_dir: Directory for the final `{slide}_norm.tiff` files.
            temp_dir: Parent directory for the per-slide memmap scratch space;
                defaults to the system temp location when None.
        """
        # Initialize both bases explicitly — the two hierarchies are unrelated,
        # so a single super().__init__() chain is not relied upon here.
        ImageCallback.__init__(self)
        MultiloaderLifecycle.__init__(self)
        self.output_dir = Path(output_dir)
        self.temp_dir = str(temp_dir) if temp_dir else None
        # State for the slide currently being assembled (one at a time).
        self._active: _SlideBuffers | None = None
        self._active_name: str | None = None
        # Names of slides whose final save raised; reported at predict end.
        self._failed_slides: list[str] = []

    def on_predict_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Ensure the output directory exists before any slide is saved."""
        self.output_dir.mkdir(parents=True, exist_ok=True)

    def on_predict_dataloader_start(
        self, trainer: Trainer, pl_module: LightningModule, dataloader_idx: int
    ) -> None:
        """Open buffers for the slide served by the starting dataloader.

        NOTE(review): assumes the datamodule exposes a `predict.slides`
        dataframe aligned with dataloader order — confirm against the
        datamodule implementation.
        """
        slide = trainer.datamodule.predict.slides.iloc[dataloader_idx]  # type: ignore[attr-defined]
        meta = _SlideMeta(
            path=slide.path,
            level=int(slide.level),
            extent_x=int(slide.extent_x),
            extent_y=int(slide.extent_y),
            tile_extent_x=int(slide.tile_extent_x),
            tile_extent_y=int(slide.tile_extent_y),
            mpp_x=float(slide.mpp_x),
            mpp_y=float(slide.mpp_y),
        )
        slide_name = Path(slide.path).stem
        self._open_slide(slide_name, meta)

    def on_predict_dataloader_end(
        self, trainer: Trainer, pl_module: LightningModule, dataloader_idx: int
    ) -> None:
        """Save and release the slide once its dataloader is exhausted."""
        self._close_slide()

    def _open_slide(self, slide_name: str, meta: _SlideMeta) -> None:
        """Allocate memmap buffers for one slide."""
        h, w = meta.extent_y, meta.extent_x

        # Disk-backed buffers: a full slide at pixel resolution would not
        # reliably fit in RAM. The TemporaryDirectory owns the files.
        tmp = tempfile.TemporaryDirectory(
            prefix=f"wsi_{slide_name}_", dir=self.temp_dir
        )
        result_buf = np.memmap(
            Path(tmp.name) / "result.raw",
            dtype=np.uint8,
            mode="w+",
            shape=(h, w, 3),
        )
        count_buf = np.memmap(
            Path(tmp.name) / "count.raw",
            dtype=np.uint8,
            mode="w+",
            shape=(h, w),
        )

        self._active = _SlideBuffers(
            meta=meta,
            temp_dir=tmp,
            result_buffer=result_buf,
            count_buffer=count_buf,
        )
        self._active_name = slide_name

    def _close_slide(self) -> None:
        """Save and free the currently active slide."""
        if self._active is None:
            return
        assert self._active_name is not None
        slide_name = self._active_name
        try:
            self._save_slide(slide_name, self._active)
        except Exception:
            # Best-effort: a failed save must not abort the remaining slides;
            # the failure is recorded and reported at predict end.
            print(f"ERROR: Failed to save slide '{slide_name}'")
            traceback.print_exc()
            self._failed_slides.append(slide_name)
        finally:
            # Drop the memmap references before removing their backing files.
            del self._active.result_buffer
            del self._active.count_buffer
            self._active.temp_dir.cleanup()
            self._active = None
            self._active_name = None

    def on_predict_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: Outputs,
        batch: tuple[torch.Tensor, list[dict[str, Any]]],
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Place each predicted tile of the batch into the active slide buffer."""
        for b in range(len(outputs)):
            tile = self.tensor_to_image(outputs[b])
            metadata = batch[1][b]
            # Tile origin is encoded in the metadata as "x_y" (pixel coords).
            x, y = (int(v) for v in metadata["xy"].split("_"))
            self._place_tile(tile, x, y)

    def _place_tile(self, tile: np.ndarray[Any, Any], x: int, y: int) -> None:
        """Place a predicted tile into the active slide buffer with overlap averaging."""
        assert self._active is not None
        sb = self._active
        ex, ey = sb.meta.extent_x, sb.meta.extent_y

        # Crop the tile so it never writes past the slide extent.
        h = max(0, min(tile.shape[0], ey - y))
        w = max(0, min(tile.shape[1], ex - x))
        if h == 0 or w == 0:
            return
        tile = tile[:h, :w]

        region = sb.result_buffer[y : y + h, x : x + w]
        count = sb.count_buffer[y : y + h, x : x + w]

        # Running average: avg = (old * n + new) / (n + 1)
        overlap = count > 0
        if overlap.any():
            n = count[:, :, np.newaxis].astype(np.float32)
            blended = np.where(
                overlap[:, :, np.newaxis],
                (region.astype(np.float32) * n + tile) / (n + 1),
                tile,
            )
            sb.result_buffer[y : y + h, x : x + w] = np.clip(blended, 0, 255).astype(
                np.uint8
            )
        else:
            sb.result_buffer[y : y + h, x : x + w] = tile

        # NOTE(review): uint8 counter wraps after 255 overlapping tiles per
        # pixel — confirm the tiling overlap is bounded well below that.
        sb.count_buffer[y : y + h, x : x + w] = count + 1

    def on_predict_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Report slides whose final save failed (collected in _close_slide)."""
        if self._failed_slides:
            print(
                f"WARNING: Failed to save {len(self._failed_slides)} slide(s): "
                f"{self._failed_slides}"
            )

    def _save_slide(self, slide_name: str, sb: _SlideBuffers) -> None:
        """Write the assembled buffers as a tiled, pyramidal, DEFLATE TIFF."""
        # Imported here — module-level import causes OpenSlide segfault (libtiff conflict).
        import pyvips

        meta = sb.meta
        # Flush pending memmap writes so pyvips reads complete data from disk.
        sb.result_buffer.flush()
        sb.count_buffer.flush()

        result_path = Path(sb.temp_dir.name) / "result.raw"
        count_path = Path(sb.temp_dir.name) / "count.raw"

        # rawload maps the files without copying; band counts (3 / 1) must
        # match the memmap shapes written in _open_slide.
        result_img = pyvips.Image.rawload(
            str(result_path), meta.extent_x, meta.extent_y, 3
        )
        result_img = result_img.copy(interpretation=pyvips.Interpretation.SRGB)

        count_img = pyvips.Image.rawload(
            str(count_path), meta.extent_x, meta.extent_y, 1
        )
        mask = count_img > 0
        # add white background for untouched areas (count=0)
        white = (pyvips.Image.black(meta.extent_x, meta.extent_y, bands=3) + 255).cast(
            pyvips.BandFormat.UCHAR
        )
        final_img = mask.ifthenelse(result_img, white)

        output_path = self.output_dir / f"{slide_name}_norm.tiff"
        # xres/yres are pixels per millimetre, hence 1000 / microns-per-pixel.
        final_img.tiffsave(
            str(output_path),
            bigtiff=True,
            compression=pyvips.enums.ForeignTiffCompression.DEFLATE,
            tile=True,
            tile_width=512,
            tile_height=512,
            pyramid=True,
            xres=1000.0 / meta.mpp_x,
            yres=1000.0 / meta.mpp_y,
        )
Loading