From d3a1ce7e3cd4b2412d52c8d38f883827b121e3aa Mon Sep 17 00:00:00 2001 From: Alvaro-Exposito-MTZ Date: Tue, 13 Jan 2026 15:22:19 +0000 Subject: [PATCH 1/9] example script Sparse2Inverse --- scripts/example_scripts/Sparse2Inverse.py | 75 +++++++++++++++++++++++ 1 file changed, 75 insertions(+) create mode 100644 scripts/example_scripts/Sparse2Inverse.py diff --git a/scripts/example_scripts/Sparse2Inverse.py b/scripts/example_scripts/Sparse2Inverse.py new file mode 100644 index 00000000..949587dd --- /dev/null +++ b/scripts/example_scripts/Sparse2Inverse.py @@ -0,0 +1,75 @@ +from LION.classical_algorithms.fdk import fdk +from Sparse2InverseSolver import Sparse2InverseSolver +from LION.models.CNNs.UNets.Unet import UNet +import LION.experiments.ct_experiments as ct_experiments +from torch.utils.data import DataLoader +from torch.optim.adam import Adam +import torch.nn as nn +import torch +import pathlib +import torch.utils.data as data_utils +import random + + +seed = 42 +random.seed(seed) +torch.manual_seed(seed) +torch.cuda.manual_seed_all(seed) +torch.backends.cudnn.deterministic = True +torch.backends.cudnn.benchmark = False + +# Set Device +#%% +# % Chose device: +device = torch.device("cuda:1") +torch.cuda.set_device(device) + +# Define your data paths +savefolder = pathlib.Path("/store/LION/ea692/LION/LION/trained_models/Sparse2Inverse/Train/SparseAngleLowDoseCTRecon/30sin2000ep/16Angles") +# Creates the folders if they does not exist +savefolder.mkdir(parents=True, exist_ok=True) +final_result_fname = "S2I.pt" +checkpoint_fname = "S2I_check_*.pt" + +# Define experiment +experiment = ct_experiments.SparseAngleLowDoseCTRecon() +train_dataset = experiment.get_training_dataset() +indices = torch.arange(30) +train_dataset = data_utils.Subset(train_dataset, indices) + +# Data to train +batch_size = 1 +dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) + +# Define model +model = UNet() + +# Create optimizer and loss function +optimizer = Adam(model.parameters(), lr=1e-4) +loss_fn = nn.MSELoss() + +#Sparse2InverseSolver. +s2i_params = Sparse2InverseSolver.default_parameters() +# Sparse to inverse requires certain user specifications. +s2i_params.sino_split_count = 4 +s2i_params.recon_fn = fdk + +# Initialize the solver as the other solvers in LION +solver = Sparse2InverseSolver( + model, + optimizer, + loss_fn, + solver_params=s2i_params, + geometry=experiment.geometry, + verbose=True, + device=device, +) + +solver.set_training(dataloader) +solver.set_checkpointing(checkpoint_fname, 100, save_folder=savefolder) + +epochs = 100 + +solver.train(epochs) +solver.save_final_results(final_result_fname, savefolder) +solver.clean_checkpoints() From df1ec3a44ec010223ec99012372663934fbc0c79 Mon Sep 17 00:00:00 2001 From: Alvaro-Exposito-MTZ Date: Wed, 14 Jan 2026 09:17:37 +0000 Subject: [PATCH 2/9] Update example script Sparse2Inverse.py --- scripts/example_scripts/Sparse2Inverse.py | 107 +++++++++++++++++++++- 1 file changed, 103 insertions(+), 4 deletions(-) diff --git a/scripts/example_scripts/Sparse2Inverse.py b/scripts/example_scripts/Sparse2Inverse.py index 949587dd..1742e674 100644 --- a/scripts/example_scripts/Sparse2Inverse.py +++ b/scripts/example_scripts/Sparse2Inverse.py @@ -9,7 +9,10 @@ import pathlib import torch.utils.data as data_utils import random - +import numpy as np +import matplotlib.pyplot as plt +from skimage.metrics import structural_similarity as ssim +from LION.metrics.haarpsi import HAARPsi seed = 42 random.seed(seed) @@ -25,7 +28,7 @@ torch.cuda.set_device(device) # Define your data paths -savefolder = pathlib.Path("/store/LION/ea692/LION/LION/trained_models/Sparse2Inverse/Train/SparseAngleLowDoseCTRecon/30sin2000ep/16Angles") +savefolder = pathlib.Path("/store/LION/ea692/LION/LION/trained_models/Sparse2Inverse/Train/SparseAngleLowDoseCTRecon") # Creates the folders if they does not exist savefolder.mkdir(parents=True, exist_ok=True) final_result_fname = "S2I.pt" @@ -34,14 +37,15 @@ # Define experiment experiment = ct_experiments.SparseAngleLowDoseCTRecon() train_dataset = experiment.get_training_dataset() +#30 sinograms for the experiment indices = torch.arange(30) train_dataset = data_utils.Subset(train_dataset, indices) # Data to train batch_size = 1 -dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) +dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False) -# Define model +# Define model. In the original paper used UNet model = UNet() # Create optimizer and loss function @@ -73,3 +77,98 @@ solver.train(epochs) solver.save_final_results(final_result_fname, savefolder) solver.clean_checkpoints() + +# Test using the training data +savefolder = pathlib.Path("/home/ea692/LION/LION/trained_models/Sparse2Inverse/Test/SparseAngleLowDoseCTRecon/SparseVSNoise/30sin2000ep/64Angles_Haarpsi_and_SSIM") +savefolder.mkdir(parents=True, exist_ok=True) + +#Load the trained model of Sparse2Inverse +model_Sparse, _, _ = UNet().load("/store/LION/ea692/LION/LION/trained_models/Sparse2Inverse/Train/SparseAngleLowDoseCTRecon/S2I.json") +model_Sparse.eval() +çsolver_params = Sparse2InverseSolver.default_parameters() +solver_params.sino_split_count = 4 +solver_params.recon_fn = fdk +optimizer = Adam(model_Sparse.parameters()) +#Not used directly, the solver defines its own loss. +loss_fn = nn.MSELoss() + +solver_sparse = Sparse2InverseSolver( + model_Sparse, + optimizer, + loss_fn, + solver_params=solver_params, + geometry=experiment.geometry, + verbose=False, + device=device, +) + +#Normalization in order to ensure a fair comparison of structural and perceptual image quality. +def normalize_01(x,y): + x = (x - y.min())/ (y.max() - y.min()) + x[x>1]=1 + x[x<0]=0 + return x + +#HAARPsi metric +haarpsi = HAARPsi(C=5.0, a=4.9) +haarpsi.eval() + +haarpsi_values_sparse = [] + +#SSIM metric +def my_ssim(x, y): + x = x.cpu().numpy().squeeze() + y = y.cpu().numpy().squeeze() + return ssim(x, y, data_range=x.max() - x.min()) + +ssim_values_sparse = [] + +# Fixed visualization window for all images to ensure fair visual comparison. +vmin, vmax = 0, 5 + +for idx, (sino, target) in enumerate(dataloader): + sino = sino.to(device) + with torch.no_grad(): + model_reco_sparse = solver_sparse.reconstruct(sino).detach().cpu() + target_cpu = target.cpu() + + target_n = normalize_01(target_cpu,target_cpu) + sparse_n = normalize_01(model_reco_sparse,target_cpu) + + haarspi_sparse,_,_ = haarpsi(target_n, sparse_n) + ssim_sparse=my_ssim(target_n,sparse_n) + + ssim_values_sparse.append(ssim_sparse) + + haarpsi_values_sparse.append(haarspi_sparse.item()) + + #Figure the comparison between target and reconstruction. + #Raw reconstructions are shown without normalization. + if idx == 0: + plt.figure(figsize=(12,4)) + + plt.subplot(1,2,1) + plt.title("Target (clean)") + im0 = plt.imshow(target[0,0].cpu(), cmap="gray") + plt.axis("off") + im0.set_clim(vmin, vmax) + + plt.subplot(1,2,2) + plt.title(f"Model raw reconstruction Sparse\nhaarpsi={haarspi_sparse.item():.3f}\nssim={ssim_sparse:.3f}") + im2 = plt.imshow(model_reco_sparse[0,0], cmap="gray") + plt.axis("off") + im2.set_clim(vmin, vmax) + + plt.tight_layout() + plt.savefig(savefolder / "Reconstruction_Sparse2Inverse_SparseAngleLowDoseCTRecon_Haarspi_SSIM.png", dpi=150) + plt.close() + +haarpsi_mean_sparse = np.mean(haarpsi_values_sparse) +haarpsi_std_sparse = np.std(haarpsi_values_sparse) + +ssim_mean_sparse = np.mean(ssim_values_sparse) +ssim_std_sparse = np.std(ssim_values_sparse) + +#Plot the metrics obtained throw all the sinograms tested +print(f"haarpsi mean Sparse SparseAngleLowDoseCTRecon 30sin2000ep 64 angles: {haarpsi_mean_sparse:.4f}, haarpsi std Sparse: {haarpsi_std_sparse:.4f}") +print(f"ssim mean Sparse SparseAngleLowDoseCTRecon 30sin2000ep 64 angles: {ssim_mean_sparse:.4f}, ssim std Sparse: {ssim_std_sparse:.4f}") From 49918ec620e84ac5ce21e2d2cee57238dae20a3e Mon Sep 17 00:00:00 2001 From: Alvaro-Exposito-MTZ Date: Thu, 15 Jan 2026 14:18:50 +0000 Subject: [PATCH 3/9] Update in loss Sparse2InverseSolver.py --- LION/optimizers/Sparse2InverseSolver.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/LION/optimizers/Sparse2InverseSolver.py b/LION/optimizers/Sparse2InverseSolver.py index 5a7931b8..6e7128db 100644 --- a/LION/optimizers/Sparse2InverseSolver.py +++ b/LION/optimizers/Sparse2InverseSolver.py @@ -139,9 +139,7 @@ def mini_batch_step(self, sinos, targets): for b in range(batch_size): projected_sino = projector(output_recon[b:b+1]) target_sino = torch.cat([sinos[b:b+1, :, self.subgroup_indices[i], :] for i in remaining_splits],dim=2) - batch_loss += ((projected_sino - target_sino) ** 2).sum() - total_pixels += projected_sino.numel() - batch_loss /= total_pixels + batch_loss += self.loss_fn(projected_sino, target_sino) return batch_loss From cd212dde107f05f44670b7e7d8dfd10cc085a9f6 Mon Sep 17 00:00:00 2001 From: Alvaro-Exposito-MTZ Date: Thu, 15 Jan 2026 14:39:07 +0000 Subject: [PATCH 4/9] Update example script Sparse2Inverse.py --- scripts/example_scripts/Sparse2Inverse.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/scripts/example_scripts/Sparse2Inverse.py b/scripts/example_scripts/Sparse2Inverse.py index 1742e674..098c3fe2 100644 --- a/scripts/example_scripts/Sparse2Inverse.py +++ b/scripts/example_scripts/Sparse2Inverse.py @@ -82,18 +82,16 @@ savefolder = pathlib.Path("/home/ea692/LION/LION/trained_models/Sparse2Inverse/Test/SparseAngleLowDoseCTRecon/SparseVSNoise/30sin2000ep/64Angles_Haarpsi_and_SSIM") savefolder.mkdir(parents=True, exist_ok=True) -#Load the trained model of Sparse2Inverse -model_Sparse, _, _ = UNet().load("/store/LION/ea692/LION/LION/trained_models/Sparse2Inverse/Train/SparseAngleLowDoseCTRecon/S2I.json") -model_Sparse.eval() -çsolver_params = Sparse2InverseSolver.default_parameters() +model.eval() +solver_params = Sparse2InverseSolver.default_parameters() solver_params.sino_split_count = 4 solver_params.recon_fn = fdk -optimizer = Adam(model_Sparse.parameters()) +optimizer = Adam(model.parameters()) #Not used directly, the solver defines its own loss. loss_fn = nn.MSELoss() solver_sparse = Sparse2InverseSolver( - model_Sparse, + model, optimizer, loss_fn, solver_params=solver_params, From f5034da8d06819e0c5f3125180afb07740a269dd Mon Sep 17 00:00:00 2001 From: Alvaro-Exposito-MTZ Date: Mon, 19 Jan 2026 08:40:50 +0000 Subject: [PATCH 5/9] Update in Test Sparse2Inverse.py I've changed my_ssim function and put solver.set_testing() --- scripts/example_scripts/Sparse2Inverse.py | 69 +++-------------------- 1 file changed, 9 insertions(+), 60 deletions(-) diff --git a/scripts/example_scripts/Sparse2Inverse.py b/scripts/example_scripts/Sparse2Inverse.py index 098c3fe2..b72249bd 100644 --- a/scripts/example_scripts/Sparse2Inverse.py +++ b/scripts/example_scripts/Sparse2Inverse.py @@ -107,66 +107,15 @@ def normalize_01(x,y): x[x<0]=0 return x -#HAARPsi metric -haarpsi = HAARPsi(C=5.0, a=4.9) -haarpsi.eval() - -haarpsi_values_sparse = [] - #SSIM metric def my_ssim(x, y): - x = x.cpu().numpy().squeeze() - y = y.cpu().numpy().squeeze() - return ssim(x, y, data_range=x.max() - x.min()) - -ssim_values_sparse = [] - -# Fixed visualization window for all images to ensure fair visual comparison. -vmin, vmax = 0, 5 - -for idx, (sino, target) in enumerate(dataloader): - sino = sino.to(device) - with torch.no_grad(): - model_reco_sparse = solver_sparse.reconstruct(sino).detach().cpu() - target_cpu = target.cpu() - - target_n = normalize_01(target_cpu,target_cpu) - sparse_n = normalize_01(model_reco_sparse,target_cpu) - - haarspi_sparse,_,_ = haarpsi(target_n, sparse_n) - ssim_sparse=my_ssim(target_n,sparse_n) + x = x.detach().squeeze().cpu() + y = y.detach().squeeze().cpu() - ssim_values_sparse.append(ssim_sparse) - - haarpsi_values_sparse.append(haarspi_sparse.item()) - - #Figure the comparison between target and reconstruction. - #Raw reconstructions are shown without normalization. - if idx == 0: - plt.figure(figsize=(12,4)) - - plt.subplot(1,2,1) - plt.title("Target (clean)") - im0 = plt.imshow(target[0,0].cpu(), cmap="gray") - plt.axis("off") - im0.set_clim(vmin, vmax) - - plt.subplot(1,2,2) - plt.title(f"Model raw reconstruction Sparse\nhaarpsi={haarspi_sparse.item():.3f}\nssim={ssim_sparse:.3f}") - im2 = plt.imshow(model_reco_sparse[0,0], cmap="gray") - plt.axis("off") - im2.set_clim(vmin, vmax) - - plt.tight_layout() - plt.savefig(savefolder / "Reconstruction_Sparse2Inverse_SparseAngleLowDoseCTRecon_Haarspi_SSIM.png", dpi=150) - plt.close() - -haarpsi_mean_sparse = np.mean(haarpsi_values_sparse) -haarpsi_std_sparse = np.std(haarpsi_values_sparse) - -ssim_mean_sparse = np.mean(ssim_values_sparse) -ssim_std_sparse = np.std(ssim_values_sparse) - -#Plot the metrics obtained throw all the sinograms tested -print(f"haarpsi mean Sparse SparseAngleLowDoseCTRecon 30sin2000ep 64 angles: {haarpsi_mean_sparse:.4f}, haarpsi std Sparse: {haarpsi_std_sparse:.4f}") -print(f"ssim mean Sparse SparseAngleLowDoseCTRecon 30sin2000ep 64 angles: {ssim_mean_sparse:.4f}, ssim std Sparse: {ssim_std_sparse:.4f}") + target_n = normalize_01(y,y) + sparse_n = normalize_01(x,y) + return ssim(target_n, sparse_n, data_range=1) + +model.eval() +solver.set_testing(dataloader, my_ssim) +solver.test() From 5b25ad7f1366bcbaeaaf9bc3441844fa6dd16d7f Mon Sep 17 00:00:00 2001 From: Alvaro-Exposito-MTZ Date: Wed, 21 Jan 2026 15:33:03 +0000 Subject: [PATCH 6/9] Update Sparse2InverseSolver.py --- LION/optimizers/Sparse2InverseSolver.py | 1 - 1 file changed, 1 deletion(-) diff --git a/LION/optimizers/Sparse2InverseSolver.py b/LION/optimizers/Sparse2InverseSolver.py index 6e7128db..6099f7f9 100644 --- a/LION/optimizers/Sparse2InverseSolver.py +++ b/LION/optimizers/Sparse2InverseSolver.py @@ -38,7 +38,6 @@ def __init__( self.model.geometry = self.geometry self.model._make_operator() - self.A_full = self.model.A self.sino_split_count = self.solver_params.sino_split_count self.recon_fn = self.solver_params.recon_fn self.split_combinations = self.two_two_strategy(self.sino_split_count) From d98a3b7b2e4d6e9603af183a49bc62f973e80c95 Mon Sep 17 00:00:00 2001 From: Alvaro-Exposito-MTZ Date: Wed, 25 Feb 2026 13:44:33 +0000 Subject: [PATCH 7/9] Create Proj2ProjSolver.py --- LION/optimizers/Proj2ProjSolver.py | 98 ++++++++++++++++++++++++++++++ 1 file changed, 98 insertions(+) create mode 100644 LION/optimizers/Proj2ProjSolver.py diff --git a/LION/optimizers/Proj2ProjSolver.py b/LION/optimizers/Proj2ProjSolver.py new file mode 100644 index 00000000..02720344 --- /dev/null +++ b/LION/optimizers/Proj2ProjSolver.py @@ -0,0 +1,98 @@ +from typing import Callable, Optional +import warnings +import numpy as np +from tqdm import tqdm +from LION.CTtools.ct_geometry import Geometry +from LION.classical_algorithms.fdk import fdk +from LION.models.LIONmodel import LIONmodel +import torch +import torch.nn.functional as F +from torch.optim.optimizer import Optimizer +from LION.optimizers.LIONsolver import LIONsolver +from LION.utils.parameter import LIONParameter +import tomosipo as ts +import LION.CTtools.ct_utils as ct_utils +from tomosipo.torch_support import to_autograd + +class Proj2ProjSolver(LIONsolver): + def __init__( + self, + model: LIONmodel, + optimizer: Optimizer, + loss_fn, + solver_params: Optional[LIONParameter] = None, + geometry: Geometry = None, + verbose: bool = True, + device: torch.device = None, + ) -> None: + print(device) + super().__init__( + model, + optimizer, + loss_fn, + geometry, + verbose, + device, + solver_params=solver_params, + ) + + self.operator = ct_utils.make_operator(self.geometry) + self.model.geometry = self.geometry + self.model.operator = self.operator + self.projector = to_autograd(self.operator, num_extra_dims=1) + self.recon_fn = self.solver_params.recon_fn + self.global_step=0 + + def get_mask(self, shape, step): + # shape: (B, C, H, W) + mask = torch.ones(shape, device=self.device) + grid = self.solver_params.grid_size + + for b in range(shape[0]): + idx = (step+b) % (grid * grid) + r = idx // grid + c = idx % grid + + mask[b, :, r::grid, c::grid] = 0 + return mask + + def fill_mean(self, sinos, mask): + kernel = torch.tensor([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=torch.float32, device=self.device) / 4.0 + + kernel = kernel.view(1, 1, 3, 3) + local_mean_sino = F.conv2d(sinos, kernel, padding=1) + + filled_sinos = (sinos * mask) + (local_mean_sino * (1 - mask)) + return filled_sinos + + @staticmethod + def default_parameters() -> LIONParameter: + params = LIONParameter() + params.grid_size = 4 + params.recon_fn = fdk + return params + + def mini_batch_step(self, sinos, targets): + mask=self.get_mask(sinos.shape, self.global_step) + self.global_step += sinos.shape[0] + input_sino=self.fill_mean(sinos,mask) + + input_recon=self.recon_fn(input_sino,self.model.operator) + output_recon = self.model(input_recon) + output_sino=self.projector(output_recon) + + output_sino_mask=output_sino*(1-mask) + target_sino=sinos*(1-mask) + + batch_loss = ((output_sino_mask - target_sino) ** 2).mean() + return batch_loss + + + # No validation in Proj2Proj + def validate(self): + return 0 + + def reconstruct(self, sinos): + input_recon = self.recon_fn(sinos, self.operator) + output_recon = self.model(input_recon) + return output_recon From 5dd7b9f7e72a7b50ded702390a0295845a4d6cac Mon Sep 17 00:00:00 2001 From: Alvaro-Exposito-MTZ Date: Wed, 25 Feb 2026 13:45:33 +0000 Subject: [PATCH 8/9] Create Noisier2Inverse.py --- LION/optimizers/Noisier2Inverse.py | 87 ++++++++++++++++++++++++++++++ 1 file changed, 87 insertions(+) create mode 100644 LION/optimizers/Noisier2Inverse.py diff --git a/LION/optimizers/Noisier2Inverse.py b/LION/optimizers/Noisier2Inverse.py new file mode 100644 index 00000000..ebaac2b1 --- /dev/null +++ b/LION/optimizers/Noisier2Inverse.py @@ -0,0 +1,87 @@ +from typing import Callable, Optional +import warnings +import numpy as np +from tqdm import tqdm +from LION.CTtools.ct_geometry import Geometry +from LION.classical_algorithms.fdk import fdk +from LION.models.LIONmodel import LIONmodel +import torch +import torch.nn.functional as F +from torch.optim.optimizer import Optimizer +from LION.optimizers.LIONsolver import LIONsolver +from LION.utils.parameter import LIONParameter +import tomosipo as ts +import LION.CTtools.ct_utils as ct_utils +from tomosipo.torch_support import to_autograd + +import torchvision.transforms.functional as TF + +class Noisier2Inverse(LIONsolver): + def __init__( + self, + model: LIONmodel, + optimizer: Optimizer, + loss_fn, + solver_params: Optional[LIONParameter] = None, + geometry: Geometry = None, + verbose: bool = True, + device: torch.device = None, + ) -> None: + print(device) + super().__init__( + model, + optimizer, + loss_fn, + geometry, + verbose, + device, + solver_params=solver_params, + ) + + self.operator = ct_utils.make_operator(self.geometry) + self.model.geometry = self.geometry + self.model.operator = self.operator + self.projector = to_autograd(self.operator, num_extra_dims=1) + self.recon_fn = self.solver_params.recon_fn + + + @staticmethod + def default_parameters() -> LIONParameter: + params = LIONParameter() + params.sigma=3 + params.delta=1 + params.recon_fn = fdk + return params + + def mini_batch_step(self, sinos, targets): + sigma = self.solver_params.sigma + delta = self.solver_params.delta + ks = int(sigma * 3) * 2 + 1 + + N = torch.randn_like(sinos) * delta + N = TF.gaussian_blur(N, kernel_size=[ks, ks], sigma=[sigma, sigma]) + z = sinos + N + + input_recon = self.recon_fn(z,self.model.operator) + output_recon = self.model(input_recon) + output_sino = self.projector(output_recon) + target_sino = sinos - N + + #Sobolev Loss + #res = output_sino - target_sino + #grad_x = res[:, :, :, 1:] - res[:, :, :, :-1] + #grad_y = res[:, :, 1:, :] - res[:, :, :-1, :] + + #batch_loss = ((output_sino - target_sino)**2).mean() + (grad_x**2).mean() + (grad_y**2).mean() + batch_loss= ((output_sino - target_sino)**2).mean() + return batch_loss + + + # No validation in Noisier2Inverse + def validate(self): + return 0 + + def reconstruct(self, sinos): + input_recon = self.recon_fn(sinos, self.operator) + output_recon = self.model(input_recon) + return output_recon From 521039fb264bc5596510cdd1a8f4d31544feaac0 Mon Sep 17 00:00:00 2001 From: Alvaro-Exposito-MTZ Date: Wed, 25 Feb 2026 13:46:34 +0000 Subject: [PATCH 9/9] Create SUREpgImage.py --- LION/losses/SUREpgImage.py | 42 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 LION/losses/SUREpgImage.py diff --git a/LION/losses/SUREpgImage.py b/LION/losses/SUREpgImage.py new file mode 100644 index 00000000..995ba215 --- /dev/null +++ b/LION/losses/SUREpgImage.py @@ -0,0 +1,42 @@ +import torch + +class SUREpgLoss(): + + def __init__(self, zeta: float, sigma2: float, eps1: float = 1e-3, eps2: float = 1e-3, kappa: float = 1.0): + self.zeta = zeta + self.sigma2 = sigma2 + self.eps1 = eps1 + self.eps2 = eps2 + self.kappa = kappa + + self.p = (1/2) * (1 + self.kappa / (self.kappa**2 + 4)**0.5) + self.q = 1 - self.p + self.a = (self.q / self.p)**0.5 + self.b = (self.p / self.q)**0.5 + + def __call__(self, model, y): + B=y.shape[0] + N_per_img = y.shape[1] * y.shape[2] * y.shape[3] + + fy=model(y) + loss = ((fy - y) ** 2).sum(dim=(1,2,3)) - self.zeta * y.sum(dim=(1,2,3)) - self.sigma2 * N_per_img + + #1st derivative MC + delta1 = torch.randn_like(y) + fy_perturbated = model(y + self.eps1 * delta1) + + u = self.zeta * y + self.sigma2 + mc1 = (delta1 * u * (fy_perturbated - fy)).sum(dim=(1,2,3)) + loss += 2.0 * mc1 / self.eps1 + + #2nd derivative MC + u_rand = torch.rand_like(y) + delta2 = torch.where(u_rand < self.p,-self.a * torch.ones_like(y),+self.b * torch.ones_like(y)) + + fy_plus = model(y + self.eps2 * delta2) + fy_minus = model(y - self.eps2 * delta2) + + mc2 = (delta2 * (fy_plus - 2*fy + fy_minus)).sum(dim=(1,2,3)) + loss -= (2 * self.sigma2 * self.zeta / (self.eps2**2 * self.kappa)) * mc2 + + return loss.mean()/N_per_img