diff --git a/.github/workflows/publish-to-pypi.yml b/.github/workflows/publish-to-pypi.yml index 9979e81..95a12cb 100644 --- a/.github/workflows/publish-to-pypi.yml +++ b/.github/workflows/publish-to-pypi.yml @@ -77,3 +77,4 @@ jobs: uses: pypa/gh-action-pypi-publish@release/v1 with: repository-url: https://test.pypi.org/legacy/ + verbose: true diff --git a/docs/index.md b/docs/index.md index a4dc16c..836f753 100644 --- a/docs/index.md +++ b/docs/index.md @@ -21,6 +21,7 @@ Welcome to the ITKIT documentation! ITKIT is a user-friendly toolkit built on `S - **[itk_extract](itk_extract.md)** - Label extraction - **[itk_combine](itk_combine.md)** - Label merging and intersection - **[itk_convert](itk_convert.md)** - Format conversion +- **[itk_infer](itk_infer.md)** - Batch inference with MMEngine/ONNX backends ### Advanced Topics diff --git a/docs/itk_infer.md b/docs/itk_infer.md new file mode 100644 index 0000000..d04ea03 --- /dev/null +++ b/docs/itk_infer.md @@ -0,0 +1,179 @@ +# itk_infer + +Perform batch inference on 3D medical images using trained segmentation models with support for MMEngine and ONNX backends. + +## Usage + +```bash +itk_infer -i -o --backend [options] +``` + +## Backends + +- **mmengine**: Use MMEngine models with config and checkpoint files +- **onnx**: Use ONNX runtime for optimized inference + +## Required Parameters + +- `-i, --input-folder PATH`: Input folder containing image files (supports `*.mha`, `*.nii`, `*.nii.gz`) +- `-o, --output PATH`: Output folder for segmentation results + +### Backend-Specific Requirements + +**For MMEngine backend:** + +- `-cfg, --cfg-path PATH`: Model configuration file path +- `-ckpt, --ckpt-path PATH`: Model checkpoint file path + +**For ONNX backend:** + +- `--onnx PATH`: ONNX model file path + +## Optional Parameters + +### Windowing Parameters + +- `--wl INT`: Window level for CT preprocessing (optional; defaults to config value for MMEngine) +- `--ww INT`: Window width for CT preprocessing (optional; defaults to config value for MMEngine) + +> **Note**: For ONNX backend, if `--wl/--ww` are not provided, the tool attempts to read them from the ONNX model's metadata (`window_level`/`window_width`). + +### Inference Configuration + +- `--patch-size Z Y X`: Override patch size for sliding window inference (three integers) +- `--patch-stride Z Y X`: Override patch stride for sliding window inference (three integers) + +### Performance Options + +- `--num-proc N`: Number of parallel processes (default: 1) +- `--gpus N`: Number of GPUs to use (default: 1) +- `--fp16`: Enable FP16 mixed precision for faster inference +- `--save-logits`: Save raw segmentation logits as `.zarr` files (compressed with LZ4) +- `--save-conf`: Calculate and save prediction confidence scores to `confidences.xlsx` + +## Output Files + +The tool generates the following outputs in the specified output folder: + +1. **Segmentation Maps**: One file per input image with the same filename + - Format: Same as input (`.mha`, `.nii`, or `.nii.gz`) + - Orientation: Automatically reoriented to LPI + - Metadata: Copied from input image (spacing, origin, direction) + +2. **Logits (Optional)**: When `--save-logits` is enabled + - Format: `.zarr` files with Blosc+LZ4 compression + - Shape: `(C, Z, Y, X)` where C is the number of classes + - Data type: float16 + +3. **Confidence Scores (Optional)**: When `--save-conf` is enabled + - File: `confidences.xlsx` + - Content: Per-image prediction confidence based on inverse entropy + +## Examples + +### MMEngine Backend + +```bash +# Basic inference with MMEngine +itk_infer -i /data/images -o /data/results \ + --backend mmengine \ + -cfg /models/config.py \ + -ckpt /models/checkpoint.pth + +# Multi-GPU inference with custom windowing +itk_infer -i /data/images -o /data/results \ + --backend mmengine \ + -cfg /models/config.py \ + -ckpt /models/checkpoint.pth \ + --wl 50 --ww 400 \ + --num-proc 4 --gpus 2 + +# FP16 inference with custom patch configuration +itk_infer -i /data/images -o /data/results \ + --backend mmengine \ + -cfg /models/config.py \ + -ckpt /models/checkpoint.pth \ + --patch-size 96 96 96 \ + --patch-stride 48 48 48 \ + --fp16 +``` + +### ONNX Backend + +```bash +# Basic ONNX inference +itk_infer -i /data/images -o /data/results \ + --backend onnx \ + --onnx /models/model.onnx \ + --wl 50 --ww 400 + +# Multi-process ONNX inference with logits and confidence +itk_infer -i /data/images -o /data/results \ + --backend onnx \ + --onnx /models/model.onnx \ + --num-proc 4 --gpus 2 \ + --save-logits --save-conf +``` + +## Features + +### Automatic Skipping + +The tool automatically skips files that have already been processed, checking for existing output files before inference. This enables resumable batch processing. + +### Multi-Processing + +Supports parallel processing across multiple GPUs: + +- Files are evenly distributed across processes +- Each process is assigned to a GPU in round-robin fashion +- Progress bars show per-process status + +### Prediction Confidence + +When `--save-conf` is enabled, the tool calculates prediction confidence using inverse normalized entropy: + +- **High confidence** (close to 1.0): Model is certain about predictions +- **Low confidence** (close to 0.0): Model is uncertain, predictions may be less reliable +- Useful for quality control and identifying cases requiring manual review + +### Sliding Window Inference + +Processes large 3D volumes by dividing them into overlapping patches: + +- Configurable patch size and stride +- Automatic overlap blending +- Memory-efficient processing of arbitrarily large volumes + +## Integration with 3D Slicer + +For interactive inference within 3D Slicer, see the **[3D Slicer Integration](slicer_integration.md)** guide, which provides a GUI-based extension using the same inference backend. + +## Performance Tips + +1. **GPU Memory**: Use `--fp16` to reduce memory usage and increase speed +2. **Batch Processing**: Increase `--num-proc` to parallelize across multiple GPUs +3. **Patch Configuration**: Larger patches may improve accuracy but require more memory +4. **Windowing**: Proper `--wl/--ww` values are critical for CT image preprocessing + +## Troubleshooting + +**Error: "No input files found"** + +- Ensure input folder contains files with supported extensions (`.mha`, `.nii`, `.nii.gz`) + +**Error: "requires --wl/--ww"** + +- For ONNX backend, specify windowing parameters or embed them in ONNX metadata + +**Out of Memory** + +- Reduce patch size using `--patch-size` +- Enable `--fp16` mode +- Reduce `--num-proc` if multiple processes compete for GPU memory + +**Slow Performance** + +- Enable `--fp16` for faster inference +- Increase `--num-proc` and `--gpus` for parallel processing +- Increase `--patch-stride` (less overlap means faster processing but potentially lower quality) diff --git a/itkit/process/itk_infer.py b/itkit/process/itk_infer.py new file mode 100644 index 0000000..29046aa --- /dev/null +++ b/itkit/process/itk_infer.py @@ -0,0 +1,277 @@ +import glob +import multiprocessing as mp +import os +import traceback +from multiprocessing import Manager +from pathlib import Path + +import numpy as np +import pandas as pd +import SimpleITK as sitk +import torch +import zarr +from torch import Tensor +from tqdm import tqdm + +VALID_INPUT_EXTS = ['*.mha', '*.nii', '*.nii.gz'] + + +def gen_tasks(args): + os.makedirs(args.output, exist_ok=True) + + # Collect all available files + input_files: list[str] = [] + for ext in VALID_INPUT_EXTS: + input_files.extend(glob.glob(os.path.join(args.input_folder, ext))) + if not input_files: + raise ValueError(f"No input files found in {args.input_folder}") + + # Pre-filter out existing output files + pending = [] + for f in input_files: + output_path = Path(args.output) / Path(f).name + if output_path.exists(): + print(f"Skipping existing {output_path}") + else: + pending.append(f) + if not pending: + print("No new files to process") + return [[] for _ in range(args.num_proc)] + + # Distribute tasks evenly to each process + if args.num_proc == 1 or len(pending) == 1: + return [pending] + avg = len(pending) // args.num_proc + rem = len(pending) % args.num_proc + tasks = [] + idx = 0 + for i in range(args.num_proc): + cnt = avg + (1 if i < rem else 0) + tasks.append(pending[idx:idx+cnt]) + idx += cnt + return tasks + + +def set_window(image_array:np.ndarray, wl:int, ww:int) -> np.ndarray: + left = wl - ww/2 + right = wl + ww/2 + image_array = np.clip(image_array.astype(np.int16), left, right) + image_array = (image_array - left) / ww + return image_array + + +def calc_classwise_pred_confidence(seg_logits: Tensor) -> Tensor: + """ + Calculate prediction confidence (inverse entropy) across all spatial dimensions. + + Args: + seg_logits (Tensor): Segmentation logits tensor with shape (N, C, Z, Y, X). + + Returns: + confidence (Tensor): Tensor with shape (N,) representing the mean confidence across all classes. + """ + assert (C:=seg_logits.size(1)) >= 2, f"Number of classes must be at least 2, got {C}." + + # Compute softmax probabilities + probs = torch.softmax(seg_logits, dim=1) # (N, C, Z, Y, X) + + # Calculate entropy: -sum(p * log(p)) + entropy = - (probs * torch.log(probs + 1e-8)).sum(dim=1) # (N, Z, Y, X) + + # Normalize entropy to [0,1] by dividing by maximum entropy = log(C) + max_entropy = np.log(C) + + normalized_entropy = entropy / max_entropy + confidence = 1.0 - normalized_entropy # (N, Z, Y, X) + + return confidence.mean(dim=(1, 2, 3)) # (N,) + + +def process_gpu_task(process_id, file_list, args, pred_conf_shared_dict=None): + # NOTE Local environment setup for each GPU process. + gpu_id = process_id % args.gpus + os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) + from itkit.mm.inference import Inferencer_Seg3D, MMEngineInferBackend, ONNXInferBackend + from itkit.mm.sliding_window import InferenceConfig + tqdm.write(f"Process {process_id} using GPU {gpu_id}, processing {len(file_list)} files") + + infer_cfg_override = None + if args.patch_size is not None or args.patch_stride is not None: + infer_cfg_override = InferenceConfig( + patch_size=tuple(args.patch_size) if args.patch_size is not None else None, + patch_stride=tuple(args.patch_stride) if args.patch_stride is not None else None + ) + + if args.backend == "mmengine": + backend = MMEngineInferBackend( + cfg_path=args.cfg_path, + ckpt_path=args.ckpt_path, + inference_config=infer_cfg_override, + allow_tqdm=False + ) + wl = args.wl if args.wl is not None else backend.cfg.get('wl') + ww = args.ww if args.ww is not None else backend.cfg.get('ww') + + elif args.backend == "onnx": # onnx + backend = ONNXInferBackend( + onnx_path=args.onnx, + inference_config=infer_cfg_override, + allow_tqdm=False + ) + wl, ww = args.wl, args.ww + if wl is None or ww is None: + meta = backend.session.get_modelmeta().custom_metadata_map or {} + if wl is None and 'window_level' in meta: + wl = int(meta['window_level']) + if ww is None and 'window_width' in meta: + ww = int(meta['window_width']) + if wl is None or ww is None: + raise ValueError("--backend onnx requires --wl/--ww, or the ONNX must contain metadata: window_level/window_width") + + else: + raise NotImplementedError(f"Backend {args.backend} not supported.") + + inferencer = Inferencer_Seg3D( + backend=backend, + fp16=args.fp16, + allow_tqdm=False + ) + + pred_confidences = {} + output_folder = Path(args.output) + output_folder.mkdir(parents=True, exist_ok=True) + for file_path in tqdm(file_list, + dynamic_ncols=True, + leave=False, + mininterval=1, + position=process_id, + desc=f"Proc{process_id}-GPU{gpu_id}"): + try: + # Prepare + file = Path(file_path) + itk_image = sitk.ReadImage(file_path) + image_array = sitk.GetArrayFromImage(itk_image) + image_array = set_window(image_array, wl, ww) + image_array = image_array.astype(np.float16 if args.fp16 else np.float32) + + # Inference + pred_seg_logits, pred_sem_seg = inferencer.Inference_FromNDArray(image_array) + assert pred_seg_logits.size(0) == 1, "Batch size > 1 not supported in this script." + assert pred_sem_seg.size(0) == 1, "Batch size > 1 not supported in this script." + + # Save semantic segmentation map + itk_pred = sitk.GetImageFromArray(pred_sem_seg[0].cpu().numpy()) + itk_pred.CopyInformation(itk_image) + itk_pred = sitk.DICOMOrient(itk_pred, 'LPI') + sitk.WriteImage(itk_pred, output_folder/file.name, True) + + # Register Confidence Map + if pred_conf_shared_dict is not None: + confidence = calc_classwise_pred_confidence(pred_seg_logits) # (N,) + pred_confidences[file.stem] = confidence[0].cpu().item() + + # Save segmentation logits as .npz + logits_np = pred_seg_logits[0].cpu().numpy().astype(np.float16) + if args.save_logits: + zarr.save_array( + output_folder / (file.stem+'.zarr'), + logits_np, # pyright: ignore[reportArgumentType] + codecs=[{"name": "bytes"}, + {"name": "blosc", + "configuration": {"cname": "lz4", + "clevel": 6}}] + ) + + except Exception as e: + traceback.print_exc() + tqdm.write(f"Error processing {file_path}: {e}") + + # Update shared dict with pred_confidences + if pred_conf_shared_dict is not None: + pred_conf_shared_dict.update(pred_confidences) + + +def parse_args(): + import argparse + parser = argparse.ArgumentParser(description='Inferencer') + parser.add_argument('-i', '--input-folder', type=str, required=True, help='Input folder path') + parser.add_argument('-o', '--output', type=str, required=True, help='Output folder path') + + # Backend options + parser.add_argument('--backend', type=str, choices=['mmengine', 'onnx'], default='mmengine', help='Inference backend') + # MMEngine related + parser.add_argument('-cfg', '--cfg-path', type=str, help='Config file path (required for mmengine)') + parser.add_argument('-ckpt', '--ckpt-path', type=str, help='Checkpoint file path (required for mmengine)') + # ONNX related + parser.add_argument('--onnx', type=str, help='ONNX model path (required for onnx)') + + # Windowing parameters (optional, defaults to config if mmengine) + parser.add_argument('--wl', type=int, help='Window level') + parser.add_argument('--ww', type=int, help='Window width') + # Inference config overrides (optional) + parser.add_argument('--patch-size', type=int, nargs=3, metavar=('Z', 'Y', 'X'), + help='Override inference patch size (Z Y X)') + parser.add_argument('--patch-stride', type=int, nargs=3, metavar=('Z', 'Y', 'X'), + help='Override inference patch stride (Z Y X)') + + # Other options + parser.add_argument('--num-proc', type=int, default=1, help='Number of processes to use') + parser.add_argument('--gpus', type=int, default=1, help='Number of GPUs to use') + parser.add_argument('--fp16', action='store_true', default=False, help='Use FP16 precision') + parser.add_argument('--save-logits', action='store_true', default=False) + parser.add_argument('--save-conf', action='store_true', default=False) + + args = parser.parse_args() + + if args.backend == 'mmengine': + if not args.cfg_path or not args.ckpt_path: + parser.error("--backend mmengine requires --cfg-path and --ckpt-path") + elif args.backend == 'onnx': + if not args.onnx: + parser.error("--backend onnx requires --onnx") + + return args + + +def main(): + args = parse_args() + + # Allocate task + task_per_process = gen_tasks(args) + total_tasks = sum(len(t) for t in task_per_process) + print(f"Found {total_tasks} files to process") + + # Create shared dict for collecting pred_confidences + if args.save_conf: + manager = Manager() + pred_conf_shared_dict = manager.dict() + else: + pred_conf_shared_dict = None + + processes = [] + for process_id, file_list in enumerate(task_per_process): + if not file_list: + print(f"Process {process_id} has no tasks to process") + continue + p = mp.get_context('spawn').Process( + target = process_gpu_task, + args = (process_id, file_list, args, pred_conf_shared_dict), + daemon = True + ) + p.start() + processes.append(p) + + for p in processes: + p.join() + + # Collect all pred_confidences from pred_conf_shared_dict + if pred_conf_shared_dict is not None: + pd.DataFrame.from_dict( + dict(pred_conf_shared_dict), + orient='index', + columns=['Confidence'] + ).to_excel(os.path.join(args.output, 'confidences.xlsx'), sheet_name='Confidences') + + +if __name__ == '__main__': + main() diff --git a/mkdocs.yml b/mkdocs.yml index 0d8260a..d960959 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -24,6 +24,7 @@ nav: - itk_aug: itk_aug.md - itk_extract: itk_extract.md - itk_convert: itk_convert.md + - itk_infer: itk_infer.md - Advanced Topics: - Framework Integration: framework_integration.md - 3D Slicer Integration: slicer_integration.md diff --git a/pyproject.toml b/pyproject.toml index b492524..6e01edc 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "itkit" -version = "4.0.0rc3" +version = "4.0.0rc4" requires-python = ">= 3.10" description = "ITKIT: Feasible Medical Image Operation based on SimpleITK API" readme = "README.md" @@ -109,6 +109,7 @@ itk_evaluate = "itkit.process.itk_evaluate:main" itk_combine = "itkit.process.itk_combine:main" itkit-app = "itkit.gui.app:main" itk_slicer = "SlicerITKIT.server.itkit_server:main" +itk_infer = "itkit.inference.itk_infer:main" [tool.setuptools.packages.find] where = ["."]