diff --git a/LION/losses/SUREpgImage.py b/LION/losses/SUREpgImage.py new file mode 100644 index 00000000..995ba215 --- /dev/null +++ b/LION/losses/SUREpgImage.py @@ -0,0 +1,42 @@ +import torch + +class SUREpgLoss(): + + def __init__(self, zeta: float, sigma2: float, eps1: float = 1e-3, eps2: float = 1e-3, kappa: float = 1.0): + self.zeta = zeta + self.sigma2 = sigma2 + self.eps1 = eps1 + self.eps2 = eps2 + self.kappa = kappa + + self.p = (1/2) * (1 + self.kappa / (self.kappa**2 + 4)**0.5) + self.q = 1 - self.p + self.a = (self.q / self.p)**0.5 + self.b = (self.p / self.q)**0.5 + + def __call__(self, model, y): + B=y.shape[0] + N_per_img = y.shape[1] * y.shape[2] * y.shape[3] + + fy=model(y) + loss = ((fy - y) ** 2).sum(dim=(1,2,3)) - self.zeta * y.sum(dim=(1,2,3)) - self.sigma2 * N_per_img + + #1st derivative MC + delta1 = torch.randn_like(y) + fy_perturbated = model(y + self.eps1 * delta1) + + u = self.zeta * y + self.sigma2 + mc1 = (delta1 * u * (fy_perturbated - fy)).sum(dim=(1,2,3)) + loss += 2.0 * mc1 / self.eps1 + + #2nd derivative MC + u_rand = torch.rand_like(y) + delta2 = torch.where(u_rand < self.p,-self.a * torch.ones_like(y),+self.b * torch.ones_like(y)) + + fy_plus = model(y + self.eps2 * delta2) + fy_minus = model(y - self.eps2 * delta2) + + mc2 = (delta2 * (fy_plus - 2*fy + fy_minus)).sum(dim=(1,2,3)) + loss -= (2 * self.sigma2 * self.zeta / (self.eps2**2 * self.kappa)) * mc2 + + return loss.mean()/N_per_img diff --git a/LION/optimizers/Noisier2Inverse.py b/LION/optimizers/Noisier2Inverse.py new file mode 100644 index 00000000..ebaac2b1 --- /dev/null +++ b/LION/optimizers/Noisier2Inverse.py @@ -0,0 +1,87 @@ +from typing import Callable, Optional +import warnings +import numpy as np +from tqdm import tqdm +from LION.CTtools.ct_geometry import Geometry +from LION.classical_algorithms.fdk import fdk +from LION.models.LIONmodel import LIONmodel +import torch +import torch.nn.functional as F +from torch.optim.optimizer import Optimizer +from LION.optimizers.LIONsolver import LIONsolver +from LION.utils.parameter import LIONParameter +import tomosipo as ts +import LION.CTtools.ct_utils as ct_utils +from tomosipo.torch_support import to_autograd + +import torchvision.transforms.functional as TF + +class Noisier2Inverse(LIONsolver): + def __init__( + self, + model: LIONmodel, + optimizer: Optimizer, + loss_fn, + solver_params: Optional[LIONParameter] = None, + geometry: Geometry = None, + verbose: bool = True, + device: torch.device = None, + ) -> None: + print(device) + super().__init__( + model, + optimizer, + loss_fn, + geometry, + verbose, + device, + solver_params=solver_params, + ) + + self.operator = ct_utils.make_operator(self.geometry) + self.model.geometry = self.geometry + self.model.operator = self.operator + self.projector = to_autograd(self.operator, num_extra_dims=1) + self.recon_fn = self.solver_params.recon_fn + + + @staticmethod + def default_parameters() -> LIONParameter: + params = LIONParameter() + params.sigma=3 + params.delta=1 + params.recon_fn = fdk + return params + + def mini_batch_step(self, sinos, targets): + sigma = self.solver_params.sigma + delta = self.solver_params.delta + ks = int(sigma * 3) * 2 + 1 + + N = torch.randn_like(sinos) * delta + N = TF.gaussian_blur(N, kernel_size=[ks, ks], sigma=[sigma, sigma]) + z = sinos + N + + input_recon = self.recon_fn(z,self.model.operator) + output_recon = self.model(input_recon) + output_sino = self.projector(output_recon) + target_sino = sinos - N + + #Sobolev Loss + #res = output_sino - target_sino + #grad_x = res[:, :, :, 1:] - res[:, :, :, :-1] + #grad_y = res[:, :, 1:, :] - res[:, :, :-1, :] + + #batch_loss = ((output_sino - target_sino)**2).mean() + (grad_x**2).mean() + (grad_y**2).mean() + batch_loss= ((output_sino - target_sino)**2).mean() + return batch_loss + + + # No validation in Noisier2Inverse + def validate(self): + return 0 + + def reconstruct(self, sinos): + input_recon = self.recon_fn(sinos, self.operator) + output_recon = self.model(input_recon) + return output_recon diff --git a/LION/optimizers/Proj2ProjSolver.py b/LION/optimizers/Proj2ProjSolver.py new file mode 100644 index 00000000..02720344 --- /dev/null +++ b/LION/optimizers/Proj2ProjSolver.py @@ -0,0 +1,98 @@ +from typing import Callable, Optional +import warnings +import numpy as np +from tqdm import tqdm +from LION.CTtools.ct_geometry import Geometry +from LION.classical_algorithms.fdk import fdk +from LION.models.LIONmodel import LIONmodel +import torch +import torch.nn.functional as F +from torch.optim.optimizer import Optimizer +from LION.optimizers.LIONsolver import LIONsolver +from LION.utils.parameter import LIONParameter +import tomosipo as ts +import LION.CTtools.ct_utils as ct_utils +from tomosipo.torch_support import to_autograd + +class Proj2ProjSolver(LIONsolver): + def __init__( + self, + model: LIONmodel, + optimizer: Optimizer, + loss_fn, + solver_params: Optional[LIONParameter] = None, + geometry: Geometry = None, + verbose: bool = True, + device: torch.device = None, + ) -> None: + print(device) + super().__init__( + model, + optimizer, + loss_fn, + geometry, + verbose, + device, + solver_params=solver_params, + ) + + self.operator = ct_utils.make_operator(self.geometry) + self.model.geometry = self.geometry + self.model.operator = self.operator + self.projector = to_autograd(self.operator, num_extra_dims=1) + self.recon_fn = self.solver_params.recon_fn + self.global_step=0 + + def get_mask(self, shape, step): + # shape: (B, C, H, W) + mask = torch.ones(shape, device=self.device) + grid = self.solver_params.grid_size + + for b in range(shape[0]): + idx = (step+b) % (grid * grid) + r = idx // grid + c = idx % grid + + mask[b, :, r::grid, c::grid] = 0 + return mask + + def fill_mean(self, sinos, mask): + kernel = torch.tensor([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=torch.float32, device=self.device) / 4.0 + + kernel = kernel.view(1, 1, 3, 3) + local_mean_sino = F.conv2d(sinos, kernel, padding=1) + + filled_sinos = (sinos * mask) + (local_mean_sino * (1 - mask)) + return filled_sinos + + @staticmethod + def default_parameters() -> LIONParameter: + params = LIONParameter() + params.grid_size = 4 + params.recon_fn = fdk + return params + + def mini_batch_step(self, sinos, targets): + mask=self.get_mask(sinos.shape, self.global_step) + self.global_step += sinos.shape[0] + input_sino=self.fill_mean(sinos,mask) + + input_recon=self.recon_fn(input_sino,self.model.operator) + output_recon = self.model(input_recon) + output_sino=self.projector(output_recon) + + output_sino_mask=output_sino*(1-mask) + target_sino=sinos*(1-mask) + + batch_loss = ((output_sino_mask - target_sino) ** 2).mean() + return batch_loss + + + # No validation in Proj2Proj + def validate(self): + return 0 + + def reconstruct(self, sinos): + input_recon = self.recon_fn(sinos, self.operator) + output_recon = self.model(input_recon) + return output_recon diff --git a/LION/optimizers/Sparse2InverseSolver.py b/LION/optimizers/Sparse2InverseSolver.py index 5a7931b8..6099f7f9 100644 --- a/LION/optimizers/Sparse2InverseSolver.py +++ b/LION/optimizers/Sparse2InverseSolver.py @@ -38,7 +38,6 @@ def __init__( self.model.geometry = self.geometry self.model._make_operator() - self.A_full = self.model.A self.sino_split_count = self.solver_params.sino_split_count self.recon_fn = self.solver_params.recon_fn self.split_combinations = self.two_two_strategy(self.sino_split_count) @@ -139,9 +138,7 @@ def mini_batch_step(self, sinos, targets): for b in range(batch_size): projected_sino = projector(output_recon[b:b+1]) target_sino = torch.cat([sinos[b:b+1, :, self.subgroup_indices[i], :] for i in remaining_splits],dim=2) - batch_loss += ((projected_sino - target_sino) ** 2).sum() - total_pixels += projected_sino.numel() - batch_loss /= total_pixels + batch_loss += self.loss_fn(projected_sino, target_sino) return batch_loss diff --git a/scripts/example_scripts/Sparse2Inverse.py b/scripts/example_scripts/Sparse2Inverse.py new file mode 100644 index 00000000..b72249bd --- /dev/null +++ b/scripts/example_scripts/Sparse2Inverse.py @@ -0,0 +1,121 @@ +from LION.classical_algorithms.fdk import fdk +from Sparse2InverseSolver import Sparse2InverseSolver +from LION.models.CNNs.UNets.Unet import UNet +import LION.experiments.ct_experiments as ct_experiments +from torch.utils.data import DataLoader +from torch.optim.adam import Adam +import torch.nn as nn +import torch +import pathlib +import torch.utils.data as data_utils +import random +import numpy as np +import matplotlib.pyplot as plt +from skimage.metrics import structural_similarity as ssim +from LION.metrics.haarpsi import HAARPsi + +seed = 42 +random.seed(seed) +torch.manual_seed(seed) +torch.cuda.manual_seed_all(seed) +torch.backends.cudnn.deterministic = True +torch.backends.cudnn.benchmark = False + +# Set Device +#%% +# % Chose device: +device = torch.device("cuda:1") +torch.cuda.set_device(device) + +# Define your data paths +savefolder = pathlib.Path("/store/LION/ea692/LION/LION/trained_models/Sparse2Inverse/Train/SparseAngleLowDoseCTRecon") +# Creates the folders if they does not exist +savefolder.mkdir(parents=True, exist_ok=True) +final_result_fname = "S2I.pt" +checkpoint_fname = "S2I_check_*.pt" + +# Define experiment +experiment = ct_experiments.SparseAngleLowDoseCTRecon() +train_dataset = experiment.get_training_dataset() +#30 sinograms for the experiment +indices = torch.arange(30) +train_dataset = data_utils.Subset(train_dataset, indices) + +# Data to train +batch_size = 1 +dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False) + +# Define model. In the original paper used UNet +model = UNet() + +# Create optimizer and loss function +optimizer = Adam(model.parameters(), lr=1e-4) +loss_fn = nn.MSELoss() + +#Sparse2InverseSolver. +s2i_params = Sparse2InverseSolver.default_parameters() +# Sparse to inverse requires certain user specifications. +s2i_params.sino_split_count = 4 +s2i_params.recon_fn = fdk + +# Initialize the solver as the other solvers in LION +solver = Sparse2InverseSolver( + model, + optimizer, + loss_fn, + solver_params=s2i_params, + geometry=experiment.geometry, + verbose=True, + device=device, +) + +solver.set_training(dataloader) +solver.set_checkpointing(checkpoint_fname, 100, save_folder=savefolder) + +epochs = 100 + +solver.train(epochs) +solver.save_final_results(final_result_fname, savefolder) +solver.clean_checkpoints() + +# Test using the training data +savefolder = pathlib.Path("/home/ea692/LION/LION/trained_models/Sparse2Inverse/Test/SparseAngleLowDoseCTRecon/SparseVSNoise/30sin2000ep/64Angles_Haarpsi_and_SSIM") +savefolder.mkdir(parents=True, exist_ok=True) + +model.eval() +solver_params = Sparse2InverseSolver.default_parameters() +solver_params.sino_split_count = 4 +solver_params.recon_fn = fdk +optimizer = Adam(model.parameters()) +#Not used directly, the solver defines its own loss. +loss_fn = nn.MSELoss() + +solver_sparse = Sparse2InverseSolver( + model, + optimizer, + loss_fn, + solver_params=solver_params, + geometry=experiment.geometry, + verbose=False, + device=device, +) + +#Normalization in order to ensure a fair comparison of structural and perceptual image quality. +def normalize_01(x,y): + x = (x - y.min())/ (y.max() - y.min()) + x[x>1]=1 + x[x<0]=0 + return x + +#SSIM metric +def my_ssim(x, y): + x = x.detach().squeeze().cpu() + y = y.detach().squeeze().cpu() + + target_n = normalize_01(y,y) + sparse_n = normalize_01(x,y) + return ssim(target_n, sparse_n, data_range=1) + +model.eval() +solver.set_testing(dataloader, my_ssim) +solver.test()