From 4dc1ff73223cfff01eef80c4eba9719ed92af3ea Mon Sep 17 00:00:00 2001 From: Alvaro Exposito Mtz Date: Wed, 24 Jun 2026 16:22:13 +0100 Subject: [PATCH 1/2] E2I, S2I, P2P, NN2I and SUREpgImage solvers. ct_utils and ct_experiments changes --- LION/CTtools/ct_utils.py | 35 +++--- LION/experiments/ct_experiments.py | 26 ++--- LION/losses/SUREpgImage.py | 43 +++++--- LION/optimizers/Equivariance2InverseSolver.py | 101 ++++++++++++++++++ LION/optimizers/Sparse2InverseSolver.py | 75 +++++++------ 5 files changed, 203 insertions(+), 77 deletions(-) create mode 100644 LION/optimizers/Equivariance2InverseSolver.py diff --git a/LION/CTtools/ct_utils.py b/LION/CTtools/ct_utils.py index 3ef4022b..27dd6d42 100644 --- a/LION/CTtools/ct_utils.py +++ b/LION/CTtools/ct_utils.py @@ -12,6 +12,7 @@ import numpy as np import tomosipo as ts import torch +import torchvision.transforms.functional as TF # AItomotools imports from LION.CTtools.ct_geometry import Geometry @@ -54,7 +55,8 @@ def sinogram_add_noise( proj, I0=1000, sigma=5, - cross_talk=0.05, + sigma_blur=0.3015, + ks_value=3, flat_field=None, dark_field=None, enable_gradients=False, @@ -63,17 +65,25 @@ def sinogram_add_noise( Wraper for _sinogram_add_noise to support gradients """ if enable_gradients: - sino = _sinogram_add_noise(proj, I0, sigma, cross_talk, flat_field, dark_field) + sino = _sinogram_add_noise( + proj, I0, sigma, sigma_blur, ks_value, flat_field, dark_field + ) else: with torch.no_grad(): sino = _sinogram_add_noise( - proj.detach(), I0, sigma, cross_talk, flat_field, dark_field + proj.detach(), I0, sigma, sigma_blur, ks_value, flat_field, dark_field ) return sino def _sinogram_add_noise( - proj, I0=1000, sigma=5, cross_talk=0.05, flat_field=None, dark_field=None + proj, + I0=1000, + sigma=5, + sigma_blur=0.3015, + ks_value=3, + flat_field=None, + dark_field=None, ): """ Adds realistic noise to sinograms. @@ -101,7 +111,8 @@ def _sinogram_add_noise( max_val = torch.amax( proj ) # alternatively the highest power of 2 close to this value, but lets leave it as is. - + if max_val <= 0: + max_val = 1.0 Im = I0 * torch.exp(-proj / max_val) # Uncorrect the flat fields @@ -111,17 +122,11 @@ def _sinogram_add_noise( Im = torch.poisson(Im) # Detector cross talk + ks = int(sigma_blur * ks_value) * 2 + 1 + if ks < 3: + ks = 3 + Im = TF.gaussian_blur(Im, kernel_size=[ks, ks], sigma=[sigma_blur, sigma_blur]) - kernel = torch.tensor( - [[0.0, 0.0, 0.0], [cross_talk, 1, cross_talk], [0.0, 0.0, 0.0]] - ).view(1, 1, 3, 3).repeat(1, 1, 1, 1) / (1 + 2 * cross_talk) - - conv = torch.nn.Conv2d(1, 1, 3, bias=False, padding="same") - with torch.no_grad(): - conv.weight = torch.nn.Parameter(kernel) - conv = conv.to(dev) - - Im = conv(Im.unsqueeze(0))[0] # Electronic noise: Im = Im + sigma * torch.randn(Im.shape, device=dev) diff --git a/LION/experiments/ct_experiments.py b/LION/experiments/ct_experiments.py index d7f7cead..89923076 100644 --- a/LION/experiments/ct_experiments.py +++ b/LION/experiments/ct_experiments.py @@ -35,8 +35,8 @@ def __init__(self, experiment_params=None, dataset="LIDC-IDRI", datafolder=None) self.geometry = self.experiment_params.geometry self.dataset = dataset if hasattr(self.param, "noise_params"): - self.sino_fun = lambda sino, I0=self.param.noise_params.I0, sigma=self.param.noise_params.sigma, cross_talk=self.param.noise_params.cross_talk: ct.sinogram_add_noise( - sino, I0=I0, sigma=sigma, cross_talk=cross_talk + self.sino_fun = lambda sino, I0=self.param.noise_params.I0, sigma=self.param.noise_params.sigma, sigma_blur=self.param.noise_params.sigma_blur: ct.sinogram_add_noise( + sino, I0=I0, sigma=sigma, sigma_blur=sigma_blur ) @staticmethod @@ -62,8 +62,8 @@ def __get_dataset(self, mode): self.param.data_loader_params.noise_params.I0 = ( self.param.noise_params.I0 ) - self.param.data_loader_params.noise_params.cross_talk = ( - self.param.noise_params.cross_talk + self.param.data_loader_params.noise_params.sigma_blur = ( + self.param.noise_params.sigma_blur ) self.param.data_loader_params.add_noise = True dataloader = deteCT( @@ -112,7 +112,7 @@ def default_parameters(dataset="LIDC-IDRI"): param.noise_params = LIONParameter() param.noise_params.I0 = 1000 param.noise_params.sigma = 5 - param.noise_params.cross_talk = 0.05 + param.noise_params.sigma_blur = 0.3015 param.data_loader_params = Experiment.get_dataset_parameters( dataset, geometry=param.geometry ) @@ -135,7 +135,7 @@ def default_parameters(dataset="LIDC-IDRI"): param.noise_params = LIONParameter() param.noise_params.I0 = 3500 param.noise_params.sigma = 5 - param.noise_params.cross_talk = 0.05 + param.noise_params.sigma_blur = 0.3015 param.data_loader_params = Experiment.get_dataset_parameters( dataset, geometry=param.geometry ) @@ -158,7 +158,7 @@ def default_parameters(dataset="LIDC-IDRI"): param.noise_params = LIONParameter() param.noise_params.I0 = 10000 param.noise_params.sigma = 5 - param.noise_params.cross_talk = 0.05 + param.noise_params.sigma_blur = 0.3015 param.data_loader_params = Experiment.get_dataset_parameters( dataset, geometry=param.geometry ) @@ -180,7 +180,7 @@ def default_parameters(dataset="LIDC-IDRI"): param.noise_params = LIONParameter() param.noise_params.I0 = 3500 param.noise_params.sigma = 5 - param.noise_params.cross_talk = 0.05 + param.noise_params.sigma_blur = 0.3015 if dataset == "LIDC-IDRI": # Parameters for the LIDC-IDRI dataset @@ -207,7 +207,7 @@ def default_parameters(dataset="LIDC-IDRI"): param.noise_params = LIONParameter() param.noise_params.I0 = 1000 param.noise_params.sigma = 5 - param.noise_params.cross_talk = 0.05 + param.noise_params.sigma_blur = 0.3015 if dataset == "LIDC-IDRI": # Parameters for the LIDC-IDRI dataset @@ -235,7 +235,7 @@ def default_parameters(dataset="LIDC-IDRI"): param.noise_params = LIONParameter() param.noise_params.I0 = 10000 param.noise_params.sigma = 5 - param.noise_params.cross_talk = 0.05 + param.noise_params.sigma_blur = 0.3015 param.data_loader_params = Experiment.get_dataset_parameters( dataset, geometry=param.geometry @@ -258,7 +258,7 @@ def default_parameters(dataset="LIDC-IDRI"): param.noise_params = LIONParameter() param.noise_params.I0 = 3500 param.noise_params.sigma = 5 - param.noise_params.cross_talk = 0.05 + param.noise_params.sigma_blur = 0.3015 if dataset == "LIDC-IDRI": # Parameters for the LIDC-IDRI dataset @@ -285,7 +285,7 @@ def default_parameters(dataset="LIDC-IDRI"): param.noise_params = LIONParameter() param.noise_params.I0 = 1000 param.noise_params.sigma = 5 - param.noise_params.cross_talk = 0.05 + param.noise_params.sigma_blur = 0.3015 if dataset == "LIDC-IDRI": # Parameters for the LIDC-IDRI dataset @@ -313,7 +313,7 @@ def default_parameters(dataset="LIDC-IDRI"): param.noise_params = LIONParameter() param.noise_params.I0 = 10000 param.noise_params.sigma = 5 - param.noise_params.cross_talk = 0.05 + param.noise_params.sigma_blur = 0.3015 param.data_loader_params = Experiment.get_dataset_parameters( dataset, geometry=param.geometry diff --git a/LION/losses/SUREpgImage.py b/LION/losses/SUREpgImage.py index 995ba215..a814bfdf 100644 --- a/LION/losses/SUREpgImage.py +++ b/LION/losses/SUREpgImage.py @@ -1,42 +1,55 @@ import torch -class SUREpgLoss(): - def __init__(self, zeta: float, sigma2: float, eps1: float = 1e-3, eps2: float = 1e-3, kappa: float = 1.0): +class SUREpgLoss: + def __init__( + self, + zeta: float, + sigma2: float, + eps1: float = 1e-3, + eps2: float = 1e-3, + kappa: float = 1.0, + ): self.zeta = zeta self.sigma2 = sigma2 self.eps1 = eps1 self.eps2 = eps2 self.kappa = kappa - self.p = (1/2) * (1 + self.kappa / (self.kappa**2 + 4)**0.5) + self.p = (1 / 2) * (1 + self.kappa / (self.kappa**2 + 4) ** 0.5) self.q = 1 - self.p - self.a = (self.q / self.p)**0.5 - self.b = (self.p / self.q)**0.5 + self.a = (self.q / self.p) ** 0.5 + self.b = (self.p / self.q) ** 0.5 def __call__(self, model, y): - B=y.shape[0] + B = y.shape[0] N_per_img = y.shape[1] * y.shape[2] * y.shape[3] - fy=model(y) - loss = ((fy - y) ** 2).sum(dim=(1,2,3)) - self.zeta * y.sum(dim=(1,2,3)) - self.sigma2 * N_per_img + fy = model(y) + loss = ( + ((fy - y) ** 2).sum(dim=(1, 2, 3)) + - self.zeta * y.sum(dim=(1, 2, 3)) + - self.sigma2 * N_per_img + ) - #1st derivative MC + # 1st derivative MC delta1 = torch.randn_like(y) fy_perturbated = model(y + self.eps1 * delta1) u = self.zeta * y + self.sigma2 - mc1 = (delta1 * u * (fy_perturbated - fy)).sum(dim=(1,2,3)) + mc1 = (delta1 * u * (fy_perturbated - fy)).sum(dim=(1, 2, 3)) loss += 2.0 * mc1 / self.eps1 - #2nd derivative MC + # 2nd derivative MC u_rand = torch.rand_like(y) - delta2 = torch.where(u_rand < self.p,-self.a * torch.ones_like(y),+self.b * torch.ones_like(y)) + delta2 = torch.where( + u_rand < self.p, -self.a * torch.ones_like(y), +self.b * torch.ones_like(y) + ) - fy_plus = model(y + self.eps2 * delta2) + fy_plus = model(y + self.eps2 * delta2) fy_minus = model(y - self.eps2 * delta2) - mc2 = (delta2 * (fy_plus - 2*fy + fy_minus)).sum(dim=(1,2,3)) + mc2 = (delta2 * (fy_plus - 2 * fy + fy_minus)).sum(dim=(1, 2, 3)) loss -= (2 * self.sigma2 * self.zeta / (self.eps2**2 * self.kappa)) * mc2 - return loss.mean()/N_per_img + return loss.mean() / N_per_img diff --git a/LION/optimizers/Equivariance2InverseSolver.py b/LION/optimizers/Equivariance2InverseSolver.py new file mode 100644 index 00000000..456c8458 --- /dev/null +++ b/LION/optimizers/Equivariance2InverseSolver.py @@ -0,0 +1,101 @@ +from typing import Callable, Optional +import numpy as np +from LION.CTtools.ct_geometry import Geometry +from LION.classical_algorithms.fdk import fdk +from LION.models.LIONmodel import LIONmodel +import torch +import torch.nn.functional as F +from torch.optim.optimizer import Optimizer +from LION.optimizers.LIONsolver import LIONsolver +from LION.utils.parameter import LIONParameter +import tomosipo as ts +import LION.CTtools.ct_utils as ct_utils +from tomosipo.torch_support import to_autograd +import random +import torchvision.transforms as TT + + +class Equivariance2InverseSolver(LIONsolver): + def __init__( + self, + model: LIONmodel, + optimizer: Optimizer, + loss_fn, + solver_params: Optional[LIONParameter] = None, + geometry: Geometry = None, + verbose: bool = True, + device: torch.device = None, + ) -> None: + print(device) + super().__init__( + model, + optimizer, + loss_fn, + geometry, + verbose, + device, + solver_params=solver_params, + ) + + self.operator = ct_utils.make_operator(self.geometry) + self.model.geometry = self.geometry + self.model.operator = self.operator + self.projector = to_autograd(self.operator, num_extra_dims=1) + self.recon_fn = self.solver_params.recon_fn + + @staticmethod + def default_parameters() -> LIONParameter: + params = LIONParameter() + params.recon_fn = fdk + params.I0 = 500 + params.sigma = (50) ** (0.5) + params.sigma_blur = 0.8 + return params + + def mini_batch_step(self, sinos, targets): + # masking + NP = sinos.shape[2] + YJ_num = torch.randint(0, NP, (1,)).item() + YJ = sinos[:, :, YJ_num, :] + + YJc = sinos.clone() + YJc[:, :, YJ_num, :] = 0 + + weight = NP / (NP - 1) + RJc = self.recon_fn(YJc * weight, self.model.operator) + output_recon_1 = self.model(RJc) + + output_sino_1 = self.projector(output_recon_1) + AJ = output_sino_1[:, :, YJ_num, :] + batch_loss = ((AJ - YJ) ** 2).mean() + + angle = random.uniform(0, 360) + rotated_output_recon_1 = TT.functional.rotate( + output_recon_1, angle, interpolation=TT.InterpolationMode.BILINEAR + ) + rotated_output_recon_1 = torch.clamp(rotated_output_recon_1, min=0.0) + rotated_sinogram = self.projector(rotated_output_recon_1) + rotated_sinogram = torch.clamp(rotated_sinogram, min=0.0) + rotated_sinogram_noisy = ct_utils.sinogram_add_noise( + rotated_sinogram, + I0=self.solver_params.I0, + sigma=self.solver_params.sigma, + sigma_blur=self.solver_params.sigma_blur, + ks_value=3, + flat_field=None, + dark_field=None, + ) + + rotated_noisy_image = self.recon_fn(rotated_sinogram_noisy, self.model.operator) + output_recon_2 = self.model(rotated_noisy_image) + batch_loss += ((output_recon_2 - rotated_output_recon_1) ** 2).mean() + return batch_loss + + # No validation in E2I + def validate(self): + return 0 + + def reconstruct(self, sinos): + input_recon = self.recon_fn(sinos, self.operator) + output_recon = self.model(input_recon) + return output_recon diff --git a/LION/optimizers/Sparse2InverseSolver.py b/LION/optimizers/Sparse2InverseSolver.py index 6099f7f9..a3bf6684 100644 --- a/LION/optimizers/Sparse2InverseSolver.py +++ b/LION/optimizers/Sparse2InverseSolver.py @@ -14,6 +14,7 @@ import LION.CTtools.ct_utils as ct_utils from tomosipo.torch_support import to_autograd + class Sparse2InverseSolver(LIONsolver): def __init__( self, @@ -36,7 +37,7 @@ def __init__( solver_params=solver_params, ) - self.model.geometry = self.geometry + self.model.geometry = self.geometry self.model._make_operator() self.sino_split_count = self.solver_params.sino_split_count self.recon_fn = self.solver_params.recon_fn @@ -44,46 +45,46 @@ def __init__( self._make_sub_operators() @classmethod - def two_two_strategy(cls, sino_split_count) -> list[tuple[int,int]]: - #to return all 2 element combinations from 0 to sino_split_count-1 - combos = [] - for i in range(sino_split_count): - for j in range(i + 1, sino_split_count): - combos.append((i, j)) - return combos - + def two_two_strategy(cls, sino_split_count) -> list[tuple[int, int]]: + # to return all 2 element combinations from 0 to sino_split_count-1 + combos = [] + for i in range(sino_split_count): + for j in range(i + 1, sino_split_count): + combos.append((i, j)) + return combos + def _make_sub_operators(self) -> list[ts.Operator.Operator]: - self.sub_ops = [] + self.sub_ops = [] angles = self.geometry.angles.copy() n = len(angles) k = self.sino_split_count - angles_per_group = n // k - remainder = n % k + self.subgroup_indices = [] - start = 0 + for i in range(k): - end = start + angles_per_group + (1 if i < remainder else 0) - self.subgroup_indices.append(list(range(start, end))) - start = end + indices_grupo_i = list(range(i, n, k)) + self.subgroup_indices.append(indices_grupo_i) for idx_group in range(k): sub_geom = Geometry( - image_shape=tuple(self.geometry.image_shape), # tupla, ints - image_size=tuple(self.geometry.image_size), # tupla, floats - angles=[angles[i] for i in self.subgroup_indices[idx_group]], # list, floats - voxel_size=tuple(self.geometry.voxel_size), # tupla, floats + image_shape=tuple(self.geometry.image_shape), # tupla, ints + image_size=tuple(self.geometry.image_size), # tupla, floats + angles=[ + angles[i] for i in self.subgroup_indices[idx_group] + ], # list, floats + voxel_size=tuple(self.geometry.voxel_size), # tupla, floats mode=self.geometry.mode, dso=float(self.geometry.dso), dsd=float(self.geometry.dsd), - detector_shape=tuple(self.geometry.detector_shape), # tupla, ints - detector_size=tuple(self.geometry.detector_size), # tupla, floats - pixel_size=tuple(self.geometry.pixel_size), # tupla, floats - image_pos=tuple(self.geometry.image_pos) # tupla, floats + detector_shape=tuple(self.geometry.detector_shape), # tupla, ints + detector_size=tuple(self.geometry.detector_size), # tupla, floats + pixel_size=tuple(self.geometry.pixel_size), # tupla, floats + image_pos=tuple(self.geometry.image_pos), # tupla, floats ) sub_op = ct_utils.make_operator(sub_geom) self.sub_ops.append(sub_op) - self.combo_ops_autograd = {} + self.combo_ops_autograd = {} for combo in self.split_combinations: combo_angles = [] for split_idx in combo: @@ -100,7 +101,7 @@ def _make_sub_operators(self) -> list[ts.Operator.Operator]: detector_shape=tuple(self.geometry.detector_shape), detector_size=tuple(self.geometry.detector_size), pixel_size=tuple(self.geometry.pixel_size), - image_pos=tuple(self.geometry.image_pos) + image_pos=tuple(self.geometry.image_pos), ) combo_op = ct_utils.make_operator(combo_geom) self.combo_ops_autograd[combo] = to_autograd(combo_op, num_extra_dims=1) @@ -117,14 +118,13 @@ def _calculate_noisy_sub_recons(self, sinos): subgroup_recons[combo] = mean_subgroup_recon return subgroup_recons - @staticmethod def default_parameters() -> LIONParameter: params = LIONParameter() params.sino_split_count = 4 params.recon_fn = fdk return params - + def mini_batch_step(self, sinos, targets): batch_size = sinos.shape[0] subgroup_recons = self._calculate_noisy_sub_recons(sinos) @@ -132,16 +132,23 @@ def mini_batch_step(self, sinos, targets): total_pixels = 0 for combo, mean_recon in subgroup_recons.items(): output_recon = self.model(mean_recon) - remaining_splits = [i for i in range(self.sino_split_count) if i not in combo] + remaining_splits = [ + i for i in range(self.sino_split_count) if i not in combo + ] projector_combo = tuple(sorted(remaining_splits)) projector = self.combo_ops_autograd[projector_combo] + objective_idx = [] + for i in remaining_splits: + objective_idx.extend(self.subgroup_indices[i]) + objective_idx.sort() for b in range(batch_size): - projected_sino = projector(output_recon[b:b+1]) - target_sino = torch.cat([sinos[b:b+1, :, self.subgroup_indices[i], :] for i in remaining_splits],dim=2) - batch_loss += self.loss_fn(projected_sino, target_sino) + projected_sino = projector(output_recon[b : b + 1]) + target_sino = sinos[b : b + 1, :, objective_idx, :] + batch_loss += ((projected_sino - target_sino) ** 2).sum() + total_pixels += projected_sino.numel() + batch_loss /= total_pixels return batch_loss - # No validation in Sparse2Inverse as it is unsupervised learning def validate(self): return 0 @@ -152,6 +159,6 @@ def reconstruct(self, sinos): (sinos.shape[0], *self.geometry.image_shape), device=self.device ) for combo, mean_recon in subgroup_recons.items(): - outputs += self.model(mean_recon) + outputs += self.model(mean_recon) outputs /= len(subgroup_recons) return outputs From c4ba7cc4e10615d0905422638b19afa69697175e Mon Sep 17 00:00:00 2001 From: Alvaro Exposito Mtz Date: Thu, 25 Jun 2026 13:47:44 +0100 Subject: [PATCH 2/2] N2I split solution --- LION/optimizers/Noise2InverseSolver.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/LION/optimizers/Noise2InverseSolver.py b/LION/optimizers/Noise2InverseSolver.py index 493f3807..3264e581 100644 --- a/LION/optimizers/Noise2InverseSolver.py +++ b/LION/optimizers/Noise2InverseSolver.py @@ -52,9 +52,6 @@ def _make_sub_operators(self) -> list[ts.Operator.Operator]: ops = [] # maintain a copy of the original angles to restore later angles = self.geometry.angles.copy() - assert ( - len(angles) % self.sino_split_count == 0 - ), f"Cannot construct {self.sino_split_count} sinogram splits from {len(angles)} view angles. Ensure that sino_split_count divides #view angles" for k in range(self.sino_split_count): self.geometry.angles = angles[k :: self.sino_split_count] sub_op = ct.make_operator(self.geometry)