Skip to content
42 changes: 42 additions & 0 deletions LION/losses/SUREpgImage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import torch

class SUREpgLoss():

def __init__(self, zeta: float, sigma2: float, eps1: float = 1e-3, eps2: float = 1e-3, kappa: float = 1.0):
self.zeta = zeta
self.sigma2 = sigma2
self.eps1 = eps1
self.eps2 = eps2
self.kappa = kappa

self.p = (1/2) * (1 + self.kappa / (self.kappa**2 + 4)**0.5)
self.q = 1 - self.p
self.a = (self.q / self.p)**0.5
self.b = (self.p / self.q)**0.5

def __call__(self, model, y):
B=y.shape[0]
N_per_img = y.shape[1] * y.shape[2] * y.shape[3]

fy=model(y)
loss = ((fy - y) ** 2).sum(dim=(1,2,3)) - self.zeta * y.sum(dim=(1,2,3)) - self.sigma2 * N_per_img

#1st derivative MC
delta1 = torch.randn_like(y)
fy_perturbated = model(y + self.eps1 * delta1)

u = self.zeta * y + self.sigma2
mc1 = (delta1 * u * (fy_perturbated - fy)).sum(dim=(1,2,3))
loss += 2.0 * mc1 / self.eps1

#2nd derivative MC
u_rand = torch.rand_like(y)
delta2 = torch.where(u_rand < self.p,-self.a * torch.ones_like(y),+self.b * torch.ones_like(y))

fy_plus = model(y + self.eps2 * delta2)
fy_minus = model(y - self.eps2 * delta2)

mc2 = (delta2 * (fy_plus - 2*fy + fy_minus)).sum(dim=(1,2,3))
loss -= (2 * self.sigma2 * self.zeta / (self.eps2**2 * self.kappa)) * mc2

return loss.mean()/N_per_img
87 changes: 87 additions & 0 deletions LION/optimizers/Noisier2Inverse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from typing import Callable, Optional
import warnings
import numpy as np
from tqdm import tqdm
from LION.CTtools.ct_geometry import Geometry
from LION.classical_algorithms.fdk import fdk
from LION.models.LIONmodel import LIONmodel
import torch
import torch.nn.functional as F
from torch.optim.optimizer import Optimizer
from LION.optimizers.LIONsolver import LIONsolver
from LION.utils.parameter import LIONParameter
import tomosipo as ts
import LION.CTtools.ct_utils as ct_utils
from tomosipo.torch_support import to_autograd

import torchvision.transforms.functional as TF

class Noisier2Inverse(LIONsolver):
def __init__(
self,
model: LIONmodel,
optimizer: Optimizer,
loss_fn,
solver_params: Optional[LIONParameter] = None,
geometry: Geometry = None,
verbose: bool = True,
device: torch.device = None,
) -> None:
print(device)
super().__init__(
model,
optimizer,
loss_fn,
geometry,
verbose,
device,
solver_params=solver_params,
)

self.operator = ct_utils.make_operator(self.geometry)
self.model.geometry = self.geometry
self.model.operator = self.operator
self.projector = to_autograd(self.operator, num_extra_dims=1)
self.recon_fn = self.solver_params.recon_fn


@staticmethod
def default_parameters() -> LIONParameter:
params = LIONParameter()
params.sigma=3
params.delta=1
params.recon_fn = fdk
return params

def mini_batch_step(self, sinos, targets):
sigma = self.solver_params.sigma
delta = self.solver_params.delta
ks = int(sigma * 3) * 2 + 1

N = torch.randn_like(sinos) * delta
N = TF.gaussian_blur(N, kernel_size=[ks, ks], sigma=[sigma, sigma])
z = sinos + N

input_recon = self.recon_fn(z,self.model.operator)
output_recon = self.model(input_recon)
output_sino = self.projector(output_recon)
target_sino = sinos - N

#Sobolev Loss
#res = output_sino - target_sino
#grad_x = res[:, :, :, 1:] - res[:, :, :, :-1]
#grad_y = res[:, :, 1:, :] - res[:, :, :-1, :]

#batch_loss = ((output_sino - target_sino)**2).mean() + (grad_x**2).mean() + (grad_y**2).mean()
batch_loss= ((output_sino - target_sino)**2).mean()
return batch_loss


# No validation in Noisier2Inverse
def validate(self):
return 0

def reconstruct(self, sinos):
input_recon = self.recon_fn(sinos, self.operator)
output_recon = self.model(input_recon)
return output_recon
98 changes: 98 additions & 0 deletions LION/optimizers/Proj2ProjSolver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
from typing import Callable, Optional
import warnings
import numpy as np
from tqdm import tqdm
from LION.CTtools.ct_geometry import Geometry
from LION.classical_algorithms.fdk import fdk
from LION.models.LIONmodel import LIONmodel
import torch
import torch.nn.functional as F
from torch.optim.optimizer import Optimizer
from LION.optimizers.LIONsolver import LIONsolver
from LION.utils.parameter import LIONParameter
import tomosipo as ts
import LION.CTtools.ct_utils as ct_utils
from tomosipo.torch_support import to_autograd

class Proj2ProjSolver(LIONsolver):
def __init__(
self,
model: LIONmodel,
optimizer: Optimizer,
loss_fn,
solver_params: Optional[LIONParameter] = None,
geometry: Geometry = None,
verbose: bool = True,
device: torch.device = None,
) -> None:
print(device)
super().__init__(
model,
optimizer,
loss_fn,
geometry,
verbose,
device,
solver_params=solver_params,
)

self.operator = ct_utils.make_operator(self.geometry)
self.model.geometry = self.geometry
self.model.operator = self.operator
self.projector = to_autograd(self.operator, num_extra_dims=1)
self.recon_fn = self.solver_params.recon_fn
self.global_step=0

def get_mask(self, shape, step):
# shape: (B, C, H, W)
mask = torch.ones(shape, device=self.device)
grid = self.solver_params.grid_size

for b in range(shape[0]):
idx = (step+b) % (grid * grid)
r = idx // grid
c = idx % grid

mask[b, :, r::grid, c::grid] = 0
return mask

def fill_mean(self, sinos, mask):
kernel = torch.tensor([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=torch.float32, device=self.device) / 4.0

kernel = kernel.view(1, 1, 3, 3)
local_mean_sino = F.conv2d(sinos, kernel, padding=1)

filled_sinos = (sinos * mask) + (local_mean_sino * (1 - mask))
return filled_sinos

@staticmethod
def default_parameters() -> LIONParameter:
params = LIONParameter()
params.grid_size = 4
params.recon_fn = fdk
return params

def mini_batch_step(self, sinos, targets):
mask=self.get_mask(sinos.shape, self.global_step)
self.global_step += sinos.shape[0]
input_sino=self.fill_mean(sinos,mask)

input_recon=self.recon_fn(input_sino,self.model.operator)
output_recon = self.model(input_recon)
output_sino=self.projector(output_recon)

output_sino_mask=output_sino*(1-mask)
target_sino=sinos*(1-mask)

batch_loss = ((output_sino_mask - target_sino) ** 2).mean()
return batch_loss


# No validation in Proj2Proj
def validate(self):
return 0

def reconstruct(self, sinos):
input_recon = self.recon_fn(sinos, self.operator)
output_recon = self.model(input_recon)
return output_recon
5 changes: 1 addition & 4 deletions LION/optimizers/Sparse2InverseSolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ def __init__(

self.model.geometry = self.geometry
self.model._make_operator()
self.A_full = self.model.A
self.sino_split_count = self.solver_params.sino_split_count
self.recon_fn = self.solver_params.recon_fn
self.split_combinations = self.two_two_strategy(self.sino_split_count)
Expand Down Expand Up @@ -139,9 +138,7 @@ def mini_batch_step(self, sinos, targets):
for b in range(batch_size):
projected_sino = projector(output_recon[b:b+1])
target_sino = torch.cat([sinos[b:b+1, :, self.subgroup_indices[i], :] for i in remaining_splits],dim=2)
batch_loss += ((projected_sino - target_sino) ** 2).sum()
total_pixels += projected_sino.numel()
batch_loss /= total_pixels
batch_loss += self.loss_fn(projected_sino, target_sino)
return batch_loss


Expand Down
121 changes: 121 additions & 0 deletions scripts/example_scripts/Sparse2Inverse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
from LION.classical_algorithms.fdk import fdk
from Sparse2InverseSolver import Sparse2InverseSolver
from LION.models.CNNs.UNets.Unet import UNet
import LION.experiments.ct_experiments as ct_experiments
from torch.utils.data import DataLoader
from torch.optim.adam import Adam
import torch.nn as nn
import torch
import pathlib
import torch.utils.data as data_utils
import random
import numpy as np
import matplotlib.pyplot as plt
from skimage.metrics import structural_similarity as ssim
from LION.metrics.haarpsi import HAARPsi

seed = 42
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Set Device
#%%
# % Chose device:
device = torch.device("cuda:1")
torch.cuda.set_device(device)

# Define your data paths
savefolder = pathlib.Path("/store/LION/ea692/LION/LION/trained_models/Sparse2Inverse/Train/SparseAngleLowDoseCTRecon")
# Creates the folders if they does not exist
savefolder.mkdir(parents=True, exist_ok=True)
final_result_fname = "S2I.pt"
checkpoint_fname = "S2I_check_*.pt"

# Define experiment
experiment = ct_experiments.SparseAngleLowDoseCTRecon()
train_dataset = experiment.get_training_dataset()
#30 sinograms for the experiment
indices = torch.arange(30)
train_dataset = data_utils.Subset(train_dataset, indices)

# Data to train
batch_size = 1
dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)

# Define model. In the original paper used UNet
model = UNet()

# Create optimizer and loss function
optimizer = Adam(model.parameters(), lr=1e-4)
loss_fn = nn.MSELoss()

#Sparse2InverseSolver.
s2i_params = Sparse2InverseSolver.default_parameters()
# Sparse to inverse requires certain user specifications.
s2i_params.sino_split_count = 4
s2i_params.recon_fn = fdk

# Initialize the solver as the other solvers in LION
solver = Sparse2InverseSolver(
model,
optimizer,
loss_fn,
solver_params=s2i_params,
geometry=experiment.geometry,
verbose=True,
device=device,
)

solver.set_training(dataloader)
solver.set_checkpointing(checkpoint_fname, 100, save_folder=savefolder)

epochs = 100

solver.train(epochs)
solver.save_final_results(final_result_fname, savefolder)
solver.clean_checkpoints()

# Test using the training data
savefolder = pathlib.Path("/home/ea692/LION/LION/trained_models/Sparse2Inverse/Test/SparseAngleLowDoseCTRecon/SparseVSNoise/30sin2000ep/64Angles_Haarpsi_and_SSIM")
savefolder.mkdir(parents=True, exist_ok=True)

model.eval()
solver_params = Sparse2InverseSolver.default_parameters()
solver_params.sino_split_count = 4
solver_params.recon_fn = fdk
optimizer = Adam(model.parameters())
#Not used directly, the solver defines its own loss.
loss_fn = nn.MSELoss()

solver_sparse = Sparse2InverseSolver(
model,
optimizer,
loss_fn,
solver_params=solver_params,
geometry=experiment.geometry,
verbose=False,
device=device,
)

#Normalization in order to ensure a fair comparison of structural and perceptual image quality.
def normalize_01(x,y):
x = (x - y.min())/ (y.max() - y.min())
x[x>1]=1
x[x<0]=0
return x

#SSIM metric
def my_ssim(x, y):
x = x.detach().squeeze().cpu()
y = y.detach().squeeze().cpu()

target_n = normalize_01(y,y)
sparse_n = normalize_01(x,y)
return ssim(target_n, sparse_n, data_range=1)

model.eval()
solver.set_testing(dataloader, my_ssim)
solver.test()
Loading