From ab9bfc5cd0bc3eb0c8558eadce2823096b69d23d Mon Sep 17 00:00:00 2001 From: Paul Doucet Date: Mon, 12 Aug 2024 15:59:26 -0400 Subject: [PATCH] move all core methods to hestcore --- bin/extract_patch_embeddings.py | 30 +- core/preprocessing/conch_patch_embedder.py | 90 +-- core/preprocessing/hest_modules/SegDataset.py | 73 --- core/preprocessing/hest_modules/__init__.py | 0 .../hest_modules/segmentation.py | 386 ------------ core/preprocessing/hest_modules/wsi.py | 584 ------------------ core/utils/utils.py | 35 ++ requirements.txt | 3 +- 8 files changed, 86 insertions(+), 1115 deletions(-) delete mode 100644 core/preprocessing/hest_modules/SegDataset.py delete mode 100644 core/preprocessing/hest_modules/__init__.py delete mode 100644 core/preprocessing/hest_modules/segmentation.py delete mode 100644 core/preprocessing/hest_modules/wsi.py diff --git a/bin/extract_patch_embeddings.py b/bin/extract_patch_embeddings.py index 6156a21..cb7188d 100644 --- a/bin/extract_patch_embeddings.py +++ b/bin/extract_patch_embeddings.py @@ -1,14 +1,15 @@ import sys; sys.path.append('../') import argparse -import os import logging +import os import openslide from tqdm import tqdm +from core.utils.utils import get_pixel_size -from core.preprocessing.conch_patch_embedder import TileEmbedder -from core.preprocessing.hest_modules.segmentation import TissueSegmenter -from core.preprocessing.hest_modules.wsi import get_pixel_size, OpenSlideWSI +from core.preprocessing.conch_patch_embedder import ConchTileEmbedder +from hestcore.wsi import OpenSlideWSI +from hestcore.segmentation import segment_tissue_deep # Configure logger logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') @@ -38,8 +39,7 @@ def process(slide_dir, out_dir, patch_mag, patch_size): os.makedirs(patch_emb_path, exist_ok=True) # create tissue segmenter and tile embedder - segmenter = TissueSegmenter(save_path=seg_path, batch_size=64) - embedder = TileEmbedder(target_patch_size=patch_size, target_mag=patch_mag, save_path=out_dir) + embedder = ConchTileEmbedder(target_patch_size=patch_size, target_mag=patch_mag, save_path=out_dir) for fn in tqdm(fnames): @@ -49,13 +49,21 @@ def process(slide_dir, out_dir, patch_mag, patch_size): fn_no_extension = os.path.splitext(fn)[0] # 2. segment tissue - gdf_contours = segmenter.segment_tissue( - wsi=wsi, - pixel_size=pixel_size, - save_bn=fn_no_extension, + gdf_contours = segment_tissue_deep( + wsi, + pixel_size, + batch_size=64 ) - # 3. extract patches and embeddings + # 3. save segmentation + visualization + os.makedirs(os.path.join(out_dir, 'geojson'), exist_ok=True) + os.makedirs(os.path.join(out_dir, 'jpeg'), exist_ok=True) + seg_name = fn_no_extension + '_tissue_vis.jpeg' + wsi.get_tissue_vis(gdf_contours).save(os.path.join(out_dir, 'jpeg', seg_name)) + seg_name = fn_no_extension + '_tissue_mask.geojson' + gdf_contours.to_file(os.path.join(out_dir, 'geojson', seg_name), driver="GeoJSON") + + # 4. extract patches and embeddings embedder.embed_tiles( wsi=wsi, gdf_contours=gdf_contours, diff --git a/core/preprocessing/conch_patch_embedder.py b/core/preprocessing/conch_patch_embedder.py index c99d945..d5fcc19 100644 --- a/core/preprocessing/conch_patch_embedder.py +++ b/core/preprocessing/conch_patch_embedder.py @@ -1,16 +1,18 @@ -from tqdm import tqdm -import numpy as np -import h5py import os -from PIL import Image - -import torch -from torch.utils.data import Dataset +import h5py +import numpy as np +import torch +import torchvision.transforms as transforms from conch.open_clip_custom import create_model_from_pretrained - +from hestcore.datasets import WSIPatcherDataset # from core.preprocessing.hest_modules.wsi import WSIPatcher -from core.preprocessing.hest_modules.wsi import OpenSlideWSIPatcher, get_pixel_size +from hestcore.wsi import OpenSlideWSIPatcher +from PIL import Image +from torch.utils.data import Dataset +from tqdm import tqdm + +from core.utils.utils import get_pixel_size, mag_to_px_size def save_hdf5(output_fpath, @@ -72,7 +74,7 @@ def collate_features(batch): return features, coords -class TileEmbedder: +class ConchTileEmbedder: def __init__(self, model_name='conch_ViT-B-16', model_repo='hf_hub:MahmoodLab/conch', @@ -100,13 +102,23 @@ def embed_tiles(self, wsi, gdf_contours, fn) -> str: patching_save_path = os.path.join(self.save_path, 'patches', f'{fn}_patches.png') embedding_save_path = os.path.join(self.save_path, 'patch_embeddings', f'{fn}.h5') - dataset = TileDataset( - wsi=wsi, - gdf_contours=gdf_contours, - target_patch_size=self.target_patch_size, - target_mag=self.target_mag, - eval_transform=self.img_transforms, - save_path=patching_save_path) + dst_pixel_size = mag_to_px_size(self.target_mag) + src_pixel_size = get_pixel_size(wsi.img) + + patcher = wsi.create_patcher( + self.target_patch_size, + src_pixel_size, + dst_pixel_size, + mask=gdf_contours, + pil=True + ) + + conch_transforms = transforms.Compose([ + self.img_transforms, + transforms.Lambda(lambda x: torch.unsqueeze(x, 0)) + ]) + + dataset = WSIPatcherDataset(patcher, transform=conch_transforms) dataloader = torch.utils.data.DataLoader( dataset, @@ -130,46 +142,4 @@ def embed_tiles(self, wsi, gdf_contours, fn) -> str: } save_hdf5(embedding_save_path, mode=mode, asset_dict=asset_dict) - return embedding_save_path - - -class TileDataset(Dataset): - def __init__(self, wsi, gdf_contours, target_patch_size, target_mag, eval_transform, save_path=None): - self.wsi = wsi - self.gdf_contours = gdf_contours - self.eval_transform = eval_transform - - self.patcher = OpenSlideWSIPatcher( - wsi=wsi, - patch_size=target_patch_size, - src_pixel_size=get_pixel_size(wsi.img), - dst_pixel_size=self.mag_to_px_size(target_mag), - mask=gdf_contours, - coords_only=False, - ) - self.patcher.save_visualization(path=save_path) - - @staticmethod - def mag_to_px_size(mag): - if mag == 5: return 2.0 - if mag == 10: return 1.0 - if mag == 20: return 0.5 - if mag == 40: return 0.25 - else: raise ValueError('Magnification should be in [5, 10, 20, 40].') - - # def _load_coords(self): - # with h5py.File(self.coords_h5_fpath, "r") as f: - # self.attr_dict = {k: dict(f[k].attrs) for k in f.keys() if len(f[k].attrs) > 0} - # self.coords = f['coords'][:] - # self.patch_size = f['coords'].attrs['patch_size'] - # self.custom_downsample = f['coords'].attrs['custom_downsample'] - # self.target_patch_size = int(self.patch_size) // int(self.custom_downsample) if self.custom_downsample > 1 else self.patch_size - - def __len__(self): - return len(self.patcher) - - def __getitem__(self, idx): - img, x, y = self.patcher[idx] - img = Image.fromarray(img, 'RGB') - img = self.eval_transform(img).unsqueeze(dim=0) - return img, (x, y) + return embedding_save_path \ No newline at end of file diff --git a/core/preprocessing/hest_modules/SegDataset.py b/core/preprocessing/hest_modules/SegDataset.py deleted file mode 100644 index 2a2abe6..0000000 --- a/core/preprocessing/hest_modules/SegDataset.py +++ /dev/null @@ -1,73 +0,0 @@ -import os - -import numpy as np -from PIL import Image -from torch.utils.data import Dataset -from tqdm import tqdm - -from .wsi import WSIPatcher - - -class SegFileDataset(Dataset): - masks = [] - patches = [] - coords = [] - - def __init__(self, root_path, transform): - self._load_paths(root_path) - - self.transform = transform - - def _load_paths(self, root_path): - self.mask_paths = [] - self.patch_paths = [] - self.coords = [] - for mask_filename in tqdm(os.listdir(root_path)): - name = mask_filename.split('.')[0] - pxl_x, pxl_y = int(name.split('_')[0]), int(name.split('_')[1]) - self.patch_paths.append(os.path.join(root_path, mask_filename)) - self.coords.append([pxl_x, pxl_y]) - - - def __len__(self): - return len(self.patch_paths) - - def __getitem__(self, index): - - with Image.open(self.patch_paths[index]) as patch: - patch = np.array(patch) - coord = self.coords[index] - - sample = patch - - if self.transform: - sample = self.transform(sample) - - return sample, coord - - -class SegWSIDataset(Dataset): - masks = [] - patches = [] - coords = [] - - def __init__(self, patcher: WSIPatcher, transform): - self.patcher = patcher - - self.cols, self.rows = self.patcher.get_cols_rows() - - self.transform = transform - - - def __len__(self): - return len(self.patcher) - - def __getitem__(self, index): - col = index % self.cols - row = index // self.cols - tile, x, y = self.patcher.get_tile(col, row) - - if self.transform: - tile = self.transform(tile) - - return tile, (x, y) \ No newline at end of file diff --git a/core/preprocessing/hest_modules/__init__.py b/core/preprocessing/hest_modules/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/core/preprocessing/hest_modules/segmentation.py b/core/preprocessing/hest_modules/segmentation.py deleted file mode 100644 index 242e629..0000000 --- a/core/preprocessing/hest_modules/segmentation.py +++ /dev/null @@ -1,386 +0,0 @@ -## Taken from HEST - -from __future__ import annotations - -import os -import pickle -from functools import partial -from typing import Union - -import cv2 -import numpy as np -import pandas as pd -from geopandas import gpd -from huggingface_hub import snapshot_download -from PIL import Image -from shapely import Polygon -import openslide -from tqdm import tqdm -from pathlib import Path - -import torch -from torch import nn -from torch.utils.data import DataLoader -from torchvision import transforms - -from core.preprocessing.hest_modules.wsi import WSI, wsi_factory -from .SegDataset import SegWSIDataset - - -def get_path_relative(file, path) -> str: - curr_dir = os.path.dirname(os.path.abspath(file)) - return os.path.join(curr_dir, path) - -def make_valid(polygon): - for i in [0, 0.1, -0.1, 0.2]: - new_polygon = polygon.buffer(i) - if isinstance(new_polygon, Polygon) and new_polygon.is_valid: - return new_polygon - raise Exception("Failed to make a valid polygon") - - -class TissueSegmenter: - def __init__(self, - model_name='deeplabv3_seg_v4.ckpt', - batch_size=8, - auto_download=True, - num_workers=8, - save_path=None): - self.model_name = model_name - self.batch_size = batch_size - self.auto_download = auto_download - self.num_workers = num_workers - self.save_path = save_path - self.model = self._load_model() - - def _load_model(self): - model = torch.hub.load('pytorch/vision:v0.10.0', 'deeplabv3_resnet50') - model.classifier[4] = nn.Conv2d( - in_channels=256, - out_channels=2, - kernel_size=1, - stride=1 - ) - - if self.auto_download: - model_dir = Path(__file__).resolve().parents[3] / 'models' - snapshot_download(repo_id="MahmoodLab/hest-tissue-seg", repo_type='model', local_dir=model_dir, allow_patterns=self.model_name) - - weights_path = model_dir / self.model_name - - if torch.cuda.is_available(): - checkpoint = torch.load(weights_path, weights_only=False) - else: - checkpoint = torch.load(weights_path, map_location=torch.device('cpu'), weights_only=False) - - new_state_dict = {} - for key in checkpoint['state_dict']: - if 'aux' in key: - continue - new_key = key.replace('model.', '') - new_state_dict[new_key] = checkpoint['state_dict'][key] - model.load_state_dict(new_state_dict) - - if torch.cuda.is_available(): - model.cuda() - - model.eval() - return model - - def segment_tissue(self, - wsi: Union[np.ndarray, openslide.OpenSlide, WSI], - pixel_size: float, - save_bn: str=None, - fast_mode=False, - dst_pixel_size=1, - patch_size_um=512) -> gpd.GeoDataFrame: - src_pixel_size = pixel_size - - if fast_mode and dst_pixel_size == 1: - dst_pixel_size = 2 - - patch_size_deeplab = 512 - scale = src_pixel_size / dst_pixel_size - patch_size_src = round(patch_size_um / scale) - wsi = wsi_factory(wsi) - - patcher = wsi.create_patcher(patch_size_deeplab, src_pixel_size, dst_pixel_size) - - eval_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]) - dataset = SegWSIDataset(patcher, eval_transforms) - dataloader = DataLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers) - - cols, rows = patcher.get_cols_rows() - width, height = patch_size_deeplab * cols, patch_size_deeplab * rows - stitched_img = np.zeros((height, width), dtype=np.uint8) - src_to_deeplab_scale = patch_size_deeplab / patch_size_src - - with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.float16): - for batch in tqdm(dataloader, total=len(dataloader)): - imgs, coords = batch - if torch.cuda.is_available(): - imgs = imgs.cuda() - masks = self.model(imgs)['out'] - preds = masks.argmax(1).to(torch.uint8).detach() - torch.cuda.synchronize() - preds = preds.cpu().numpy() - coords = np.column_stack((coords[0], coords[1])) - - for i in range(preds.shape[0]): - pred = preds[i] - coord = coords[i] - x, y = round(coord[0] * src_to_deeplab_scale), round(coord[1] * src_to_deeplab_scale) - y_end = min(y + patch_size_deeplab, height) - x_end = min(x + patch_size_deeplab, width) - stitched_img[y:y_end, x:x_end] += pred[:y_end-y, :x_end-x] - - mask = (stitched_img > 0).astype(np.uint8) - gdf_contours = mask_to_gdf(mask, max_nb_holes=5, pixel_size=src_pixel_size, contour_scale=1 / src_to_deeplab_scale) - - if self.save_path is not None and save_bn is not None: - os.makedirs(os.path.join(self.save_path, 'pkl'), exist_ok=True) - os.makedirs(os.path.join(self.save_path, 'geojson'), exist_ok=True) - os.makedirs(os.path.join(self.save_path, 'jpeg'), exist_ok=True) - seg_name = save_bn + '_tissue_mask.jpeg' - get_tissue_vis(wsi, gdf_contours).save(os.path.join(self.save_path, 'jpeg', seg_name)) - seg_name = save_bn + '_tissue_mask.geojson' - gdf_contours.to_file(os.path.join(self.save_path, 'geojson', seg_name), driver="GeoJSON") - seg_name = save_bn + '_tissue_mask.pkl' - with open(os.path.join(self.save_path, 'pkl', seg_name), "wb") as f: - pickle.dump(gdf_contours, f) - - return gdf_contours - - -def save_pkl(filename, save_object): - writer = open(filename,'wb') - pickle.dump(save_object, writer) - writer.close() - - -def mask_rgb(rgb: np.ndarray, mask: np.ndarray) -> np.ndarray: - """Mask an RGB image - - Args: - rgb (np.ndarray): RGB image to mask with shape (height, width, 3) - mask (np.ndarray): Binary mask with shape (height, width) - - Returns: - np.ndarray: Masked image - """ - assert ( - rgb.shape[:-1] == mask.shape - ), "Mask and RGB shape are different. Cannot mask when source and mask have different dimension." - mask_positive = np.dstack([mask, mask, mask]) - mask_negative = np.dstack([~mask, ~mask, ~mask]) - positive = rgb * mask_positive - negative = rgb * mask_negative - negative = 255 * (negative > 0.0001).astype(int) - - masked_image = positive + negative - - return np.clip(masked_image, a_min=0, a_max=255) - - -def contours_to_img( - contours: gpd.GeoDataFrame, - img: np.ndarray, - draw_contours=False, - thickness=1, - downsample=1., - line_color=(0, 255, 0) -) -> np.ndarray: - draw_cont = partial(cv2.drawContours, contourIdx=-1, thickness=thickness, lineType=cv2.LINE_8) - draw_cont_fill = partial(cv2.drawContours, contourIdx=-1, thickness=cv2.FILLED) - - groups = contours.groupby('tissue_id') - for _, group in groups: - - for _, row in group.iterrows(): - cont = np.array([[round(x * downsample), round(y * downsample)] for x, y in row.geometry.exterior.coords]) - holes = [np.array([[round(x * downsample), round(y * downsample)] for x, y in hole.coords]) for hole in row.geometry.interiors] - - draw_cont_fill(image=img, contours=[cont], color=line_color) - - for hole in holes: - draw_cont_fill(image=img, contours=[hole], color=(0, 0, 0)) - - if draw_contours: - draw_cont(image=img, contours=[cont], color=line_color) - return img - - -def get_tissue_vis( - img: Union[np.ndarray, openslide.OpenSlide, WSI], - tissue_contours: gpd.GeoDataFrame, - line_color=(0, 255, 0), - line_thickness=5, - target_width=1000, - seg_display=True, - ) -> Image: - tissue_contours = tissue_contours.copy() - - wsi = wsi_factory(img) - - width, height = wsi.get_dimensions() - downsample = target_width / width - - top_left = (0,0) - - img = wsi.get_thumbnail(round(width * downsample), round(height * downsample)) - - if tissue_contours is None: - return Image.fromarray(img) - - downscaled_mask = np.zeros(img.shape[:2], dtype=np.uint8) - downscaled_mask = np.expand_dims(downscaled_mask, axis=-1) - downscaled_mask = downscaled_mask * np.array([0, 0, 0]).astype(np.uint8) - - if tissue_contours is not None and seg_display: - downscaled_mask = contours_to_img( - tissue_contours, - downscaled_mask, - draw_contours=True, - thickness=line_thickness, - downsample=downsample, - line_color=line_color - ) - - alpha = 0.4 - img = cv2.addWeighted(img, 1 - alpha, downscaled_mask, alpha, 0) - img = img.astype(np.uint8) - - return Image.fromarray(img) - - -def filter_contours(contours, hierarchy, filter_params, scale, pixel_size): - """ - Filter contours by: area - """ - filtered = [] - - # find indices of foreground contours (parent == -1) - if len(hierarchy) == 0: - hierarchy_1 = [] - else: - hierarchy_1 = np.flatnonzero(hierarchy[:,1] == -1) - all_holes = [] - - # loop through foreground contour indices - for cont_idx in hierarchy_1: - # actual contour - cont = contours[cont_idx] - # indices of holes contained in this contour (children of parent contour) - holes = np.flatnonzero(hierarchy[:, 1] == cont_idx) - # take contour area (includes holes) - a = cv2.contourArea(cont) - # calculate the contour area of each hole - hole_areas = [cv2.contourArea(contours[hole_idx]) for hole_idx in holes] - # actual area of foreground contour region - a = a - np.array(hole_areas).sum() - a *= pixel_size ** 2 - - if a == 0: continue - - - - if tuple((filter_params['a_t'],)) < tuple((a,)): - - if (filter_params['filter_color_mode'] == 'none') or (filter_params['filter_color_mode'] is None): - filtered.append(cont_idx) - holes = [hole_idx for hole_idx in holes if cv2.contourArea(contours[hole_idx]) * pixel_size ** 2 > filter_params['min_hole_area']] - all_holes.append(holes) - else: - raise Exception() - - - # for parent in filtered: - # all_holes.append(np.flatnonzero(hierarchy[:, 1] == parent)) - - ##### TODO: re-implement this in a single for-loop that - ##### loops through both parent contours and holes - - foreground_contours = [contours[cont_idx] for cont_idx in filtered] - - hole_contours = [] - - for hole_ids in all_holes: - unfiltered_holes = [contours[idx] for idx in hole_ids ] - unfilered_holes = sorted(unfiltered_holes, key=cv2.contourArea, reverse=True) - # take max_n_holes largest holes by area - filtered_holes = unfilered_holes[:filter_params['max_n_holes']] - #filtered_holes = [] - - # filter these holes - #for hole in unfilered_holes: - # if cv2.contourArea(hole) > filter_params['a_h']: - # filtered_holes.append(hole) - - hole_contours.append(filtered_holes) - - return foreground_contours, hole_contours - - -def mask_to_gdf(mask: np.ndarray, keep_ids = [], exclude_ids=[], max_nb_holes=0, min_contour_area=1000, pixel_size=1, contour_scale=1.): - TARGET_EDGE_SIZE = 2000 - scale = TARGET_EDGE_SIZE / mask.shape[0] - - downscaled_mask = cv2.resize(mask, (round(mask.shape[1] * scale), round(mask.shape[0] * scale))) - - # Find and filter contours - if max_nb_holes == 0: - contours, hierarchy = cv2.findContours(downscaled_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE) - else: - contours, hierarchy = cv2.findContours(downscaled_mask, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE) # Find contours - #print('Num Contours Before Filtering:', len(contours)) - if hierarchy is None: - hierarchy = [] - else: - hierarchy = np.squeeze(hierarchy, axis=(0,))[:, 2:] - - filter_params = { - 'filter_color_mode': 'none', - 'max_n_holes': max_nb_holes, - 'a_t': min_contour_area * pixel_size ** 2, - 'min_hole_area': 4000 * pixel_size ** 2 - } - - if filter_params: - foreground_contours, hole_contours = filter_contours(contours, hierarchy, filter_params, scale, pixel_size) # Necessary for filtering out artifacts - - - if len(foreground_contours) == 0: - raise Exception('no contour detected') - else: - contours_tissue = scale_contour_dim(foreground_contours, contour_scale / scale) - contours_holes = scale_holes_dim(hole_contours, contour_scale / scale) - - if len(keep_ids) > 0: - contour_ids = set(keep_ids) - set(exclude_ids) - else: - contour_ids = set(np.arange(len(contours_tissue))) - set(exclude_ids) - - tissue_ids = [i for i in contour_ids] - polygons = [] - for i in contour_ids: - holes = [contours_holes[i][j].squeeze(1) for j in range(len(contours_holes[i]))] if len(contours_holes[i]) > 0 else None - polygon = Polygon(contours_tissue[i].squeeze(1), holes=holes) - if not polygon.is_valid: - polygon = make_valid(polygon) - polygons.append(polygon) - - gdf_contours = gpd.GeoDataFrame(pd.DataFrame(tissue_ids, columns=['tissue_id']), geometry=polygons) - - return gdf_contours - - -def scale_holes_dim(contours, scale): - r""" - """ - return [[np.array(hole * scale, dtype = 'int32') for hole in holes] for holes in contours] - - -def scale_contour_dim(contours, scale): - r""" - """ - return [np.array(cont * scale, dtype='int32') for cont in contours] \ No newline at end of file diff --git a/core/preprocessing/hest_modules/wsi.py b/core/preprocessing/hest_modules/wsi.py deleted file mode 100644 index 18d8b0f..0000000 --- a/core/preprocessing/hest_modules/wsi.py +++ /dev/null @@ -1,584 +0,0 @@ -## Code taken from https://github.com/mahmoodlab/HEST/tree/main wsi.py -## TODO replace by an independent package - - -from __future__ import annotations - -import warnings -from abc import abstractmethod -from functools import partial -from typing import Tuple, Union - -import cv2 -import geopandas as gpd -import numpy as np -import openslide -from PIL import Image - - -class CucimWarningSingleton: - _warned_cucim = False - - @classmethod - def warn(cls): - if cls._warned_cucim is False: - # warnings.warn("CuImage is not available. Ensure you have a GPU and cucim installed to use GPU acceleration.") - cls._warned_cucim = True - return cls._warned_cucim - - -def is_cuimage(img): - try: - from cucim import CuImage - except ImportError: - CuImage = None - CucimWarningSingleton.warn() - return CuImage is not None and isinstance(img, CuImage) # type: ignore - - -class WSI: - - def __init__(self, img): - self.img = img - - if not (isinstance(img, openslide.OpenSlide) or isinstance(img, np.ndarray) or is_cuimage(img)) : - raise ValueError(f"Invalid type for img {type(img)}") - - self.width, self.height = self.get_dimensions() - - @abstractmethod - def numpy(self) -> np.ndarray: - pass - - @abstractmethod - def get_dimensions(self): - pass - - @abstractmethod - def read_region(self, location, level, size) -> np.ndarray: - pass - - @abstractmethod - def get_thumbnail(self, width, height): - pass - - def __repr__(self) -> str: - width, height = self.get_dimensions() - - return f"" - - @abstractmethod - def create_patcher( - self, - patch_size: int, - src_pixel_size: float, - dst_pixel_size: float = None, - overlap: int = 0, - mask: gpd.GeoDataFrame = None, - coords_only = False, - custom_coords = None - ) -> WSIPatcher: - pass - - -def wsi_factory(img) -> WSI: - try: - from cucim import CuImage - except ImportError: - CuImage = None - CucimWarningSingleton.warn() - - if isinstance(img, WSI): - return img - elif isinstance(img, openslide.OpenSlide): - return OpenSlideWSI(img) - elif isinstance(img, np.ndarray): - return NumpyWSI(img) - elif is_cuimage(img): - return CuImageWSI(img) - elif isinstance(img, str): - if CuImage is not None: - return CuImageWSI(CuImage(img)) - else: - warnings.warn("Cucim isn't available, opening the image with OpenSlide (will be slower)") - return OpenSlideWSI(openslide.OpenSlide(img)) - else: - raise ValueError(f'type {type(img)} is not supported') - -class NumpyWSI(WSI): - def __init__(self, img: np.ndarray): - super().__init__(img) - - def numpy(self) -> np.ndarray: - return self.img - - def get_dimensions(self): - return self.img.shape[1], self.img.shape[0] - - def read_region(self, location, level, size) -> np.ndarray: - img = self.img - x_start, y_start = location[0], location[1] - x_size, y_size = size[0], size[1] - return img[y_start:y_start + y_size, x_start:x_start + x_size] - - def get_thumbnail(self, width, height) -> np.ndarray: - return cv2.resize(self.img, (width, height)) - - def create_patcher( - self, - patch_size: int, - src_pixel_size: float, - dst_pixel_size: float = None, - overlap: int = 0, - mask: gpd.GeoDataFrame = None, - coords_only = False, - custom_coords = None - ) -> WSIPatcher: - return NumpyWSIPatcher(self, patch_size, src_pixel_size, dst_pixel_size, overlap, mask, coords_only, custom_coords) - - -class OpenSlideWSI(WSI): - def __init__(self, img: openslide.OpenSlide): - super().__init__(img) - - def numpy(self) -> np.ndarray: - return self.get_thumbnail(self.width, self.height) - - def get_dimensions(self): - return self.img.dimensions - - def read_region(self, location, level, size) -> np.ndarray: - return np.array(self.img.read_region(location, level, size)) - - def get_thumbnail(self, width, height): - return np.array(self.img.get_thumbnail((width, height))) - - def get_best_level_for_downsample(self, downsample): - return self.img.get_best_level_for_downsample(downsample) - - def level_dimensions(self): - return self.img.level_dimensions - - def level_downsamples(self): - return self.img.level_downsamples - - def create_patcher( - self, - patch_size: int, - src_pixel_size: float, - dst_pixel_size: float = None, - overlap: int = 0, - mask: gpd.GeoDataFrame = None, - coords_only = False, - custom_coords = None - ) -> WSIPatcher: - return OpenSlideWSIPatcher(self, patch_size, src_pixel_size, dst_pixel_size, overlap, mask, coords_only, custom_coords) - -class CuImageWSI(WSI): - def __init__(self, img: 'CuImage'): - super().__init__(img) - - def numpy(self) -> np.ndarray: - return self.get_thumbnail(self.width, self.height) - - def get_dimensions(self): - return self.img.resolutions['level_dimensions'][0] - - def read_region(self, location, level, size) -> np.ndarray: - return np.array(self.img.read_region(location=location, level=level, size=size)) - - def get_thumbnail(self, width, height): - downsample = self.width / width - downsamples = self.img.resolutions['level_downsamples'] - closest = 0 - for i in range(len(downsamples)): - if downsamples[i] > downsample: - break - closest = i - - curr_width, curr_height = self.img.resolutions['level_dimensions'][closest] - thumbnail = np.array(self.img.read_region(location=(0, 0), level=closest, size=(curr_width, curr_height))) - thumbnail = cv2.resize(thumbnail, (width, height)) - - return thumbnail - - def get_best_level_for_downsample(self, downsample): - downsamples = self.img.resolutions['level_downsamples'] - last = 0 - for i in range(len(downsamples)): - down = downsamples[i] - if downsample < down: - return last - last = i - return last - - def level_dimensions(self): - return self.img.resolutions['level_dimensions'] - - def level_downsamples(self): - return self.img.resolutions['level_downsamples'] - - def create_patcher( - self, - patch_size: int, - src_pixel_size: float, - dst_pixel_size: float = None, - overlap: int = 0, - mask: gpd.GeoDataFrame = None, - coords_only = False, - custom_coords = None - ) -> WSIPatcher: - return CuImageWSIPatcher(self, patch_size, src_pixel_size, dst_pixel_size, overlap, mask, coords_only, custom_coords) - - -class WSIPatcher: - """ Iterator class to handle patching, patch scaling and tissue mask intersection """ - - def __init__( - self, - wsi: WSI, - patch_size: int, - src_pixel_size: float, - dst_pixel_size: float = None, - overlap: int = 0, - mask: gpd.GeoDataFrame = None, - coords_only = False, - custom_coords = None - ): - """ Initialize patcher, compute number of (masked) rows, columns. - - Args: - wsi (WSI): wsi to patch - patch_size (int): patch width/height in pixel on the slide after rescaling - src_pixel_size (float, optional): pixel size in um/px of the slide before rescaling. Defaults to None. - dst_pixel_size (float, optional): pixel size in um/px of the slide after rescaling. Defaults to None. - overlap (int, optional): overlap size in pixel before rescaling. Defaults to 0. - mask (gpd.GeoDataFrame, optional): geopandas dataframe of Polygons. Defaults to None. - coords_only (bool, optional): whenever to extract only the coordinates insteaf of coordinates + tile. Default to False. - """ - self.wsi = wsi - self.overlap = overlap - self.width, self.height = self.wsi.get_dimensions() - self.patch_size_target = patch_size - self.mask = mask - self.i = 0 - self.coords_only = coords_only - self.custom_coords = custom_coords - - if dst_pixel_size is None: - self.downsample = 1. - else: - self.downsample = dst_pixel_size / src_pixel_size - - self.patch_size_src = round(patch_size * self.downsample) - - self.level, self.patch_size_level, self.overlap_level = self._prepare() - - if custom_coords is None: - self.cols, self.rows = self._compute_cols_rows() - - col_rows = np.array([ - [col, row] - for col in range(self.cols) - for row in range(self.rows) - ]) - coords = np.array([self._colrow_to_xy(xy[0], xy[1]) for xy in col_rows]) - else: - coords = custom_coords - - if self.mask is not None: - self.valid_patches_nb, self.valid_coords = self._compute_masked(coords) - else: - self.valid_patches_nb, self.valid_coords = len(coords), coords - - def _colrow_to_xy(self, col, row): - """ Convert col row of a tile to its top-left coordinates before rescaling (x, y) """ - x = col * (self.patch_size_src) - self.overlap * np.clip(col - 1, 0, None) - y = row * (self.patch_size_src) - self.overlap * np.clip(row - 1, 0, None) - return (x, y) - - # def _compute_masked(self, coords) -> None: - # """ Compute tiles which center falls under the tissue """ - - # xy_centers = coords + self.patch_size_src // 2 - - # union_mask = self.mask.union_all() - - # points = gpd.points_from_xy(xy_centers[:, 0], xy_centers[:, 1]) - # valid_mask = gpd.GeoSeries(points).within(union_mask).values - # valid_patches_nb = valid_mask.sum() - # valid_coords = coords[valid_mask] - # return valid_patches_nb, valid_coords - - def _compute_masked(self, coords) -> None: - """ Compute tiles which any corner falls under the tissue """ - - # Filter coordinates by bounding boxes of mask polygons - bounding_boxes = self.mask.geometry.bounds - valid_coords = [] - - for _, bbox in bounding_boxes.iterrows(): - bbox_coords = coords[ - (coords[:, 0] >= bbox['minx'] - self.patch_size_src) & (coords[:, 0] <= bbox['maxx'] + self.patch_size_src) & - (coords[:, 1] >= bbox['miny'] - self.patch_size_src) & (coords[:, 1] <= bbox['maxy'] + self.patch_size_src) - ] - valid_coords.append(bbox_coords) - - if len(valid_coords) > 0: - coords = np.vstack(valid_coords) - coords = np.unique(coords, axis=0) - else: - coords = np.array([]) - - # Calculate corner coordinates - top_left = coords - top_right = coords + np.array([self.patch_size_src, 0]) - bottom_left = coords + np.array([0, self.patch_size_src]) - bottom_right = coords + np.array([self.patch_size_src, self.patch_size_src]) - - # Combine all corner coordinates - corners = np.stack([top_left, top_right, bottom_left, bottom_right], axis=1).reshape(-1, 2) - - union_mask = self.mask.union_all() - - # Check if any of the corners fall within the mask - points = gpd.points_from_xy(corners[:, 0], corners[:, 1]) - valid_mask = gpd.GeoSeries(points).within(union_mask).values - valid_mask = valid_mask.reshape(-1, 4).any(axis=1) # Check any corner within mask - valid_patches_nb = valid_mask.sum() - valid_coords = coords[valid_mask] - - return valid_patches_nb, valid_coords - - def __len__(self): - return self.valid_patches_nb - - def __iter__(self): - self.i = 0 - return self - - def __next__(self): - if self.i >= self.valid_patches_nb: - raise StopIteration - x = self.__getitem__(self.i) - self.i += 1 - return x - - def __getitem__(self, index): - if 0 <= index < len(self): - xy = self.valid_coords[index] - x, y = xy[0], xy[1] - if self.coords_only: - return x, y - tile, x, y = self.get_tile_xy(x, y) - return tile, x, y - else: - raise IndexError("Index out of range") - - - @abstractmethod - def _prepare(self) -> None: - pass - - def get_cols_rows(self) -> Tuple[int, int]: - """ Get the number of columns and rows in the associated WSI - - Returns: - Tuple[int, int]: (nb_columns, nb_rows) - """ - return self.cols, self.rows - - def get_tile_xy(self, x: int, y: int) -> Tuple[np.ndarray, int, int]: - raw_tile = self.wsi.read_region(location=(x, y), level=self.level, size=(self.patch_size_level, self.patch_size_level)) - tile = np.array(raw_tile) - if self.patch_size_target is not None: - tile = cv2.resize(tile, (self.patch_size_target, self.patch_size_target)) - assert x < self.width and y < self.height - return tile[:, :, :3], x, y - - def get_tile(self, col: int, row: int) -> Tuple[np.ndarray, int, int]: - """ get tile at position (column, row) - - Args: - col (int): column - row (int): row - - Returns: - Tuple[np.ndarray, int, int]: (tile, pixel x of top-left corner (before rescaling), pixel_y of top-left corner (before rescaling)) - """ - if self.custom_coords is not None: - raise ValueError("Can't use get_tile as 'custom_coords' was passed to the constructor") - - x, y = self._colrow_to_xy(col, row) - return self.get_tile_xy(x, y) - - def _compute_cols_rows(self) -> Tuple[int, int]: - col = 0 - row = 0 - x, y = self._colrow_to_xy(col, row) - while x < self.width: - col += 1 - x, _ = self._colrow_to_xy(col, row) - cols = col - while y < self.height: - row += 1 - _, y = self._colrow_to_xy(col, row) - rows = row - return cols, rows - - def save_visualization(self, path, vis_width=1000, dpi=150): - mask_plot = get_tissue_vis( - self.wsi, - self.mask, - line_color=(0, 255, 0), - line_thickness=5, - target_width=vis_width, - seg_display=True, - ) - import matplotlib.pyplot as plt - from matplotlib.collections import PatchCollection - from matplotlib.patches import Rectangle - - downscale_vis = vis_width / self.width - - _, ax = plt.subplots() - ax.imshow(mask_plot) - - patch_rectangles = [] - for xy in self.valid_coords: - x, y = xy[0], xy[1] - x, y = x * downscale_vis, y * downscale_vis - - patch_rectangles.append(Rectangle((x, y), self.patch_size_src * downscale_vis, self.patch_size_src * downscale_vis)) - - ax.add_collection(PatchCollection(patch_rectangles, facecolor='none', edgecolor='black', linewidth=0.3)) - ax.set_axis_off() - plt.tight_layout() - plt.savefig(path, dpi=dpi, bbox_inches = 'tight') - - -class OpenSlideWSIPatcher(WSIPatcher): - wsi: OpenSlideWSI - - def _prepare(self) -> None: - level = self.wsi.get_best_level_for_downsample(self.downsample) - level_downsample = self.wsi.level_downsamples()[level] - patch_size_level = round(self.patch_size_src / level_downsample) - overlap_level = round(self.overlap / level_downsample) - return level, patch_size_level, overlap_level - -class CuImageWSIPatcher(WSIPatcher): - wsi: CuImageWSI - - def _prepare(self) -> None: - level = self.wsi.get_best_level_for_downsample(self.downsample) - level_downsample = self.wsi.level_downsamples()[level] - patch_size_level = round(self.patch_size_src / level_downsample) - overlap_level = round(self.overlap / level_downsample) - return level, patch_size_level, overlap_level - -class NumpyWSIPatcher(WSIPatcher): - WSI: NumpyWSI - - def _prepare(self) -> None: - patch_size_level = self.patch_size_src - overlap_level = self.overlap - level = -1 - return level, patch_size_level, overlap_level - - - -def contours_to_img( - contours: gpd.GeoDataFrame, - img: np.ndarray, - draw_contours=False, - thickness=1, - downsample=1., - line_color=(0, 255, 0) -) -> np.ndarray: - draw_cont = partial(cv2.drawContours, contourIdx=-1, thickness=thickness, lineType=cv2.LINE_8) - draw_cont_fill = partial(cv2.drawContours, contourIdx=-1, thickness=cv2.FILLED) - - groups = contours.groupby('tissue_id') - for _, group in groups: - - for _, row in group.iterrows(): - cont = np.array([[round(x * downsample), round(y * downsample)] for x, y in row.geometry.exterior.coords]) - holes = [np.array([[round(x * downsample), round(y * downsample)] for x, y in hole.coords]) for hole in row.geometry.interiors] - - draw_cont_fill(image=img, contours=[cont], color=line_color) - - for hole in holes: - draw_cont_fill(image=img, contours=[hole], color=(0, 0, 0)) - - if draw_contours: - draw_cont(image=img, contours=[cont], color=line_color) - return img - - -def get_tissue_vis( - img: Union[np.ndarray, openslide.OpenSlide, CuImage, WSI], - tissue_contours: gpd.GeoDataFrame, - line_color=(0, 255, 0), - line_thickness=5, - target_width=1000, - seg_display=True, - ) -> Image: - - wsi = wsi_factory(img) - - width, height = wsi.get_dimensions() - downsample = target_width / width - - top_left = (0,0) - - img = wsi.get_thumbnail(round(width * downsample), round(height * downsample)) - - if tissue_contours is None: - return Image.fromarray(img) - - tissue_contours = tissue_contours.copy() - - downscaled_mask = np.zeros(img.shape[:2], dtype=np.uint8) - downscaled_mask = np.expand_dims(downscaled_mask, axis=-1) - downscaled_mask = downscaled_mask * np.array([0, 0, 0]).astype(np.uint8) - - if tissue_contours is not None and seg_display: - downscaled_mask = contours_to_img( - tissue_contours, - downscaled_mask, - draw_contours=True, - thickness=line_thickness, - downsample=downsample, - line_color=line_color - ) - - alpha = 0.4 - img = cv2.addWeighted(img, 1 - alpha, downscaled_mask, alpha, 0) - img = img.astype(np.uint8) - - return Image.fromarray(img) - -def get_pixel_size(slide): - """ - Extracts the pixel size from a whole slide image (WSI). - - Parameters: - slide (OpenSlide): - - Returns: - float: pixel size in microns. - """ - - # Get the pixel size - try: - # OpenSlide provides pixel size in microns in the properties - pixel_size = float(slide.properties.get(openslide.PROPERTY_NAME_MPP_X, 0)) - except Exception as e: - raise ValueError("Could not retrieve pixel size: " + str(e)) - - # Check if pixel size was successfully retrieved - if pixel_size == 0: - raise ValueError("Pixel size information is not available in the slide metadata.") - - return pixel_size diff --git a/core/utils/utils.py b/core/utils/utils.py index 467897f..0a63fa2 100644 --- a/core/utils/utils.py +++ b/core/utils/utils.py @@ -199,3 +199,38 @@ def smooth_rank_measure(embedding_matrix, eps=1e-7): smooth_rank = round(smooth_rank.item(), 2) return smooth_rank + + + +def mag_to_px_size(mag): + if mag == 5: return 2.0 + if mag == 10: return 1.0 + if mag == 20: return 0.5 + if mag == 40: return 0.25 + else: raise ValueError('Magnification should be in [5, 10, 20, 40].') + + +def get_pixel_size(slide): + """ + Extracts the pixel size from a whole slide image (WSI). + + Parameters: + slide (OpenSlide): + + Returns: + float: pixel size in microns. + """ + from openslide import PROPERTY_NAME_MPP_X + + # Get the pixel size + try: + # OpenSlide provides pixel size in microns in the properties + pixel_size = float(slide.properties.get(PROPERTY_NAME_MPP_X, 0)) + except Exception as e: + raise ValueError("Could not retrieve pixel size: " + str(e)) + + # Check if pixel size was successfully retrieved + if pixel_size == 0: + raise ValueError("Pixel size information is not available in the slide metadata.") + + return pixel_size diff --git a/requirements.txt b/requirements.txt index 996ca01..e7e4630 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,4 +7,5 @@ torch>=2.3.1 clamlite>=0.0.4 geopandas>=1.0 huggingface_hub -shapely \ No newline at end of file +shapely +hestcore==1.0.0 \ No newline at end of file